azula.plugins.adm

Ablated diffusion model (ADM) plugin.

from azula.plugins import adm

References

Diffusion Models Beat GANs on Image Synthesis (Dhariwal et al., 2021)

Classes

AblatedDenoiser

Creates an ablated denoiser.

Functions

load_model

Loads a pre-trained ADM denoiser.

Descriptions

class azula.plugins.adm.AblatedDenoiser(backbone, schedule=None, clip_mean=False, learn_var=False, discrete_schedule='linear', discrete_steps=1000)

Creates an ablated denoiser.

Parameters:
  • backbone (Module) – A time conditional network.

  • schedule (Schedule) – A noise schedule. If None, azula.noise.VPSchedule is used instead.

  • clip_mean (bool) – Whether the mean \(\mu_\phi(x_t)\) is clipped to \([-1, 1]\) during evaluation.

  • learn_var (bool) – Whether the variance \(\sigma^2_\phi(x_t)\) is learned. For pre-trained models, the learned variance is indicative but inexact.
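The signature also exposes discrete_schedule='linear' and discrete_steps=1000, which the parameter list does not describe. Assuming they refer to the discrete training schedule of the original ADM backbones (the linear β schedule from DDPM, with β growing from 1e-4 to 0.02 over 1000 steps), a continuous noise level \((\alpha_t, \sigma_t)\), with \(x_t = \alpha_t x + \sigma_t \epsilon\), could be matched to the discrete step with the closest ratio \(\sigma_t / \alpha_t\). The pure-Python sketch below illustrates that idea; linear_alpha_bars, vp_coefficients and nearest_step are hypothetical helpers, not part of azula.

```python
import math


def linear_alpha_bars(steps: int = 1000, beta_start: float = 1e-4, beta_end: float = 0.02) -> list[float]:
    """Cumulative products alpha_bar_k = prod_{i <= k} (1 - beta_i) of a linear beta schedule."""
    alpha_bars = []
    prod = 1.0
    for k in range(steps):
        # beta grows linearly from beta_start (k = 0) to beta_end (k = steps - 1).
        beta = beta_start + (beta_end - beta_start) * k / max(steps - 1, 1)
        prod *= 1.0 - beta
        alpha_bars.append(prod)
    return alpha_bars


def vp_coefficients(alpha_bar: float) -> tuple[float, float]:
    """Variance-preserving coefficients (alpha_t, sigma_t) of x_t = alpha_t * x + sigma_t * eps."""
    return math.sqrt(alpha_bar), math.sqrt(1.0 - alpha_bar)


def nearest_step(alpha_t: float, sigma_t: float, alpha_bars: list[float]) -> int:
    """Hypothetical lookup: the discrete step whose noise ratio sigma/alpha is closest to the target."""
    target = sigma_t / alpha_t
    ratios = [math.sqrt((1.0 - ab) / ab) for ab in alpha_bars]
    return min(range(len(ratios)), key=lambda k: abs(ratios[k] - target))


alpha_bars = linear_alpha_bars()
alpha_t, sigma_t = vp_coefficients(alpha_bars[500])
step = nearest_step(alpha_t, sigma_t, alpha_bars)  # recovers step 500
```

Matching on \(\sigma_t / \alpha_t\) rather than on \(t\) itself keeps the lookup independent of how the continuous schedule parameterizes time; whether the plugin performs exactly this conversion is an assumption of the sketch.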

forward(x_t, t, label=None, **kwargs)
Parameters:
  • x_t (Tensor) – A noisy tensor \(x_t\), with shape \((B, 3, H, W)\).

  • t (Tensor) – The time \(t\), with shape \(()\) or \((B)\).

  • label (Tensor | None) – The class label \(c\) as an integer, with shape \((B)\).

  • kwargs – Optional keyword arguments.

Returns:

The Gaussian \(\mathcal{N}(X \mid \mu_\phi(x_t \mid c), \sigma^2_\phi(x_t \mid c))\).

Return type:

GaussianPosterior
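forward returns a Gaussian over the clean image rather than a point estimate. The minimal pure-Python sketch below shows such a return value together with the effect of the clip_mean option. GaussianSketch and clip_mean are hypothetical stand-ins, not azula's GaussianPosterior, and the sketch assumes a diagonal variance that clipping leaves unchanged.

```python
import math
import random
from dataclasses import dataclass


@dataclass
class GaussianSketch:
    """Hypothetical stand-in for a diagonal Gaussian N(X | mean, var)."""

    mean: list[float]
    var: list[float]

    def sample(self, rng: random.Random) -> list[float]:
        # Reparameterized draw X = mean + sqrt(var) * eps with eps ~ N(0, 1).
        return [m + math.sqrt(v) * rng.gauss(0.0, 1.0) for m, v in zip(self.mean, self.var)]


def clip_mean(posterior: GaussianSketch) -> GaussianSketch:
    # Clamp every mean entry to [-1, 1]; this sketch leaves the variance untouched.
    clipped = [min(max(m, -1.0), 1.0) for m in posterior.mean]
    return GaussianSketch(mean=clipped, var=list(posterior.var))


posterior = GaussianSketch(mean=[-1.7, 0.3, 2.5], var=[0.01, 0.04, 0.09])
clipped = clip_mean(posterior)
draw = clipped.sample(random.Random(0))
```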

azula.plugins.adm.load_model(name, **kwargs)

Loads a pre-trained ADM denoiser.

Parameters:
  • name (str) – The pre-trained model name.

  • kwargs – Keyword arguments passed to torch.load.

Returns:

A pre-trained denoiser.

Return type:

Denoiser