azula.nn.vit

Vision Transformer (ViT) building blocks.

References

An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale (Dosovitskiy et al., 2021)
Scalable Diffusion Models with Transformers (Peebles and Xie, 2022)

Classes

ViTBlock

Creates a ViT block module.

ViT

Creates a modulated ViT-like module.

Descriptions

class azula.nn.vit.ViTBlock(channels, mod_features=0, ffn_factor=4, ffn_activation='silu', dropout=None, checkpointing=False, **kwargs)

Creates a ViT block module.

Parameters:
  • channels (int) – The number of channels \(C\).

  • mod_features (int) – The number of modulating features \(D\).

  • ffn_factor (int) – The channel factor in the FFN.

  • ffn_activation (str) – The activation function in the FFN. Options are relu, relu2, silu and swiglu.

  • dropout (float | None) – The dropout rate in \([0, 1]\).

  • checkpointing (bool) – Whether to use gradient checkpointing or not.

  • kwargs – Keyword arguments passed to azula.nn.attention.MultiheadSelfAttention.
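To make the `ffn_factor` and `ffn_activation` options concrete, here is a minimal scalar sketch in plain Python. It is not azula's implementation. It assumes that the FFN hidden width is `ffn_factor * channels`, that `relu2` denotes the squared ReLU, and that `swiglu` gates one projection with the SiLU of another. The helper names (`ffn_hidden_width`, `swiglu`, and so on) are hypothetical and chosen for illustration only.

```python
import math


def relu(x: float) -> float:
    """Rectified linear unit: max(0, x)."""
    return max(0.0, x)


def relu2(x: float) -> float:
    """Squared ReLU (assumed meaning of the 'relu2' option)."""
    return max(0.0, x) ** 2


def silu(x: float) -> float:
    """SiLU, x * sigmoid(x), written to avoid overflow for large |x|."""
    if x >= 0.0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def swiglu(gate: float, value: float) -> float:
    """Gated unit silu(gate) * value (assumed convention for 'swiglu')."""
    return silu(gate) * value


def ffn_hidden_width(channels: int, ffn_factor: int = 4) -> int:
    """Hidden width of the FFN, assuming it is ffn_factor * channels."""
    return ffn_factor * channels


print(ffn_hidden_width(256))  # 1024
print(relu2(3.0))             # 9.0
```

Under these assumptions, the default `ffn_factor=4` gives the fourfold expansion used in the original Transformer and ViT feed-forward layers.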

forward(x, mod=None, pos=None, mask=None)
Parameters:
  • x (Tensor) – The input tokens \(x\), with shape \((*, L, C)\).

  • mod (Tensor | None) – The modulation vector, with shape \((D)\) or \((*, D)\).

  • pos (Tensor | None) – The position coordinates, with shape \((*, L, N)\).

  • mask (Tensor | None) – The attention mask, with shape \((*, L, L)\).

Returns:

The output tokens \(y\), with shape \((*, L, C)\).

Return type:

Tensor
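The shape contract of `forward` can be summarized with a small checker that works on shape tuples only. This is a sketch and not part of azula: `vit_block_output_shape` is a hypothetical helper, and it deliberately leaves the leading `*` dimensions of `mod`, `pos` and `mask` unchecked, because how they broadcast against `x` is not specified here.

```python
from typing import Optional, Tuple

Shape = Tuple[int, ...]


def vit_block_output_shape(
    x: Shape,
    channels: int,
    mod: Optional[Shape] = None,
    mod_features: int = 0,
    pos: Optional[Shape] = None,
    mask: Optional[Shape] = None,
) -> Shape:
    """Validate the documented trailing dimensions and return the shape of y."""
    if len(x) < 2:
        raise ValueError("x must have shape (*, L, C)")
    length, c = x[-2], x[-1]
    if c != channels:
        raise ValueError(f"expected C={channels}, got {c}")
    # mod has shape (D) or (*, D): only the last dimension is checked.
    if mod is not None and mod[-1] != mod_features:
        raise ValueError(f"expected D={mod_features}, got {mod[-1]}")
    # pos has shape (*, L, N): the token dimension must match x.
    if pos is not None and (len(pos) < 2 or pos[-2] != length):
        raise ValueError("pos must have shape (*, L, N)")
    # mask has shape (*, L, L): one row and one column per token.
    if mask is not None and tuple(mask[-2:]) != (length, length):
        raise ValueError("mask must have shape (*, L, L)")
    # The block maps tokens to tokens, so y has the same shape as x.
    return tuple(x)


y_shape = vit_block_output_shape(
    x=(8, 196, 64),
    channels=64,
    mod=(8, 16),
    mod_features=16,
    pos=(8, 196, 2),
    mask=(196, 196),
)
print(y_shape)  # (8, 196, 64)
```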

class azula.nn.vit.ViT(in_channels, out_channels, cond_channels=0, mod_features=0, hid_channels=1024, hid_blocks=3, spatial=2, patch_size=1, unpatch_size=None, **kwargs)

Creates a modulated ViT-like module.

Parameters:
  • in_channels (int) – The number of input channels \(C_i\).

  • out_channels (int) – The number of output channels \(C_o\).

  • cond_channels (int) – The number of condition channels \(C_c\).

  • mod_features (int) – The number of modulating features \(D\).

  • hid_channels (int) – The number of hidden token channels.

  • hid_blocks (int) – The number of hidden transformer blocks.

  • spatial (int) – The number of spatial dimensions \(N\).

  • patch_size (int | Sequence[int]) – The patch size or shape.

  • unpatch_size (int | Sequence[int] | None) – The unpatch size or shape.

  • kwargs – Keyword arguments passed to ViTBlock.

forward(x, mod=None, cond=None)
Parameters:
  • x (Tensor) – The input tensor, with shape \((B, C_i, L_1, ..., L_N)\).

  • mod (Tensor | None) – The modulation vector, with shape \((D)\) or \((B, D)\).

  • cond (Tensor | None) – The condition tensor, with shape \((B, C_c, L_1, ..., L_N)\).

Returns:

The output tensor, with shape \((B, C_o, L_1, ..., L_N)\).

Return type:

Tensor
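To illustrate how `patch_size` relates the spatial shape \((L_1, ..., L_N)\) to the number of tokens, the following sketch computes the token count and the number of raw values per patch. It assumes standard non-overlapping ViT patchification, where every \(L_i\) is divisible by the corresponding patch length. It does not reproduce azula's internals, it ignores `cond_channels`, and the helpers `as_tuple` and `patch_token_layout` are hypothetical.

```python
from math import prod
from typing import Sequence, Tuple, Union


def as_tuple(size: Union[int, Sequence[int]], spatial: int) -> Tuple[int, ...]:
    """Expand an int patch size to one entry per spatial dimension."""
    if isinstance(size, int):
        return (size,) * spatial
    size = tuple(size)
    if len(size) != spatial:
        raise ValueError(f"expected {spatial} patch lengths, got {len(size)}")
    return size


def patch_token_layout(
    in_channels: int,
    spatial_shape: Sequence[int],
    patch_size: Union[int, Sequence[int]] = 1,
) -> Tuple[int, int]:
    """Return (number of tokens, raw values per patch) for one sample."""
    patch = as_tuple(patch_size, len(spatial_shape))
    for length, p in zip(spatial_shape, patch):
        if length % p != 0:
            raise ValueError(f"length {length} is not divisible by patch {p}")
    # One token per patch on the grid (L_1 / p_1, ..., L_N / p_N).
    grid = tuple(length // p for length, p in zip(spatial_shape, patch))
    # Each patch flattens in_channels * p_1 * ... * p_N values.
    return prod(grid), in_channels * prod(patch)


print(patch_token_layout(3, (32, 32), patch_size=4))       # (64, 48)
print(patch_token_layout(3, (32, 32), patch_size=(4, 8)))  # (32, 96)
```

Since self-attention cost grows quadratically with the number of tokens, a larger `patch_size` trades spatial resolution of the token grid for a shorter sequence.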