chop.utils

General utility functions.

Functions

bdot(tensor, other)

Returns the batch-wise dot product between tensor and other. Assumes that both have shape (batch_size, *).

bmm(tensor, other)

bmul(tensor, other)

Multiplies tensor and other batch-wise.

bmv(tensor, vector)

closure(f)

get_func_and_jac(func, x, *args, **kwargs)

Computes the Jacobian of a batch-wise separable function func of x: func returns a torch.Tensor of shape (batch_size,) when x is a torch.Tensor of shape (batch_size, *). Adapted from https://gist.github.com/sbarratt/37356c46ad1350d4c30aefbd488a4faa by Shane Barratt.

init_lipschitz(closure, x0[, L0, n_it])

Estimates, for each datapoint in the batch, the Lipschitz constant of the gradient of the function evaluated by closure, using backtracking line search.

power_iteration(mat[, n_iter, tol])

Computes, batch-wise, the largest singular value of a matrix and the associated left and right singular vectors.

chop.utils.bdot(tensor, other)

Returns the batch-wise dot product between tensor and other. Assumes that both have shape (batch_size, *).
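A minimal sketch of what a batch-wise dot product computes, assuming both inputs share the shape (batch_size, *): flatten every non-batch dimension and take one ordinary dot product per datapoint. The name bdot_sketch is hypothetical, and this is not necessarily how chop implements bdot.

```python
import torch


def bdot_sketch(tensor: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
    """Hypothetical batch-wise dot product for inputs of shape (batch_size, *)."""
    # Flatten all non-batch dimensions so each datapoint becomes a vector.
    t1 = tensor.reshape(tensor.size(0), -1)
    t2 = other.reshape(other.size(0), -1)
    # Elementwise product, then sum over the flattened features.
    return (t1 * t2).sum(dim=-1)


x = torch.arange(12.0).reshape(2, 2, 3)
y = torch.ones(2, 2, 3)
print(bdot_sketch(x, y))  # tensor([15., 51.])
```

The output has shape (batch_size,): one scalar per datapoint, regardless of how many trailing dimensions the inputs have.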

chop.utils.bmul(tensor, other)

Multiplies tensor and other batch-wise.
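The description above does not state how the shapes of tensor and other must relate. The sketch below assumes one common convention, which is an assumption and not documented behavior: other holds one scalar per datapoint (shape (batch_size,)) and scales the matching slice of tensor. The name bmul_sketch is hypothetical.

```python
import torch


def bmul_sketch(tensor: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
    # Assumption: other has shape (batch_size,) or the same shape as tensor.
    # Append singleton dims so other broadcasts over tensor's trailing dims.
    extra = tensor.dim() - other.dim()
    return tensor * other.reshape(tuple(other.shape) + (1,) * extra)


x = torch.ones(2, 3)
scales = torch.tensor([2.0, 5.0])
print(bmul_sketch(x, scales))  # rows become [2., 2., 2.] and [5., 5., 5.]
```

Appending singleton dimensions lets ordinary broadcasting do the work; when other already has the same shape as tensor, no dimensions are added and the sketch reduces to elementwise multiplication.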

chop.utils.get_func_and_jac(func, x, *args, **kwargs)

Computes the Jacobian of a batch-wise separable function func of x: func returns a torch.Tensor of shape (batch_size,) when x is a torch.Tensor of shape (batch_size, *). Adapted from https://gist.github.com/sbarratt/37356c46ad1350d4c30aefbd488a4faa by Shane Barratt.
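Batch-wise separability is what makes this cheap: since output i depends only on x[i], the full Jacobian of func is block-diagonal, and a single vector-Jacobian product with a vector of ones recovers every per-datapoint gradient at once. The sketch below illustrates that idea with torch.autograd.grad. The name get_func_and_jac_sketch is hypothetical, and chop's actual implementation and return format may differ.

```python
import torch


def get_func_and_jac_sketch(func, x, *args, **kwargs):
    # Fresh leaf tensor that records operations for autograd.
    x = x.detach().requires_grad_(True)
    output = func(x, *args, **kwargs)  # shape (batch_size,)
    # output[i] depends only on x[i], so the Jacobian is block-diagonal and a
    # vector-Jacobian product with ones yields each datapoint's gradient.
    (jac,) = torch.autograd.grad(output, x, grad_outputs=torch.ones_like(output))
    return output.detach(), jac  # jac has the same shape as x


x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
vals, jac = get_func_and_jac_sketch(lambda z: (z ** 2).sum(dim=-1), x)
# vals is [5., 25.] and jac equals 2 * x
```

Note that this shortcut is only valid for batch-wise separable functions; if output[i] also depended on x[j], the ones-vector product would sum contributions across datapoints.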

chop.utils.init_lipschitz(closure, x0, L0=0.001, n_it=100)

Estimates, for each datapoint in the batch, the Lipschitz constant of the gradient of the function evaluated by closure, using backtracking line search.

Parameters
  • closure – callable that returns func_val, jacobian

  • x0 – torch.Tensor of shape (batch_size, *)

  • L0 – float, initial guess for the Lipschitz constant

  • n_it – int, number of iterations

Returns

Lt – torch.Tensor of shape (batch_size,), the estimated Lipschitz constant for each datapoint

Return type

torch.Tensor
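A minimal sketch of a batch-wise backtracking estimate, under stated assumptions: closure(x) returns values of shape (batch_size,) and gradients of shape x.shape; the target is the Lipschitz constant of the gradient; and each datapoint's estimate is multiplied by a growth factor (10 here) until the sufficient-decrease test from the descent lemma, f(x - g/L) <= f(x) - ||g||^2 / (2L), holds. The name init_lipschitz_sketch, the factor parameter and this exact test are assumptions, not chop's documented behavior.

```python
import torch


def init_lipschitz_sketch(closure, x0, L0=1e-3, n_it=100, factor=10.0):
    # Assumption: closure(x) returns (values, grads) with shapes (batch_size,) and x.shape.
    f0, grad = closure(x0)
    Lt = torch.full_like(f0, L0)
    # Squared gradient norm of each datapoint, shape (batch_size,).
    grad_sq = grad.reshape(grad.size(0), -1).pow(2).sum(dim=-1)
    # Shape that broadcasts a (batch_size,) tensor against x0.
    bshape = (-1,) + (1,) * (x0.dim() - 1)
    for _ in range(n_it):
        # Gradient step of size 1 / Lt for every datapoint.
        xt = x0 - grad / Lt.reshape(bshape)
        ft, _ = closure(xt)
        # Sufficient-decrease test from the descent lemma.
        fails = ft > f0 - grad_sq / (2 * Lt)
        if not fails.any():
            break
        # Grow the estimate only where the test fails (hypothetical factor).
        Lt = torch.where(fails, Lt * factor, Lt)
    return Lt


# Example: f_i(x) = 0.5 * a_i * ||x_i||^2 has gradient Lipschitz constant a_i.
a = torch.tensor([2.0, 50.0])


def quad_closure(x):
    values = 0.5 * a * x.pow(2).sum(dim=-1)
    grads = a.unsqueeze(-1) * x
    return values, grads


print(init_lipschitz_sketch(quad_closure, torch.ones(2, 3)))  # estimates close to 10 and 100
```

For a function whose gradient is L-Lipschitz the test always passes once the estimate reaches the true constant, so when L0 starts below it (and n_it is large enough) the returned value is at most factor times the true constant. In the quadratic example the true constants are 2 and 50, and the estimates land on the next powers of ten.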

chop.utils.power_iteration(mat, n_iter: int = 10, tol: float = 1e-06)

Computes, batch-wise, the largest singular value of a matrix and the associated left and right singular vectors.

Parameters
  • mat – torch.Tensor of shape (*, M, N)

  • n_iter – int, number of iterations to perform

  • tol – float, tolerance; currently not used.
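Power iteration on matᵀmat converges to the top right singular vector v; the singular value is then ||mat v||, and the left singular vector is mat v / ||mat v||. The sketch below applies this to every matrix in a (*, M, N) batch at once through batched matrix products. The name power_iteration_sketch is hypothetical, the (sigma, u, v) return order is an assumption because the return value is not documented above, and the sketch omits tol, mirroring the documented tol being unused.

```python
import torch


def power_iteration_sketch(mat: torch.Tensor, n_iter: int = 10):
    """Largest singular value and singular vectors of each matrix in a (*, M, N) batch."""
    # Random unit starting vector per matrix, shape (*, N, 1).
    v = torch.randn(*mat.shape[:-2], mat.size(-1), 1, dtype=mat.dtype, device=mat.device)
    v = v / v.norm(dim=-2, keepdim=True)
    mat_t = mat.transpose(-2, -1)
    for _ in range(n_iter):
        # One power-iteration step on mat^T mat, normalizing after each product.
        u = mat @ v
        u = u / u.norm(dim=-2, keepdim=True)
        v = mat_t @ u
        v = v / v.norm(dim=-2, keepdim=True)
    # With unit v, sigma = ||mat v|| and u = mat v / sigma.
    mv = mat @ v
    sigma = mv.norm(dim=-2, keepdim=True)
    u = mv / sigma
    return sigma[..., 0, 0], u.squeeze(-1), v.squeeze(-1)


torch.manual_seed(0)
mat = torch.stack([torch.diag(torch.tensor([3.0, 1.0])),
                   torch.diag(torch.tensor([2.0, 5.0]))])
sigma, u, v = power_iteration_sketch(mat, n_iter=50)
print(sigma)  # approximately 3 and 5
```

Convergence speed depends on the ratio between the two largest singular values, which is why a fixed n_iter can be enough for well-separated spectra but not for nearly degenerate ones.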