# Source code for torchdr.neighbor_embedding.pacmap

"""PACMAP algorithm."""

# Author: Hugues Van Assel <vanasselhugues@gmail.com>
#
# License: BSD 3-Clause License

import torch
from torchdr.neighbor_embedding.base import SampledNeighborEmbedding
from typing import Union, Optional, Dict, Type, Any
from torchdr.affinity import PACMAPAffinity
from torchdr.utils import kmin, sum_red, symmetric_pairwise_distances_indices


class PACMAP(SampledNeighborEmbedding):
    r"""PACMAP algorithm introduced in :cite:`wang2021understanding`.

    It uses a :class:`~torchdr.PACMAPAffinity` as input affinity.
    The loss function is defined as:

    .. math::

        w_{\mathrm{NB}} \sum_{i, j \in \mathrm{NB}(i)} \frac{d_{ij}}{10 + d_{ij}}
        + w_{\mathrm{MN}} \sum_{i,j \in \mathrm{MN}(i)} \frac{d_{ij}}{10^4 + d_{ij}}
        + w_{\mathrm{FP}} \sum_{i,j \in \mathrm{FP}(i)} \frac{1}{1 + d_{ij}}

    where :math:`\mathrm{NB}(i)`, :math:`\mathrm{MN}(i)` and :math:`\mathrm{FP}(i)`
    are the nearest neighbors, mid-near neighbors and far neighbors of point
    :math:`i` respectively, and :math:`d_{ij} = 1 + \|\mathbf{z}_i - \mathbf{z}_j\|^2`
    (more details in :cite:`wang2021understanding`).

    Mid-near pairs are resampled each time the attractive loss is evaluated.
    For every point and every mid-near slot, 6 candidates (other than the
    point itself) are drawn uniformly with replacement. The candidate that is
    second closest in input space is kept.

    The weights follow a three-phase schedule of ``iter_per_phase`` iterations
    each:

    - Phase 1: :math:`w_{\mathrm{NB}} = 2`, :math:`w_{\mathrm{FP}} = 1`, and
      :math:`w_{\mathrm{MN}}` is annealed linearly from 1000 to 3.
    - Phase 2: :math:`w_{\mathrm{NB}} = 3`, :math:`w_{\mathrm{MN}} = 3`,
      :math:`w_{\mathrm{FP}} = 1`.
    - Phase 3: :math:`w_{\mathrm{NB}} = 1`, :math:`w_{\mathrm{MN}} = 0`,
      :math:`w_{\mathrm{FP}} = 1`.

    Parameters
    ----------
    n_neighbors : int, optional
        Number of nearest neighbors, by default 10.
    n_components : int, optional
        Dimension of the embedding space, by default 2.
    lr : float or 'auto', optional
        Learning rate for the algorithm, by default 1e0.
    optimizer : str or torch.optim.Optimizer, optional
        Name of an optimizer from torch.optim or an optimizer class.
        Default is "Adam".
    optimizer_kwargs : dict or 'auto', optional
        Additional keyword arguments for the optimizer, by default None.
    scheduler : str or torch.optim.lr_scheduler.LRScheduler, optional
        Name of a scheduler from torch.optim.lr_scheduler or a scheduler class.
        Default is None (no scheduler).
    scheduler_kwargs : dict, optional
        Additional keyword arguments for the scheduler.
    init : {'normal', 'pca'} or torch.Tensor of shape (n_samples, output_dim), optional
        Initialization for the embedding Z, default 'pca'.
    init_scaling : float, optional
        Scaling factor for the initialization, by default 1e-4.
    min_grad_norm : float, optional
        Precision threshold at which the algorithm stops, by default 1e-7.
    max_iter : int, optional
        Number of maximum iterations for the descent algorithm, by default 450.
    device : str, optional
        Device to use, by default None.
    backend : {"keops", "faiss", None}, optional
        Which backend to use for handling sparsity and memory efficiency.
        Default is "faiss".
    verbose : bool, optional
        Verbosity, by default False.
    random_state : float, optional
        Random seed for reproducibility, by default None.
    metric_in : {'sqeuclidean', 'manhattan'}, optional
        Metric to use for the input affinity, by default 'sqeuclidean'.
    metric_out : {'sqeuclidean', 'manhattan'}, optional
        Metric to use for the output affinity, by default 'sqeuclidean'.
    MN_ratio : float, optional
        Ratio of mid-near pairs to nearest neighbor pairs, by default 0.5.
    FP_ratio : float, optional
        Ratio of far pairs to nearest neighbor pairs, by default 2.
    check_interval : int, optional
        Interval for checking convergence, by default 50.
    iter_per_phase : int, optional
        Number of iterations for each phase of the algorithm, by default 100.
    """  # noqa: E501

    # Number of random candidates drawn per mid-near slot. The second closest
    # one in input space is kept (PaCMAP's mid-near sampling rule).
    _N_MID_NEAR_CANDIDATES = 6

    def __init__(
        self,
        n_neighbors: int = 10,
        n_components: int = 2,
        lr: Union[float, str] = 1e0,
        optimizer: Union[str, Type[torch.optim.Optimizer]] = "Adam",
        optimizer_kwargs: Optional[Union[Dict, str]] = None,
        scheduler: Optional[
            Union[str, Type[torch.optim.lr_scheduler.LRScheduler]]
        ] = None,
        scheduler_kwargs: Optional[Dict] = None,
        init: str = "pca",
        init_scaling: float = 1e-4,
        min_grad_norm: float = 1e-7,
        max_iter: int = 450,
        device: Optional[str] = None,
        backend: Optional[str] = "faiss",
        verbose: bool = False,
        random_state: Optional[float] = None,
        metric_in: str = "sqeuclidean",
        metric_out: str = "sqeuclidean",
        MN_ratio: float = 0.5,
        FP_ratio: float = 2,
        check_interval: int = 50,
        iter_per_phase: int = 100,
    ):
        self.n_neighbors = n_neighbors
        self.metric_in = metric_in
        self.metric_out = metric_out
        self.MN_ratio = MN_ratio
        self.FP_ratio = FP_ratio
        # Pair counts per point derive from the ratios and n_neighbors.
        self.n_mid_near = int(MN_ratio * n_neighbors)
        self.n_further = int(FP_ratio * n_neighbors)
        self.iter_per_phase = iter_per_phase

        affinity_in = PACMAPAffinity(
            n_neighbors=n_neighbors,
            metric=metric_in,
            device=device,
            backend=backend,
            verbose=verbose,
        )

        super().__init__(
            affinity_in=affinity_in,
            affinity_out=None,
            n_components=n_components,
            optimizer=optimizer,
            optimizer_kwargs=optimizer_kwargs,
            min_grad_norm=min_grad_norm,
            max_iter=max_iter,
            lr=lr,
            scheduler=scheduler,
            scheduler_kwargs=scheduler_kwargs,
            init=init,
            init_scaling=init_scaling,
            device=device,
            backend=backend,
            verbose=verbose,
            random_state=random_state,
            check_interval=check_interval,
            # Far pairs are drawn by the base class's negative sampler.
            n_negatives=self.n_further,
        )

    def _fit_transform(self, X: torch.Tensor, y: Optional[Any] = None):
        self.X_ = X  # Keep input data to score mid-near candidates
        # Evaluate the schedule at iteration 0 explicitly. ``n_iter_`` is not
        # guaranteed to be reset at this point: it may be unset, or left over
        # from a previous fit. Either case would start from the wrong phase.
        self._set_weights(n_iter=0)
        return super()._fit_transform(X, y)

    def _set_weights(self, n_iter: Optional[int] = None):
        """Set ``w_NB``, ``w_MN`` and ``w_FP`` from the three-phase schedule.

        Parameters
        ----------
        n_iter : int, optional
            Iteration at which to evaluate the schedule. Defaults to
            ``self.n_iter_``.
        """
        if n_iter is None:
            n_iter = self.n_iter_

        if n_iter < self.iter_per_phase:
            # Phase 1: linearly anneal the mid-near weight from 1000 to 3.
            progress = n_iter / self.iter_per_phase
            self.w_NB = 2
            self.w_MN = 1000 * (1 - progress) + 3 * progress
            self.w_FP = 1
        elif n_iter < 2 * self.iter_per_phase:
            # Phase 2: constant weights.
            self.w_NB = 3
            self.w_MN = 3
            self.w_FP = 1
        else:
            # Phase 3: mid-near pairs are switched off.
            self.w_NB = 1
            self.w_MN = 0
            self.w_FP = 1

    def _after_step(self):
        self._set_weights()

    def _sample_mid_near_indices(self):
        """Sample ``n_mid_near`` mid-near neighbors for every point.

        For each point and each mid-near slot, draws ``_N_MID_NEAR_CANDIDATES``
        candidates uniformly (with replacement) among all other points. It
        keeps the one whose input-space distance (``metric_in`` on ``X_``) is
        the second smallest.

        Returns
        -------
        torch.Tensor of shape (n_samples_in_, n_mid_near) and dtype long
            Sample indices of the selected mid-near points.

        Raises
        ------
        ValueError
            If there are fewer other points than candidates to draw.
        """
        n_candidates = self._N_MID_NEAR_CANDIDATES
        device = getattr(self.NN_indices_, "device", "cpu")
        n_possible_idxs = self.n_samples_in_ - 1  # every point except itself
        if n_possible_idxs < n_candidates:
            raise ValueError(
                f"[TorchDR] ERROR : Not enough points to sample {n_candidates} "
                "mid-near points."
            )

        self_idxs = torch.arange(self.n_samples_in_, device=device).unsqueeze(1)
        # Long dtype is required for the result to be usable as indices.
        mid_near_indices = torch.empty(
            self.n_samples_in_, self.n_mid_near, dtype=torch.long, device=device
        )
        for slot in range(self.n_mid_near):
            # Draw from [0, n - 1), then shift values >= the row's own index
            # up by one. Candidates end up uniform over all points except self.
            candidates = torch.randint(
                0, n_possible_idxs, (self.n_samples_in_, n_candidates), device=device
            )
            candidates += (candidates >= self_idxs).long()

            D_candidates = symmetric_pairwise_distances_indices(
                self.X_, indices=candidates, metric=self.metric_in
            )[0]
            _, positions = kmin(D_candidates, k=2, dim=1)
            # ``positions`` index the candidate axis, not the samples. Map the
            # second-closest position back to its sample index.
            mid_near_indices[:, slot] = candidates.gather(
                1, positions[:, 1:2].long()
            ).squeeze(1)
        return mid_near_indices

    def _attractive_loss(self):
        # Nearest-neighbor pairs: d / (10 + d) with d = 1 + output distance.
        Q_near = (
            1
            + symmetric_pairwise_distances_indices(
                self.embedding_, indices=self.NN_indices_, metric=self.metric_out
            )[0]
        )
        Q_near = Q_near / (1e1 + Q_near)
        near_loss = self.w_NB * sum_red(Q_near, dim=(0, 1))

        if self.w_MN > 0:
            # Mid-near pairs (only active while w_MN > 0, i.e. phases 1-2):
            # d / (1e4 + d) on freshly sampled mid-near indices.
            mid_near_indices = self._sample_mid_near_indices()
            Q_mid_near = (
                1
                + symmetric_pairwise_distances_indices(
                    self.embedding_, indices=mid_near_indices, metric=self.metric_out
                )[0]
            )
            Q_mid_near = Q_mid_near / (1e4 + Q_mid_near)
            mid_near_loss = self.w_MN * sum_red(Q_mid_near, dim=(0, 1))
        else:
            mid_near_loss = 0

        return near_loss + mid_near_loss

    def _repulsive_loss(self):
        # Far pairs: 1 / (1 + d) with d = 1 + output distance, on negatives
        # sampled by the base class (nearest neighbors discarded).
        indices = self._sample_negatives(discard_NNs=True)
        Q_further = (
            1
            + symmetric_pairwise_distances_indices(
                self.embedding_, metric=self.metric_out, indices=indices
            )[0]
        )
        Q_further = 1 / (1 + Q_further)
        return self.w_FP * sum_red(Q_further, dim=(0, 1))