azula.nn.unet

U-Net building blocks.

Classes

  • UNetBlock – Creates a modulated U-Net block module.

  • UNet – Creates a modulated U-Net module.

Descriptions

class azula.nn.unet.UNetBlock(channels, mod_features, attention_heads=None, dropout=None, spatial=2, **kwargs)

Creates a modulated U-Net block module.

Parameters:
  • channels (int) – The number of channels \(C\).

  • mod_features (int) – The number of modulating features \(D\).

  • attention_heads (int | None) – The number of attention heads.

  • dropout (float | None) – The dropout rate in \([0, 1]\).

  • spatial (int) – The number of spatial dimensions \(N\).

  • kwargs – Keyword arguments passed to torch.nn.Conv2d.

forward(x, mod)
Parameters:
  • x (Tensor) – The input tensor, with shape \((B, C, H_1, ..., H_N)\).

  • mod (Tensor) – The modulation vector, with shape \((D)\) or \((B, D)\).

Returns:

The output tensor, with shape \((B, C, H_1, ..., H_N)\).

Return type:

Tensor
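The shape contract of `UNetBlock.forward` can be checked with plain tuples, without building the module. The helper below (`block_output_shape`) is hypothetical and not part of azula. It only encodes the documented shapes: `x` is \((B, C, H_1, ..., H_N)\), `mod` is \((D)\) or \((B, D)\), and the output has the same shape as `x`. Treating an unbatched `mod` as shared across the whole batch is an assumption.

```python
from typing import Sequence, Tuple


def block_output_shape(
    x_shape: Sequence[int],
    mod_shape: Sequence[int],
    channels: int,
    mod_features: int,
    spatial: int = 2,
) -> Tuple[int, ...]:
    """Hypothetical check of the documented UNetBlock.forward shapes."""
    # x must be (B, C, H_1, ..., H_N) with C == channels and N == spatial.
    if len(x_shape) != 2 + spatial:
        raise ValueError(f"expected {2 + spatial} dims for x, got {len(x_shape)}")
    batch, c = x_shape[0], x_shape[1]
    if c != channels:
        raise ValueError(f"expected {channels} channels, got {c}")
    # mod is either (D) or (B, D), with D == mod_features.
    if tuple(mod_shape) == (mod_features,):
        pass  # assumed to be shared by every element of the batch
    elif tuple(mod_shape) != (batch, mod_features):
        raise ValueError(f"mod must be ({mod_features},) or ({batch}, {mod_features})")
    # The block preserves the shape of its input.
    return tuple(x_shape)


print(block_output_shape((8, 64, 32, 32), (256,), channels=64, mod_features=256))
# (8, 64, 32, 32)
```

A mismatched batch size in `mod`, such as \((4, D)\) for a batch of 8, is rejected rather than broadcast.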

class azula.nn.unet.UNet(in_channels, out_channels, mod_features, hid_channels=(64, 128, 256), hid_blocks=(3, 3, 3), kernel_size=3, stride=2, attention_heads={}, dropout=None, spatial=2)

Creates a modulated U-Net module.

Parameters:
  • in_channels (int) – The number of input channels \(C_i\).

  • out_channels (int) – The number of output channels \(C_o\).

  • mod_features (int) – The number of modulating features \(D\).

  • hid_channels (Sequence[int]) – The number of channels at each depth.

  • hid_blocks (Sequence[int]) – The number of hidden blocks at each depth.

  • kernel_size (int | Sequence[int]) – The kernel size of all convolutions.

  • stride (int | Sequence[int]) – The stride of the downsampling convolutions.

  • attention_heads (Dict[int, int]) – The number of attention heads at each depth.

  • dropout (float | None) – The dropout rate in \([0, 1]\).

  • spatial (int) – The number of spatial dimensions \(N\).

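The per-depth hyperparameters can be summarized in plain Python. This is a minimal sketch under stated assumptions, not azula's internals: `as_tuple` and `depth_schedule` are hypothetical helpers; expanding an `int` kernel size or stride to one value per spatial dimension, requiring `hid_channels` and `hid_blocks` to have the same length, and treating depths absent from `attention_heads` (keyed from depth 0) as having no attention are all assumptions.

```python
from typing import Dict, List, Optional, Sequence, Tuple, Union


def as_tuple(value: Union[int, Sequence[int]], spatial: int) -> Tuple[int, ...]:
    """Expand an int into one value per spatial dimension (assumed convention)."""
    if isinstance(value, int):
        return (value,) * spatial
    values = tuple(value)
    if len(values) != spatial:
        raise ValueError(f"expected {spatial} values, got {len(values)}")
    return values


def depth_schedule(
    hid_channels: Sequence[int] = (64, 128, 256),
    hid_blocks: Sequence[int] = (3, 3, 3),
    attention_heads: Optional[Dict[int, int]] = None,
) -> List[dict]:
    """Hypothetical per-depth summary of the UNet hyperparameters."""
    # None instead of {} avoids a shared mutable default argument.
    attention_heads = attention_heads or {}
    if len(hid_channels) != len(hid_blocks):
        raise ValueError("hid_channels and hid_blocks must have the same length")
    # Depths missing from attention_heads are assumed to have no attention (None).
    return [
        {"depth": i, "channels": c, "blocks": b, "heads": attention_heads.get(i)}
        for i, (c, b) in enumerate(zip(hid_channels, hid_blocks))
    ]


print(as_tuple(3, spatial=2))
for row in depth_schedule(attention_heads={2: 4}):
    print(row)
```

With the defaults above and `attention_heads={2: 4}`, only the deepest level (256 channels) would carry 4 attention heads under these assumptions.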
forward(x, mod)
Parameters:
  • x (Tensor) – The input tensor, with shape \((B, C_i, H_1, ..., H_N)\).

  • mod (Tensor) – The modulation vector, with shape \((D)\) or \((B, D)\).

Returns:

The output tensor, with shape \((B, C_o, H_1, ..., H_N)\).

Return type:

Tensor
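The documented contract of `UNet.forward` maps \((B, C_i, H_1, ..., H_N)\) to \((B, C_o, H_1, ..., H_N)\). The sketch below traces plausible intermediate shapes with a hypothetical helper, `trace_shapes`. It assumes one strided downsampling between consecutive depths and a decoder that mirrors the encoder. Under that assumption each spatial size must be divisible by the corresponding stride at every downsampling step. These are assumptions about the architecture, not a statement of azula's implementation.

```python
from typing import List, Sequence, Tuple, Union


def trace_shapes(
    x_shape: Sequence[int],
    out_channels: int,
    hid_channels: Sequence[int] = (64, 128, 256),
    stride: Union[int, Sequence[int]] = 2,
) -> Tuple[List[Tuple[int, ...]], Tuple[int, ...]]:
    """Hypothetical trace of per-depth shapes and the output shape of UNet.forward."""
    batch, _, *sizes = x_shape
    strides = [stride] * len(sizes) if isinstance(stride, int) else list(stride)
    if len(strides) != len(sizes):
        raise ValueError("stride must have one entry per spatial dimension")
    encoder = []
    for depth, channels in enumerate(hid_channels):
        if depth > 0:
            # Assumed: one strided downsampling between consecutive depths,
            # which requires every spatial size to be divisible by its stride.
            if any(n % s for n, s in zip(sizes, strides)):
                raise ValueError(f"sizes {sizes} not divisible by strides {strides}")
            sizes = [n // s for n, s in zip(sizes, strides)]
        encoder.append((batch, channels, *sizes))
    # As documented, the output keeps the input's spatial sizes with C_o channels.
    output = (batch, out_channels, *x_shape[2:])
    return encoder, output


encoder, output = trace_shapes((4, 3, 64, 64), out_channels=3)
print(encoder)  # [(4, 64, 64, 64), (4, 128, 32, 32), (4, 256, 16, 16)]
print(output)   # (4, 3, 64, 64)
```

Under these assumptions, a \(30 \times 30\) input with the default three depths and stride 2 would fail at the second downsampling, since 15 is not divisible by 2.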