"""PACMAP algorithm."""
# Author: Hugues Van Assel <vanasselhugues@gmail.com>
#
# License: BSD 3-Clause License
import torch
from torchdr.neighbor_embedding.base import SampledNeighborEmbedding
from typing import Union, Optional, Dict, Type, Any
from torchdr.affinity import PACMAPAffinity
from torchdr.utils import kmin, sum_red, symmetric_pairwise_distances_indices
class PACMAP(SampledNeighborEmbedding):
    r"""PACMAP algorithm introduced in :cite:`wang2021understanding`.

    It uses a :class:`~torchdr.PACMAPAffinity` as input affinity.
    The loss function is defined as:

    .. math::

        w_{\mathrm{NB}} \sum_{i, j \in \mathrm{NB}(i)} \frac{d_{ij}}{10 + d_{ij}} + w_{\mathrm{MN}} \sum_{i,j \in \mathrm{MN}(i)} \frac{d_{ij}}{10^4 + d_{ij}} + w_{\mathrm{FP}} \sum_{i,j \in \mathrm{FP}(i)} \frac{1}{1 + d_{ij}}

    where :math:`\mathrm{NB}(i)`, :math:`\mathrm{MN}(i)` and :math:`\mathrm{FP}(i)` are the nearest neighbors, mid-near neighbors and far neighbors of point :math:`i` respectively,
    and :math:`d_{ij} = 1 + \|\mathbf{z}_i - \mathbf{z}_j\|^2` (more details in :cite:`wang2021understanding`).

    The weights :math:`w_{\mathrm{NB}}`, :math:`w_{\mathrm{MN}}` and
    :math:`w_{\mathrm{FP}}` follow a three-phase schedule, each phase lasting
    ``iter_per_phase`` iterations (see ``_set_weights``). Mid-near pairs are
    resampled at every evaluation of the attractive loss: for each point and
    each of its ``int(MN_ratio * n_neighbors)`` slots, 6 random candidates are
    drawn and the one second closest in input space is kept.

    Parameters
    ----------
    n_neighbors : int, optional
        Number of nearest neighbors.
    n_components : int, optional
        Dimension of the embedding space.
    lr : float or 'auto', optional
        Learning rate for the algorithm, by default 1e0.
    optimizer : str or torch.optim.Optimizer, optional
        Name of an optimizer from torch.optim or an optimizer class.
        Default is "Adam".
    optimizer_kwargs : dict or 'auto', optional
        Additional keyword arguments for the optimizer. Default is None,
        which sets appropriate momentum values for SGD based on early exaggeration phase.
    scheduler : str or torch.optim.lr_scheduler.LRScheduler, optional
        Name of a scheduler from torch.optim.lr_scheduler or a scheduler class.
        Default is None (no scheduler).
    scheduler_kwargs : dict, optional
        Additional keyword arguments for the scheduler.
    init : {'normal', 'pca'} or torch.Tensor of shape (n_samples, output_dim), optional
        Initialization for the embedding Z, default 'pca'.
    init_scaling : float, optional
        Scaling factor for the initialization, by default 1e-4.
    min_grad_norm : float, optional
        Precision threshold at which the algorithm stops, by default 1e-7.
    max_iter : int, optional
        Number of maximum iterations for the descent algorithm, by default 450.
    device : str, optional
        Device to use, by default None.
    backend : {"keops", "faiss", None}, optional
        Which backend to use for handling sparsity and memory efficiency.
        Default is "faiss".
    verbose : bool, optional
        Verbosity, by default False.
    random_state : float, optional
        Random seed for reproducibility, by default None.
    metric_in : {'sqeuclidean', 'manhattan'}, optional
        Metric to use for the input affinity, by default 'sqeuclidean'.
    metric_out : {'sqeuclidean', 'manhattan'}, optional
        Metric to use for the output affinity, by default 'sqeuclidean'.
    MN_ratio : float, optional
        Ratio of mid-near pairs to nearest neighbor pairs, by default 0.5.
    FP_ratio : float, optional
        Ratio of far pairs to nearest neighbor pairs, by default 2.
    check_interval : int, optional
        Interval for checking convergence, by default 50.
    iter_per_phase : int, optional
        Number of iterations for each phase of the algorithm, by default 100.
    """  # noqa: E501

    def __init__(
        self,
        n_neighbors: int = 10,
        n_components: int = 2,
        lr: Union[float, str] = 1e0,
        optimizer: Union[str, Type[torch.optim.Optimizer]] = "Adam",
        optimizer_kwargs: Optional[Union[Dict, str]] = None,
        scheduler: Optional[
            Union[str, Type[torch.optim.lr_scheduler.LRScheduler]]
        ] = None,
        scheduler_kwargs: Optional[Dict] = None,
        init: str = "pca",
        init_scaling: float = 1e-4,
        min_grad_norm: float = 1e-7,
        max_iter: int = 450,
        device: Optional[str] = None,
        backend: Optional[str] = "faiss",
        verbose: bool = False,
        random_state: Optional[float] = None,
        metric_in: str = "sqeuclidean",
        metric_out: str = "sqeuclidean",
        MN_ratio: float = 0.5,
        FP_ratio: float = 2,
        check_interval: int = 50,
        iter_per_phase: int = 100,
    ):
        self.n_neighbors = n_neighbors
        self.metric_in = metric_in
        self.metric_out = metric_out
        self.MN_ratio = MN_ratio
        self.FP_ratio = FP_ratio
        # Per-point counts of mid-near and far pairs, derived from n_neighbors.
        self.n_mid_near = int(MN_ratio * n_neighbors)
        self.n_further = int(FP_ratio * n_neighbors)
        self.iter_per_phase = iter_per_phase

        affinity_in = PACMAPAffinity(
            n_neighbors=n_neighbors,
            metric=metric_in,
            device=device,
            backend=backend,
            verbose=verbose,
        )

        super().__init__(
            affinity_in=affinity_in,
            affinity_out=None,
            n_components=n_components,
            optimizer=optimizer,
            optimizer_kwargs=optimizer_kwargs,
            min_grad_norm=min_grad_norm,
            max_iter=max_iter,
            lr=lr,
            scheduler=scheduler,
            scheduler_kwargs=scheduler_kwargs,
            init=init,
            init_scaling=init_scaling,
            device=device,
            backend=backend,
            verbose=verbose,
            random_state=random_state,
            check_interval=check_interval,
            # Far pairs are drawn through the base-class negative sampler.
            n_negatives=self.n_further,
        )

    def _fit_transform(self, X: torch.Tensor, y: Optional[Any] = None):
        """Store the input data, initialize the loss weights, then fit.

        ``X`` is kept on the instance because mid-near candidates are ranked
        by their input-space distance at every loss evaluation.
        """
        self.X_ = X  # Keep input data to compute mid-near loss
        self._set_weights()
        return super()._fit_transform(X, y)

    def _set_weights(self):
        """Set ``w_NB``, ``w_MN`` and ``w_FP`` from the current ``n_iter_``.

        - Phase 1 (``n_iter_ < iter_per_phase``): ``w_NB = 2``, ``w_FP = 1`` and
          ``w_MN`` is interpolated linearly from 1000 (at iteration 0)
          towards 3.
        - Phase 2 (``n_iter_ < 2 * iter_per_phase``): ``w_NB = 3``,
          ``w_MN = 3``, ``w_FP = 1``.
        - Phase 3 (afterwards): ``w_NB = 1``, ``w_MN = 0`` (mid-near term
          disabled), ``w_FP = 1``.
        """
        if self.n_iter_ < self.iter_per_phase:
            self.w_NB = 2
            self.w_MN = (
                1000 * (1 - self.n_iter_ / self.iter_per_phase)
                + 3 * self.n_iter_ / self.iter_per_phase
            )
            self.w_FP = 1
        elif self.n_iter_ < 2 * self.iter_per_phase:
            self.w_NB = 3
            self.w_MN = 3
            self.w_FP = 1
        else:
            self.w_NB = 1
            self.w_MN = 0
            self.w_FP = 1

    def _after_step(self):
        """Base-class hook (presumably run after each optimizer step).

        Refreshes the phase-dependent loss weights.
        """
        self._set_weights()

    def _sample_mid_near_indices(self):
        """Sample ``n_mid_near`` mid-near points for every point.

        For each slot, 6 candidates (never the point itself) are drawn
        uniformly at random, and the candidate that is second closest to the
        point in input space (``metric_in``) is kept.

        Returns
        -------
        torch.Tensor of shape (n_samples_in_, n_mid_near), dtype long
            Sample indices of the selected mid-near points.

        Raises
        ------
        ValueError
            If fewer than 6 other points are available.
        """
        n_candidates = 6
        device = getattr(self.NN_indices_, "device", "cpu")
        n_possible_idxs = self.n_samples_in_ - 1  # every point except itself
        if n_possible_idxs < n_candidates:
            raise ValueError(
                "[TorchDR] ERROR : Not enough points to sample 6 mid-near points."
            )

        # Row i holds [i]; searchsorted(right=True) then yields 1 exactly for
        # draws >= i, which shifts them past the point itself.
        self_idxs = torch.arange(self.n_samples_in_, device=device).unsqueeze(1)
        mid_near_indices = torch.empty(
            self.n_samples_in_, self.n_mid_near, dtype=torch.long, device=device
        )

        for i in range(self.n_mid_near):  # to do: broadcast for efficiency
            # Draw from [0, n - 2] so that, after the shift, every index in
            # [0, n - 1] except the point itself can be selected.
            candidates = torch.randint(
                0,
                n_possible_idxs,
                (self.n_samples_in_, n_candidates),
                device=device,
            )
            candidates += torch.searchsorted(self_idxs, candidates, right=True)

            D_candidates = symmetric_pairwise_distances_indices(
                self.X_, indices=candidates, metric=self.metric_in
            )[0]
            _, idxs = kmin(D_candidates, k=2, dim=1)
            # kmin returns positions within each candidate row: map the
            # second-closest position back to an actual sample index.
            mid_near_indices[:, i] = candidates.gather(1, idxs[:, 1:2]).squeeze(1)

        return mid_near_indices

    def _attractive_loss(self):
        """Compute the nearest-neighbor and mid-near attractive terms.

        With ``d = 1 + distance`` in the embedding (``metric_out``), returns
        ``w_NB * sum(d / (10 + d))`` over nearest-neighbor pairs. While
        ``w_MN > 0`` and ``n_mid_near > 0``, it adds
        ``w_MN * sum(d / (1e4 + d))`` over freshly sampled mid-near pairs.
        """
        Q_near = (
            1
            + symmetric_pairwise_distances_indices(
                self.embedding_, indices=self.NN_indices_, metric=self.metric_out
            )[0]
        )
        Q_near = Q_near / (1e1 + Q_near)
        near_loss = self.w_NB * sum_red(Q_near, dim=(0, 1))

        # Mid-near term is switched off in the last phase (w_MN == 0) or when
        # no mid-near pairs are requested.
        if self.w_MN <= 0 or self.n_mid_near <= 0:
            return near_loss

        mid_near_indices = self._sample_mid_near_indices()
        Q_mid_near = (
            1
            + symmetric_pairwise_distances_indices(
                self.embedding_, indices=mid_near_indices, metric=self.metric_out
            )[0]
        )
        Q_mid_near = Q_mid_near / (1e4 + Q_mid_near)
        mid_near_loss = self.w_MN * sum_red(Q_mid_near, dim=(0, 1))
        return near_loss + mid_near_loss

    def _repulsive_loss(self):
        """Compute the far-pair repulsive term ``w_FP * sum(1 / (1 + d))``.

        ``d = 1 + distance`` in the embedding (``metric_out``). Far pairs come
        from the base-class ``_sample_negatives`` with ``discard_NNs=True``.
        """
        indices = self._sample_negatives(discard_NNs=True)
        Q_further = (
            1
            + symmetric_pairwise_distances_indices(
                self.embedding_, metric=self.metric_out, indices=indices
            )[0]
        )
        Q_further = 1 / (1 + Q_further)
        return self.w_FP * sum_red(Q_further, dim=(0, 1))