User Guide

Overview

General Formulation of Dimensionality Reduction

DR aims to construct a low-dimensional representation (or embedding) \(\mathbf{Z} = (\mathbf{z}_1, ..., \mathbf{z}_n)^\top\) of an input dataset \(\mathbf{X} = (\mathbf{x}_1, ..., \mathbf{x}_n)^\top\) that best preserves its geometry, encoded via a pairwise affinity matrix \(\mathbf{A_X}\). To this end, DR methods optimize \(\mathbf{Z}\) such that a pairwise affinity matrix in the embedding space (denoted \(\mathbf{A_Z}\)) matches \(\mathbf{A_X}\). This general problem is as follows

\[\min_{\mathbf{Z}} \: \mathcal{L}( \mathbf{A_X}, \mathbf{A_Z}) \quad \text{(DR)}\]

where \(\mathcal{L}\) is typically the \(\ell_2\) or cross-entropy loss. Each DR method is thus characterized by a triplet \((\mathcal{L}, \mathbf{A_X}, \mathbf{A_Z})\).

TorchDR is structured around the above formulation \(\text{(DR)}\). Defining a DR algorithm only requires providing an Affinity object for the input and embedding spaces, as well as a loss function \(\mathcal{L}\).

All modules follow the sklearn [21] API and can be used in sklearn pipelines.
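
For example, a TorchDR estimator can be used as the final step of a standard sklearn pipeline. The sketch below is illustrative and assumes the default hyperparameters of torchdr.TSNE:

>>> from sklearn.datasets import load_digits
>>> from sklearn.pipeline import Pipeline
>>> from sklearn.preprocessing import StandardScaler
>>> import torchdr
>>>
>>> X, _ = load_digits(return_X_y=True)
>>>
>>> # TorchDR estimator as the last step of an sklearn Pipeline
>>> pipeline = Pipeline([
...     ("scaler", StandardScaler()),
...     ("tsne", torchdr.TSNE(n_components=2)),
... ])
>>> Z = pipeline.fit_transform(X)  # (n_samples, 2) embedding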

Torch GPU support and automatic differentiation

TorchDR is built on top of torch [20], offering GPU support and automatic differentiation. This foundation enables efficient computations and straightforward implementation of new DR methods.

To utilize GPU support, set device="cuda" when initializing any module. For CPU computations, set device="cpu".
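
For instance, the following sketch (using torchdr.TSNE with otherwise default settings) runs on the GPU when one is available and falls back to the CPU otherwise:

>>> import torch, torchdr
>>>
>>> X = torch.randn(5_000, 50)
>>>
>>> # any TorchDR module accepts the same device argument at initialization
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> Z = torchdr.TSNE(n_components=2, device=device).fit_transform(X)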

Note

DR particularly benefits from GPU acceleration as most computations, including affinity calculations and the DR objective, involve matrix reductions that are highly parallelizable.

Avoiding memory overflows with KeOps symbolic (lazy) tensors

Affinities incur a quadratic memory cost, which can be particularly problematic when dealing with large numbers of samples, especially when using GPUs.

To prevent memory overflows, TorchDR relies on pykeops [19] lazy tensors. These tensors are expressed as mathematical formulas, evaluated directly on the data samples. This symbolic representation allows computations to be performed without storing the entire matrix in memory, thereby effectively eliminating any memory limitation.

[Figure: a symbolic (lazy) tensor is represented by a mathematical formula evaluated on the data samples, rather than a dense matrix stored in memory. Figure taken from the pykeops documentation.]

Set keops=True as input to any module to use symbolic tensors. For small datasets, setting keops=False allows the computation of the full affinity matrix directly in memory.
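
For example, using the GaussianAffinity introduced below (a minimal sketch; keops only toggles the backend, the mathematical result is the same):

>>> import torch, torchdr
>>>
>>> X = torch.randn(50_000, 50)
>>>
>>> # keops=True: the (n, n) affinity is a symbolic formula, never stored in memory
>>> A_lazy = torchdr.GaussianAffinity(keops=True)(X)
>>>
>>> # keops=False: the dense matrix is materialized, suitable for small datasets
>>> A_dense = torchdr.GaussianAffinity(keops=False)(torch.randn(100, 50))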

Affinities

Affinities are the essential building blocks of dimensionality reduction methods. TorchDR provides a wide range of affinities, including basic ones such as GaussianAffinity, StudentAffinity and ScalarProductAffinity.

Base structure

Affinities inherit the structure of the following Affinity() class.

torchdr.Affinity

Base class for affinity matrices.

If computations can be performed in log domain, the LogAffinity() class should be used.

torchdr.LogAffinity

Base class for affinity matrices in log domain.

Affinity objects can be called directly on the data. The output affinity matrix is a square matrix of size (n, n), where n is the number of input samples.

Here is an example with the GaussianAffinity:

>>> import torch, torchdr
>>>
>>> n = 100
>>> data = torch.randn(n, 2)
>>> affinity = torchdr.GaussianAffinity()
>>> affinity_matrix = affinity(data)
>>> print(affinity_matrix.shape)
(100, 100)

Spotlight on affinities based on entropic projections

A widely used family of affinities focuses on controlling the entropy of the affinity matrix. It is notably a crucial component of Neighbor-Embedding methods (see Neighbor Embedding).

These affinities are normalized such that each row sums to one, allowing the affinity matrix to be viewed as a Markov transition matrix. An adaptive bandwidth parameter then determines how the mass from each point spreads to its neighbors. The bandwidth is set via the perplexity hyperparameter, which controls the number of effective neighbors of each point.

The resulting affinities can be viewed as a soft approximation of a k-nearest-neighbor graph, where the perplexity plays the role of k. This captures more nuance than binary weights, as closer neighbors receive a higher weight than those farther away. Ultimately, the perplexity is an interpretable hyperparameter that governs the scale of the dependencies represented in the affinity.
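
As a minimal sketch (assuming the default settings of EntropicAffinity), the perplexity is simply passed at construction time:

>>> import torch, torchdr
>>>
>>> X = torch.randn(500, 10)
>>>
>>> # perplexity plays the role of k in a soft k-NN graph:
>>> # each point keeps roughly 30 effective neighbors
>>> affinity = torchdr.EntropicAffinity(perplexity=30)
>>> P = affinity(X)  # (500, 500) affinity with controlled row-wise entropy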

The following table outlines the aspects controlled by different formulations of entropic affinities. Marginal indicates whether each row of the affinity matrix has a controlled sum. Symmetry indicates whether the affinity matrix is symmetric. Entropy indicates whether each row of the affinity matrix has controlled entropy, dictated by the perplexity hyperparameter.

| Affinity (associated DR method) | Marginal | Symmetry | Entropy |
|---------------------------------|----------|----------|---------|
| NormalizedGaussianAffinity      | ✗        | ✓        | ✗       |
| SinkhornAffinity [5] [9]        | ✓        | ✓        | ✗       |
| EntropicAffinity [1]            | ✓        | ✗        | ✓       |
| SymmetricEntropicAffinity [3]   | ✓        | ✓        | ✓       |

More details on these affinities can be found in the SNEkhorn paper.

Examples using EntropicAffinity:

Entropic Affinities can adapt to varying noise levels

Neighbor Embedding on genomics & equivalent affinity matcher formulation

Other various affinities

TorchDR features other affinities that can be used in various contexts.

For instance, the UMAP [8] algorithm relies on the affinity UMAPAffinityIn for the input data and UMAPAffinityOut for the embedding space. UMAPAffinityIn follows a construction similar to that of entropic affinities to ensure a constant number of effective neighbors, with n_neighbors playing the role of the perplexity hyperparameter.
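
A minimal sketch of computing this input affinity (assuming the default settings of UMAPAffinityIn besides n_neighbors):

>>> import torch, torchdr
>>>
>>> X = torch.randn(500, 10)
>>>
>>> # n_neighbors plays the same role as perplexity for entropic affinities
>>> affinity_in = torchdr.UMAPAffinityIn(n_neighbors=15)
>>> P = affinity_in(X)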

Another example is the doubly stochastic normalization of a base affinity under the \(\ell_2\) geometry, recently proposed for DR [10]. This method is analogous to SinkhornAffinity, where the Shannon entropy is replaced by the \(\ell_2\) norm to recover a sparse affinity. It is available as DoublyStochasticQuadraticAffinity.

Dimensionality Reduction Modules

TorchDR provides a wide range of dimensionality reduction (DR) methods. All DR estimators inherit the structure of the DRModule() class:

torchdr.DRModule

Base class for DR methods.

They inherit from sklearn.base.BaseEstimator and sklearn.base.TransformerMixin, and can be used via the fit_transform method.

Outside of Spectral methods, a closed-form solution to the DR problem is typically not available. The problem can then be solved using gradient-based optimizers.

The following classes serve as parent classes for this approach, requiring the user to provide affinity objects for the input and output spaces, referred to as affinity_in and affinity_out.

torchdr.AffinityMatcher

Perform dimensionality reduction by matching two affinity matrices.
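
For illustration, the following sketch assembles a custom DR method from the building blocks above. It assumes AffinityMatcher accepts an n_components argument and leaves the loss and optimizer settings to their defaults:

>>> import torch, torchdr
>>>
>>> X = torch.randn(300, 20)
>>>
>>> # match an entropic affinity in the input space
>>> # with a Student affinity in the embedding space
>>> model = torchdr.AffinityMatcher(
...     n_components=2,
...     affinity_in=torchdr.EntropicAffinity(perplexity=30),
...     affinity_out=torchdr.StudentAffinity(),
... )
>>> Z = model.fit_transform(X)  # (300, 2) embedding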

In what follows, we briefly present two families of DR algorithms: spectral methods and neighbor embedding methods.

Spectral methods

Spectral methods correspond to choosing the scalar product affinity \([\mathbf{A_Z}]_{ij} = \langle \mathbf{z}_i, \mathbf{z}_j \rangle\) for the embeddings and the \(\ell_2\) loss.

\[\min_{\mathbf{Z}} \: \sum_{ij} ( [\mathbf{A_X}]_{ij} - \langle \mathbf{z}_i, \mathbf{z}_j \rangle )^{2}\]

When \(\mathbf{A_X}\) is positive semi-definite, this problem is commonly known as kernel Principal Component Analysis [11] and an optimal solution is given by

\[\mathbf{Z}^{\star} = (\sqrt{\lambda_1} \mathbf{v}_1, ..., \sqrt{\lambda_d} \mathbf{v}_d)^\top\]

where \(\lambda_1, ..., \lambda_d\) are the largest eigenvalues of the centered kernel matrix \(\mathbf{A_X}\) and \(\mathbf{v}_1, ..., \mathbf{v}_d\) are the corresponding eigenvectors.
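
For illustration, here is a minimal sketch of this closed-form solution in the linear case, where \(\mathbf{A_X}\) is the centered Gram matrix (plain torch only, not a TorchDR API):

>>> import torch
>>>
>>> torch.manual_seed(0)
>>> X = torch.randn(200, 10)
>>> d = 2
>>>
>>> # linear case: A_X is the centered Gram matrix of scalar products
>>> X_c = X - X.mean(dim=0)
>>> A_X = X_c @ X_c.T
>>>
>>> # Z* stacks the top-d eigenvectors scaled by the square roots of their eigenvalues
>>> eigvals, eigvecs = torch.linalg.eigh(A_X)  # eigenvalues in ascending order
>>> lambdas, V = eigvals[-d:], eigvecs[:, -d:]  # d largest eigenpairs
>>> Z = V * lambdas.clamp_min(0).sqrt()  # (200, 2) spectral embedding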

Note

PCA (available at torchdr.PCA) corresponds to choosing \([\mathbf{A_X}]_{ij} = \langle \mathbf{x}_i, \mathbf{x}_j \rangle\).

Neighbor Embedding

TorchDR aims to implement most popular neighbor embedding (NE) algorithms. In these methods, \(\mathbf{A_X}\) and \(\mathbf{A_Z}\) can be viewed as soft neighborhood graphs, hence the term neighbor embedding.

NE objectives share a common structure: they aim to minimize the weighted sum of an attractive term and a repulsive term. Interestingly, the attractive term is often the cross-entropy between the input and output affinities. Additionally, the repulsive term is typically a function of the output affinities only. Thus, the NE problem can be formulated as the following minimization problem:

\[\min_{\mathbf{Z}} \: - \sum_{ij} [\mathbf{A_X}]_{ij} \log [\mathbf{A_Z}]_{ij} + \gamma \mathcal{L}_{\mathrm{rep}}(\mathbf{A_Z}) \:.\]

In the above, \(\mathcal{L}_{\mathrm{rep}}(\mathbf{A_Z})\) represents the repulsive part of the loss function while \(\gamma\) is a hyperparameter that controls the balance between attraction and repulsion. The latter is called coeff_repulsion in TorchDR.
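
As an illustrative sketch (not TorchDR's internal implementation), the NE objective with the t-SNE repulsive term can be written as follows, given dense input and output affinity matrices:

>>> import torch
>>>
>>> def neighbor_embedding_loss(A_X, A_Z, coeff_repulsion=1.0, eps=1e-12):
...     # attractive term: cross-entropy between input and output affinities
...     attraction = -(A_X * (A_Z + eps).log()).sum()
...     # repulsive term: the t-SNE choice log(sum_ij [A_Z]_ij); other methods
...     # swap in a different function of A_Z (see the table below)
...     repulsion = (A_Z.sum() + eps).log()
...     return attraction + coeff_repulsion * repulsion
>>>
>>> A_X, A_Z = torch.rand(100, 100), torch.rand(100, 100)
>>> loss = neighbor_embedding_loss(A_X, A_Z, coeff_repulsion=1.0)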

Many NE methods can be represented within this framework. The following table summarizes the ones implemented in TorchDR, detailing their respective repulsive loss function, as well as their input and output affinities.

| Method | Repulsive term \(\mathcal{L}_{\mathrm{rep}}\) | Affinity input \(\mathbf{A_X}\) | Affinity output \(\mathbf{A_Z}\) |
|--------|-----------------------------------------------|---------------------------------|----------------------------------|
| SNE [1] | \(\sum_{i} \log(\sum_j [\mathbf{A_Z}]_{ij})\) | EntropicAffinity | GaussianAffinity |
| TSNE [2] | \(\log(\sum_{ij} [\mathbf{A_Z}]_{ij})\) | EntropicAffinity | StudentAffinity |
| TSNEkhorn [3] | \(\sum_{ij} [\mathbf{A_Z}]_{ij}\) | SymmetricEntropicAffinity | SinkhornAffinity(base_kernel="student") |
| InfoTSNE [15] | \(\sum_i \log(\sum_{j \in N(i)} [\mathbf{A_Z}]_{ij})\) | EntropicAffinity | StudentAffinity |
| UMAP [8] | \(- \sum_{i, j \in N(i)} \log (1 - [\mathbf{A_Z}]_{ij})\) | UMAPAffinityIn | UMAPAffinityOut |
| LargeVis [13] | \(- \sum_{i, j \in N(i)} \log (1 - [\mathbf{A_Z}]_{ij})\) | EntropicAffinity | StudentAffinity |

In the above table, \(N(i)\) denotes the set of negative samples for point \(i\). They are usually sampled uniformly at random from the dataset.

References