# Source code for torchdr.eval.neighborhood_preservation

"""K-ary neighborhood preservation metric for dimensionality reduction evaluation."""

# Author: Hugues Van Assel <vanasselhugues@gmail.com>
#
# License: BSD 3-Clause License

import numpy as np
import torch
import torch.distributed as dist
from typing import Union, Optional

from torchdr.utils import to_torch
from torchdr.distributed import DistributedContext
from torchdr.distance import pairwise_distances, FaissConfig


def neighborhood_preservation(
    X: Union[torch.Tensor, np.ndarray],
    Z: Union[torch.Tensor, np.ndarray],
    K: int,
    metric: str = "euclidean",
    backend: Optional[Union[str, FaissConfig]] = None,
    device: Optional[str] = None,
    distributed: Union[bool, str] = "auto",
    return_per_sample: bool = False,
):
    r"""Compute K-ary neighborhood preservation between input data and embeddings.

    This metric measures how well local neighborhood structure is preserved
    when reducing from high-dimensional input data (X) to low-dimensional
    embeddings (Z).

    Parameters
    ----------
    X : torch.Tensor or np.ndarray of shape (n_samples, n_features)
        Original high-dimensional data.
    Z : torch.Tensor or np.ndarray of shape (n_samples, n_features_reduced)
        Reduced low-dimensional embeddings.
    K : int
        Neighborhood size (number of nearest neighbors to consider).
    metric : str, default='euclidean'
        Distance metric to use for computing nearest neighbors.
        Options: 'euclidean', 'sqeuclidean', 'manhattan', 'angular'.
    backend : {'keops', 'faiss', None} or FaissConfig, optional
        Backend to use for k-NN computation:

        - 'keops': Memory-efficient symbolic computations
        - 'faiss': Fast approximate nearest neighbors (recommended for large datasets)
        - None: Standard PyTorch operations
        - FaissConfig object: FAISS with custom configuration
    device : str, optional
        Device to use for computation. If None, uses input device.
    distributed : bool or 'auto', default='auto'
        Whether to use multi-GPU distributed computation.

        - 'auto': Automatically detects if torch.distributed is initialized
        - True: Forces distributed mode (requires torch.distributed to be initialized)
        - False: Disables distributed mode

        When enabled:

        - Each GPU computes preservation for its assigned chunk of samples
        - Automatically creates DistributedContext if torch.distributed is initialized
        - Device is automatically set to the local GPU rank
        - Backend is forced to 'faiss' for efficient distributed k-NN
        - Returns per-chunk results (no automatic gathering across GPUs)

        Requires launching with torchrun:
        ``torchrun --nproc_per_node=N script.py``
    return_per_sample : bool, default=False
        If True, returns per-sample preservation scores instead of the mean.
        Shape: (n_samples,) or (chunk_size,) in distributed mode.

    Returns
    -------
    score : float or torch.Tensor
        If return_per_sample=False: Mean neighborhood preservation across all samples.
        If return_per_sample=True: Per-sample neighborhood preservation scores.
        Value between 0 and 1, where 1 indicates perfect preservation.
        Returns numpy array/float if inputs are numpy, torch.Tensor otherwise.

    Raises
    ------
    ValueError
        If K < 1, if K >= n_samples, if X and Z have a different number of
        samples, or if a CPU device is requested in distributed mode.
    RuntimeError
        If distributed=True but torch.distributed is not initialized.

    Examples
    --------
    >>> import torch
    >>> from torchdr.eval.neighborhood_preservation import neighborhood_preservation
    >>>
    >>> # Generate example data
    >>> X = torch.randn(100, 50)  # High-dimensional data
    >>> Z = torch.randn(100, 2)  # Low-dimensional embedding
    >>>
    >>> # Compute preservation score
    >>> score = neighborhood_preservation(X, Z, K=10)
    >>> print(f"Neighborhood preservation: {score:.3f}")

    Notes
    -----
    For each point, the metric computes the fraction of shared neighbors
    between the K-nearest neighbor sets in the original and reduced spaces,
    i.e. :math:`|N_X(i) \cap N_Z(i)| / K`, then averages across all points.
    (Note this is the overlap fraction, not the Jaccard index: both sets
    have exactly K elements, so the denominator is K rather than the size
    of the union.)

    For large datasets, using backend='faiss' is recommended for efficiency.
    The metric excludes self-neighbors (i.e., the point itself).
    """
    if K < 1:
        raise ValueError(f"K must be at least 1, got {K}")

    # Remember input container type so the result can be returned in kind.
    input_is_numpy = not isinstance(X, torch.Tensor) or not isinstance(Z, torch.Tensor)

    X = to_torch(X)
    Z = to_torch(Z)

    if X.shape[0] != Z.shape[0]:
        raise ValueError(
            f"X and Z must have same number of samples, got {X.shape[0]} and {Z.shape[0]}"
        )

    n_samples = X.shape[0]
    if K >= n_samples:
        raise ValueError(f"K ({K}) must be less than number of samples ({n_samples})")

    # Resolve the distributed flag: 'auto' follows torch.distributed state.
    if distributed == "auto":
        distributed = dist.is_initialized()
    else:
        distributed = bool(distributed)

    if distributed:
        if not dist.is_initialized():
            raise RuntimeError(
                "[TorchDR] distributed=True requires launching with torchrun. "
                "Example: torchrun --nproc_per_node=4 your_script.py"
            )
        dist_ctx = DistributedContext()
        # In distributed mode the device is always the local GPU rank; an
        # explicit CPU request is an error. Normalize via torch.device so
        # "cpu", "cpu:0" and torch.device("cpu") are all rejected (the
        # original only matched the literal string "cpu").
        if device is not None and torch.device(device).type == "cpu":
            raise ValueError(
                "[TorchDR] Distributed mode requires GPU (device cannot be 'cpu')"
            )
        device = torch.device(f"cuda:{dist_ctx.local_rank}")
    else:
        dist_ctx = None
        device = X.device if device is None else torch.device(device)

    X = X.to(device)
    Z = Z.to(device)

    # K nearest neighbors in the input space (self excluded).
    _, neighbors_X = pairwise_distances(
        X,
        metric=metric,
        backend=backend,
        k=K,
        exclude_diag=True,
        return_indices=True,
        device=device,
        distributed_ctx=dist_ctx,
    )

    # K nearest neighbors in the embedding space (self excluded).
    _, neighbors_Z = pairwise_distances(
        Z,
        metric=metric,
        backend=backend,
        k=K,
        exclude_diag=True,
        return_indices=True,
        device=device,
        distributed_ctx=dist_ctx,
    )

    # Broadcasted comparison: (chunk_size, K, 1) vs (chunk_size, 1, K) gives
    # a (chunk_size, K, K) boolean cube; any(dim=2) flags, for each sample,
    # which input-space neighbors also appear among its embedding neighbors.
    matches = (neighbors_X.unsqueeze(2) == neighbors_Z.unsqueeze(1)).any(dim=2)
    overlaps = matches.float().sum(dim=1) / K  # per-sample overlap in [0, 1]

    if return_per_sample:
        result = overlaps
        if input_is_numpy:
            result = result.detach().cpu().numpy()
    else:
        result = overlaps.mean()
        if input_is_numpy:
            result = result.detach().cpu().numpy().item()

    return result