azula.linalg.solve

Linear system solvers.

Functions

  • cg – Solves a linear system \(Ax = b\) with conjugate gradient (CG) iterations.

  • gmres – Solves a linear system \(Ax = b\) with generalized minimal residual (GMRES) iterations.

Descriptions

azula.linalg.solve.cg(A, b, x0=None, iterations=1, dtype=None)

Solves a linear system \(Ax = b\) with conjugate gradient (CG) iterations.

The matrix \(A \in \mathbb{R}^{D \times D}\) must be symmetric positive (semi)definite.
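As background, here is the standard textbook characterization of CG. It is not a statement about azula's internals. When \(A\) is positive definite, the \(n\)-th CG iterate minimizes the \(A\)-norm of the error over a shifted Krylov subspace:

```latex
\begin{aligned}
r_0 &= b - A x_0, \qquad
\mathcal{K}_n(A, r_0) = \operatorname{span}\{r_0, A r_0, \dots, A^{n-1} r_0\}, \\
x_n &= \arg\min_{x \in x_0 + \mathcal{K}_n(A, r_0)} \lVert x - x_\star \rVert_A,
\qquad \lVert e \rVert_A = \sqrt{e^\top A e},
\end{aligned}
```

Here \(x_\star = A^{-1} b\). In exact arithmetic, \(x_\star\) is reached after at most \(D\) iterations. In floating point, rounding can delay convergence.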

Wikipedia

https://wikipedia.org/wiki/Conjugate_gradient_method

Warning

This function is optimized for GPU execution. To avoid CPU-GPU communication, all iterations are performed regardless of convergence.
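The following is a minimal sketch of fixed-iteration, batched CG in PyTorch. It is not azula's source, and it rests on several assumptions:

  • The name `cg_fixed` is hypothetical.

  • `A` is assumed to accept tensors in the computation `dtype`.

  • The result is cast back to `b.dtype`, which is a guess at the convention.

The loop runs exactly `iterations` times with no data-dependent branching, so nothing in it forces a device synchronization. The `clamp_min` guards prevent a division by zero once the residual reaches exactly zero.

```python
import torch
from typing import Callable, Optional

Tensor = torch.Tensor


def cg_fixed(
    A: Callable[[Tensor], Tensor],
    b: Tensor,
    x0: Optional[Tensor] = None,
    iterations: int = 1,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    """Hypothetical sketch: batched CG over the last dim, always `iterations` steps."""
    dtype = torch.float64 if dtype is None else dtype
    b_ = b.to(dtype)
    x = torch.zeros_like(b_) if x0 is None else x0.to(dtype)
    tiny = torch.finfo(dtype).tiny

    r = b_ - A(x)  # residual r_0
    p = r.clone()  # first search direction p_0 = r_0
    rr = (r * r).sum(dim=-1, keepdim=True)  # r_k^T r_k per system, shape (*, 1)

    for _ in range(iterations):  # fixed count: no convergence test, no host sync
        Ap = A(p)
        alpha = rr / (p * Ap).sum(dim=-1, keepdim=True).clamp_min(tiny)
        x = x + alpha * p
        r = r - alpha * Ap
        rr_new = (r * r).sum(dim=-1, keepdim=True)
        p = r + (rr_new / rr.clamp_min(tiny)) * p  # beta_k = rr_new / rr
        rr = rr_new

    return x.to(b.dtype)
```

For a well-conditioned \(D \times D\) system, running `iterations=D` steps in float64 should match a direct solve to high accuracy.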

Parameters:
  • A (Callable[[Tensor], Tensor]) – The linear operator \(x \mapsto Ax\).

  • b (Tensor) – The right-hand side vector \(b\), with shape \((*, D)\).

  • x0 (Tensor | None) – An initial guess \(x_0\), with shape \((*, D)\). If None, use \(x_0 = 0\) instead.

  • iterations (int) – The number of CG iterations \(n\).

  • dtype (dtype | None) – The data type used for intermediate computations. If None, use torch.float64 instead.
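Because `A` is passed as a function rather than a matrix, it can be matrix-free. As an illustration, consider the 1-D discrete Laplacian with zero boundary values. It is symmetric positive definite and acts along the last dimension of a \((*, D)\) tensor. The operator `laplacian_1d` is hypothetical and not part of azula.

```python
import torch
import torch.nn.functional as F


def laplacian_1d(x: torch.Tensor) -> torch.Tensor:
    """(Ax)_i = 2 x_i - x_{i-1} - x_{i+1} along the last dim, with out-of-range neighbours = 0."""
    xp = F.pad(x, (1, 1))  # zero-pad the last dimension on both sides
    return 2 * x - xp[..., :-2] - xp[..., 2:]
```

Since it maps a \((*, D)\) tensor to a \((*, D)\) tensor, it could be passed as `A` in a call such as `azula.linalg.solve.cg(laplacian_1d, b, iterations=D)`.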

Returns:

The \(n\)-th iterate \(x_n\), with shape \((*, D)\).

Return type:

Tensor
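The solver always runs exactly \(n\) iterations and does not report convergence, so you may want to check the returned \(x_n\) yourself. Here is a minimal sketch of a per-system relative residual \(\lVert b - A x \rVert / \lVert b \rVert\). The helper name `relative_residual` is hypothetical.

```python
import torch
from typing import Callable


def relative_residual(
    A: Callable[[torch.Tensor], torch.Tensor], x: torch.Tensor, b: torch.Tensor
) -> torch.Tensor:
    """Per-system ||b - A x|| / ||b|| over the last dim; returns shape (*,). NaN where b = 0."""
    return torch.linalg.vector_norm(b - A(x), dim=-1) / torch.linalg.vector_norm(b, dim=-1)
```

Turning the result into a Python bool, for example with `(res > 1e-6).any().item()`, forces a device synchronization. That is exactly the CPU-GPU communication the solver avoids, so it is best done once, outside any hot loop.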

azula.linalg.solve.gmres(A, b, x0=None, iterations=1, dtype=None)

Solves a linear system \(Ax = b\) with generalized minimal residual (GMRES) iterations.

The matrix \(A \in \mathbb{R}^{D \times D}\) may be non-symmetric and indefinite.
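As background, here is the standard characterization of GMRES. It is not a statement about azula's internals. The \(n\)-th GMRES iterate minimizes the Euclidean norm of the residual over a shifted Krylov subspace:

```latex
\begin{aligned}
r_0 &= b - A x_0, \qquad
\mathcal{K}_n(A, r_0) = \operatorname{span}\{r_0, A r_0, \dots, A^{n-1} r_0\}, \\
x_n &= \arg\min_{x \in x_0 + \mathcal{K}_n(A, r_0)} \lVert b - A x \rVert_2 .
\end{aligned}
```

Because these subspaces are nested, \(\lVert b - A x_n \rVert_2\) is non-increasing in \(n\). For nonsingular \(A\), the exact solution is reached after at most \(D\) iterations in exact arithmetic.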

Wikipedia

https://wikipedia.org/wiki/Generalized_minimal_residual_method

Warning

This function is optimized for GPU execution. To avoid CPU-GPU communication, all iterations are performed regardless of convergence.
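The following is a minimal sketch of batched, restart-free GMRES with a fixed iteration count. It is not azula's implementation, and it makes several assumptions:

  • The name `gmres_fixed` is hypothetical.

  • `A` is assumed to accept tensors in the computation `dtype`.

  • The result is cast back to `b.dtype`.

  • The small least-squares problem is solved with `torch.linalg.lstsq`, rather than the Givens-rotation updates found in most textbook presentations.

```python
import torch
from typing import Callable, Optional

Tensor = torch.Tensor


def gmres_fixed(
    A: Callable[[Tensor], Tensor],
    b: Tensor,
    x0: Optional[Tensor] = None,
    iterations: int = 1,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    """Hypothetical sketch: batched GMRES over the last dim, always `iterations` steps."""
    dtype = torch.float64 if dtype is None else dtype
    b_ = b.to(dtype)
    x = torch.zeros_like(b_) if x0 is None else x0.to(dtype)
    n = iterations
    if n == 0:
        return x.to(b.dtype)
    tiny = torch.finfo(dtype).tiny

    r0 = b_ - A(x)
    beta = torch.linalg.vector_norm(r0, dim=-1)  # ||r_0|| per system, shape (*,)
    V = [r0 / beta.unsqueeze(-1).clamp_min(tiny)]  # orthonormal Krylov basis, each (*, D)
    H = r0.new_zeros((*r0.shape[:-1], n + 1, n))  # upper Hessenberg matrix

    # Arnoldi process with modified Gram-Schmidt, building A V_n = V_{n+1} H
    for j in range(n):
        w = A(V[j])
        for i in range(j + 1):
            h = (V[i] * w).sum(dim=-1)
            H[..., i, j] = h
            w = w - h.unsqueeze(-1) * V[i]
        h_next = torch.linalg.vector_norm(w, dim=-1)
        H[..., j + 1, j] = h_next
        V.append(w / h_next.unsqueeze(-1).clamp_min(tiny))

    # Small least-squares problem: y = argmin_y || beta e_1 - H y ||
    rhs = r0.new_zeros((*r0.shape[:-1], n + 1, 1))
    rhs[..., 0, 0] = beta
    y = torch.linalg.lstsq(H, rhs).solution  # shape (*, n, 1)

    # x_n = x_0 + V_n y
    Vn = torch.stack(V[:n], dim=-1)  # shape (*, D, n)
    x = x + (Vn @ y).squeeze(-1)
    return x.to(b.dtype)
```

This sketch stores \(n + 1\) basis vectors per system, so its memory grows as \(O(nD)\).

On CUDA, `torch.linalg.lstsq` only supports the `gels` driver, which assumes `H` has full rank. An early breakdown, where the Krylov space stops growing before \(n\) steps, would therefore need extra handling there.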

Parameters:
  • A (Callable[[Tensor], Tensor]) – The linear operator \(x \mapsto Ax\).

  • b (Tensor) – The right-hand side vector \(b\), with shape \((*, D)\).

  • x0 (Tensor | None) – An initial guess \(x_0\), with shape \((*, D)\). If None, use \(x_0 = 0\) instead.

  • iterations (int) – The number of GMRES iterations \(n\).

  • dtype (dtype | None) – The data type used for intermediate computations. If None, use torch.float64 instead.

Returns:

The \(n\)-th iterate \(x_n\), with shape \((*, D)\).

Return type:

Tensor