SinkhornAffinity#
- class torchdr.SinkhornAffinity(eps: float = 1.0, tol: float = 1e-05, max_iter: int = 1000, base_kernel: str = 'gaussian', metric: str = 'sqeuclidean', zero_diag: bool = True, device: str = 'auto', backend: str | None = None, verbose: bool = False, with_grad: bool = False)[source]#
Bases: LogAffinity
Compute the symmetric doubly stochastic affinity matrix.
The algorithm computes the doubly stochastic matrix \(\mathbf{P}^{\mathrm{ds}}\) with controlled global entropy using the symmetric Sinkhorn algorithm [Sinkhorn and Knopp, 1967].

The algorithm computes the optimal dual variable \(\mathbf{f}^\star \in \mathbb{R}^n\) such that

\[\mathbf{P}^{\mathrm{ds}} \mathbf{1} = \mathbf{1} \quad \text{where} \quad \forall (i,j), \; P^{\mathrm{ds}}_{ij} = \exp\big((f^\star_i + f^\star_j - C_{ij}) / \varepsilon\big)\]

where:

- \(\mathbf{C}\): symmetric pairwise distance matrix between the samples.
- \(\varepsilon\): entropic regularization parameter.
- \(\mathbf{1} := (1, ..., 1)^\top\): all-ones vector.

\(\mathbf{f}^\star\) is computed by performing dual ascent via the Sinkhorn fixed-point iteration (eq. 25 in [Feydy et al., 2019]):

\[\forall i, \; f_i \leftarrow \frac{1}{2} \Big( f_i - \varepsilon \log \sum_k \exp\big( (f_k - C_{ik}) / \varepsilon \big) \Big)\]

**Convex problem.** This consists in solving the following symmetric entropic optimal transport problem [Cuturi, 2013]:

\[\mathbf{P}^{\mathrm{ds}} = \mathop{\arg\min}_{\mathbf{P} \in \mathcal{DS}} \; \langle \mathbf{C}, \mathbf{P} \rangle - \varepsilon \, \mathrm{H}(\mathbf{P})\]

where:

- \(\mathcal{DS} := \{ \mathbf{P} \in \mathbb{R}_+^{n \times n} : \mathbf{P} = \mathbf{P}^\top, \; \mathbf{P} \mathbf{1} = \mathbf{1} \}\): set of symmetric doubly stochastic matrices.
- \(\mathrm{H}\): (global) Shannon entropy such that \(\mathrm{H}(\mathbf{P}) := - \sum_{ij} P_{ij} (\log P_{ij} - 1)\).

**Bregman projection.** Another way to write this problem is as the KL projection of the Gaussian kernel \(\mathbf{K}_\varepsilon = \exp(-\mathbf{C} / \varepsilon)\) onto the set of doubly stochastic matrices:

\[\mathbf{P}^{\mathrm{ds}} = \mathrm{Proj}^{\mathrm{KL}}_{\mathcal{DS}}(\mathbf{K}_\varepsilon) := \mathop{\arg\min}_{\mathbf{P} \in \mathcal{DS}} \; \mathrm{KL}(\mathbf{P} \,\|\, \mathbf{K}_\varepsilon)\]

where \(\mathrm{KL}(\mathbf{P} \,\|\, \mathbf{Q}) := \sum_{ij} P_{ij} (\log (P_{ij} / Q_{ij}) - 1)\) is the Kullback-Leibler divergence between \(\mathbf{P}\) and \(\mathbf{Q}\).

- Parameters:
eps (float, optional) – Entropic regularization parameter for the Sinkhorn algorithm. Default is 1.0.
tol (float, optional) – Precision threshold at which the algorithm stops. Default is 1e-5.
max_iter (int, optional) – Maximum number of iterations for the algorithm. Default is 1000.
base_kernel ({"gaussian", "student"}, optional) – Which base kernel to normalize as doubly stochastic. Default is "gaussian".
metric (str, optional) – Metric to use for computing distances (default “sqeuclidean”).
zero_diag (bool, optional) – Whether to set the diagonal of the distance matrix to 0. Default is True.
device (str, optional) – Device to use for computation. Default is "auto".
backend ({"keops", "faiss", None}, optional) – Which backend to use for handling sparsity and memory efficiency. Default is None.
verbose (bool, optional) – Verbosity. Default is False.
with_grad (bool, optional) – If True, the Sinkhorn iterations are performed with gradient tracking; if False, torch.no_grad() is used for the iterations. Default is False.
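The fixed-point iteration above can be sketched in a few lines of PyTorch. This is a minimal illustration of the symmetric Sinkhorn update, not torchdr's implementation: it assumes a dense symmetric cost matrix and ignores the `backend`, `device`, and `base_kernel` machinery.

```python
import torch

def symmetric_sinkhorn(C, eps=1.0, max_iter=1000, tol=1e-5):
    """Sketch of the symmetric Sinkhorn fixed-point iteration
    (eq. 25 in Feydy et al., 2019) on a dense symmetric cost C."""
    f = torch.zeros(C.shape[0], dtype=C.dtype)
    for _ in range(max_iter):
        # f_i <- (f_i - eps * LSE_k((f_k - C_ik) / eps)) / 2
        lse = eps * torch.logsumexp((f[None, :] - C) / eps, dim=1)
        f_new = 0.5 * (f - lse)
        if torch.max(torch.abs(f_new - f)) < tol:
            f = f_new
            break
        f = f_new
    # P_ij = exp((f_i + f_j - C_ij) / eps): symmetric, rows/columns sum to 1
    return torch.exp((f[:, None] + f[None, :] - C) / eps)

# toy example: squared Euclidean costs between a few random points
X = torch.randn(6, 2, dtype=torch.float64)
C = torch.cdist(X, X) ** 2
P = symmetric_sinkhorn(C, eps=1.0)
```

At a fixed point, \(f_i = -\varepsilon \log \sum_k \exp((f_k - C_{ik})/\varepsilon)\), which is exactly the condition \(\sum_j P_{ij} = 1\); symmetry of \(\mathbf{P}\) follows from the symmetry of \(\mathbf{C}\).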