azula.nn.dit

Diffusion Transformer (DiT) building blocks.

References

Scalable Diffusion Models with Transformers (Peebles et al., 2022)

Classes

DiT

Creates a modulated DiT-like module.

DiTBlock

Creates a modulated DiT block module.

Descriptions

class azula.nn.dit.DiT(in_channels, out_channels, cond_channels=0, mod_features=0, pos_channels=1, hid_channels=1024, hid_blocks=3, **kwargs)

Creates a modulated DiT-like module.

Parameters:
  • in_channels (int) – The number of input channels \(C_i\).

  • out_channels (int) – The number of output channels \(C_o\).

  • cond_channels (int) – The number of condition channels \(C_c\).

  • mod_features (int) – The number of modulating features \(D\).

  • pos_channels (int) – The number of positional channels \(P\).

  • hid_channels (int) – The number of hidden token channels \(C_h\).

  • hid_blocks (int) – The number of hidden transformer blocks.

  • kwargs – Keyword arguments passed to DiTBlock.

forward(x, mod=None, pos=None, cond=None)
Parameters:
  • x (Tensor) – The input tensor, with shape \((*, L, C_i)\).

  • mod (Tensor | None) – The modulation vector, with shape \((D)\) or \((*, D)\).

  • pos (Tensor | None) – The position tensor, with shape \((*, L, P)\). If None, the sequence indices are used instead.

  • cond (Tensor | None) – The condition tensor, with shape \((*, L, C_c)\).

Returns:

The output tensor, with shape \((*, L, C_o)\).

Return type:

Tensor
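The shape contract of forward can be illustrated with a small hypothetical module. DiTLikeSketch is not part of azula and is not its implementation; it assumes that cond is concatenated to x along the channel axis, that positions are embedded by a linear map and added to the tokens, and that a missing pos falls back to the indices 0, ..., L - 1 as a single channel (consistent with the default pos_channels=1). The hidden blocks are plain residual stand-ins and the modulation vector is omitted.

```python
import torch
import torch.nn as nn


class DiTLikeSketch(nn.Module):
    """Hypothetical DiT-like wrapper illustrating the documented shapes.

    Assumptions (not taken from azula): cond is concatenated to x along
    channels, pos is embedded by a linear map and added to the tokens,
    and the hidden blocks are unmodulated residual stand-ins.
    """

    def __init__(self, in_channels, out_channels, cond_channels=0, pos_channels=1, hid_channels=64, hid_blocks=3):
        super().__init__()
        self.embed_in = nn.Linear(in_channels + cond_channels, hid_channels)  # C_i + C_c -> C_h
        self.embed_pos = nn.Linear(pos_channels, hid_channels)  # P -> C_h
        # Stand-ins for the modulated DiT blocks.
        self.blocks = nn.ModuleList(
            [nn.Sequential(nn.LayerNorm(hid_channels), nn.Linear(hid_channels, hid_channels)) for _ in range(hid_blocks)]
        )
        self.embed_out = nn.Linear(hid_channels, out_channels)  # C_h -> C_o

    def forward(self, x, pos=None, cond=None):
        # x: (*, L, C_i), cond: (*, L, C_c), pos: (*, L, P)
        if cond is not None:
            x = torch.cat((x, cond), dim=-1)
        if pos is None:
            # Fall back to the sequence indices 0..L-1 as one positional channel: (L, 1).
            L = x.shape[-2]
            pos = torch.arange(L, dtype=x.dtype, device=x.device).unsqueeze(-1)
        h = self.embed_in(x) + self.embed_pos(pos)  # (*, L, C_h), pos broadcasts over *
        for block in self.blocks:
            h = h + block(h)  # residual update
        return self.embed_out(h)  # (*, L, C_o)
```

Passing pos explicitly as the expanded indices gives the same result as pos=None, which is one way to check the fallback.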

class azula.nn.dit.DiTBlock(channels, mod_features=0, ffn_factor=4, ffn_activation='silu', dropout=None, checkpointing=False, **kwargs)

Creates a modulated DiT block module.

Parameters:
  • channels (int) – The number of channels \(C\).

  • mod_features (int) – The number of modulating features \(D\).

  • ffn_factor (int) – The channel factor in the FFN.

  • ffn_activation (Literal['relu', 'relu2', 'silu', 'swiglu']) – The activation function in the FFN.

  • dropout (float | None) – The dropout rate in \([0, 1]\).

  • checkpointing (bool) – Whether to use activation checkpointing or not.

  • kwargs – Keyword arguments passed to MultiheadSelfAttention.
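The ffn_factor and ffn_activation options can be read as follows in a minimal sketch. FeedForwardSketch is hypothetical; it assumes the hidden width is channels * ffn_factor, that 'relu2' denotes the squared ReLU, and that 'swiglu' splits a doubled first projection into a value half and a SiLU-activated gate half. azula's own FFN may size or arrange these layers differently.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FeedForwardSketch(nn.Module):
    """Hypothetical token-wise FFN with hidden width channels * ffn_factor.

    Assumptions: 'relu2' is squared ReLU; 'swiglu' is a SiLU-gated linear
    unit whose first projection produces both the value and the gate.
    """

    def __init__(self, channels, ffn_factor=4, ffn_activation="silu"):
        super().__init__()
        hidden = channels * ffn_factor
        self.activation = ffn_activation
        # SwiGLU needs twice the hidden width: one half gates the other.
        self.fc_in = nn.Linear(channels, 2 * hidden if ffn_activation == "swiglu" else hidden)
        self.fc_out = nn.Linear(hidden, channels)

    def forward(self, x):
        h = self.fc_in(x)
        if self.activation == "relu":
            h = F.relu(h)
        elif self.activation == "relu2":
            h = F.relu(h) ** 2
        elif self.activation == "silu":
            h = F.silu(h)
        elif self.activation == "swiglu":
            value, gate = h.chunk(2, dim=-1)
            h = value * F.silu(gate)
        else:
            raise ValueError(f"unknown activation: {self.activation}")
        return self.fc_out(h)  # back to (*, L, C)
```

All four variants preserve the token shape (*, L, C); only the parameter count of the first projection changes with 'swiglu'.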

forward(x, mod=None, pos=None, mask=None)
Parameters:
  • x (Tensor) – The input tokens \(x\), with shape \((*, L, C)\).

  • mod (Tensor | None) – The modulation vector, with shape \((D)\) or \((*, D)\).

  • pos (Tensor | None) – The position coordinates, with shape \((*, L, N)\).

  • mask (Tensor | None) – The attention mask, with shape \((*, L, L)\).

Returns:

The output tokens \(y\), with shape \((*, L, C)\).

Return type:

Tensor
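The modulation scheme of Peebles et al. (adaLN-Zero) can be sketched in plain PyTorch to show what a modulated block does with mod and mask. ModulatedBlockSketch is hypothetical and simplified (single-head attention, a SiLU FFN, no dropout, pos ignored) and is not azula's implementation. It maps a modulation vector of shape (D) or (*, D) to per-channel shift, scale, and gate values that broadcast over the L tokens, applies a boolean (*, L, L) mask to the attention scores, and uses the checkpointing flag to wrap the forward pass with torch.utils.checkpoint.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ModulatedBlockSketch(nn.Module):
    """Hypothetical adaLN-Zero block in the style of Peebles et al. (2022).

    Simplifications (assumptions, not azula's code): single-head attention,
    SiLU FFN, no dropout, and no positional input.
    """

    def __init__(self, channels, mod_features, ffn_factor=4, checkpointing=False):
        super().__init__()
        self.checkpointing = checkpointing
        self.norm1 = nn.LayerNorm(channels, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(channels, elementwise_affine=False)
        self.qkv = nn.Linear(channels, 3 * channels)
        self.proj = nn.Linear(channels, channels)
        self.ffn = nn.Sequential(
            nn.Linear(channels, ffn_factor * channels),
            nn.SiLU(),
            nn.Linear(ffn_factor * channels, channels),
        )
        # Maps the modulation vector to shift/scale/gate for both sub-layers.
        self.ada = nn.Sequential(nn.SiLU(), nn.Linear(mod_features, 6 * channels))
        # "Zero" init: all modulation outputs start at 0, so the block starts as the identity.
        nn.init.zeros_(self.ada[-1].weight)
        nn.init.zeros_(self.ada[-1].bias)

    def _forward(self, x, mod, mask):
        # mod: (D) or (*, D) -> six tensors of shape (1, C) or (*, 1, C),
        # which broadcast over the L tokens of x: (*, L, C).
        shift1, scale1, gate1, shift2, scale2, gate2 = self.ada(mod).unsqueeze(-2).chunk(6, dim=-1)

        # Modulated self-attention sub-layer.
        h = self.norm1(x) * (1 + scale1) + shift1
        q, k, v = self.qkv(h).chunk(3, dim=-1)
        scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5  # (*, L, L)
        if mask is not None:
            # Boolean mask: True means the query may attend to the key.
            scores = scores.masked_fill(~mask, float("-inf"))
        h = scores.softmax(dim=-1) @ v
        x = x + gate1 * self.proj(h)

        # Modulated feed-forward sub-layer.
        h = self.norm2(x) * (1 + scale2) + shift2
        return x + gate2 * self.ffn(h)

    def forward(self, x, mod, mask=None):
        if self.checkpointing and self.training:
            # Recompute activations during the backward pass to save memory.
            return checkpoint(self._forward, x, mod, mask, use_reentrant=False)
        return self._forward(x, mod, mask)
```

Zero-initializing the modulation output makes every gate start at 0, so each block initially passes its input through unchanged; this is the "Zero" in adaLN-Zero. With a mask, every query row must allow at least one key, otherwise the softmax of an all -inf row yields NaN.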