"""Base classes for affinity matrices."""
# Author: Hugues Van Assel <vanasselhugues@gmail.com>
#
# License: BSD 3-Clause License
from abc import ABC
from typing import Union, Any, Optional
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchdr.utils import (
to_torch,
bool_arg,
set_logger,
DistributedContext,
)
from torchdr.distance import (
pairwise_distances,
FaissConfig,
)
import torch.distributed as dist
class Affinity(nn.Module, ABC):
r"""Base class for affinity matrices.
Parameters
----------
metric : str, optional
Distance metric for pairwise distances. Default is "sqeuclidean".
zero_diag : bool, optional
Whether to set the diagonal to zero. Default is True.
device : str, optional
Device for computation. ``"auto"`` uses the input data's device.
Default is "auto".
backend : {"keops", "faiss", None} or FaissConfig, optional
Backend for handling sparsity and memory efficiency.
Default is None (standard PyTorch).
verbose : bool, optional
Verbosity. Default is False.
random_state : float, optional
Random seed for reproducibility. Default is None.
compile : bool, optional
Whether to compile the affinity computation. Default is False.
_pre_processed : bool, optional
If True, skips ``to_torch`` conversion (inputs are already tensors
on the correct device). Default is False.
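Examples
--------
A minimal, illustrative subclass; the ``GaussianAffinity`` name and
``sigma`` parameter are assumptions for this sketch, not TorchDR API.
Only :meth:`_compute_affinity` needs to be overridden:

>>> import torch
>>> from torchdr.affinity import Affinity
>>> class GaussianAffinity(Affinity):
...     def __init__(self, sigma=1.0, **kwargs):
...         super().__init__(**kwargs)
...         self.sigma = sigma
...     def _compute_affinity(self, X, **kwargs):
...         C = self._distance_matrix(X)  # pairwise squared-Euclidean costs
...         return (-C / self.sigma).exp()
>>> P = GaussianAffinity()(torch.randn(10, 3))  # pairwise affinity matrix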
"""
def __init__(
self,
metric: str = "sqeuclidean",
zero_diag: bool = True,
device: str = "auto",
backend: Optional[Union[str, FaissConfig]] = None,
verbose: bool = False,
random_state: Optional[float] = None,
compile: bool = False,
_pre_processed: bool = False,
):
super().__init__()
self.log = {}
self.metric = metric
self.zero_diag = bool_arg(zero_diag)
self.device = device if device is not None else "auto"
self.backend = backend
self.verbose = bool_arg(verbose)
self.random_state = random_state
self.compile = compile
self._pre_processed = _pre_processed
self.logger = set_logger(self.__class__.__name__, self.verbose)
# --- Public API ---
def __call__(self, X: Union[torch.Tensor, np.ndarray], **kwargs):
r"""Compute the affinity matrix from the input data.
Parameters
----------
X : torch.Tensor or np.ndarray of shape (n_samples, n_features)
Input data.
Returns
-------
affinity_matrix : torch.Tensor or pykeops.torch.LazyTensor
The computed affinity matrix.
"""
if not self._pre_processed:
X = to_torch(X)
return self._compute_affinity(X, **kwargs)
# --- Core computation (must be implemented by subclasses) ---
def _compute_affinity(self, X: torch.Tensor, **kwargs):
r"""Compute the affinity matrix. Must be overridden by subclasses."""
raise NotImplementedError(
"[TorchDR] ERROR: `_compute_affinity` method is not implemented."
)
# --- Distance computation ---
def _distance_matrix(
self, X: torch.Tensor, k: Optional[int] = None, return_indices: bool = False
):
r"""Compute the pairwise distance matrix.
Parameters
----------
X : torch.Tensor of shape (n_samples, n_features)
Input data.
k : int, optional
Number of nearest neighbors. Default is None (full matrix).
return_indices : bool, optional
Whether to return k-NN indices. Default is False.
Returns
-------
C : torch.Tensor or pykeops.torch.LazyTensor
The pairwise distance matrix.
"""
return pairwise_distances(
X=X,
metric=self.metric,
backend=self.backend,
exclude_diag=self.zero_diag,
k=k,
return_indices=return_indices,
device=self.device,
)
# --- Utilities ---
def _get_compute_device(self, X):
"""Return the target device (from ``self.device`` or inferred from X)."""
if self.device != "auto":
return self.device
if isinstance(X, DataLoader):
from torchdr.distance.faiss import get_dataloader_metadata
metadata = get_dataloader_metadata(X)
if metadata is not None and "device" in metadata:
return metadata["device"]
for batch in X:
if isinstance(batch, (list, tuple)):
batch = batch[0]
return batch.device
return torch.device("cpu")
return X.device
def _get_n_samples(self, X):
"""Return the number of samples in the input."""
if isinstance(X, DataLoader):
return len(X.dataset)
return X.shape[0]
def _get_dtype(self, X):
"""Return the dtype of the input."""
if isinstance(X, DataLoader):
from torchdr.distance.faiss import get_dataloader_metadata
metadata = get_dataloader_metadata(X)
if metadata is not None:
return metadata["dtype"]
for batch in X:
if isinstance(batch, (list, tuple)):
batch = batch[0]
return batch.dtype
raise ValueError("[TorchDR] DataLoader is empty, cannot determine dtype.")
return X.dtype
# --- Memory management ---
def clear_memory(self):
"""Clear non-persistent buffers to free memory."""
if hasattr(self, "_non_persistent_buffers_set"):
for name in list(self._non_persistent_buffers_set):
if hasattr(self, name):
delattr(self, name)
if torch.cuda.is_available():
torch.cuda.empty_cache()
class LogAffinity(Affinity):
r"""Base class for affinity matrices in log domain.
Subclasses must implement :meth:`_compute_log_affinity`.
Parameters
----------
metric : str, optional
Distance metric for pairwise distances. Default is "sqeuclidean".
zero_diag : bool, optional
Whether to set the diagonal to zero. Default is True.
device : str, optional
Device for computation. ``"auto"`` uses the input data's device.
Default is "auto".
backend : {"keops", "faiss", None} or FaissConfig, optional
Backend for handling sparsity and memory efficiency.
Default is None (standard PyTorch).
verbose : bool, optional
Verbosity. Default is False.
random_state : float, optional
Random seed for reproducibility. Default is None.
compile : bool, optional
Whether to compile the affinity computation. Default is False.
_pre_processed : bool, optional
If True, skips ``to_torch`` conversion. Default is False.
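Examples
--------
An illustrative subclass working purely in the log domain (the
``NegCostLogAffinity`` name is hypothetical). Only
:meth:`_compute_log_affinity` needs to be overridden:

>>> import torch
>>> from torchdr.affinity import LogAffinity
>>> class NegCostLogAffinity(LogAffinity):
...     def _compute_log_affinity(self, X, **kwargs):
...         return -self._distance_matrix(X)  # log P = -C
>>> aff = NegCostLogAffinity()
>>> X = torch.randn(10, 3)
>>> log_P = aff(X, log=True)  # stays in log domain, numerically stable
>>> P = aff(X)                # exponentiated affinities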
"""
def __init__(
self,
metric: str = "sqeuclidean",
zero_diag: bool = True,
device: str = "auto",
backend: Optional[Union[str, FaissConfig]] = None,
verbose: bool = False,
random_state: Optional[float] = None,
compile: bool = False,
_pre_processed: bool = False,
):
super().__init__(
metric=metric,
zero_diag=zero_diag,
device=device,
backend=backend,
verbose=verbose,
random_state=random_state,
compile=compile,
_pre_processed=_pre_processed,
)
def __call__(
self,
X: Union[torch.Tensor, np.ndarray],
log: bool = False,
**kwargs: Any,
):
r"""Compute the affinity matrix (or its log) from the input data.
Parameters
----------
X : torch.Tensor or np.ndarray of shape (n_samples, n_features)
Input data.
log : bool, optional
If True, returns the log affinity. Otherwise, exponentiates it.
Returns
-------
affinity_matrix : torch.Tensor or pykeops.torch.LazyTensor
The affinity matrix (or log affinity if ``log=True``).
"""
if not self._pre_processed:
X = to_torch(X)
log_affinity = self._compute_log_affinity(X, **kwargs)
if log:
return log_affinity
else:
return log_affinity.exp()
def _compute_log_affinity(self, X: torch.Tensor, **kwargs):
r"""Compute the log affinity matrix. Must be overridden by subclasses."""
raise NotImplementedError(
"[TorchDR] ERROR: `_compute_log_affinity` method is not implemented."
)
class SparseAffinity(Affinity):
r"""Base class for sparse affinity matrices.
Returns the affinity matrix in rectangular format (n_samples, k) with
the corresponding k-NN indices when sparsity is enabled. Otherwise,
returns the full (n_samples, n_samples) matrix.
**Distributed training:** When ``distributed="auto"`` (default) and
launched with ``torchrun``, each GPU processes a chunk of the dataset
in parallel. Requires ``sparsity=True`` and ``backend="faiss"``.
Subclasses must implement :meth:`_compute_sparse_affinity`.
Parameters
----------
metric : str, optional
Distance metric for pairwise distances. Default is "sqeuclidean".
zero_diag : bool, optional
Whether to set the diagonal to zero. Default is True.
device : str, optional
Device for computation. ``"auto"`` uses the input data's device.
Default is "auto".
backend : {"keops", "faiss", None} or FaissConfig, optional
Backend for handling sparsity and memory efficiency.
Default is None (standard PyTorch).
verbose : bool, optional
Verbosity. Default is False.
compile : bool, optional
Whether to compile the affinity computation. Default is False.
sparsity : bool or 'auto', optional
Whether to use sparse (rectangular) format. Default is True.
distributed : bool or 'auto', optional
Whether to use distributed multi-GPU computation.
``"auto"`` detects ``torchrun`` automatically. Default is "auto".
random_state : float, optional
Random seed for reproducibility. Default is None.
_pre_processed : bool, optional
If True, skips ``to_torch`` conversion. Default is False.
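Examples
--------
An illustrative subclass (the ``KNNAffinity`` name and ``k`` parameter
are assumptions for this sketch). Overriding
:meth:`_compute_sparse_affinity` and honoring ``return_indices`` is all
that is required:

>>> from torchdr.affinity import SparseAffinity
>>> class KNNAffinity(SparseAffinity):
...     def __init__(self, k=10, **kwargs):
...         super().__init__(**kwargs)
...         self.k = k
...     def _compute_sparse_affinity(self, X, return_indices=True, **kwargs):
...         C, indices = self._distance_matrix(X, k=self.k, return_indices=True)
...         P = (-C).exp()  # rectangular (n_samples, k) affinity
...         return (P, indices) if return_indices else P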
"""
def __init__(
self,
metric: str = "sqeuclidean",
zero_diag: bool = True,
device: str = "auto",
backend: Optional[Union[str, FaissConfig]] = None,
verbose: bool = False,
compile: bool = False,
sparsity: bool = True,
distributed: Union[bool, str] = "auto",
random_state: Optional[float] = None,
_pre_processed: bool = False,
):
# --- Distributed setup ---
if distributed == "auto":
self.distributed = dist.is_initialized()
else:
self.distributed = bool(distributed)
if self.distributed:
if not dist.is_initialized():
raise RuntimeError(
"[TorchDR] distributed=True requires launching with "
"torchrun. "
"Example: torchrun --nproc_per_node=4 your_script.py"
)
self.dist_ctx = DistributedContext()
self.rank = self.dist_ctx.rank
self.world_size = self.dist_ctx.world_size
self.is_multi_gpu = self.world_size > 1
if device == "cpu":
raise ValueError(
"[TorchDR] Distributed mode requires GPU (device cannot be 'cpu')"
)
device = torch.device(f"cuda:{self.dist_ctx.local_rank}")
# Force sparsity and FAISS backend for distributed mode
self._sparsity_forced = not sparsity
if self._sparsity_forced:
sparsity = True
self._backend_forced = backend not in [
"faiss",
None,
] and not isinstance(backend, FaissConfig)
if self._backend_forced:
self._original_backend = backend
backend = "faiss"
else:
self.dist_ctx = None
self.rank = 0
self.world_size = 1
self.is_multi_gpu = False
super().__init__(
metric=metric,
zero_diag=zero_diag,
device=device,
backend=backend,
verbose=verbose,
random_state=random_state,
compile=compile,
_pre_processed=_pre_processed,
)
self.sparsity = sparsity
if self.distributed and self.verbose:
if self._sparsity_forced:
self.logger.warning(
"Distributed mode requires sparsity=True, enabling sparsity."
)
if self._backend_forced:
self.logger.warning(
f"Distributed mode requires FAISS backend, "
f"switching from '{self._original_backend}' to 'faiss'."
)
if self.is_multi_gpu:
self.logger.info(
f"Distributed mode enabled: rank {self.rank}/{self.world_size}"
)
# --- Sparsity property ---
@property
def sparsity(self):
"""Return the sparsity setting."""
return self._sparsity
@sparsity.setter
def sparsity(self, value):
"""Set the sparsity setting."""
self._sparsity = bool_arg(value)
# --- Public API ---
def __call__(
self,
X: Union[torch.Tensor, np.ndarray],
return_indices: bool = True,
**kwargs,
):
r"""Compute the sparse affinity matrix from the input data.
Parameters
----------
X : torch.Tensor or np.ndarray of shape (n_samples, n_features)
Input data.
return_indices : bool, optional
Whether to return k-NN indices. Default is True.
Returns
-------
affinity_matrix : torch.Tensor
The computed affinity matrix.
indices : torch.Tensor or None
k-NN indices if ``return_indices=True`` and sparsity is enabled.
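Example (illustrative; assumes ``aff`` is an instance of a concrete
subclass and ``X`` is the input data):

>>> P, indices = aff(X)               # rectangular (n, k) + k-NN indices
>>> P = aff(X, return_indices=False)  # affinity matrix only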
"""
if not self._pre_processed:
X = to_torch(X)
return self._compute_sparse_affinity(X, return_indices, **kwargs)
# --- Core computation (must be implemented by subclasses) ---
def _compute_sparse_affinity(
self, X: torch.Tensor, return_indices: bool = True, **kwargs
):
r"""Compute the sparse affinity matrix. Must be overridden."""
raise NotImplementedError(
"[TorchDR] ERROR: `_compute_sparse_affinity` method is not implemented."
)
# --- Distance computation ---
def _distance_matrix(
self, X: torch.Tensor, k: Optional[int] = None, return_indices: bool = False
):
"""Compute pairwise distances, passing distributed context if active.
Parameters
----------
X : torch.Tensor
Input data.
k : int, optional
Number of nearest neighbors.
return_indices : bool, default=False
Whether to return k-NN indices.
Returns
-------
distances : torch.Tensor
Distance matrix.
indices : torch.Tensor, optional
Indices if ``return_indices=True``.
"""
result = pairwise_distances(
X=X,
metric=self.metric,
backend=self.backend,
exclude_diag=self.zero_diag,
k=k,
return_indices=return_indices,
device=self.device,
distributed_ctx=self.dist_ctx if self.distributed else None,
)
# Store chunk bounds for downstream use (e.g. distributed symmetrization)
if self.distributed and self.dist_ctx is not None:
chunk_start, chunk_end = self.dist_ctx.compute_chunk_bounds(
self._get_n_samples(X)
)
self.chunk_start_ = chunk_start
self.chunk_end_ = chunk_end
self.chunk_size_ = chunk_end - chunk_start
return result
class SparseLogAffinity(SparseAffinity, LogAffinity):
r"""Base class for sparse log affinity matrices.
Combines :class:`SparseAffinity` (sparse format, distributed support)
with :class:`LogAffinity` (log-domain computation).
Subclasses must implement :meth:`_compute_sparse_log_affinity`.
Parameters
----------
metric : str, optional
Distance metric for pairwise distances. Default is "sqeuclidean".
zero_diag : bool, optional
Whether to set the diagonal to zero. Default is True.
device : str, optional
Device for computation. ``"auto"`` uses the input data's device.
Default is "auto".
backend : {"keops", "faiss", None} or FaissConfig, optional
Backend for handling sparsity and memory efficiency.
Default is None (standard PyTorch).
verbose : bool, optional
Verbosity. Default is False.
compile : bool, optional
Whether to compile the affinity computation. Default is False.
sparsity : bool or 'auto', optional
Whether to use sparse (rectangular) format. Default is True.
distributed : bool or 'auto', optional
Whether to use distributed multi-GPU computation.
``"auto"`` detects ``torchrun`` automatically. Default is "auto".
random_state : float, optional
Random seed for reproducibility. Default is None.
_pre_processed : bool, optional
If True, skips ``to_torch`` conversion. Default is False.
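Examples
--------
An illustrative subclass (the ``KNNLogAffinity`` name and ``k``
parameter are assumptions for this sketch). The override returns the
log affinity, together with the k-NN indices when requested:

>>> from torchdr.affinity import SparseLogAffinity
>>> class KNNLogAffinity(SparseLogAffinity):
...     def __init__(self, k=10, **kwargs):
...         super().__init__(**kwargs)
...         self.k = k
...     def _compute_sparse_log_affinity(self, X, return_indices=False, **kwargs):
...         C, indices = self._distance_matrix(X, k=self.k, return_indices=True)
...         log_P = -C  # rectangular (n_samples, k) log affinity
...         return (log_P, indices) if return_indices else log_P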
"""
def __call__(
self,
X: Union[torch.Tensor, np.ndarray],
log: bool = False,
return_indices: bool = True,
**kwargs,
):
r"""Compute the sparse (log) affinity matrix from the input data.
Parameters
----------
X : torch.Tensor or np.ndarray of shape (n_samples, n_features)
Input data.
log : bool, optional
If True, returns the log affinity. Otherwise, exponentiates it.
return_indices : bool, optional
Whether to return k-NN indices. Default is True.
Returns
-------
affinity_matrix : torch.Tensor
The affinity matrix (or log affinity if ``log=True``).
indices : torch.Tensor or None
k-NN indices if ``return_indices=True`` and sparsity is enabled.
"""
if not self._pre_processed:
X = to_torch(X)
if return_indices:
log_affinity, indices = self._compute_sparse_log_affinity(
X, return_indices, **kwargs
)
affinity_to_return = log_affinity if log else log_affinity.exp()
return (affinity_to_return, indices)
else:
log_affinity = self._compute_sparse_log_affinity(
X, return_indices, **kwargs
)
affinity_to_return = log_affinity if log else log_affinity.exp()
return affinity_to_return
def _compute_sparse_log_affinity(
self, X: torch.Tensor, return_indices: bool = False, **kwargs
):
r"""Compute the sparse log affinity matrix. Must be overridden."""
raise NotImplementedError(
"[TorchDR] ERROR: `_compute_sparse_log_affinity` method is "
"not implemented."
)