azula.nn.vit

Vision Transformer (ViT) building blocks.

References

An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale (Dosovitskiy et al., 2021)

Classes

ViT

Creates a modulated ViT-like module.

Descriptions

class azula.nn.vit.ViT(in_channels, out_channels, cond_channels=0, mod_features=0, hid_channels=1024, hid_blocks=3, spatial=2, patch_size=1, unpatch_size=None, **kwargs)

Creates a modulated ViT-like module.

Parameters:
  • in_channels (int) – The number of input channels \(C_i\).

  • out_channels (int) – The number of output channels \(C_o\).

  • cond_channels (int) – The number of condition channels \(C_c\).

  • mod_features (int) – The number of modulating features \(D\).

  • hid_channels (int) – The number of hidden token channels.

  • hid_blocks (int) – The number of hidden transformer blocks.

  • spatial (int) – The number of spatial dimensions \(N\).

  • patch_size (int | Sequence[int]) – The patch size or shape.

  • unpatch_size (int | Sequence[int] | None) – The unpatch size or shape.

  • kwargs – Keyword arguments passed to ViTBlock.
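
As a rough illustration of how patch_size affects the sequence length, the sketch below counts the tokens a patch embedding would produce. It assumes the standard ViT scheme, in which each spatial dimension \(L_i\) is split into non-overlapping patches and must be divisible by the patch size; this page does not state that azula enforces this. The helper name token_grid is hypothetical and not part of azula.

```python
from math import prod
from typing import Sequence, Tuple, Union


def token_grid(
    spatial_shape: Sequence[int],
    patch_size: Union[int, Sequence[int]],
) -> Tuple[int, ...]:
    """Hypothetical helper (not part of azula): tokens per spatial dimension."""
    # An int patch size is broadcast to every spatial dimension.
    if isinstance(patch_size, int):
        patch_size = [patch_size] * len(spatial_shape)
    if len(patch_size) != len(spatial_shape):
        raise ValueError("patch_size must have one entry per spatial dimension")
    # Assumption: non-overlapping patches, so each L_i must be divisible by p_i.
    for length, patch in zip(spatial_shape, patch_size):
        if length % patch != 0:
            raise ValueError(f"{length} is not divisible by patch size {patch}")
    return tuple(length // patch for length, patch in zip(spatial_shape, patch_size))


grid = token_grid((64, 64), patch_size=4)
num_tokens = prod(grid)  # sequence length seen by the transformer blocks
```

Under the same assumption, the default patch_size=1 turns every spatial position into its own token, so the sequence length (and the cost of attention over it) grows quickly with resolution.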

forward(x, mod=None, cond=None)
Parameters:
  • x (Tensor) – The input tensor, with shape \((B, C_i, L_1, ..., L_N)\).

  • mod (Tensor | None) – The modulation vector, with shape \((D)\) or \((B, D)\).

  • cond (Tensor | None) – The condition tensor, with shape \((B, C_c, L_1, ..., L_N)\).

Returns:

The output tensor, with shape \((B, C_o, L_1, ..., L_N)\).

Return type:

Tensor
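
The shape contract of forward can be summarized as a small checker. This is a minimal sketch that works on plain shape tuples rather than tensors; the function expected_output_shape and its arguments are hypothetical and not part of azula. It only encodes the shapes documented above.

```python
from typing import Optional, Tuple

Shape = Tuple[int, ...]


def expected_output_shape(
    x: Shape,
    in_channels: int,
    out_channels: int,
    spatial: int = 2,
    mod: Optional[Shape] = None,
    mod_features: int = 0,
    cond: Optional[Shape] = None,
    cond_channels: int = 0,
) -> Shape:
    """Hypothetical checker (not part of azula) for the documented shapes."""
    # x must be (B, C_i, L_1, ..., L_N) with N = spatial.
    if len(x) != spatial + 2 or x[1] != in_channels:
        raise ValueError(f"x has shape {x}, expected (B, {in_channels}, L_1, ..., L_N)")
    batch, lengths = x[0], x[2:]
    # mod must be (D) or (B, D).
    if mod is not None and mod not in ((mod_features,), (batch, mod_features)):
        raise ValueError(f"mod has shape {mod}, expected (D) or (B, D)")
    # cond must be (B, C_c, L_1, ..., L_N), matching x in batch and spatial sizes.
    if cond is not None and cond != (batch, cond_channels, *lengths):
        raise ValueError(f"cond has shape {cond}, expected (B, C_c, L_1, ..., L_N)")
    # The output keeps the batch and spatial sizes and has C_o channels.
    return (batch, out_channels, *lengths)


out_shape = expected_output_shape(
    x=(8, 3, 32, 32),
    in_channels=3,
    out_channels=3,
    mod=(8, 256),
    mod_features=256,
    cond=(8, 1, 32, 32),
    cond_channels=1,
)
```

Here a batch of 8 three-channel 32×32 inputs, a per-sample modulation vector with \(D = 256\), and a single-channel condition map all satisfy the documented shapes, so the output shape matches the input's batch and spatial sizes.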