Source code for torchdr.base

"""Base class for dimensionality reduction methods."""

# Author: Hugues Van Assel <vanasselhugues@gmail.com>
#
# License: BSD 3-Clause License

from abc import ABC, abstractmethod

import torch
import torch.nn as nn
import numpy as np
from sklearn.base import BaseEstimator
from torch.utils.data import DataLoader

from torchdr.utils import (
    seed_everything,
    set_logger,
    handle_input_output,
)
from torchdr.distance import FaissConfig

from typing import Optional, Any, TypeVar, Union

ArrayLike = TypeVar("ArrayLike", torch.Tensor, np.ndarray)


class DRModule(BaseEstimator, nn.Module, ABC):
    """Abstract base for dimensionality reduction estimators.

    Concrete models derive from this class and provide the actual algorithm
    by implementing :meth:`_fit_transform`. The public entry points
    (:meth:`fit`, :meth:`fit_transform`, :meth:`transform`) are defined here.

    Parameters
    ----------
    n_components : int, optional
        Number of dimensions for the embedding. Default is 2.
    device : str, optional
        Device for computations. ``"auto"`` uses the input tensor's device.
        ``None`` is treated as ``"auto"``. Default is "auto".
    backend : {"keops", "faiss", None} or FaissConfig, optional
        Backend for handling sparsity and memory efficiency.
        Default is None (standard PyTorch).
    verbose : bool, optional
        Verbosity. Default is False.
    random_state : float, optional
        Random seed for reproducibility. Default is None.
    compile : bool, default=False
        Whether to use ``torch.compile`` for faster computation.
    process_duplicates : bool, default=True
        Whether to handle duplicate data points by default.
    **kwargs
        Extra keyword arguments; accepted but not used by this constructor.
    """

    def __init__(
        self,
        n_components: int = 2,
        device: str = "auto",
        backend: Union[str, FaissConfig, None] = None,
        verbose: bool = False,
        random_state: Optional[float] = None,
        compile: bool = False,
        process_duplicates: bool = True,
        **kwargs,
    ):
        # nn.Module must be initialised before any attribute is assigned.
        super().__init__()

        self.n_components = n_components
        self.device = "auto" if device is None else device
        self.backend = backend
        self.verbose = verbose
        self.random_state = random_state
        self.compile = compile
        self.process_duplicates = process_duplicates

        # The logger is named after the concrete subclass.
        self.logger = set_logger(self.__class__.__name__, self.verbose)

        # Seeding happens only when a seed was explicitly requested.
        if self.random_state is not None:
            self._actual_seed = seed_everything(
                self.random_state, fast=True, deterministic=False
            )
            self.logger.info(f"Random seed set to: {self._actual_seed}.")

        # Fitted state, populated by fit_transform.
        self.embedding_ = None
        self.is_fitted_ = False

    # --- Public API ---

    @handle_input_output()
    def fit(self, X: ArrayLike, y: Optional[Any] = None) -> "DRModule":
        """Fit the model on ``X`` and return the estimator itself.

        This delegates to :meth:`fit_transform` and discards the embedding it
        returns (the embedding stays available as ``embedding_``).

        Parameters
        ----------
        X : ArrayLike of shape (n_samples, n_features)
            Input data (or ``(n_samples, n_samples)`` if precomputed).
        y : None
            Ignored.

        Returns
        -------
        self : DRModule
            The fitted instance.
        """
        self.fit_transform(X, y=y)
        return self

    @handle_input_output()
    def fit_transform(self, X: ArrayLike, y: Optional[Any] = None) -> ArrayLike:
        """Fit the model on ``X`` and return the resulting embedding.

        When :attr:`process_duplicates` is enabled and ``X`` is not a
        ``DataLoader``, duplicated rows are collapsed before running the
        core algorithm, and the embedding of the unique rows is expanded
        back so that the output has one row per original sample.

        Parameters
        ----------
        X : ArrayLike of shape (n_samples, n_features)
            Input data (or ``(n_samples, n_samples)`` if precomputed).
        y : None
            Ignored.

        Returns
        -------
        embedding_ : ArrayLike of shape (n_samples, n_components)
            The embedding.
        """
        is_loader = isinstance(X, DataLoader)

        if self.process_duplicates and is_loader:
            self.logger.warning(
                "process_duplicates is not supported with DataLoader input. "
                "Consider deduplicating your dataset before creating "
                "the DataLoader."
            )

        # Set to True only if the deduplicated path actually produced the
        # embedding; otherwise the plain path below is taken.
        fitted_on_unique = False

        if self.process_duplicates and not is_loader:
            unique_rows, row_to_unique = torch.unique(
                X, dim=0, return_inverse=True
            )
            if unique_rows.shape[0] < X.shape[0]:
                n_duplicates = X.shape[0] - unique_rows.shape[0]
                self.logger.info(
                    f"Detected {n_duplicates} duplicate samples, "
                    "performing DR on unique data."
                )
                unique_embedding = self._fit_transform(unique_rows, y=y)
                # If embedding_ is a Parameter at this point, keep the same
                # Parameter object and only swap its underlying data.
                if isinstance(self.embedding_, nn.Parameter):
                    self.embedding_.data = unique_embedding[row_to_unique]
                else:
                    self.embedding_ = unique_embedding[row_to_unique]
                fitted_on_unique = True

        if not fitted_on_unique:
            self.embedding_ = self._fit_transform(X, y=y)

        self.is_fitted_ = True
        return self.embedding_

    def transform(self, X: Optional[ArrayLike] = None) -> ArrayLike:
        """Map data into the learned embedding space.

        With ``X=None`` the training embedding is returned. Otherwise, if the
        instance has a non-``None`` ``encoder`` attribute, ``X`` is converted
        to a tensor, moved to ``self.device_`` and passed through the encoder
        with gradients disabled.

        Parameters
        ----------
        X : ArrayLike of shape (n_samples, n_features), optional
            Data to transform. If None, returns the training embedding.

        Returns
        -------
        embedding_ : ArrayLike of shape (n_samples, n_components)
            The embedding.

        Raises
        ------
        ValueError
            If the model has not been fitted yet.
        NotImplementedError
            If ``X`` is given but the model has no encoder.
        """
        if not self.is_fitted_:
            raise ValueError(
                "This DRModule instance is not fitted yet. "
                "Call 'fit' or 'fit_transform' with some data first."
            )

        if X is None:
            return self.embedding_

        if getattr(self, "encoder", None) is None:
            raise NotImplementedError(
                "Transforming new data is not implemented for this model."
            )

        # Imported locally, as in the original implementation.
        from torchdr.utils import to_torch

        # NOTE(review): ``device_`` is not assigned anywhere in this class;
        # presumably it is set elsewhere before transform is called - confirm.
        X_tensor = to_torch(X).to(device=self.device_)
        with torch.no_grad():
            return self.encoder(X_tensor)

    # --- Core algorithm (must be implemented by subclasses) ---

    @abstractmethod
    def _fit_transform(self, X: torch.Tensor, y: Optional[Any] = None) -> torch.Tensor:
        """Run the core DR algorithm and return the embedding.

        Subclasses override this with the actual method. It is invoked by
        :meth:`fit_transform` once duplicate handling has been applied.

        Parameters
        ----------
        X : torch.Tensor of shape (n_samples, n_features)
            Input data (or ``(n_samples, n_samples)`` if precomputed).
        y : None
            Ignored.

        Returns
        -------
        embedding_ : torch.Tensor of shape (n_samples, n_components)
            The embedding.
        """
        raise NotImplementedError(
            "[TorchDR] ERROR : _fit_transform method is not implemented."
        )

    # --- Utilities ---

    def _get_compute_device(self, X: torch.Tensor):
        """Return ``X.device`` when ``self.device`` is "auto", else ``self.device``."""
        if self.device == "auto":
            return X.device
        return self.device

    # --- Memory management ---

    def clear_memory(self):
        """Delete non-persistent buffers and release cached CUDA memory."""
        # Iterate over a snapshot: deleting attributes may mutate the set.
        buffer_names = list(getattr(self, "_non_persistent_buffers_set", ()))
        for name in buffer_names:
            if hasattr(self, name):
                delattr(self, name)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()