azula.nn.dit
Diffusion Transformer (DiT) building blocks.
Classes
Descriptions
- class azula.nn.dit.DiT(in_channels, out_channels, cond_channels=0, mod_features=0, pos_channels=1, hid_channels=1024, hid_blocks=3, **kwargs)
Creates a modulated DiT-like module.
- Parameters:
in_channels (int) – The number of input channels \(C_i\).
out_channels (int) – The number of output channels \(C_o\).
cond_channels (int) – The number of condition channels \(C_c\).
mod_features (int) – The number of modulating features \(D\).
pos_channels (int) – The number of positional channels \(P\).
hid_channels (int) – The number of hidden token channels \(C_h\).
hid_blocks (int) – The number of hidden transformer blocks.
kwargs – Keyword arguments passed to DiTBlock.
- forward(x, mod=None, pos=None, cond=None)
- Parameters:
x (Tensor) – The input tensor, with shape \((*, L, C_i)\).
mod (Tensor | None) – The modulation vector, with shape \((D)\) or \((*, D)\).
pos (Tensor | None) – The position tensor, with shape \((*, L, P)\). If None, use the sequence indices instead.
cond (Tensor | None) – The condition tensor, with shape \((*, L, C_c)\).
- Returns:
The output tensor, with shape \((*, L, C_o)\).
- Return type:
Tensor
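The shape contract of forward can be sketched as a small pure-Python check. This is an illustration only, not part of azula: the helper name dit_output_shape is hypothetical, shapes are plain tuples rather than tensors, and it assumes exact shape matches without any broadcasting beyond the two documented forms of mod.

```python
from typing import Optional, Tuple

Shape = Tuple[int, ...]


def dit_output_shape(
    x: Shape,
    mod: Optional[Shape] = None,
    pos: Optional[Shape] = None,
    cond: Optional[Shape] = None,
    *,
    in_channels: int,
    out_channels: int,
    cond_channels: int = 0,
    mod_features: int = 0,
    pos_channels: int = 1,
) -> Shape:
    """Hypothetical helper: check the documented shapes and return the output shape."""
    # x has shape (*, L, C_i): any batch dimensions, then length, then channels.
    if len(x) < 2 or x[-1] != in_channels:
        raise ValueError(f"x must have shape (*, L, {in_channels}), got {x}")
    batch, length = tuple(x[:-2]), x[-2]

    # mod is either one vector (D) or one vector per batch element (*, D).
    if mod is not None and tuple(mod) not in {(mod_features,), batch + (mod_features,)}:
        raise ValueError(f"mod must have shape (D) or (*, D), got {mod}")

    # pos and cond share the batch and length dimensions of x.
    if pos is not None and tuple(pos) != batch + (length, pos_channels):
        raise ValueError(f"pos must have shape (*, L, {pos_channels}), got {pos}")
    if cond is not None and tuple(cond) != batch + (length, cond_channels):
        raise ValueError(f"cond must have shape (*, L, {cond_channels}), got {cond}")

    # The output keeps the batch and length dimensions and has C_o channels.
    return batch + (length, out_channels)
```

For example, an input of shape (8, 16, 3) with in_channels=3 and out_channels=5 maps to an output of shape (8, 16, 5), whether or not mod, pos, or cond are supplied.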
- class azula.nn.dit.DiTBlock(channels, mod_features=0, ffn_factor=4, ffn_activation='silu', dropout=None, checkpointing=False, **kwargs)
Creates a modulated DiT block module.
- Parameters:
channels (int) – The number of channels \(C\).
mod_features (int) – The number of modulating features \(D\).
ffn_factor (int) – The channel factor in the FFN.
ffn_activation (Literal['relu', 'relu2', 'silu', 'swiglu']) – The activation function in the FFN.
dropout (float | None) – The dropout rate in \([0, 1]\).
checkpointing (bool) – Whether to use activation checkpointing or not.
kwargs – Keyword arguments passed to MultiheadSelfAttention.
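The ffn_factor and ffn_activation options can be illustrated with scalar versions of the four activations. This sketch assumes the names carry their usual meanings (relu2 as squared ReLU, swiglu as a SiLU-gated product) and that the FFN hidden width is ffn_factor times channels. None of this is confirmed by the documentation above, and the function names are hypothetical rather than azula's own.

```python
import math


def ffn_hidden_channels(channels: int, ffn_factor: int = 4) -> int:
    # Assumption: the FFN expands tokens to ffn_factor * channels features.
    return ffn_factor * channels


def relu(x: float) -> float:
    return max(x, 0.0)


def relu2(x: float) -> float:
    # Squared ReLU: zero for negative inputs, x ** 2 otherwise.
    return max(x, 0.0) ** 2


def silu(x: float) -> float:
    # SiLU (swish): x * sigmoid(x). Overflows for very negative x, which is fine for a sketch.
    return x / (1.0 + math.exp(-x))


def swiglu(gate: float, value: float) -> float:
    # Gated variant: one projection passes through SiLU and scales the other.
    return silu(gate) * value
```

Under these assumptions, the default ffn_factor=4 with channels=1024 gives a hidden width of 4096.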