# -*- coding: utf-8 -*-
"""Affinity matrices with quadratic constraints."""

# Author: Hugues Van Assel <vanasselhugues@gmail.com>
#
# License: BSD 3-Clause License

import torch
from tqdm import tqdm
import warnings

from torchdr.affinity import Affinity
from torchdr.utils import (
    OPTIMIZERS,
    wrap_vectors,
    check_NaNs,
    batch_transpose,
)


@wrap_vectors
def _Pds(C, dual, eps):
    r"""Return the quadratic doubly stochastic matrix from dual variable and cost.

    Parameters
    ----------
    C : torch.Tensor or pykeops.torch.LazyTensor of shape (n, n)
        or shape (n_batch, batch_size, batch_size)
        Pairwise distance matrix.
    dual : torch.Tensor of shape (n) or (n_batch, batch_size)
        Dual variable of the normalization constraint.
    eps : float
        Quadratic regularization parameter.

    Returns
    -------
    P : torch.Tensor or pykeops.torch.LazyTensor of shape (n, n)
        or shape (n_batch, batch_size, batch_size)
        Quadratic doubly stochastic affinity matrix.
    """
    dual_t = batch_transpose(dual)
    return (dual + dual_t - C).clamp(0, float("inf")) / eps
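

# A minimal illustrative sketch of `_Pds` on dense tensors (the helper name,
# data and values below are assumptions for demonstration, not part of the
# TorchDR API). With a constant dual f = 1 and squared Euclidean costs, the
# entries are max(2 - C_ij, 0) / eps, so the diagonal (C_ii = 0) is positive.
def _example_Pds():  # pragma: no cover
    X = torch.randn(5, 3)
    C = torch.cdist(X, X) ** 2  # (5, 5) squared Euclidean cost matrix
    dual = torch.ones(5)  # one dual value per sample
    P = _Pds(C, dual, 1.0)  # P_ij = max(f_i + f_j - C_ij, 0) / eps
    assert P.shape == (5, 5) and (P >= 0).all()
    return P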


class DoublyStochasticQuadraticAffinity(Affinity):
    r"""Compute the symmetric doubly stochastic affinity.

    Implement the doubly stochastic normalization with controlled global
    :math:`\ell_2` norm.

    The algorithm computes the optimal dual variable
    :math:`\mathbf{f}^\star \in \mathbb{R}^n` such that

    .. math::
        \mathbf{P}^{\star} \mathbf{1} = \mathbf{1} \quad \text{where} \quad
        \forall (i,j), \: P^{\star}_{ij} =
        \left[f^\star_i + f^\star_j - C_{ij} \right]_{+} / \varepsilon \:.

    :math:`\mathbf{f}^\star` is computed by performing dual ascent.

    **Convex problem.** It amounts to solving the following symmetric
    quadratic optimal transport problem [10]_:

    .. math::
        \mathop{\arg\min}_{\mathbf{P} \in \mathcal{DS}} \:
        \langle \mathbf{C}, \mathbf{P} \rangle
        + \varepsilon \| \mathbf{P} \|_2^2

    where:

    - :math:`\mathcal{DS} := \left\{ \mathbf{P} \in \mathbb{R}_+^{n \times n}: \: \mathbf{P} = \mathbf{P}^\top \:,\: \mathbf{P} \mathbf{1} = \mathbf{1} \right\}`: set of symmetric doubly stochastic matrices.
    - :math:`\mathbf{C}`: symmetric pairwise distance matrix between the samples.
    - :math:`\varepsilon`: quadratic regularization parameter.
    - :math:`\mathbf{1} := (1,...,1)^\top`: all-ones vector.

    **Bregman projection.** Another way to write this problem is to consider the
    :math:`\ell_2` projection of :math:`- \mathbf{C} / \varepsilon` onto the set
    of doubly stochastic matrices :math:`\mathcal{DS}`, as follows:

    .. math::
        \mathrm{Proj}_{\mathcal{DS}}^{\ell_2}(- \mathbf{C} / \varepsilon) :=
        \mathop{\arg\min}_{\mathbf{P} \in \mathcal{DS}} \:
        \| \mathbf{P} + \mathbf{C} / \varepsilon \|_2 \:.

    Parameters
    ----------
    eps : float, optional
        Quadratic regularization parameter. Default is 1.0.
    init_dual : torch.Tensor of shape (n_samples), optional
        Initialization for the dual variable (default None).
    tol : float, optional
        Precision threshold at which the algorithm stops. Default is 1e-5.
    max_iter : int, optional
        Maximum number of iterations for the algorithm. Default is 1000.
    optimizer : {"Adam", "SGD", "NAdam"}, optional
        Optimizer to use for the dual ascent. Default is "Adam".
    lr : float, optional
        Learning rate for the optimizer. Default is 1.0.
    base_kernel : {"gaussian", "student"}, optional
        Which base kernel to normalize as doubly stochastic.
        Default is "gaussian".
    tolog : bool, optional
        Whether to store intermediate results in a dictionary.
        Default is False.
    metric : str, optional
        Metric to use for computing distances (default "sqeuclidean").
    zero_diag : bool, optional
        Whether to set the diagonal elements of the affinity matrix to 0.
    device : str, optional
        Device to use for computation. Default is "auto".
    keops : bool, optional
        Whether to use KeOps for computation. Default is False.
    verbose : bool, optional
        Verbosity. Default is False.

    References
    ----------
    .. [10] Stephen Zhang, Gilles Mordant, Tetsuya Matsumoto, Geoffrey Schiebinger
        (2023). `Manifold Learning with Sparse Regularised Optimal Transport
        <https://arxiv.org/abs/2307.09816>`_. arXiv preprint.
    """  # noqa: E501

    def __init__(
        self,
        eps: float = 1.0,
        init_dual: torch.Tensor = None,
        tol: float = 1e-5,
        max_iter: int = 1000,
        optimizer: str = "Adam",
        lr: float = 1e0,
        base_kernel: str = "gaussian",
        tolog: bool = False,
        metric: str = "sqeuclidean",
        zero_diag: bool = True,
        device: str = "auto",
        keops: bool = False,
        verbose: bool = False,
    ):
        super().__init__(
            metric=metric,
            zero_diag=zero_diag,
            device=device,
            keops=keops,
            verbose=verbose,
        )
        self.eps = eps
        self.init_dual = init_dual
        self.tol = tol
        self.max_iter = max_iter
        self.optimizer = optimizer
        self.lr = lr
        self.base_kernel = base_kernel
        self.tolog = tolog

    def _compute_affinity(self, X: torch.Tensor):
        r"""Compute the quadratic doubly stochastic affinity matrix from input data X.

        Parameters
        ----------
        X : torch.Tensor of shape (n_samples, n_features)
            Data on which the affinity is computed.

        Returns
        -------
        affinity_matrix : torch.Tensor or pykeops.torch.LazyTensor
            The computed affinity matrix.
        """
        if self.verbose:
            print(
                "[TorchDR] Affinity : computing the Doubly Stochastic Quadratic "
                "Affinity matrix."
            )

        C = self._distance_matrix(X)
        if self.base_kernel == "student":
            C = (1 + C).log()

        n_samples_in = C.shape[0]
        one = torch.ones(n_samples_in, dtype=X.dtype, device=X.device)

        # Perform warm start if an initial dual variable is provided.
        self.dual_ = (
            torch.ones(n_samples_in, dtype=X.dtype, device=X.device)
            if self.init_dual is None
            else self.init_dual
        )
        if self.tolog:
            self.log["dual"] = [self.dual_.clone().detach().cpu()]

        optimizer = OPTIMIZERS[self.optimizer]([self.dual_], lr=self.lr)

        # Dual ascent iterations.
        pbar = tqdm(range(self.max_iter), disable=not self.verbose)
        for k in pbar:
            with torch.no_grad():
                P = _Pds(C, self.dual_, self.eps)
                P_sum = P.sum(1).squeeze()
                grad_dual = P_sum - one
                self.dual_.grad = grad_dual
                optimizer.step()

            check_NaNs(
                [self.dual_],
                msg=(
                    f"[TorchDR] Affinity (ERROR): NaN at iter {k}, "
                    "consider decreasing the learning rate."
                ),
            )

            if self.tolog:
                self.log["dual"].append(self.dual_.clone().detach().cpu())

            if self.verbose:
                pbar.set_description(
                    f"MARGINAL:{float(P_sum.mean().item()): .2e} "
                    f"(std:{float(P_sum.std().item()): .2e})"
                )

            if torch.norm(grad_dual) < self.tol:
                if self.verbose:
                    print(f"[TorchDR] Affinity : convergence reached at iter {k}.")
                break

            if k == self.max_iter - 1 and self.verbose:
                warnings.warn(
                    "[TorchDR] Affinity (WARNING) : max iter attained, "
                    "algorithm stops but may not have converged."
                )

        self.n_iter_ = k
        affinity_matrix = _Pds(C, self.dual_, self.eps)
        affinity_matrix /= n_samples_in

        return affinity_matrix
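

if __name__ == "__main__":  # pragma: no cover
    # Minimal usage sketch under illustrative assumptions (random data, default
    # hyperparameters). It calls the internal `_compute_affinity` defined above;
    # user code would normally go through the public TorchDR affinity interface.
    torch.manual_seed(0)
    X = torch.randn(100, 10)
    affinity = DoublyStochasticQuadraticAffinity(eps=1.0, verbose=True)
    P = affinity._compute_affinity(X)
    # After the final division by n_samples_in, each row of P sums to roughly
    # 1 / n, so n * (row sums) should be close to 1 at convergence.
    print((X.shape[0] * P.sum(1).squeeze()).mean())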