azula.nn.vit
Vision Transformer (ViT) building blocks.
Classes
Descriptions
- class azula.nn.vit.ViTBlock(channels, mod_features=0, ffn_factor=4, spatial=2, rope=True, dropout=None, checkpointing=False, **kwargs)
Creates a ViT block module.
- Parameters:
channels (int) – The number of channels \(C\).
mod_features (int) – The number of modulating features \(D\).
ffn_factor (int) – The channel factor in the FFN.
spatial (int) – The number of spatial dimensions \(N\). Only necessary with RoPE.
rope (bool) – Whether to use rotary positional embedding (RoPE) or not.
dropout (float | None) – The dropout rate in \([0, 1]\).
checkpointing (bool) – Whether to use gradient checkpointing or not.
kwargs – Keyword arguments passed to azula.nn.attention.MultiheadSelfAttention.
- forward(x, mod=None, pos=None, mask=None)
- Parameters:
- Returns:
The output tokens \(y\), with shape \((*, L, C)\).
- Return type:
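The `rope` option refers to rotary positional embedding. As a rough illustration of the general technique, and not of azula's own implementation, the sketch below rotates consecutive feature pairs of a single token by angles proportional to its position along one axis. The helper names `rope_1d` and `dot`, the base `10000.0`, and the one-dimensional setting are all assumptions taken from the common formulation; how positions are handled for \(N\) spatial dimensions is not shown here.

```python
import math


def rope_1d(x, position, base=10000.0):
    """Rotate consecutive feature pairs of one token by position-dependent angles.

    x: list of floats with even length C (hypothetical single-token features).
    position: index of the token along a single axis.
    base: frequency base of the common RoPE formulation (an assumption here).
    """
    channels = len(x)
    assert channels % 2 == 0, "RoPE pairs features, so C must be even"
    out = []
    for i in range(0, channels, 2):
        # The frequency decreases geometrically with the pair index.
        theta = position * base ** (-i / channels)
        cos, sin = math.cos(theta), math.sin(theta)
        a, b = x[i], x[i + 1]
        # 2-D rotation of the pair (a, b) by the angle theta.
        out.extend([a * cos - b * sin, a * sin + b * cos])
    return out


def dot(u, v):
    """Plain dot product, standing in for an attention score."""
    return sum(a * b for a, b in zip(u, v))


q = [0.5, -1.0, 2.0, 0.25]  # hypothetical query features
k = [1.5, 0.3, -0.7, 1.0]   # hypothetical key features

# Both pairs of positions have the same relative offset (3).
score_a = dot(rope_1d(q, 5), rope_1d(k, 2))
score_b = dot(rope_1d(q, 13), rope_1d(k, 10))
```

Because each pair is only rotated, the norm of the token is unchanged, and the score between a rotated query and a rotated key depends only on the offset between their positions. In the example, `score_a` and `score_b` agree up to floating-point error.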
- class azula.nn.vit.ViT(in_channels, out_channels, cond_channels=0, mod_features=0, hid_channels=1024, hid_blocks=3, spatial=2, patch_size=1, unpatch_size=None, **kwargs)
Creates a modulated ViT-like module.
- Parameters:
in_channels (int) – The number of input channels \(C_i\).
out_channels (int) – The number of output channels \(C_o\).
cond_channels (int) – The number of condition channels \(C_c\).
mod_features (int) – The number of modulating features \(D\).
hid_channels (int) – The number of hidden token channels.
hid_blocks (int) – The number of hidden transformer blocks.
spatial (int) – The number of spatial dimensions \(N\).
patch_size (int | Sequence[int]) – The patch size or shape.
unpatch_size (int | Sequence[int] | None) – The unpatch size or shape.
kwargs – Keyword arguments passed to ViTBlock.
- forward(x, mod=None, cond=None)
- Parameters:
- Returns:
The output tensor, with shape \((B, C_o, L_1, ..., L_N)\).
- Return type:
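As a back-of-the-envelope check of the shapes involved, the sketch below computes the token layout a patch-based ViT would typically work with for an input of shape \((B, C_i, L_1, ..., L_N)\). It assumes the common convention that each spatial length is divided by the patch size and that the contents of each patch are folded into the channel axis; whether azula follows exactly this convention is not stated on this page. `token_shape` is a hypothetical helper and not part of the library.

```python
import math


def token_shape(in_shape, patch_size):
    """Return (B, L, C_patch) for an input of shape (B, C_i, L_1, ..., L_N).

    Assumes non-overlapping patches whose contents are folded into channels.
    """
    batch, channels, *lengths = in_shape
    if isinstance(patch_size, int):
        # A single integer is applied to every spatial dimension.
        patch_size = [patch_size] * len(lengths)
    assert len(patch_size) == len(lengths), "one patch size per spatial dimension"
    assert all(l % p == 0 for l, p in zip(lengths, patch_size)), \
        "each spatial length must be divisible by its patch size"
    # Number of tokens: product of the patch counts along each dimension.
    tokens = math.prod(l // p for l, p in zip(lengths, patch_size))
    # Channels per token before any projection to the hidden width.
    patch_channels = channels * math.prod(patch_size)
    return batch, tokens, patch_channels


shape = token_shape((8, 3, 64, 64), patch_size=4)
print(shape)  # (8, 256, 48)
```

Under these assumptions, a larger patch size reduces the sequence length \(L\) seen by the transformer blocks, which matters because self-attention cost grows quadratically with \(L\).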