azula.guidance.mmps

Moment matching posterior sampling (MMPS) internals.
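For orientation, here is a sketch of the moment-matching update as described in the cited paper. It assumes a perturbation \(x_t = \alpha_t x + \sigma_t \epsilon\) with \(\epsilon \sim \mathcal{N}(0, I)\) and a linear operator \(A\); for a nonlinear \(A\), its Jacobian at \(\mathbb{E}[x \mid x_t]\) would take its place. The denoiser's mean is corrected by Gaussian conditioning on \(y\):

```latex
\begin{aligned}
\mathbb{V}[x \mid x_t] &= \frac{\sigma_t^2}{\alpha_t} \nabla_{x_t}^\top \mathbb{E}[x \mid x_t] \\
\mathbb{E}[x \mid x_t, y] &\approx \mathbb{E}[x \mid x_t]
  + \mathbb{V}[x \mid x_t] \, A^\top
    \left( \Sigma_y + A \, \mathbb{V}[x \mid x_t] \, A^\top \right)^{-1}
    \left( y - A \, \mathbb{E}[x \mid x_t] \right)
\end{aligned}
```

The inverse is typically not formed explicitly. The linear system is instead solved with an iterative method, which appears to be the role of the solver and iterations parameters.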

References

Learning Diffusion Priors from Observations by Expectation Maximization (Rozet et al., 2024)

Classes

MMPSDenoiser

Creates an MMPS denoiser module.

Descriptions

class azula.guidance.mmps.MMPSDenoiser(denoiser, y, A, var_y, tweedie_covariance=True, solver='gmres', iterations=1)

Creates an MMPS denoiser module.

Parameters:
  • denoiser (GaussianDenoiser) – A Gaussian denoiser.

  • y (Tensor) – An observation \(y \sim \mathcal{N}(A(x), \Sigma_y)\), with shape \((*, D)\).

  • A (Callable[[Tensor], Tensor]) – The forward operator \(x \mapsto A(x)\).

  • var_y (Tensor) – The noise variance \(\Sigma_y\).

  • tweedie_covariance (bool) – Whether to use Tweedie's covariance formula for \(\mathbb{V}[x \mid x_t]\). If False, use the denoiser's covariance \(\Sigma_\phi(x_t)\) instead.

  • solver (str) – The linear solver, either "cg" (conjugate gradient) or "gmres" (generalized minimal residual).

  • iterations (int) – The number of solver iterations.
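To show how solver and iterations interact with the update, here is a minimal NumPy sketch for a dense linear operator. It assumes that A is a matrix, that \(\Sigma_y\) is diagonal and passed as the vector var_y, and that cov_x is an explicit covariance matrix (for example, the Tweedie covariance or a diagonal \(\Sigma_\phi(x_t)\)). The helpers conjugate_gradient and mmps_mean are hypothetical illustrations, not part of azula, which works with PyTorch tensors and matrix-free operators.

```python
import numpy as np


def conjugate_gradient(matvec, b, iterations):
    """Approximately solve M v = b for a symmetric positive-definite M.

    Runs at most `iterations` conjugate gradient steps starting from v = 0.
    `matvec` computes M @ v without M being formed explicitly.
    """
    v = np.zeros_like(b)
    r = b - matvec(v)  # initial residual (equals b since v = 0)
    p = r.copy()  # initial search direction
    rs = r @ r
    b_norm = np.linalg.norm(b)
    for _ in range(iterations):
        if np.sqrt(rs) <= 1e-12 * b_norm:
            break  # converged; avoids dividing by a vanishing p @ Mp
        Mp = matvec(p)
        step = rs / (p @ Mp)
        v = v + step * p
        r = r - step * Mp
        rs_new = r @ r
        p = r + (rs_new / rs) * p  # next M-conjugate direction
        rs = rs_new
    return v


def mmps_mean(x_hat, cov_x, y, A, var_y, iterations):
    """Moment-matched estimate of E[x | x_t, y] for a linear operator A.

    x_hat: E[x | x_t], shape (N,)
    cov_x: V[x | x_t], shape (N, N)
    y: observation, shape (D,)
    A: forward matrix, shape (D, N)
    var_y: diagonal of Sigma_y, shape (D,)
    """

    def cov_y(u):
        # (Sigma_y + A V[x | x_t] A^T) u, applied as a sequence of products
        return var_y * u + A @ (cov_x @ (A.T @ u))

    residual = y - A @ x_hat
    u = conjugate_gradient(cov_y, residual, iterations)
    return x_hat + cov_x @ (A.T @ u)
```

In exact arithmetic, conjugate gradient solves the D-dimensional system in at most D iterations, so a small iterations value such as the default of 1 trades accuracy for fewer evaluations of the operator. GMRES, the default solver in the signature, does not require the system to be symmetric.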