Source code for torchdr.utils.root_search

"""Root search algorithms for solving scalar equations."""

# Author: Hugues Van Assel <vanasselhugues@gmail.com>
#         Rémi Flamary <remi.flamary@polytechnique.edu>
#
# License: BSD 3-Clause License

import torch


from typing import Callable, Tuple

_DEFAULT_TOL = torch.tensor(1e-6)






@torch.compiler.disable
def false_position(
    f: Callable[[torch.Tensor], torch.Tensor],
    n: int,
    begin: float = 1.0,
    end: float = 1.0,
    max_iter: int = 100,
    dtype: torch.dtype = torch.float32,
    device: torch.device = torch.device("cpu"),
    tol=None,
) -> torch.Tensor:
    """Batched false-position (regula falsi) root finding.

    Brackets the root of each coordinate of an increasing function ``f`` with
    :func:`init_bounds`. It then repeatedly replaces one end of the bracket with
    the zero of the secant line through ``(b, f(b))`` and ``(e, f(e))``.

    Parameters
    ----------
    f : Callable[[torch.Tensor], torch.Tensor]
        Batched 1-D increasing function. It is evaluated on tensors of shape
        ``(n,)`` and its output is combined elementwise with them.
    n : int
        Batch size (length of the input/output vectors).
    begin : float or torch.Tensor of shape (n,), optional
        Initial lower bound(s), forwarded to :func:`init_bounds`
        (default: 1.0).
    end : float or torch.Tensor of shape (n,), optional
        Initial upper bound(s), forwarded to :func:`init_bounds`
        (default: 1.0).
    max_iter : int, optional
        Maximum number of iterations, used both for the bracketing phase and
        for the false-position loop (default: 100).
    dtype : torch.dtype, optional
        Data type of all tensors (default: torch.float32).
    device : torch.device, optional
        Device for all tensors (default: CPU).
    tol : float or torch.Tensor, optional
        Convergence threshold on ``|f(m)|``. Defaults to ``_DEFAULT_TOL``
        (1e-6), cast to ``dtype``/``device``.

    Returns
    -------
    m : torch.Tensor of shape (n,)
        Estimated roots. Entries satisfy ``|f(m)| < tol`` if the loop
        converged within ``max_iter`` iterations. Otherwise the last estimate
        is returned.
    """
    tol = torch.as_tensor(
        _DEFAULT_TOL if tol is None else tol, dtype=dtype, device=device
    )

    def secant_zero(b, e, f_b, f_e):
        # Zero of the line through (b, f_b) and (e, f_e). When f_b == f_e the
        # line is flat and the formula would evaluate 0/0 = NaN. This happens,
        # e.g., when init_bounds returns b == e because f(begin) == 0 exactly.
        # In that case, fall back to the bracket midpoint.
        denom = f_b - f_e
        flat = denom == 0
        m = b - (b - e) / torch.where(flat, torch.ones_like(denom), denom) * f_b
        return torch.where(flat, 0.5 * (b + e), m)

    b, e = init_bounds(
        f, n, begin, end, max_iter=max_iter, dtype=dtype, device=device
    )
    f_b = f(b)
    f_e = f(e)
    m = secant_zero(b, e, f_b, f_e)
    f_m = f(m)

    for _ in range(max_iter):
        active = torch.abs(f_m) >= tol
        if not active.any():
            break
        # init_bounds aims for f(b) <= 0 <= f(e). If f(m) has the sign of f(b),
        # the root lies above m, so move b. Otherwise move e.
        same_sign = f_m * f_b > 0
        move_b = active & same_sign
        move_e = active & ~same_sign
        # Out-of-place updates: tensors handed back by init_bounds or returned
        # by f may be shared with the caller and must not be modified.
        b = torch.where(move_b, m, b)
        f_b = torch.where(move_b, f_m, f_b)
        e = torch.where(move_e, m, e)
        f_e = torch.where(move_e, f_m, f_e)
        m = secant_zero(b, e, f_b, f_e)
        f_m = f(m)
    return m
@torch.compiler.disable def init_bounds( f: Callable[[torch.Tensor], torch.Tensor], n: int, begin=1.0, end=1.0, max_iter=100, dtype: torch.dtype = torch.float32, device="cpu", ) -> Tuple[torch.Tensor, torch.Tensor]: """Initialize root‐search bounds for f, supporting both scalar and tensor inputs.""" if isinstance(begin, torch.Tensor): b = begin.to(dtype=dtype, device=device) if b.shape != (n,): raise ValueError(f"begin tensor must have shape ({n},), got {b.shape}") else: b = torch.full((n,), float(begin), dtype=dtype, device=device) if isinstance(end, torch.Tensor): e = end.to(dtype=dtype, device=device) if e.shape != (n,): raise ValueError(f"end tensor must have shape ({n},), got {e.shape}") else: e = torch.full((n,), float(end), dtype=dtype, device=device) # shrink `b` downward until f(b) ≤ 0, pulling `e` in with it for _ in range(max_iter): mask = f(b) > 0 if not mask.any(): break old_b = b e = torch.where(mask, torch.min(e, old_b), e) b = torch.where(mask, b * 0.5, b) # expand `e` upward until f(e) ≥ 0, pushing `b` out with it for _ in range(max_iter): mask = f(e) < 0 if not mask.any(): break old_e = e b = torch.where(mask, torch.max(b, old_e), b) e = torch.where(mask, e * 2.0, e) return b, e