# -*- coding: utf-8 -*-
"""SNEkhorn algorithm (inverse OT DR)."""

# Author: Hugues Van Assel <vanasselhugues@gmail.com>
#
# License: BSD 3-Clause License

from torchdr.neighbor_embedding.base import NeighborEmbedding
from torchdr.affinity import (
    SymmetricEntropicAffinity,
    EntropicAffinity,
    SinkhornAffinity,
)
from torchdr.utils import logsumexp_red, cross_entropy_loss


class TSNEkhorn(NeighborEmbedding):
    r"""Implementation of the TSNEkhorn algorithm introduced in [3]_.

    It involves selecting a :class:`~torchdr.SymmetricEntropicAffinity` as input
    affinity :math:`\mathbf{P}` and a :class:`~torchdr.SinkhornAffinity` as output
    affinity :math:`\mathbf{Q}`.

    The loss function is defined as:

    .. math::

        -\sum_{ij} P_{ij} \log Q_{ij} + \sum_{ij} Q_{ij} \:.

    The above loss is the gap objective for inverse symmetric optimal transport
    described in this
    `blog <https://huguesva.github.io/blog/2024/inverseOT_mongegap/>`_.

    .. note::
        The :class:`~torchdr.SymmetricEntropicAffinity` requires a careful choice
        of learning rate (parameter :attr:`lr_affinity_in`). If it is too unstable,
        one can resort to using :class:`~torchdr.EntropicAffinity` instead by
        setting :attr:`symmetric_affinity` to ``False``.

    Parameters
    ----------
    perplexity : float
        Number of 'effective' nearest neighbors.
        Consider selecting a value between 2 and the number of samples.
        Different values can result in significantly different results.
    n_components : int, optional
        Dimension of the embedding space.
    lr : float, optional
        Learning rate for the algorithm, by default 1e0.
    optimizer : {'SGD', 'Adam', 'NAdam'}, optional
        Which pytorch optimizer to use, by default 'Adam'.
    optimizer_kwargs : dict, optional
        Arguments for the optimizer, by default None.
    scheduler : {'constant', 'linear'}, optional
        Learning rate scheduler, by default 'constant'.
    scheduler_kwargs : dict, optional
        Arguments for the scheduler, by default None.
    init : {'normal', 'pca'} or torch.Tensor of shape (n_samples, output_dim), optional
        Initialization for the embedding Z, default 'pca'.
    init_scaling : float, optional
        Scaling factor for the initialization, by default 1e-4.
    tol : float, optional
        Precision threshold at which the algorithm stops, by default 1e-4.
    max_iter : int, optional
        Number of maximum iterations for the descent algorithm, by default 2000.
    tolog : bool, optional
        Whether to store intermediate results in a dictionary, by default False.
    device : str, optional
        Device to use, by default None.
    keops : bool, optional
        Whether to use KeOps, by default False.
    verbose : bool, optional
        Verbosity, by default False.
    random_state : float, optional
        Random seed for reproducibility, by default 0.
    early_exaggeration : float, optional
        Coefficient for the attraction term during the early exaggeration phase,
        by default 10.0.
    coeff_repulsion : float, optional
        Coefficient for the repulsion term, by default 1.0.
    early_exaggeration_iter : int, optional
        Number of iterations for early exaggeration, by default 250.
    lr_affinity_in : float, optional
        Learning rate used to update dual variables for the symmetric entropic
        affinity computation.
    eps_square_affinity_in : bool, optional
        When computing the symmetric entropic affinity, whether to optimize
        on the square of the dual variables. May be more stable in practice.
    tol_affinity_in : float, optional
        Precision threshold for the symmetric entropic affinity computation.
    max_iter_affinity_in : int, optional
        Number of maximum iterations for the symmetric entropic affinity
        computation.
    metric_in : {'sqeuclidean', 'manhattan'}, optional
        Metric to use for the input affinity, by default 'sqeuclidean'.
    metric_out : {'sqeuclidean', 'manhattan'}, optional
        Metric to use for the output affinity, by default 'sqeuclidean'.
    unrolling : bool, optional
        Whether to use unrolling for solving inverse OT. If False, uses
        the gap objective. Default is False.
    symmetric_affinity : bool, optional
        Whether to use symmetric entropic affinity. If False, uses
        entropic affinity. Default is True.

    References
    ----------
    .. [3] SNEkhorn: Dimension Reduction with Symmetric Entropic Affinities,
        Hugues Van Assel, Titouan Vayer, RĂ©mi Flamary, Nicolas Courty.
        Advances in Neural Information Processing Systems 36 (NeurIPS).
    """  # noqa: E501

    def __init__(
        self,
        perplexity: float = 30,
        n_components: int = 2,
        lr: float = 1.0,
        optimizer: str = "Adam",
        optimizer_kwargs: dict = None,
        scheduler: str = "constant",
        scheduler_kwargs: dict = None,
        init: str = "pca",
        init_scaling: float = 1e-4,
        tol: float = 1e-4,
        max_iter: int = 2000,
        tolog: bool = False,
        device: str = None,
        keops: bool = False,
        verbose: bool = False,
        random_state: float = 0,
        early_exaggeration: float = 10.0,
        coeff_repulsion: float = 1.0,
        early_exaggeration_iter: int = 250,
        lr_affinity_in: float = 1e-1,
        eps_square_affinity_in: bool = True,
        tol_affinity_in: float = 1e-3,
        max_iter_affinity_in: int = 100,
        metric_in: str = "sqeuclidean",
        metric_out: str = "sqeuclidean",
        unrolling: bool = False,
        symmetric_affinity: bool = True,
    ):
        self.metric_in = metric_in
        self.metric_out = metric_out
        self.perplexity = perplexity
        self.lr_affinity_in = lr_affinity_in
        self.eps_square_affinity_in = eps_square_affinity_in
        self.max_iter_affinity_in = max_iter_affinity_in
        self.tol_affinity_in = tol_affinity_in
        self.unrolling = unrolling
        self.symmetric_affinity = symmetric_affinity

        if self.symmetric_affinity:
            affinity_in = SymmetricEntropicAffinity(
                perplexity=perplexity,
                lr=lr_affinity_in,
                eps_square=eps_square_affinity_in,
                metric=metric_in,
                tol=tol_affinity_in,
                max_iter=max_iter_affinity_in,
                device=device,
                keops=keops,
                verbose=verbose,
                zero_diag=False,
            )
        else:
            affinity_in = EntropicAffinity(
                perplexity=perplexity,
                metric=metric_in,
                tol=tol_affinity_in,
                max_iter=max_iter_affinity_in,
                device=device,
                keops=keops,
                verbose=verbose,
            )
        affinity_out = SinkhornAffinity(
            metric=metric_out,
            device=device,
            keops=keops,
            verbose=False,
            base_kernel="student",
            with_grad=unrolling,
            max_iter=5,
        )

        super().__init__(
            affinity_in=affinity_in,
            affinity_out=affinity_out,
            n_components=n_components,
            optimizer=optimizer,
            optimizer_kwargs=optimizer_kwargs,
            tol=tol,
            max_iter=max_iter,
            lr=lr,
            scheduler=scheduler,
            scheduler_kwargs=scheduler_kwargs,
            init=init,
            init_scaling=init_scaling,
            tolog=tolog,
            device=device,
            keops=keops,
            verbose=verbose,
            random_state=random_state,
            early_exaggeration=early_exaggeration,
            coeff_repulsion=coeff_repulsion,
            early_exaggeration_iter=early_exaggeration_iter,
        )

    def _loss(self):
        # Warm-start the Sinkhorn iterations with the dual variables
        # obtained at the previous optimization step.
        if not hasattr(self, "dual_sinkhorn_"):
            self.dual_sinkhorn_ = None

        log_Q = self.affinity_out(
            self.embedding_, log=True, init_dual=self.dual_sinkhorn_
        )
        self.dual_sinkhorn_ = self.affinity_out.dual_.detach()
        P = self.PX_

        # Attraction term: cross-entropy -sum_ij P_ij log Q_ij.
        attractive_term = cross_entropy_loss(P, log_Q, log=True)
        if self.unrolling:
            # With unrolling, gradients flow through the Sinkhorn iterations
            # and the mass term is dropped.
            repulsive_term = 0
        else:
            # Gap objective: the repulsion is the total mass sum_ij Q_ij,
            # computed in the log domain as exp(logsumexp(log_Q)).
            repulsive_term = logsumexp_red(log_Q, dim=(0, 1)).exp()

        loss = (
            self.early_exaggeration_ * attractive_term
            + self.coeff_repulsion * repulsive_term
        )
        return loss
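

# ---------------------------------------------------------------------------
# Editorial sketch (not part of torchdr): the gap objective computed by
# ``_loss`` above, spelled out with plain PyTorch on small dense matrices in
# place of the ``cross_entropy_loss`` and ``logsumexp_red`` helpers. The
# attraction term is the cross-entropy -sum_ij P_ij log Q_ij and the repulsion
# term is the total output mass sum_ij Q_ij, recovered from the log domain as
# exp(logsumexp(log_Q)). The helper name below is hypothetical, for
# illustration only.
def _dense_gap_objective_sketch():
    import torch

    torch.manual_seed(0)
    P = torch.rand(5, 5)
    P = P / P.sum()  # toy input affinity, normalized to sum to one
    log_Q = torch.randn(5, 5)  # toy output log-affinity

    attraction = -(P * log_Q).sum()  # -sum_ij P_ij log Q_ij
    repulsion = torch.logsumexp(log_Q.flatten(), dim=0).exp()  # sum_ij Q_ij
    return attraction + repulsion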
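

# Minimal usage sketch (editorial): torchdr estimators follow the
# scikit-learn API, so the embedding is obtained with ``fit_transform``.
# The dataset and hyperparameter values below are illustrative, not tuned.
if __name__ == "__main__":
    import torch

    torch.manual_seed(0)
    X = torch.randn(100, 10)  # 100 samples in 10 dimensions

    model = TSNEkhorn(perplexity=10, max_iter=500)
    Z = model.fit_transform(X)  # embedding of shape (100, 2)
    print(Z.shape)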