IncrementalPCA#
- class torchdr.IncrementalPCA(n_components: int | None = None, copy: bool | None = True, batch_size: int | None = None, svd_driver: str | None = None, lowrank: bool = False, lowrank_q: int | None = None, lowrank_niter: int = 4, device: str = 'auto', verbose: bool = False, random_state: float | None = None, **kwargs)[source]#
Bases:
DRModule
Incremental Principal Components Analysis (IPCA) leveraging PyTorch for GPU acceleration.
This class provides methods to fit the model on data incrementally in batches, and to transform new data based on the principal components learned during the fitting process.
The algorithm uses incremental SVD updates based on [Ross et al., 2008], which allows maintaining a low-rank approximation of the data covariance matrix without storing all data or recomputing from scratch.
It is particularly useful when the dataset to be decomposed is too large to fit in memory. Adapted from Scikit-learn Incremental PCA.
Memory Management Strategy:
- Data is processed in batches to avoid loading the entire dataset into memory.
- Each batch is temporarily moved to the computation device (GPU if specified).
- Model parameters (mean_, components_) are kept on the computation device.
- Only the current batch needs to fit in GPU memory, not the full dataset.
Examples
Using with PyTorch DataLoader for true out-of-core learning:
from torch.utils.data import DataLoader, TensorDataset

# Create a DataLoader for a huge dataset
dataset = TensorDataset(huge_X_tensor, huge_y_tensor)
dataloader = DataLoader(dataset, batch_size=1000, shuffle=True)

# Fit incrementally using DataLoader
ipca = IncrementalPCA(n_components=50, device='cuda')
for batch in dataloader:
    X_batch = batch[0]  # DataLoader returns (X, y) tuples
    ipca.partial_fit(X_batch)

# Transform new data in batches
test_loader = DataLoader(test_dataset, batch_size=1000)
for batch in test_loader:
    X_batch = batch[0]
    X_transformed = ipca.transform(X_batch)
    # Process transformed batch...
Using with data generators for streaming large files:
import pandas as pd
import torch

def data_generator():
    # Read huge CSV in chunks
    for chunk in pd.read_csv('huge_file.csv', chunksize=1000):
        yield torch.tensor(chunk.values, dtype=torch.float32)

ipca = IncrementalPCA(n_components=50)
for batch in data_generator():
    ipca.partial_fit(batch)
Using with HDF5 or memory-mapped arrays:
import h5py
import torch

with h5py.File('huge_dataset.h5', 'r') as f:
    X = f['data']  # HDF5 dataset, not loaded into memory
    n_samples = X.shape[0]
    batch_size = 1000
    ipca = IncrementalPCA(n_components=100)
    for i in range(0, n_samples, batch_size):
        batch = torch.tensor(X[i:i+batch_size])
        ipca.partial_fit(batch)
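Basic in-memory usage, letting fit handle the batching internally (a minimal sketch on synthetic data; the tensor X and the hyperparameter values are illustrative):
import torch
from torchdr import IncrementalPCA

X = torch.randn(10_000, 200)  # synthetic data for illustration

# fit() splits X into batches of `batch_size` internally
ipca = IncrementalPCA(n_components=20, batch_size=1000)
ipca.fit(X)

# Project onto the learned components
Z = ipca.transform(X)  # shape (10_000, 20)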
- Parameters:
n_components (int, optional) – Number of components to keep. If None, it’s set to the minimum of the number of samples and features. Defaults to None.
copy (bool) – If False, input data will be overwritten. Defaults to True.
batch_size (int, optional) – The number of samples to use for each batch. Only needed if self.fit is called. If None, it’s inferred from the data and set to 5 * n_features. Defaults to None.
svd_driver (str, optional) – Name of the cuSOLVER method to be used for torch.linalg.svd. This keyword argument only works on CUDA inputs. Available options are: None, gesvd, gesvdj and gesvda. Defaults to None.
lowrank (bool, optional) – Whether to use torch.svd_lowrank instead of torch.linalg.svd, which can be faster. Defaults to False. See the configuration sketch after this parameter list.
lowrank_q (int, optional) – Rank estimate q passed to torch.svd_lowrank. For an adequate approximation of n_components, defaults to n_components * 2.
lowrank_niter (int, optional) – Number of subspace iterations to conduct for torch.svd_lowrank. Defaults to 4.
device (str, optional) – Device on which the computations are performed. Defaults to “auto”.
random_state (float, optional) – Random state for reproducibility. Defaults to None.
**kwargs (dict) – Additional keyword arguments.
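For illustration, two hedged configuration sketches for the SVD backends controlled by the parameters above (the variable names and values are arbitrary):
from torchdr import IncrementalPCA

# Full SVD on GPU with an explicit cuSOLVER driver
ipca_full = IncrementalPCA(n_components=64, device="cuda", svd_driver="gesvdj")

# Randomized low-rank SVD, often faster for large feature counts
ipca_lowrank = IncrementalPCA(n_components=64, device="cuda", lowrank=True, lowrank_niter=6)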
- static gen_batches(n: int, batch_size: int, min_batch_size: int = 0)[source]#
Generate slices containing batch_size elements from 0 to n.
Used to split the dataset into manageable batches that fit in GPU memory. The last slice may contain fewer than batch_size elements, when batch_size does not divide n.
- Parameters:
n (int) – Size of the sequence to split into batches.
batch_size (int) – Number of elements in each batch.
min_batch_size (int, optional) – Minimum number of elements in each batch. Defaults to 0.
- Yields:
slice – A slice object representing indices [start:end] for the current batch.
Examples
>>> list(IncrementalPCA.gen_batches(10, 3))
[slice(0, 3, None), slice(3, 6, None), slice(6, 9, None), slice(9, 10, None)]
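A hedged sketch of driving partial_fit manually with gen_batches (the tensor X and the sizes are illustrative):
import torch
from torchdr import IncrementalPCA

X = torch.randn(10_000, 100)  # illustrative in-memory data
ipca = IncrementalPCA(n_components=10)
for sl in IncrementalPCA.gen_batches(X.shape[0], batch_size=1000):
    ipca.partial_fit(X[sl])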
- partial_fit(X, check_input=True)[source]#
Incrementally fit the model with a batch of data X.
This method updates the PCA model with a new batch of data without requiring access to previously seen data. It maintains running statistics (mean, variance) and incrementally updates the principal components.
The batch X should already be on the computation device when called from _fit_transform. When called directly, X can be on any device.
- Parameters:
X (torch.Tensor) – The batch input data tensor with shape (n_samples, n_features). Should fit in memory/GPU memory.
check_input (bool, optional) – If True, validates the input. Defaults to True.
- Returns:
The updated IPCA model after processing the batch.
- Return type:
IncrementalPCA
Examples
Basic usage with manual batching:
ipca = IncrementalPCA(n_components=10)
for i in range(0, len(X), batch_size):
    ipca.partial_fit(X[i:i+batch_size])
With PyTorch DataLoader (recommended for large datasets):
from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=256)
ipca = IncrementalPCA(n_components=50)
for batch in dataloader:
    # Handle DataLoader's (X, y) batch format
    X_batch = batch[0] if isinstance(batch, (list, tuple)) else batch
    ipca.partial_fit(X_batch)
Notes
Parameters (mean_, components_, etc.) are stored on the computation device, which is either self.device (if specified) or the device of the first batch (if self.device == “auto”).
Uses Welford's algorithm for numerically stable incremental mean/variance (see the sketch after these notes).
The SVD is performed on an augmented matrix containing the previous components and the current batch to update the decomposition.
This is the recommended method for fitting with DataLoader or generators.
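For intuition, the following is a minimal, self-contained sketch of a Welford-style (parallel/Chan form) mean and variance update over batches. It is illustrative only, not torchdr's internal implementation, and the helper name update_mean_var is hypothetical:
import torch

def update_mean_var(mean, var, n_seen, X_batch):
    # Hypothetical helper, not part of torchdr: merge running mean/variance
    # with a new batch using the parallel (Chan) form of Welford's update.
    n_new = X_batch.shape[0]
    n_total = n_seen + n_new

    batch_mean = X_batch.mean(dim=0)
    delta = batch_mean - mean
    new_mean = mean + delta * n_new / n_total

    # Combine within-group sums of squared deviations plus a between-group term
    m_a = var * n_seen
    m_b = X_batch.var(dim=0, unbiased=False) * n_new
    new_var = (m_a + m_b + delta**2 * n_seen * n_new / n_total) / n_total

    return new_mean, new_var, n_total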
- set_partial_fit_request(*, check_input: bool | None | str = '$UNCHANGED$') → IncrementalPCA#
Configure whether metadata should be requested to be passed to the partial_fit method.
Note that this method is only relevant when this estimator is used as a sub-estimator within a meta-estimator and metadata routing is enabled with enable_metadata_routing=True (see sklearn.set_config()). Please check the User Guide on how the routing mechanism works.
The options for each parameter are:
- True: metadata is requested, and passed to partial_fit if provided. The request is ignored if metadata is not provided.
- False: metadata is not requested and the meta-estimator will not pass it to partial_fit.
- None: metadata is not requested, and the meta-estimator will raise an error if the user provides it.
- str: metadata should be passed to the meta-estimator with this given alias instead of the original name.
The default (sklearn.utils.metadata_routing.UNCHANGED) retains the existing request. This allows you to change the request for some parameters and not others.
Added in version 1.3.
- Parameters:
check_input (str, True, False, or None) – Metadata routing for the check_input parameter in partial_fit. Defaults to sklearn.utils.metadata_routing.UNCHANGED.
- Returns:
self – The updated object.
- Return type:
object
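A hedged usage sketch: this call only has an effect when metadata routing is enabled globally in scikit-learn, and the request matters only once the estimator is wrapped in a meta-estimator (not shown):
import sklearn
from torchdr import IncrementalPCA

# Metadata routing must be enabled for the request to take effect
sklearn.set_config(enable_metadata_routing=True)

# Ask meta-estimators not to forward `check_input` to partial_fit
ipca = IncrementalPCA(n_components=10).set_partial_fit_request(check_input=False)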
- transform(X: ArrayLike) → ArrayLike[source]#
Apply dimensionality reduction to X.
Projects input data onto the principal components learned during fitting. Unlike fit, this processes the full input at once, not in batches.
Device Management:
- If X and the parameters are on different devices, X is temporarily moved to the parameters' device for computation.
- The result is moved back to X's original device.
- This avoids moving the parameters, which would be inefficient for repeated transforms (see the example at the end of this entry).
- Parameters:
X (ArrayLike) – New data with shape (n_samples, n_features) to be transformed. Can be on any device.
- Returns:
Transformed data with shape (n_samples, n_components). Will be on the same device and format as input X.
- Return type:
ArrayLike
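A short sketch of the device behavior described above, assuming a CUDA device is available (the data and hyperparameters are illustrative):
import torch
from torchdr import IncrementalPCA

# Fit on GPU so that mean_ and components_ live on CUDA
ipca = IncrementalPCA(n_components=20, device="cuda", batch_size=1000)
ipca.fit(torch.randn(10_000, 200))

# Transform CPU data: X is moved to the model's device for the projection,
# and the result comes back on the CPU, matching the input
X_cpu = torch.randn(5_000, 200)
Z = ipca.transform(X_cpu)
assert Z.device == X_cpu.device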