"""Base classes for Neighbor Embedding methods."""
# Author: Hugues Van Assel <vanasselhugues@gmail.com>
#
# License: BSD 3-Clause License
import warnings
import os
import numpy as np
from typing import Any, Dict, Union, Optional, Type
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torchdr.affinity import (
Affinity,
SparseAffinity,
)
from torchdr.distance import FaissConfig
from torchdr.affinity_matcher import AffinityMatcher
class NeighborEmbedding(AffinityMatcher):
r"""Solves the neighbor embedding problem.
It amounts to solving:
.. math::
\min_{\mathbf{Z}} \: - \lambda \sum_{ij} P_{ij} \log Q_{ij} + \mathcal{L}_{\mathrm{rep}}(\mathbf{Q})
    where :math:`\mathbf{P}` is the input affinity matrix, :math:`\mathbf{Q}` is the
    output affinity matrix, :math:`\mathcal{L}_{\mathrm{rep}}` is the repulsive
    term of the loss function, and :math:`\lambda` is the
    :attr:`early_exaggeration_coeff` parameter.

    Note that the early exaggeration coefficient :math:`\lambda` is set to
    :math:`1` after the early exaggeration phase, whose duration is controlled
    by the :attr:`early_exaggeration_iter` parameter.
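
    As a minimal sketch (not the library implementation), the objective above can be
    assembled from hypothetical dense tensors ``P`` and ``Q``, with ``lam`` playing the
    role of :attr:`early_exaggeration_coeff` and a log-partition repulsion chosen purely
    for illustration:

    .. code-block:: python

        import torch

        def toy_neighbor_embedding_loss(P, Q, lam=1.0):
            # attractive term: early-exaggerated cross-entropy between P and Q
            attraction = -lam * (P * torch.log(Q)).sum()
            # one possible choice of repulsive term: log-partition function of Q
            repulsion = torch.log(Q.sum())
            return attraction + repulsion

    Concrete subclasses supply the actual attractive and repulsive terms through their
    loss (or direct gradient) implementations.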
Parameters
----------
affinity_in : Affinity
The affinity object for the input space.
affinity_out : Affinity, optional
The affinity object for the output embedding space. Default is None.
kwargs_affinity_out : dict, optional
Additional keyword arguments for the affinity_out method.
n_components : int, optional
Number of dimensions for the embedding. Default is 2.
lr : float or 'auto', optional
Learning rate for the optimizer. Default is 1e0.
optimizer : str or torch.optim.Optimizer, optional
Name of an optimizer from torch.optim or an optimizer class.
Default is "SGD". For best results, we recommend using "SGD" with 'auto' learning rate.
optimizer_kwargs : dict or 'auto', optional
Additional keyword arguments for the optimizer. Default is 'auto',
which sets appropriate momentum values for SGD based on early exaggeration phase.
scheduler : str or torch.optim.lr_scheduler.LRScheduler, optional
Name of a scheduler from torch.optim.lr_scheduler or a scheduler class.
Default is None.
scheduler_kwargs : dict, 'auto', or None, optional
Additional keyword arguments for the scheduler. Default is 'auto', which
corresponds to a linear decay from the learning rate to 0 for `LinearLR`.
min_grad_norm : float, optional
Tolerance for stopping criterion. Default is 1e-7.
max_iter : int, optional
Maximum number of iterations. Default is 2000.
init : str or torch.Tensor or np.ndarray, optional
Initialization method for the embedding. Default is "pca".
init_scaling : float, optional
Scaling factor for the initial embedding. Default is 1e-4.
device : str, optional
Device to use for computations. Default is "auto".
backend : {"keops", "faiss", None} or FaissConfig, optional
Which backend to use for handling sparsity and memory efficiency.
Can be:
- "keops": Use KeOps for memory-efficient symbolic computations
- "faiss": Use FAISS for fast k-NN computations with default settings
- None: Use standard PyTorch operations
- FaissConfig object: Use FAISS with custom configuration
Default is None.
verbose : bool, optional
Verbosity of the optimization process. Default is False.
random_state : float, optional
Random seed for reproducibility. Default is None.
    early_exaggeration_coeff : float, optional
        Coefficient for the attraction term during the early exaggeration phase.
        Default is None (treated as 1, i.e. no early exaggeration).
    early_exaggeration_iter : int, optional
        Number of iterations for the early exaggeration phase.
        Default is None (treated as 0).
check_interval : int, optional
Number of iterations between two checks for convergence. Default is 50.
compile : bool, default=False
Whether to use torch.compile for faster computation.
""" # noqa: E501
def __init__(
self,
affinity_in: Affinity,
affinity_out: Optional[Affinity] = None,
kwargs_affinity_out: Optional[Dict] = None,
n_components: int = 2,
lr: Union[float, str] = 1e0,
optimizer: Union[str, Type[torch.optim.Optimizer]] = "SGD",
optimizer_kwargs: Union[Dict, str] = "auto",
scheduler: Optional[
Union[str, Type[torch.optim.lr_scheduler.LRScheduler]]
] = None,
scheduler_kwargs: Union[Dict, str, None] = "auto",
min_grad_norm: float = 1e-7,
max_iter: int = 2000,
init: Union[str, torch.Tensor, np.ndarray] = "pca",
init_scaling: float = 1e-4,
device: str = "auto",
backend: Union[str, FaissConfig, None] = None,
verbose: bool = False,
random_state: Optional[float] = None,
early_exaggeration_coeff: Optional[float] = None,
early_exaggeration_iter: Optional[int] = None,
check_interval: int = 50,
compile: bool = False,
**kwargs: Any,
):
self.early_exaggeration_iter = early_exaggeration_iter
if self.early_exaggeration_iter is None:
self.early_exaggeration_iter = 0
self.early_exaggeration_coeff = early_exaggeration_coeff
if self.early_exaggeration_coeff is None:
self.early_exaggeration_coeff = 1
        # aliases for consistency with the sklearn API
        if "learning_rate" in kwargs:
            lr = kwargs.pop("learning_rate")
        if "early_exaggeration" in kwargs:
            self.early_exaggeration_coeff = kwargs.pop("early_exaggeration")
# by default, the linear scheduler goes from 1 to 0
_scheduler_kwargs = scheduler_kwargs
if scheduler == "LinearLR" and scheduler_kwargs == "auto":
_scheduler_kwargs = {
"start_factor": torch.tensor(1.0),
"end_factor": torch.tensor(0),
"total_iters": max_iter,
}
super().__init__(
affinity_in=affinity_in,
affinity_out=affinity_out,
kwargs_affinity_out=kwargs_affinity_out,
n_components=n_components,
optimizer=optimizer,
optimizer_kwargs=optimizer_kwargs,
lr=lr,
scheduler=scheduler,
scheduler_kwargs=_scheduler_kwargs,
min_grad_norm=min_grad_norm,
max_iter=max_iter,
init=init,
init_scaling=init_scaling,
device=device,
backend=backend,
verbose=verbose,
random_state=random_state,
check_interval=check_interval,
compile=compile,
**kwargs,
)
def on_training_step_end(self):
if ( # stop early exaggeration phase
self.early_exaggeration_coeff_ > 1
and self.n_iter_ == self.early_exaggeration_iter
):
self.early_exaggeration_coeff_ = 1
# reinitialize optim
self._set_learning_rate()
self._configure_optimizer()
self._configure_scheduler()
return self
def _check_n_neighbors(self, n):
param_list = ["perplexity", "n_neighbors"]
for param_name in param_list:
if hasattr(self, param_name):
param_value = getattr(self, param_name)
if n <= param_value:
raise ValueError(
f"[TorchDR] ERROR : Number of samples is smaller than {param_name} "
f"({n} <= {param_value})."
)
return self
def _fit_transform(self, X: torch.Tensor, y: Optional[Any] = None) -> torch.Tensor:
n_samples = len(X.dataset) if isinstance(X, DataLoader) else X.shape[0]
self._check_n_neighbors(n_samples)
        # early_exaggeration_coeff_ may change during the optimization
        self.early_exaggeration_coeff_ = self.early_exaggeration_coeff
return super()._fit_transform(X, y)
def _compute_loss(self):
raise NotImplementedError(
"[TorchDR] ERROR : _compute_loss method must be implemented."
)
def _compute_gradients(self):
raise NotImplementedError(
"[TorchDR] ERROR : _compute_gradients method must be implemented "
"when _use_direct_gradients is True."
)
def _set_learning_rate(self):
if self.lr == "auto":
if self.optimizer != "SGD":
if self.verbose:
warnings.warn(
"[TorchDR] WARNING : when 'auto' is used for the learning "
"rate, the optimizer should be 'SGD'."
)
# from the sklearn TSNE implementation
self.lr_ = max(self.n_samples_in_ / self.early_exaggeration_coeff_ / 4, 50)
else:
self.lr_ = self.lr
def _configure_optimizer(self):
if isinstance(self.optimizer, str):
# Get optimizer directly from torch.optim
try:
optimizer_class = getattr(torch.optim, self.optimizer)
except AttributeError:
raise ValueError(
f"[TorchDR] ERROR: Optimizer '{self.optimizer}' not found in torch.optim"
)
else:
if not issubclass(self.optimizer, torch.optim.Optimizer):
raise ValueError(
"[TorchDR] ERROR: optimizer must be a string (name of an optimizer in "
"torch.optim) or a subclass of torch.optim.Optimizer"
)
# Assume it's already an optimizer class
optimizer_class = self.optimizer
# If 'auto' and SGD, set momentum based on early exaggeration phase
if self.optimizer_kwargs == "auto":
if self.optimizer == "SGD":
if self.early_exaggeration_coeff_ > 1:
optimizer_kwargs = {"momentum": 0.5}
else:
optimizer_kwargs = {"momentum": 0.8}
else:
optimizer_kwargs = {}
else:
optimizer_kwargs = self.optimizer_kwargs or {}
self.optimizer_ = optimizer_class(self.params_, lr=self.lr_, **optimizer_kwargs)
return self.optimizer_
def _configure_scheduler(self):
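        # the scheduler horizon covers only the current phase: the early exaggeration
        # iterations first, then the remaining iterations once the optimizer and
        # scheduler are reinitialized at the end of early exaggeration
        # (see on_training_step_end)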
if self.early_exaggeration_coeff_ > 1:
n_iter = min(self.early_exaggeration_iter, self.max_iter)
else:
n_iter = self.max_iter - self.early_exaggeration_iter
super()._configure_scheduler(n_iter)
class SparseNeighborEmbedding(NeighborEmbedding):
r"""Solves the neighbor embedding problem with a sparse input affinity matrix.
It amounts to solving:
.. math::
\min_{\mathbf{Z}} \: - \lambda \sum_{ij} P_{ij} \log Q_{ij} + \mathcal{L}_{\mathrm{rep}}( \mathbf{Q})
    where :math:`\mathbf{P}` is the input affinity matrix, :math:`\mathbf{Q}` is the
    output affinity matrix, :math:`\mathcal{L}_{\mathrm{rep}}` is the repulsive
    term of the loss function, and :math:`\lambda` is the
    :attr:`early_exaggeration_coeff` parameter.
**Fast attraction.** This class should be used when the input affinity matrix is sparse. In such cases, the attractive term can be computed with linear complexity.
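
    As a rough sketch (with assumed names, not the internal API), an attraction restricted
    to the :math:`k` nearest neighbors only touches :math:`n \times k` pairs. Here
    ``P_values`` holds the sparse affinity values with shape ``(n, k)``, ``NN_indices``
    the corresponding neighbor indices, and ``Z`` the embedding:

    .. code-block:: python

        import torch

        def sparse_attraction(P_values, NN_indices, Z):
            # squared distances to the k stored neighbors only: (n, k)
            sq_dist = (Z.unsqueeze(1) - Z[NN_indices]).pow(2).sum(-1)
            # unnormalized Student kernel, used here only as an example output affinity
            log_Q = -torch.log1p(sq_dist)
            return -(P_values * log_Q).sum()

    The cost scales with the number of stored affinities rather than with :math:`n^2`.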
Parameters
----------
affinity_in : Affinity
The affinity object for the input space.
affinity_out : Affinity, optional
The affinity object for the output embedding space. Default is None.
kwargs_affinity_out : dict, optional
Additional keyword arguments for the affinity_out method.
n_components : int, optional
Number of dimensions for the embedding. Default is 2.
lr : float or 'auto', optional
Learning rate for the optimizer. Default is 1e0.
optimizer : str or torch.optim.Optimizer, optional
Name of an optimizer from torch.optim or an optimizer class.
Default is "SGD". For best results, we recommend using "SGD" with 'auto' learning rate.
optimizer_kwargs : dict or 'auto', optional
Additional keyword arguments for the optimizer. Default is 'auto',
which sets appropriate momentum values for SGD based on early exaggeration phase.
scheduler : str or torch.optim.lr_scheduler.LRScheduler, optional
Name of a scheduler from torch.optim.lr_scheduler or a scheduler class.
Default is None (no scheduler).
    scheduler_kwargs : dict, 'auto', or None, optional
Additional keyword arguments for the scheduler.
Default is "auto", which corresponds to a linear decay from the learning rate to 0 for `LinearLR`.
min_grad_norm : float, optional
Tolerance for stopping criterion. Default is 1e-7.
max_iter : int, optional
Maximum number of iterations. Default is 2000.
init : str or torch.Tensor or np.ndarray, optional
Initialization method for the embedding. Default is "pca".
init_scaling : float, optional
Scaling factor for the initial embedding. Default is 1e-4.
device : str, optional
Device to use for computations. Default is "auto".
backend : {"keops", "faiss", None} or FaissConfig, optional
Which backend to use for handling sparsity and memory efficiency.
Can be:
- "keops": Use KeOps for memory-efficient symbolic computations
- "faiss": Use FAISS for fast k-NN computations with default settings
- None: Use standard PyTorch operations
- FaissConfig object: Use FAISS with custom configuration
Default is None.
verbose : bool, optional
Verbosity of the optimization process. Default is False.
random_state : float, optional
Random seed for reproducibility. Default is None.
early_exaggeration_coeff : float, optional
Coefficient for the attraction term during the early exaggeration phase.
Default is 1.0.
early_exaggeration_iter : int, optional
Number of iterations for early exaggeration. Default is None.
repulsion_strength: float, optional
Strength of the repulsive term. Default is 1.0.
check_interval : int, optional
Number of iterations between two checks for convergence. Default is 50.
compile : bool, default=False
Whether to use torch.compile for faster computation.
""" # noqa: E501
def __init__(
self,
affinity_in: Affinity,
affinity_out: Optional[Affinity] = None,
kwargs_affinity_out: Optional[Dict] = None,
n_components: int = 2,
lr: Union[float, str] = 1e0,
optimizer: Union[str, Type[torch.optim.Optimizer]] = "SGD",
optimizer_kwargs: Union[Dict, str] = "auto",
scheduler: Optional[
Union[str, Type[torch.optim.lr_scheduler.LRScheduler]]
] = None,
        scheduler_kwargs: Union[Dict, str, None] = "auto",
min_grad_norm: float = 1e-7,
max_iter: int = 2000,
init: Union[str, torch.Tensor, np.ndarray] = "pca",
init_scaling: float = 1e-4,
device: str = "auto",
backend: Union[str, FaissConfig, None] = None,
verbose: bool = False,
random_state: Optional[float] = None,
early_exaggeration_coeff: float = 1.0,
early_exaggeration_iter: Optional[int] = None,
repulsion_strength: float = 1.0,
check_interval: int = 50,
compile: bool = False,
):
        # check that affinity_in is a sparse affinity
if not isinstance(affinity_in, SparseAffinity):
raise NotImplementedError(
"[TorchDR] ERROR : when using SparseNeighborEmbedding, affinity_in "
"must be a sparse affinity."
)
self.repulsion_strength = repulsion_strength
super().__init__(
affinity_in=affinity_in,
affinity_out=affinity_out,
kwargs_affinity_out=kwargs_affinity_out,
n_components=n_components,
lr=lr,
optimizer=optimizer,
optimizer_kwargs=optimizer_kwargs,
scheduler=scheduler,
scheduler_kwargs=scheduler_kwargs,
min_grad_norm=min_grad_norm,
max_iter=max_iter,
init=init,
init_scaling=init_scaling,
device=device,
backend=backend,
verbose=verbose,
random_state=random_state,
early_exaggeration_coeff=early_exaggeration_coeff,
early_exaggeration_iter=early_exaggeration_iter,
check_interval=check_interval,
compile=compile,
)
def _compute_attractive_loss(self):
raise NotImplementedError(
"[TorchDR] ERROR : _compute_attractive_loss method must be implemented."
)
def _compute_repulsive_loss(self):
raise NotImplementedError(
"[TorchDR] ERROR : _compute_repulsive_loss method must be implemented."
)
def _compute_loss(self):
loss = (
self.early_exaggeration_coeff_ * self._compute_attractive_loss()
+ self.repulsion_strength * self._compute_repulsive_loss()
)
return loss
@torch.no_grad()
def _compute_gradients(self):
# triggered when _use_direct_gradients is True
gradients = (
self.early_exaggeration_coeff_ * self._compute_attractive_gradients()
+ self.repulsion_strength * self._compute_repulsive_gradients()
)
return gradients
def _compute_attractive_gradients(self):
raise NotImplementedError(
"[TorchDR] ERROR : _compute_attractive_gradients method must be implemented "
"when _use_direct_gradients is True."
)
def _compute_repulsive_gradients(self):
raise NotImplementedError(
"[TorchDR] ERROR : _compute_repulsive_gradients method must be implemented "
"when _use_direct_gradients is True."
)
class NegativeSamplingNeighborEmbedding(SparseNeighborEmbedding):
r"""Solves the neighbor embedding problem with both sparsity and sampling.
It amounts to solving:
.. math::
\min_{\mathbf{Z}} \: - \lambda \sum_{ij} P_{ij} \log Q_{ij} + \mathcal{L}_{\mathrm{rep}}( \mathbf{Q})
    where :math:`\mathbf{P}` is the input affinity matrix, :math:`\mathbf{Q}` is the
    output affinity matrix, :math:`\mathcal{L}_{\mathrm{rep}}` is the repulsive
    term of the loss function, and :math:`\lambda` is the
    :attr:`early_exaggeration_coeff` parameter.
**Fast attraction.** This class should be used when the input affinity matrix is sparse.
In such cases, the attractive term can be computed with linear complexity.
    **Fast repulsion.** A stochastic estimate of the repulsive term is used to
    reduce its complexity to linear. This is done by sampling a fixed number
    :attr:`n_negatives` of negative samples for each point.
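
    As a rough sketch (with assumed names), one possible stochastic repulsion over the
    sampled negatives, where ``Z`` is the embedding and ``neg_indices`` has shape
    ``(n, n_negatives)``:

    .. code-block:: python

        import torch

        def sampled_repulsion(Z, neg_indices):
            # similarities against the sampled negatives only: (n, n_negatives)
            sq_dist = (Z.unsqueeze(1) - Z[neg_indices]).pow(2).sum(-1)
            q = 1.0 / (1.0 + sq_dist)  # Student similarities
            return torch.log(q.sum())

    The exact estimator depends on the subclass; the point is that its cost is
    proportional to ``n * n_negatives``.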
**Multi-GPU training.** When launched with torchrun, this class supports
distributed multi-GPU training. Each rank processes its chunk of the input
affinity, the embedding is replicated across ranks, and gradients are
synchronized during optimization.
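
    A minimal sketch of this synchronization pattern (not necessarily the exact internal
    call): each rank keeps a full copy of the embedding, computes gradients from its own
    chunk, and the gradients are summed across ranks before the optimizer step.

    .. code-block:: python

        import torch.distributed as dist

        def sync_embedding_grad(embedding):
            # sum the per-rank gradients of the replicated embedding
            if dist.is_initialized() and embedding.grad is not None:
                dist.all_reduce(embedding.grad, op=dist.ReduceOp.SUM)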
Parameters
----------
affinity_in : Affinity
The affinity object for the input space.
affinity_out : Affinity, optional
The affinity object for the output embedding space. Default is None.
kwargs_affinity_out : dict, optional
Additional keyword arguments for the affinity_out method.
n_components : int, optional
Number of dimensions for the embedding. Default is 2.
lr : float or 'auto', optional
Learning rate for the optimizer. Default is 1e0.
optimizer : str or torch.optim.Optimizer, optional
Name of an optimizer from torch.optim or an optimizer class.
Default is "SGD". For best results, we recommend using "SGD" with 'auto' learning rate.
optimizer_kwargs : dict or 'auto', optional
Additional keyword arguments for the optimizer. Default is 'auto',
which sets appropriate momentum values for SGD based on early exaggeration phase.
scheduler : str or torch.optim.lr_scheduler.LRScheduler, optional
Name of a scheduler from torch.optim.lr_scheduler or a scheduler class.
Default is None (no scheduler).
    scheduler_kwargs : dict, 'auto', or None, optional
Additional keyword arguments for the scheduler.
Default is "auto", which corresponds to a linear decay from the learning rate to 0 for `LinearLR`.
min_grad_norm : float, optional
Tolerance for stopping criterion. Default is 1e-7.
max_iter : int, optional
Maximum number of iterations. Default is 2000.
init : str, optional
Initialization method for the embedding. Default is "pca".
init_scaling : float, optional
Scaling factor for the initial embedding. Default is 1e-4.
device : str, optional
Device to use for computations. Default is "auto".
backend : {"keops", "faiss", None} or FaissConfig, optional
Which backend to use for handling sparsity and memory efficiency.
Can be:
- "keops": Use KeOps for memory-efficient symbolic computations
- "faiss": Use FAISS for fast k-NN computations with default settings
- None: Use standard PyTorch operations
- FaissConfig object: Use FAISS with custom configuration
Default is None.
verbose : bool, optional
Verbosity of the optimization process. Default is False.
random_state : float, optional
Random seed for reproducibility. Default is None.
early_exaggeration_coeff : float, optional
Coefficient for the attraction term during the early exaggeration phase.
Default is 1.0.
early_exaggeration_iter : int, optional
Number of iterations for early exaggeration. Default is None.
repulsion_strength: float, optional
Strength of the repulsive term. Default is 1.0.
n_negatives : int, optional
Number of negative samples to use. Default is 5.
check_interval : int, optional
Number of iterations between two checks for convergence. Default is 50.
discard_NNs : bool, optional
Whether to discard nearest neighbors from negative sampling. Default is False.
compile : bool, default=False
Whether to use torch.compile for faster computation.
distributed : bool or 'auto', optional
Whether to use distributed computation across multiple GPUs.
- "auto": Automatically detect if running with torchrun (default)
- True: Force distributed mode (requires torchrun)
- False: Disable distributed mode
Default is "auto".
""" # noqa: E501
def __init__(
self,
affinity_in: Affinity,
affinity_out: Optional[Affinity] = None,
kwargs_affinity_out: Optional[Dict] = None,
n_components: int = 2,
lr: Union[float, str] = 1e0,
optimizer: Union[str, Type[torch.optim.Optimizer]] = "SGD",
optimizer_kwargs: Union[Dict, str] = "auto",
scheduler: Optional[
Union[str, Type[torch.optim.lr_scheduler.LRScheduler]]
] = None,
scheduler_kwargs: Union[Dict, str, None] = "auto",
min_grad_norm: float = 1e-7,
max_iter: int = 2000,
init: str = "pca",
init_scaling: float = 1e-4,
device: str = "auto",
backend: Union[str, FaissConfig, None] = None,
verbose: bool = False,
random_state: Optional[float] = None,
early_exaggeration_coeff: float = 1.0,
early_exaggeration_iter: Optional[int] = None,
repulsion_strength: float = 1.0,
n_negatives: int = 5,
check_interval: int = 50,
discard_NNs: bool = False,
compile: bool = False,
distributed: Union[bool, str] = "auto",
):
super().__init__(
affinity_in=affinity_in,
affinity_out=affinity_out,
kwargs_affinity_out=kwargs_affinity_out,
n_components=n_components,
lr=lr,
optimizer=optimizer,
optimizer_kwargs=optimizer_kwargs,
scheduler=scheduler,
scheduler_kwargs=scheduler_kwargs,
min_grad_norm=min_grad_norm,
max_iter=max_iter,
init=init,
init_scaling=init_scaling,
device=device,
backend=backend,
verbose=verbose,
random_state=random_state,
early_exaggeration_coeff=early_exaggeration_coeff,
early_exaggeration_iter=early_exaggeration_iter,
repulsion_strength=repulsion_strength,
check_interval=check_interval,
compile=compile,
)
self.n_negatives = n_negatives
self.discard_NNs = discard_NNs
if distributed == "auto":
self.distributed = dist.is_initialized()
else:
self.distributed = bool(distributed)
if self.distributed:
if not dist.is_initialized():
raise RuntimeError(
"[TorchDR] distributed=True requires launching with torchrun. "
"Example: torchrun --nproc_per_node=4 your_script.py"
)
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.is_multi_gpu = self.world_size > 1
# Bind to local CUDA device
local_rank = int(os.environ.get("LOCAL_RANK", 0))
if torch.cuda.is_available():
torch.cuda.set_device(local_rank)
if self.device == "cpu":
raise ValueError(
"[TorchDR] Distributed mode requires GPU (device cannot be 'cpu')"
)
self.device = torch.device(f"cuda:{local_rank}")
else:
self.rank = 0
self.world_size = 1
self.is_multi_gpu = False
def on_affinity_computation_end(self):
"""Prepare for negative sampling by building per-row exclusion indices.
Unified logic for single- and multi-GPU using chunk bounds.
"""
super().on_affinity_computation_end()
# Get chunk bounds from affinity (stored during _distance_matrix call)
if hasattr(self.affinity_in, "chunk_start_"):
chunk_start = self.affinity_in.chunk_start_
chunk_size = self.affinity_in.chunk_size_
else:
if self.world_size > 1:
raise ValueError(
"[TorchDR] ERROR: Distributed mode is enabled but affinity_in does not "
"have chunk bounds. Make sure affinity_in has distributed=True."
)
chunk_start = 0
chunk_size = self.n_samples_in_
self.chunk_indices_ = torch.arange(
chunk_start, chunk_start + chunk_size, device=self.device_
)
global_self_idx = self.chunk_indices_.unsqueeze(1)
chunk_size = len(global_self_idx)
# Optionally include NN indices (rows aligned with local slice)
if self.discard_NNs:
if not hasattr(self, "NN_indices_"):
self.logger.warning(
"NN_indices_ not found. Cannot discard NNs from negative sampling."
)
exclude = global_self_idx
else:
nn_rows = self.NN_indices_
if nn_rows.shape[0] != chunk_size:
raise ValueError(
f"[TorchDR] ERROR: In distributed mode, expected NN_indices_ to have "
f"{chunk_size} rows for chunk size, but got {nn_rows.shape[0]}."
)
exclude = torch.cat([global_self_idx, nn_rows], dim=1)
else:
exclude = global_self_idx
# Sort per-row exclusions for searchsorted
exclude_sorted, _ = exclude.sort(dim=1)
self.register_buffer(
"negative_exclusion_indices_", exclude_sorted, persistent=False
)
# Safety check on number of available negatives
n_possible = self.n_samples_in_ - self.negative_exclusion_indices_.shape[1]
        if self.n_negatives > n_possible:
raise ValueError(
f"[TorchDR] ERROR : requested {self.n_negatives} negatives but "
f"only {n_possible} available."
)
def on_training_step_start(self):
"""Sample negatives using a unified path for single- and multi-GPU."""
super().on_training_step_start()
chunk_size = len(self.chunk_indices_)
device = self.embedding_.device
exclusion = self.negative_exclusion_indices_
excl_width = exclusion.shape[1]
# Only excluding self-indices
if excl_width == 1:
negatives = torch.randint(
0,
self.n_samples_in_ - 1,
(chunk_size, self.n_negatives),
device=device,
)
self_idx = self.chunk_indices_.unsqueeze(1)
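            # a draw at or above the point's own index is shifted up by one,
            # so a point never samples itself as a negative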
neg_indices = negatives + (negatives >= self_idx).long()
# Excluding self-indices and NNs indices (computed in on_affinity_computation_end)
else:
negatives = torch.randint(
1,
self.n_samples_in_ - excl_width,
(chunk_size, self.n_negatives),
device=device,
)
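            # searchsorted counts, per row, how many excluded indices are <= each draw;
            # adding that count shifts the draws past the excluded indices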
shifts = torch.searchsorted(exclusion, negatives, right=True)
neg_indices = negatives + shifts
self.register_buffer("neg_indices_", neg_indices, persistent=False)
def _init_embedding(self, X: torch.Tensor):
"""Initialize embedding across ranks (broadcast from rank 0)."""
if self.world_size > 1:
if self.rank == 0:
super()._init_embedding(X)
else:
n = X.shape[0]
self.embedding_ = torch.empty(
(n, self.n_components),
device=self.device_,
dtype=X.dtype,
requires_grad=True,
)
if not self.embedding_.is_contiguous():
self.embedding_ = self.embedding_.contiguous()
dist.broadcast(self.embedding_, src=0)
self.embedding_ = self.embedding_.detach().requires_grad_(True)
return self.embedding_
else:
return super()._init_embedding(X)