Source code for torchdr.eval.knn_labels

"""K-NN label accuracy metric for evaluating embedding quality."""

# Author: Hugues Van Assel <vanasselhugues@gmail.com>
#
# License: BSD 3-Clause License

import numpy as np
import torch
import torch.distributed as dist
from typing import Union, Optional

from torchdr.utils import to_torch
from torchdr.distributed import DistributedContext
from torchdr.distance import pairwise_distances, FaissConfig


[docs] def knn_label_accuracy( X: Union[torch.Tensor, np.ndarray], labels: Union[torch.Tensor, np.ndarray], k: int = 10, metric: str = "euclidean", backend: Optional[Union[str, FaissConfig]] = "faiss", exclude_self: bool = True, distributed: Union[bool, str] = "auto", return_per_sample: bool = False, device: Optional[str] = None, ): r"""Compute k-NN label accuracy to evaluate class structure preservation. This metric measures how well local class structure is preserved in the data representation (original or embedded space). For each point, it computes the proportion of its k-nearest neighbors that share the same class label. Parameters ---------- X : torch.Tensor or np.ndarray of shape (n_samples, n_features) Data representation (can be original features or embeddings). labels : torch.Tensor or np.ndarray of shape (n_samples,) Class labels for each sample. Can be integers or any comparable type. k : int, default=10 Number of nearest neighbors to consider for each point. metric : str, default='euclidean' Distance metric to use for computing nearest neighbors. Options: 'euclidean', 'sqeuclidean', 'manhattan', 'angular'. backend : {'keops', 'faiss', None} or FaissConfig, default='faiss' Backend to use for k-NN computation: - 'keops': Memory-efficient symbolic computations - 'faiss': Fast approximate nearest neighbors (recommended for large datasets) - None: Standard PyTorch operations - FaissConfig object: FAISS with custom configuration exclude_self : bool, default=True Whether to exclude the point itself from its neighborhood. Usually True when evaluating on the same dataset used for k-NN search. distributed : bool or 'auto', default='auto' Whether to use multi-GPU distributed computation. - 'auto': Automatically detects if torch.distributed is initialized - True: Forces distributed mode (requires torch.distributed to be initialized) - False: Disables distributed mode When enabled: - Each GPU computes accuracy for its assigned chunk of samples - Device is automatically set to the local GPU rank - Backend is forced to 'faiss' for efficient distributed k-NN - Returns per-chunk results (no automatic gathering across GPUs) Requires launching with torchrun: ``torchrun --nproc_per_node=N script.py`` return_per_sample : bool, default=False If True, returns per-sample accuracies instead of the mean. Shape: (n_samples,) or (chunk_size,) in distributed mode. device : str, optional Device to use for computation. If None, uses input device. Returns ------- accuracy : float or torch.Tensor If return_per_sample=False: Mean k-NN label accuracy across all samples. If return_per_sample=True: Per-sample k-NN label accuracies. Value between 0 and 1, where 1 means all neighbors have same label. Returns numpy array/float if inputs are numpy, torch.Tensor otherwise. Examples -------- >>> import torch >>> from torchdr.eval import knn_label_accuracy >>> >>> # Generate example data with 3 classes >>> X = torch.randn(300, 50) >>> labels = torch.repeat_interleave(torch.arange(3), 100) >>> >>> # Compute k-NN label accuracy >>> accuracy = knn_label_accuracy(X, labels, k=10) >>> print(f"k-NN label accuracy: {accuracy:.3f}") >>> >>> # Get per-sample accuracies >>> per_sample = knn_label_accuracy(X, labels, k=10, return_per_sample=True) >>> print(f"Mean: {per_sample.mean():.3f}, Std: {per_sample.std():.3f}") >>> >>> # Distributed computation (launch with: torchrun --nproc_per_node=4 script.py) >>> accuracy_chunk = knn_label_accuracy(X, labels, k=10, distributed=True) Notes ----- This metric is useful for: - Evaluating how well embeddings preserve class structure - Comparing different DR methods on classification tasks - Assessing the quality of learned representations Higher values indicate better preservation of local class homogeneity. The metric is sensitive to class imbalance and noise in labels. In distributed mode, each GPU computes accuracy for its chunk of the dataset. To get the global accuracy, gather results from all GPUs and compute the mean. """ if k < 1: raise ValueError(f"k must be at least 1, got {k}") input_is_numpy = not isinstance(X, torch.Tensor) or not isinstance( labels, torch.Tensor ) X = to_torch(X) labels = to_torch(labels) if X.shape[0] != labels.shape[0]: raise ValueError( f"X and labels must have same number of samples, " f"got {X.shape[0]} and {labels.shape[0]}" ) n_samples = X.shape[0] if k >= n_samples: raise ValueError(f"k ({k}) must be less than number of samples ({n_samples})") if distributed == "auto": distributed = dist.is_initialized() else: distributed = bool(distributed) if distributed: if not dist.is_initialized(): raise RuntimeError( "[TorchDR] distributed=True requires launching with torchrun. " "Example: torchrun --nproc_per_node=4 your_script.py" ) dist_ctx = DistributedContext() if device is None: device = X.device elif device == "cpu": raise ValueError( "[TorchDR] Distributed mode requires GPU (device cannot be 'cpu')" ) device = torch.device(f"cuda:{dist_ctx.local_rank}") else: dist_ctx = None if device is None: device = X.device else: device = torch.device(device) X = X.to(device) labels = labels.to(device) _, neighbor_indices = pairwise_distances( X, metric=metric, backend=backend, k=k, exclude_diag=exclude_self, return_indices=True, device=device, distributed_ctx=dist_ctx, ) neighbor_labels = labels[neighbor_indices] if dist_ctx is not None and dist_ctx.is_initialized: chunk_start, chunk_end = dist_ctx.compute_chunk_bounds(n_samples) query_labels = labels[chunk_start:chunk_end].unsqueeze(1) else: query_labels = labels.unsqueeze(1) matches = (neighbor_labels == query_labels).float() accuracies = matches.mean(dim=1) if return_per_sample: result = accuracies if input_is_numpy: result = result.detach().cpu().numpy() else: result = accuracies.mean() if input_is_numpy: result = result.detach().cpu().numpy().item() return result