binary_search
- torchdr.binary_search(f: Callable[[Tensor], Tensor], n: int, begin: float = 1.0, end: float = 1.0, max_iter: int = 100, dtype: dtype = torch.float32, device: device = device(type='cpu')) → Tensor
Batched binary search root finding.
Finds a root of an increasing function f over positive inputs, independently for each of the n batch entries, by repeatedly halving the bracket [begin, end].
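For reference, the textbook bisection step is shown below, applied independently to each batch entry i. This is a sketch of the standard method, not necessarily torchdr's exact update, and it assumes f_i is continuous and the current bracket already satisfies f_i(b_i) ≤ 0 ≤ f_i(e_i), with r_i the root inside the initial bracket. With the defaults begin = end = 1.0 the initial bracket has zero width, so the bounds are presumably adjusted before bisection starts; this page does not say how.

$$
\begin{aligned}
m_i^{(k)} &= \tfrac{1}{2}\bigl(b_i^{(k)} + e_i^{(k)}\bigr),\\
\bigl(b_i^{(k+1)}, e_i^{(k+1)}\bigr) &=
\begin{cases}
\bigl(m_i^{(k)}, e_i^{(k)}\bigr) & \text{if } f_i\bigl(m_i^{(k)}\bigr) < 0,\\
\bigl(b_i^{(k)}, m_i^{(k)}\bigr) & \text{otherwise,}
\end{cases}\\
\bigl|m_i^{(k)} - r_i\bigr| &\le \frac{e_i^{(0)} - b_i^{(0)}}{2^{k+1}}.
\end{aligned}
$$

Each step halves the bracket width, so in practice the accuracy after max_iter steps is limited by the floating-point precision of dtype rather than by the iteration count.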
- Parameters:
f (Callable[[torch.Tensor], torch.Tensor]) – Batched function mapping a tensor of shape (n,) to a tensor of shape (n,), increasing in each coordinate.
n (int) – Batch size (length of the input/output vectors).
begin (float, optional) – Scalar initial lower bound (default: 1.0).
end (float, optional) – Scalar initial upper bound (default: 1.0).
max_iter (int, optional) – Maximum number of iterations (default: 100).
dtype (torch.dtype, optional) – Data type of all tensors (default: torch.float32).
device (torch.device, optional) – Device for all tensors (default: CPU).
- Returns:
m – Estimated roots, such that |f(m)| < tol, where tol is a convergence tolerance that does not appear in the signature above.
- Return type:
torch.Tensor of shape (n,)
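The method can be sketched in PyTorch as follows. This is a hypothetical reimplementation named `batched_bisection`, not torchdr's source. Two parts are assumptions: the bracket-widening step (halving begin and doubling end until every entry encloses a sign change), added so that the documented defaults begin = end = 1.0 can work, and the fixed number of max_iter bisection steps, used because no tolerance parameter appears in the signature.

```python
from typing import Callable

import torch


def batched_bisection(
    f: Callable[[torch.Tensor], torch.Tensor],
    n: int,
    begin: float = 1.0,
    end: float = 1.0,
    max_iter: int = 100,
    dtype: torch.dtype = torch.float32,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Hypothetical batched bisection sketch (not torchdr's implementation)."""
    b = torch.full((n,), begin, dtype=dtype, device=device)
    e = torch.full((n,), end, dtype=dtype, device=device)

    # Assumption: widen each bracket geometrically until f(b) <= 0 <= f(e).
    # Halving keeps the lower bound positive; doubling grows the upper bound.
    for _ in range(max_iter):
        fb = f(b)
        if not (fb > 0).any():
            break
        b = torch.where(fb > 0, b / 2, b)
    for _ in range(max_iter):
        fe = f(e)
        if not (fe < 0).any():
            break
        e = torch.where(fe < 0, e * 2, e)

    # Bisection: halve all n brackets in parallel with masked updates.
    for _ in range(max_iter):
        m = (b + e) / 2
        below = f(m) < 0  # root lies in the upper half [m, e]
        b = torch.where(below, m, b)
        e = torch.where(below, e, m)
    return (b + e) / 2


# Usage: solve x**2 = target for three targets at once.
targets = torch.tensor([0.25, 4.0, 9.0])
roots = batched_bisection(lambda x: x**2 - targets, n=3)
# roots should be close to [0.5, 2.0, 3.0]
```

Using `torch.where` keeps every update vectorized over the batch, so one call to f per iteration evaluates all n entries, which is the point of a batched root finder.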