azula.nn.attention

Attention layers.

Classes

MultiheadSelfAttention

Creates a multi-head self-attention layer.

Descriptions

class azula.nn.attention.MultiheadSelfAttention(channels, attention_heads=1, qk_norm=True, dropout=None, checkpointing=False)

Creates a multi-head self-attention layer.

Parameters:
  • channels (int) – The number of channels \(H \times C\).

  • attention_heads (int) – The number of attention heads \(H\).

  • qk_norm (bool) – Whether to apply RMS normalization to the queries and keys.

  • dropout (float | None) – The dropout rate in \([0, 1]\).

  • checkpointing (bool) – Whether to use gradient checkpointing or not.
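
The notation \(H \times C\) suggests that `channels` is split evenly across the `attention_heads`, leaving \(C\) channels per head. The sketch below only illustrates that arithmetic. The helper `head_channels` is hypothetical and not part of azula, and the divisibility check is an assumption about how such a layer would validate its arguments.

```python
def head_channels(channels: int, attention_heads: int = 1) -> int:
    """Return the per-head channel count C, assuming channels = H * C.

    Hypothetical helper for illustration; not an azula function.
    """
    if channels <= 0 or attention_heads <= 0:
        raise ValueError("channels and attention_heads must be positive")
    if channels % attention_heads != 0:
        # Assumption: the heads must divide the channels evenly.
        raise ValueError(
            f"channels ({channels}) is not divisible by "
            f"attention_heads ({attention_heads})"
        )
    return channels // attention_heads


print(head_channels(256, 8))  # 32
```

With the default `attention_heads=1`, the single head spans all channels, so \(C\) equals `channels`.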

forward(x, theta=None, mask=None)
Parameters:
  • x (Tensor) – The input tokens \(x\), with shape \((*, L, H \times C)\).

  • theta (Tensor | None) – Optional rotary positional embedding \(\theta\), with shape \((*, L, H \times C / 2)\).

  • mask (Tensor | None) – Optional attention mask, with shape \((L, L)\).

Returns:

The output tokens \(y\), with shape \((*, L, H \times C)\).

Return type:

Tensor
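
To make the role of an \((L, L)\) attention mask concrete, here is a minimal pure-Python sketch of a causal mask and of a softmax restricted to the allowed positions. It assumes the boolean convention in which `True` means that a query position may attend to a key position; whether azula's `mask` follows this convention is not stated above. The helpers `causal_mask` and `masked_softmax` are hypothetical and not part of azula.

```python
import math


def causal_mask(length):
    """Boolean (L, L) mask: entry [i][j] is True when token i may attend to token j."""
    return [[j <= i for j in range(length)] for i in range(length)]


def masked_softmax(scores, allowed):
    """Softmax over one row of attention scores, with disallowed positions set to 0.

    Assumes at least one position in the row is allowed.
    """
    kept = [s for s, a in zip(scores, allowed) if a]
    peak = max(kept)  # subtract the maximum for numerical stability
    exps = [math.exp(s - peak) if a else 0.0 for s, a in zip(scores, allowed)]
    total = sum(exps)
    return [e / total for e in exps]


mask = causal_mask(3)
# The second token (row 1) may attend to tokens 0 and 1, but not to token 2.
weights = masked_softmax([0.0, 0.0, 0.0], mask[1])
print(weights)  # [0.5, 0.5, 0.0]
```

Each row of the mask applies to one query position, which is why the mask has shape \((L, L)\) for a sequence of \(L\) tokens.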