Source code for torchdr.utils.geometry

# -*- coding: utf-8 -*-
"""Ground metrics and distances."""

# Author: Hugues Van Assel <vanasselhugues@gmail.com>
#
# License: BSD 3-Clause License

import torch
from .keops import LazyTensor, pykeops

from torchdr.utils.utils import identity_matrix

LIST_METRICS = ["euclidean", "sqeuclidean", "manhattan", "angular", "hyperbolic"]


def pairwise_distances(
    X: torch.Tensor,
    Y: torch.Tensor = None,
    metric: str = "sqeuclidean",
    keops: bool = False,
):
    r"""Compute the matrix of pairwise distances between two datasets.

    Dispatches to the dense torch implementation or, when ``keops`` is True,
    to the KeOps (lazy tensor) implementation.

    Parameters
    ----------
    X : torch.Tensor of shape (n_samples, n_features)
        First dataset.
    Y : torch.Tensor of shape (m_samples, n_features), optional
        Second dataset. If None, X is compared with itself.
    metric : str, optional
        Name of the metric used to compute distances.
        The default is "sqeuclidean".
    keops : bool, optional
        If True, the distances are computed with KeOps.

    Returns
    -------
    C : torch.Tensor or pykeops.torch.LazyTensor (if keops is True)
    of shape (n_samples, m_samples)
        Pairwise distances matrix.

    Raises
    ------
    ValueError
        If ``keops`` is True but pykeops is not available.
    """
    if Y is None:
        Y = X

    # Dense path: nothing about pykeops needs to be checked.
    if not keops:
        return _pairwise_distances_torch(X, Y, metric)

    # Lazy path (recommended for large datasets) requires pykeops.
    if not pykeops:
        raise ValueError(
            "pykeops is not installed. Please install it to use `keops=true`."
        )
    return _pairwise_distances_keops(X, Y, metric)
def symmetric_pairwise_distances(
    X: torch.Tensor,
    metric: str = "sqeuclidean",
    keops: bool = False,
    add_diag: float = None,
):
    r"""Compute pairwise distances matrix between points in a dataset.

    Return the pairwise distance matrix as torch tensor or KeOps lazy tensor
    (if keops is True). Supports batched input. The batch dimension should be
    the first.

    Parameters
    ----------
    X : torch.Tensor of shape (n_samples, n_features)
    or (n_batch, n_samples_batch, n_features)
        Input dataset.
    metric : str, optional
        Metric to use for computing distances. The default is "sqeuclidean".
    keops : bool, optional
        If True, uses KeOps for computing the distances.
    add_diag : float, optional
        If not None, adds weight on the diagonal of the distance matrix.

    Returns
    -------
    C : torch.Tensor or pykeops.torch.LazyTensor (if keops is True)
    of shape (n_samples, n_samples)
    or (n_batch, n_samples_batch, n_samples_batch)
        Pairwise distances matrix.

    Raises
    ------
    ValueError
        If ``keops`` is True but pykeops is not available, or if ``metric``
        is not one of ``LIST_METRICS``.
    """
    if keops and not pykeops:  # pykeops not installed
        raise ValueError(
            "pykeops is not installed. Please install it to use `keops=true`."
        )

    if keops:  # recommended for large datasets
        C = _pairwise_distances_keops(X, metric=metric)
    else:
        C = _pairwise_distances_torch(X, metric=metric)

    if add_diag is not None:  # add mass on the diagonal
        eye = identity_matrix(C.shape[-1], keops, X.device, X.dtype)
        # Out-of-place on purpose: for the "euclidean" metric C is the output
        # of sqrt, which autograd keeps for the backward pass; modifying it
        # in place would invalidate that saved tensor.
        C = C + add_diag * eye

    return C


def _sqeuclidean_torch(X: torch.Tensor, Y: torch.Tensor):
    r"""Return squared Euclidean distances using ||x||^2 + ||y||^2 - 2<x, y>."""
    X_norm = (X**2).sum(-1)
    Y_norm = (Y**2).sum(-1)
    return X_norm.unsqueeze(-1) + Y_norm.unsqueeze(-2) - 2 * X @ Y.transpose(-1, -2)


