"""Affinity matrices with normalizations using nearest neighbor distances."""
# Author: Hugues Van Assel <vanasselhugues@gmail.com>
# Cédric Vincent-Cuaz <cedric.vincent-cuaz@inria.fr>
#
# License: BSD 3-Clause License
from typing import Tuple, Union, Optional
import torch
import math
from torchdr.affinity.base import Affinity, LogAffinity, SparseLogAffinity
from torchdr.utils import (
matrix_transpose,
kmin,
logsumexp_red,
sum_red,
wrap_vectors,
pairwise_distances,
matrix_power,
check_neighbor_param,
binary_search,
)
@wrap_vectors
def _log_P_SelfTuning(C, sigma):
    # Log-kernel -C_ij / (sigma_i * sigma_j) with sample-wise bandwidths.
    sigma_t = matrix_transpose(sigma)
    return -C / (sigma * sigma_t)


@wrap_vectors
def _log_P_MAGIC(C, sigma):
    # Log-kernel -C_ij / sigma_i with a sample-wise bandwidth.
    return -C / sigma


@wrap_vectors
def _log_P_UMAP(C, rho, sigma):
    # UMAP log-kernel -(C_ij - rho_i) / sigma_i, shifted by the distance
    # rho_i to the nearest neighbor.
    return -(C - rho) / sigma


@wrap_vectors
def _log_P_PHATE(C, sigma, alpha=10.0):
    # PHATE alpha-decay log-kernel -(C_ij / sigma_i) ** alpha.
    return -((C / sigma) ** alpha)
class SelfTuningAffinity(LogAffinity):
r"""Self-tuning affinity introduced in :cite:`zelnik2004self`.
The affinity has a sample-wise bandwidth :math:`\mathbf{\sigma} \in \mathbb{R}^n`.
.. math::
\exp \left( - \frac{C_{ij}}{\sigma_i \sigma_j} \right)
    In the above, :math:`\mathbf{C}` is the pairwise distance matrix and
    :math:`\sigma_i` is the distance from :math:`\mathbf{x}_i` to its K-th
    nearest neighbor.
Parameters
----------
    K : int, optional
        K-th nearest neighbor used to set the bandwidth :math:`\sigma_i`.
        Default is 7.
normalization_dim : int or Tuple[int], optional
Dimension along which to normalize the affinity matrix.
metric : str, optional
Metric to use for pairwise distances computation.
zero_diag : bool, optional
Whether to set the diagonal of the affinity matrix to zero.
device : str, optional
Device to use for computations.
backend : {"keops", "faiss", None}, optional
Which backend to use for handling sparsity and memory efficiency.
Default is None.
verbose : bool, optional
Verbosity. Default is False.
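
    Examples
    --------
    A minimal usage sketch; it assumes the class is exported from
    ``torchdr.affinity`` and that instances are callable on a data tensor,
    as in the TorchDR ``Affinity`` API:

    >>> import torch
    >>> from torchdr.affinity import SelfTuningAffinity
    >>> X = torch.randn(100, 5)
    >>> aff = SelfTuningAffinity(K=7)
    >>> P = aff(X)  # (100, 100) affinity, normalized over normalization_dim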
"""
def __init__(
self,
K: int = 7,
normalization_dim: Union[int, Tuple[int]] = (0, 1),
metric: str = "sqeuclidean",
zero_diag: bool = True,
device: Optional[str] = None,
backend: Optional[str] = None,
verbose: bool = False,
):
super().__init__(
metric=metric,
zero_diag=zero_diag,
device=device,
backend=backend,
verbose=verbose,
)
self.K = K
self.normalization_dim = normalization_dim
def _compute_log_affinity(self, X: torch.Tensor):
r"""Fit the self-tuning affinity model to the provided data.
Parameters
----------
X : torch.Tensor
Input data.
Returns
-------
log_affinity_matrix : torch.Tensor or pykeops.torch.LazyTensor
The computed affinity matrix in log domain.
"""
        C, _ = self._distance_matrix(X)
        # Bandwidth sigma_i = distance from x_i to its K-th nearest neighbor.
        minK_values, _ = kmin(C, k=self.K, dim=1)
        self.sigma_ = minK_values[:, -1]
log_affinity_matrix = _log_P_SelfTuning(C, self.sigma_)
if self.normalization_dim is not None:
self.log_normalization_ = logsumexp_red(
log_affinity_matrix, self.normalization_dim
)
log_affinity_matrix = log_affinity_matrix - self.log_normalization_
return log_affinity_matrix
class MAGICAffinity(Affinity):
r"""Compute the MAGIC affinity with alpha-decay kernel introduced in :cite:`van2018recovering`.
The construction is as follows. First, it computes a generalized
kernel with sample-wise bandwidth :math:`\mathbf{\sigma} \in \mathbb{R}^n`:
.. math::
P_{ij} \leftarrow \exp \left( - \frac{C_{ij}}{\sigma_i} \right)
    In the above, :math:`\mathbf{C}` is the pairwise distance matrix and
    :math:`\sigma_i` is the distance from :math:`\mathbf{x}_i` to its K-th
    nearest neighbor.
Then it averages the affinity matrix with its transpose:
.. math::
P_{ij} \leftarrow \frac{P_{ij} + P_{ji}}{2} \:.
Finally, it normalizes the affinity matrix along each row:
.. math::
P_{ij} \leftarrow \frac{P_{ij}}{\sum_{t} P_{it}} \:.
Parameters
----------
    K : int, optional
        K-th nearest neighbor used to set the bandwidth :math:`\sigma_i`.
        Default is 7.
metric : str, optional
Metric to use for pairwise distances computation.
zero_diag : bool, optional
Whether to set the diagonal of the affinity matrix to zero.
device : str, optional
Device to use for computations.
backend : {"keops", "faiss", None}, optional
Which backend to use for handling sparsity and memory efficiency.
Default is None.
verbose : bool, optional
Verbosity. Default is False.
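
    Examples
    --------
    A minimal usage sketch; it assumes the class is exported from
    ``torchdr.affinity`` and that instances are callable on a data tensor,
    as in the TorchDR ``Affinity`` API:

    >>> import torch
    >>> from torchdr.affinity import MAGICAffinity
    >>> X = torch.randn(100, 5)
    >>> P = MAGICAffinity(K=7)(X)  # (100, 100) affinity, each row sums to one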
"""
def __init__(
self,
K: int = 7,
metric: str = "sqeuclidean",
zero_diag: bool = True,
device: Optional[str] = None,
backend: Optional[str] = None,
verbose: bool = False,
):
super().__init__(
metric=metric,
zero_diag=zero_diag,
device=device,
backend=backend,
verbose=verbose,
)
self.K = K
def _compute_affinity(self, X: torch.Tensor):
r"""Fit the MAGIC affinity model to the provided data.
Parameters
----------
X : torch.Tensor
Input data.
Returns
-------
affinity_matrix : torch.Tensor or pykeops.torch.LazyTensor
The computed affinity matrix.
"""
        C, _ = self._distance_matrix(X)
        # Bandwidth sigma_i = distance from x_i to its K-th nearest neighbor.
        minK_values, _ = kmin(C, k=self.K, dim=1)
        self.sigma_ = minK_values[:, -1]
        affinity_matrix = _log_P_MAGIC(C, self.sigma_).exp()
        # Symmetrize, then normalize each row to sum to one.
        affinity_matrix = (affinity_matrix + matrix_transpose(affinity_matrix)) / 2
        affinity_matrix = affinity_matrix / sum_red(affinity_matrix, dim=1)
return affinity_matrix
class PHATEAffinity(Affinity):
r"""Compute the potential affinity used in PHATE :cite:`moon2019visualizing`.
The method follows these steps:
1. Compute pairwise distance matrix
2. Find k-th nearest neighbor distances to set bandwidth sigma
3. Compute base affinity with alpha-decay kernel: exp(-((d/sigma)^alpha))
4. Symmetrize the affinity matrix
5. Row-normalize to create diffusion matrix
6. Raise diffusion matrix to power t (diffusion steps)
7. Compute potential distances from the diffused matrix
8. Return negative potential distances as affinities
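
    In math form, steps 7-8 return the negative potential distance computed
    from the diffused matrix :math:`\mathbf{P}^t` of steps 5-6 (this mirrors
    the implementation below):

    .. math::
        A_{ij} = - \left\| \log \mathbf{P}^t_{i \cdot} - \log \mathbf{P}^t_{j \cdot} \right\|_2 \:.
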
Parameters
----------
metric : str, optional (default="euclidean")
Metric to use for pairwise distances computation.
device : str, optional (default=None)
Device to use for computations. If None, uses the device of input data.
backend : {"keops", "faiss", None}, optional (default=None)
Which backend to use for handling sparsity and memory efficiency.
verbose : bool, optional (default=False)
Whether to print verbose output during computation.
k : int, optional (default=5)
Number of nearest neighbors used to determine bandwidth parameter sigma.
alpha : float, optional (default=10.0)
Exponent for the alpha-decay kernel in affinity computation.
t : int, optional (default=5)
Number of diffusion steps (power to raise diffusion matrix).
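
    Examples
    --------
    A minimal usage sketch; it assumes the class is exported from
    ``torchdr.affinity`` and that instances are callable on a data tensor,
    as in the TorchDR ``Affinity`` API:

    >>> import torch
    >>> from torchdr.affinity import PHATEAffinity
    >>> X = torch.randn(100, 5)
    >>> A = PHATEAffinity(k=5, t=5)(X)  # negative potential distances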
"""
    def __init__(
        self,
        metric: str = "euclidean",
        device: Optional[str] = None,
        backend: Optional[str] = None,
        verbose: bool = False,
        k: int = 5,
        alpha: float = 10.0,
        t: int = 5,
    ):
        if backend in ("faiss", "keops"):
            raise ValueError(
                f"[TorchDR] ERROR : {self.__class__.__name__} class does not "
                f"support backend {backend}."
            )
super().__init__(
metric=metric,
device=device,
backend=backend,
verbose=verbose,
zero_diag=False,
)
self.alpha = alpha
self.k = k
self.t = t
    def _compute_affinity(self, X: torch.Tensor):
        r"""Compute the PHATE potential affinity from input data X."""
        C, _ = self._distance_matrix(X)
        # Bandwidth sigma_i = distance from x_i to its k-th nearest neighbor.
        minK_values, _ = kmin(C, k=self.k, dim=1)
        self.sigma_ = minK_values[:, -1]
        affinity = _log_P_PHATE(C, self.sigma_, self.alpha).exp()
        # Symmetrize, row-normalize, then diffuse for t steps.
        affinity = (affinity + matrix_transpose(affinity)) / 2
        affinity = affinity / sum_red(affinity, dim=1)
        affinity = matrix_power(affinity, self.t)
        # Negative euclidean distances between the potentials -log(P^t).
        affinity = -pairwise_distances(
            -affinity.clamp(min=1e-12).log(), metric="euclidean", backend=self.backend
        )[0]
        return affinity
class UMAPAffinityIn(SparseLogAffinity):
r"""Compute the input affinity used in UMAP :cite:`mcinnes2018umap`.
    The algorithm computes, via root search, the bandwidths
    :math:`\mathbf{\sigma}^* \in \mathbb{R}^n_{>0}` such that

    .. math::
        \forall (i,j), \: P_{ij} = \exp \left( - \frac{C_{ij} - \rho_i}{\sigma^\star_i} \right) \quad \text{where} \quad \forall i, \: \sum_j P_{ij} = \log (\mathrm{n\_neighbors})
and :math:`\rho_i = \min_j C_{ij}`.
Parameters
----------
n_neighbors : float, optional
Number of effective nearest neighbors to consider. Similar to the perplexity.
tol : float, optional
Precision threshold for the root search.
max_iter : int, optional
Maximum number of iterations for the root search.
sparsity : bool, optional
Whether to use sparsity mode.
Default is True.
metric : str, optional
Metric to use for pairwise distances computation.
zero_diag : bool, optional
Whether to set the diagonal of the affinity matrix to zero.
device : str, optional
Device to use for computations.
backend : {"keops", "faiss", None}, optional
Which backend to use for handling sparsity and memory efficiency.
Default is None.
verbose : bool, optional
Verbosity. Default is False.
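
    Examples
    --------
    A minimal usage sketch; it assumes the class is exported from
    ``torchdr.affinity`` and that instances are callable on a data tensor,
    as in the TorchDR ``Affinity`` API. In sparsity mode the affinity is
    stored over each point's nearest neighbors only:

    >>> import torch
    >>> from torchdr.affinity import UMAPAffinityIn
    >>> X = torch.randn(100, 5)
    >>> P = UMAPAffinityIn(n_neighbors=30)(X)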
""" # noqa: E501
def __init__(
self,
n_neighbors: float = 30,
tol: float = 1e-5,
max_iter: int = 1000,
sparsity: bool = True,
metric: str = "sqeuclidean",
zero_diag: bool = True,
device: str = "auto",
backend: Optional[str] = None,
verbose: bool = False,
):
self.n_neighbors = n_neighbors
self.tol = tol
self.max_iter = max_iter
super().__init__(
metric=metric,
zero_diag=zero_diag,
device=device,
backend=backend,
verbose=verbose,
sparsity=sparsity,
)
def _compute_sparse_log_affinity(self, X: torch.Tensor):
r"""Compute the input affinity matrix of UMAP from input data X.
Parameters
----------
X : torch.Tensor or np.ndarray of shape (n_samples, n_features)
Data on which affinity is computed.
        Returns
        -------
        log_affinity_matrix : torch.Tensor or pykeops.torch.LazyTensor
            The computed affinity matrix in log domain.
        indices : torch.Tensor or None
            Indices of the retained nearest neighbors in sparsity mode.
        """
n_samples_in = X.shape[0]
n_neighbors = check_neighbor_param(self.n_neighbors, n_samples_in)
if self.sparsity:
if self.verbose:
self.logger.info(
f"Sparsity mode enabled, computing {n_neighbors} nearest neighbors."
)
# when using sparsity, we construct a reduced distance matrix
# of shape (n_samples, n_neighbors)
C_, indices = self._distance_matrix(X, k=n_neighbors)
else:
C_, indices = self._distance_matrix(X)
        # rho_i: distance from x_i to its nearest neighbor.
        self.rho_ = kmin(C_, k=1, dim=1)[0].squeeze().contiguous()

        def marginal_gap(eps):  # vanishes at the target bandwidths
            # sum_j P_ij - log(n_neighbors) for each sample i.
            marg = _log_P_UMAP(C_, self.rho_, eps).logsumexp(1).exp().squeeze()
            return marg - math.log(n_neighbors)
self.eps_ = binary_search(
f=marginal_gap,
n=n_samples_in,
tol=self.tol,
max_iter=self.max_iter,
verbose=self.verbose,
dtype=X.dtype,
device=X.device,
logger=self.logger if self.verbose else None,
)
log_affinity_matrix = _log_P_UMAP(C_, self.rho_, self.eps_)
return log_affinity_matrix, indices
class PACMAPAffinity(SparseLogAffinity):
r"""Compute the input affinity used in PACMAP :cite:`wang2021understanding`.
Parameters
----------
n_neighbors : float, optional
Number of effective nearest neighbors to consider. Similar to the perplexity.
metric : str, optional
Metric to use for pairwise distances computation.
zero_diag : bool, optional
Whether to set the diagonal of the affinity matrix to zero.
device : str, optional
Device to use for computations.
backend : {"keops", "faiss", None}, optional
Which backend to use for handling sparsity and memory efficiency.
Default is None.
verbose : bool, optional
Verbosity. Default is False.
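
    Examples
    --------
    A minimal sketch; it assumes the class is exported from
    ``torchdr.affinity`` and that instances are callable on a data tensor,
    as in the TorchDR ``Affinity`` API. PACMAP computes nearest-neighbor
    indices rather than affinity values:

    >>> import torch
    >>> from torchdr.affinity import PACMAPAffinity
    >>> X = torch.randn(100, 5)
    >>> aff = PACMAPAffinity(n_neighbors=10)
    >>> out = aff(X)  # carries the selected nearest-neighbor indices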
""" # noqa: E501
def __init__(
self,
n_neighbors: float = 10,
metric: str = "sqeuclidean",
zero_diag: bool = True,
device: str = "auto",
backend: Optional[str] = None,
verbose: bool = False,
):
self.n_neighbors = n_neighbors
super().__init__(
metric=metric,
zero_diag=zero_diag,
device=device,
backend=backend,
verbose=verbose,
sparsity=True, # PACMAP uses sparsity mode
)
def _compute_sparse_log_affinity(self, X: torch.Tensor):
r"""Compute the input affinity matrix of PACMAP from input data X.
Parameters
----------
X : torch.Tensor
Input data.
        Returns
        -------
        log_affinity_matrix : None
            PACMAP only uses nearest neighbor indices; no affinity values are
            computed.
        indices : torch.Tensor
            Indices of the selected nearest neighbors.
        """
n_samples_in = X.shape[0]
k = min(self.n_neighbors + 50, n_samples_in)
k = check_neighbor_param(k, n_samples_in)
if self.verbose:
self.logger.info(f"Sparsity mode enabled, computing {k} nearest neighbors.")
C_, temp_indices = self._distance_matrix(X, k=k)
        # rho_i: average distance from x_i to its 4th-6th nearest neighbors
        # (C_ holds squared euclidean distances, hence the square root).
        sq_neighbor_distances, _ = kmin(C_, k=6, dim=1)
        self.rho_ = torch.sqrt(sq_neighbor_distances)[:, 3:6].mean(dim=1).contiguous()
        rho_i = self.rho_.unsqueeze(1)  # Shape: (n_samples, 1)
        rho_j = self.rho_[temp_indices]  # Shape: (n_samples, k)
        # PACMAP scaled distances d_ij^2 / (rho_i * rho_j).
        normalized_C = C_ / (rho_i * rho_j)
        # Keep the n_neighbors smallest scaled distances and map the local
        # neighbor indices back to global sample indices.
        _, local_indices = kmin(normalized_C, k=self.n_neighbors, dim=1)
        final_indices = torch.gather(temp_indices, 1, local_indices.to(torch.int64))
        return None, final_indices  # PACMAP only uses the NN indices