def _pairwise_distances_torch(
    X: torch.Tensor, Y: torch.Tensor = None, metric: str = "sqeuclidean"
):
    r"""Compute pairwise distances matrix between points in two datasets.

    Return the pairwise distance matrix as a torch tensor.

    Parameters
    ----------
    X : torch.Tensor of shape (n_samples, n_features)
        First dataset.
    Y : torch.Tensor of shape (m_samples, n_features), optional
        Second dataset. If None, Y = X.
    metric : str
        Metric to use for computing distances.

    Returns
    -------
    C : torch.Tensor of shape (n_samples, m_samples)
        Pairwise distances matrix.

    Raises
    ------
    ValueError
        If ``metric`` is not one of ``LIST_METRICS``.
    """
    if metric not in LIST_METRICS:
        raise ValueError(f"[TorchDR] ERROR : The '{metric}' distance is not supported.")

    if Y is None:
        Y = X

    if metric == "sqeuclidean":
        C = _sqeuclidean_torch(X, Y)
    elif metric == "euclidean":
        # negative values can appear because of float precision
        C = torch.clip(_sqeuclidean_torch(X, Y), min=0.0).sqrt()
    elif metric == "manhattan":
        C = (X.unsqueeze(-2) - Y.unsqueeze(-3)).abs().sum(-1)
    elif metric == "angular":
        # negative inner product between every pair of points
        C = -X @ Y.transpose(-1, -2)
    elif metric == "hyperbolic":
        # squared Euclidean distance scaled by the product of first coordinates
        C = _sqeuclidean_torch(X, Y) / (
            X[..., 0].unsqueeze(-1) * Y[..., 0].unsqueeze(-2)
        )

    return C


def _pairwise_distances_keops(
    X: torch.Tensor, Y: torch.Tensor = None, metric: str = "sqeuclidean"
):
    r"""Compute pairwise distances matrix between points in two datasets.

    Return the pairwise distance matrix as KeOps lazy tensor.

    Parameters
    ----------
    X : torch.Tensor of shape (n_samples, n_features)
        First dataset.
    Y : torch.Tensor of shape (m_samples, n_features), optional
        Second dataset. If None, Y = X.
    metric : str
        Metric to use for computing distances.

    Returns
    -------
    C : pykeops.torch.LazyTensor of shape (n_samples, m_samples)
        Pairwise distances matrix.

    Raises
    ------
    ValueError
        If ``metric`` is not one of ``LIST_METRICS``.
    """
    if metric not in LIST_METRICS:
        raise ValueError(f"[TorchDR] ERROR : The '{metric}' distance is not supported.")

    if Y is None:
        Y = X

    # Symbolic "i" and "j" variables: the singleton axes added here are the
    # ones the pairwise matrix is laid out over.
    X_i = LazyTensor(X.unsqueeze(-2))
    Y_j = LazyTensor(Y.unsqueeze(-3))

    if metric == "sqeuclidean":
        C = ((X_i - Y_j) ** 2).sum(-1)
    elif metric == "euclidean":
        C = ((X_i - Y_j) ** 2).sum(-1) ** (1.0 / 2.0)
    elif metric == "manhattan":
        C = (X_i - Y_j).abs().sum(-1)
    elif metric == "angular":
        # `|` is the inner product between the two lazy variables
        C = -(X_i | Y_j)
    elif metric == "hyperbolic":
        # intended to mirror the X[..., 0] * Y[..., 0] scaling of the torch path
        C = ((X_i - Y_j) ** 2).sum(-1) / (X_i[0] * Y_j[0])

    return C


def symmetric_pairwise_distances_indices(
    X: torch.Tensor,
    indices: torch.Tensor,
    metric: str = "sqeuclidean",
):
    r"""Compute pairwise distances for a subset of pairs given by indices.

    The output distance matrix has shape (n, k) and its (i,j) element is the
    distance between X[i] and X[indices[i, j]].

    Parameters
    ----------
    X : torch.Tensor of shape (n, p)
        Input dataset.
    indices : torch.Tensor of shape (n, k)
        Indices of the pairs for which to compute the distances.
        Cast to an integer index tensor before use.
    metric : str, optional
        Metric to use for computing distances. The default is "sqeuclidean".

    Returns
    -------
    C_indices : torch.Tensor of shape (n, k)
        Pairwise distances matrix for the subset of pairs.

    Raises
    ------
    NotImplementedError
        If ``metric`` is not one of the metrics handled below.
    """
    # int64 is torch's standard index dtype; some older torch releases reject
    # int32 index tensors, and int64 input needs no conversion at all.
    X_indices = X[indices.long()]  # Shape (n, k, p)

    if metric == "sqeuclidean":
        C_indices = torch.sum((X.unsqueeze(1) - X_indices) ** 2, dim=-1)
    elif metric == "euclidean":
        C_indices = torch.sum((X.unsqueeze(1) - X_indices) ** 2, dim=-1).sqrt()
    elif metric == "manhattan":
        C_indices = torch.sum(torch.abs(X.unsqueeze(1) - X_indices), dim=-1)
    elif metric == "angular":
        C_indices = -torch.sum(X.unsqueeze(1) * X_indices, dim=-1)
    elif metric == "hyperbolic":
        C_indices = torch.sum((X.unsqueeze(1) - X_indices) ** 2, dim=-1) / (
            X[:, 0].unsqueeze(1) * X_indices[:, :, 0]
        )
    else:
        raise NotImplementedError(f"Metric '{metric}' is not (yet) implemented.")

    return C_indices