azula.nn.unet

U-Net building blocks.

Classes

UNet – Creates a modulated U-Net module.

UNetBlock – Creates a modulated U-Net block module.

Descriptions

class azula.nn.unet.UNet(in_channels, out_channels, cond_channels=0, hid_channels=(64, 128, 256), hid_blocks=(3, 3, 3), kernel_size=3, stride=2, spatial=2, periodic=False, identity_init=False, **kwargs)

Creates a modulated U-Net module.

Parameters:
  • in_channels (int) – The number of input channels \(C_i\).

  • out_channels (int) – The number of output channels \(C_o\).

  • cond_channels (int) – The number of condition channels \(C_c\).

  • hid_channels (Sequence[int]) – The number of channels at each depth.

  • hid_blocks (Sequence[int]) – The number of hidden blocks at each depth.

  • kernel_size (int | Sequence[int]) – The kernel size of all convolutions.

  • stride (int | Sequence[int]) – The stride of the downsampling convolutions.

  • spatial (int) – The number of spatial dimensions \(N\).

  • periodic (bool) – Whether the spatial dimensions are periodic or not.

  • identity_init (bool) – Whether to initialize the down/upsampling convolutions as identity.

  • kwargs – Keyword arguments passed to UNetBlock.
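
To make the relation between hid_channels, stride and the input resolution concrete, here is a minimal pure-Python sketch. It is not part of azula: level_shapes is a hypothetical helper, and it assumes that one strided downsampling sits between consecutive depths and that every spatial size must divide evenly by the stride. Neither assumption is stated in this reference.

```python
from typing import List, Sequence, Tuple, Union


def level_shapes(
    spatial_shape: Sequence[int],
    hid_channels: Sequence[int] = (64, 128, 256),
    stride: Union[int, Sequence[int]] = 2,
) -> List[Tuple[int, Tuple[int, ...]]]:
    """Return (channels, spatial shape) for each depth.

    Assumption (hypothetical, not from the reference): there is one strided
    downsampling between consecutive depths, i.e. len(hid_channels) - 1 in total.
    """
    n = len(spatial_shape)
    strides = (stride,) * n if isinstance(stride, int) else tuple(stride)
    if len(strides) != n:
        raise ValueError("stride must be an int or have one entry per spatial dimension")

    current = tuple(spatial_shape)
    levels = []
    for depth, channels in enumerate(hid_channels):
        if depth > 0:
            # Assumed requirement: sizes divide evenly by the stride, so that
            # upsampling by the same factor restores the original resolution.
            if any(size % s != 0 for size, s in zip(current, strides)):
                raise ValueError(f"shape {current} is not divisible by stride {strides}")
            current = tuple(size // s for size, s in zip(current, strides))
        levels.append((channels, current))
    return levels


# Default configuration applied to a 64 x 64 input.
levels = level_shapes((64, 64))
for depth, (channels, shape) in enumerate(levels):
    print(depth, channels, shape)
```

Under these assumptions, the default hid_channels=(64, 128, 256) with stride=2 implies two downsamplings, so input sizes that are multiples of 4 are the safe choice.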

forward(x, mod=None, cond=None)
Parameters:
  • x (Tensor) – The input tensor, with shape \((B, C_i, L_1, ..., L_N)\).

  • mod (Tensor | None) – The modulation vector, with shape \((D)\) or \((B, D)\).

  • cond (Tensor | None) – The condition tensor, with shape \((B, C_c, L_1, ..., L_N)\).

Returns:

The output tensor, with shape \((B, C_o, L_1, ..., L_N)\).

Return type:

Tensor
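
The shape contract of forward can be written down as a small check. The sketch below works on plain shape tuples and does not call azula; unet_output_shape is a hypothetical helper that only encodes the shapes documented above for x, cond and the returned tensor.

```python
from typing import Optional, Sequence, Tuple


def unet_output_shape(
    x_shape: Sequence[int],
    in_channels: int,
    out_channels: int,
    cond_channels: int = 0,
    cond_shape: Optional[Sequence[int]] = None,
) -> Tuple[int, ...]:
    """Validate (B, C_i, L_1, ..., L_N) and return (B, C_o, L_1, ..., L_N)."""
    if len(x_shape) < 3:
        raise ValueError("x needs a batch, a channel and at least one spatial dimension")

    batch, channels, *spatial = x_shape
    if channels != in_channels:
        raise ValueError(f"expected {in_channels} input channels, got {channels}")

    if cond_shape is not None:
        # The condition shares the batch and spatial dimensions of x.
        expected = (batch, cond_channels, *spatial)
        if tuple(cond_shape) != expected:
            raise ValueError(f"expected cond shape {expected}, got {tuple(cond_shape)}")

    return (batch, out_channels, *spatial)


out_shape = unet_output_shape(
    (8, 3, 64, 64),
    in_channels=3,
    out_channels=3,
    cond_channels=1,
    cond_shape=(8, 1, 64, 64),
)
print(out_shape)
```

Only the channel dimension changes between input and output; the batch and spatial dimensions are preserved.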

class azula.nn.unet.UNetBlock(channels, mod_features=0, norm='layer', groups=16, ffn_factor=1, spatial=2, dropout=None, checkpointing=False, **kwargs)

Creates a modulated U-Net block module.

Parameters:
  • channels (int) – The number of channels \(C\).

  • mod_features (int) – The number of modulating features \(D\).

  • norm (str) – The kind of normalization. Options are 'group', 'layer' and 'rms'.

  • groups (int) – The number of groups in torch.nn.GroupNorm layers.

  • ffn_factor (int) – The channel factor in the FFN.

  • spatial (int) – The number of spatial dimensions \(N\).

  • dropout (float | None) – The dropout rate in \([0, 1]\).

  • checkpointing (bool) – Whether to use activation checkpointing or not.

  • kwargs – Keyword arguments passed to azula.nn.layers.ConvNd.
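
When norm='group' is selected, the interaction between groups and channels deserves a quick check, because torch.nn.GroupNorm requires the number of channels to be divisible by the number of groups. The sketch below is plain Python and does not call azula; check_groups is a hypothetical helper, and it assumes that groups is forwarded to torch.nn.GroupNorm unchanged, which this reference does not confirm.

```python
from typing import Dict, Sequence


def check_groups(hid_channels: Sequence[int], groups: int = 16) -> Dict[int, bool]:
    """Map each channel count to whether it is divisible by `groups`.

    Assumption (hypothetical): `groups` reaches torch.nn.GroupNorm as is,
    so every channel count in the network must be a multiple of it.
    """
    return {channels: channels % groups == 0 for channels in hid_channels}


# The default UNet widths are all multiples of the default group count.
default_ok = check_groups((64, 128, 256), groups=16)
print(default_ok)

# A width of 40 is not a multiple of 16 and would need a different `groups`.
custom_ok = check_groups((40, 80), groups=16)
print(custom_ok)
```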

forward(x, mod=None)
Parameters:
  • x (Tensor) – The input tensor, with shape \((B, C, L_1, ..., L_N)\).

  • mod (Tensor | None) – The modulation vector, with shape \((D)\) or \((B, D)\).

Returns:

The output tensor, with shape \((B, C, L_1, ..., L_N)\).

Return type:

Tensor
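
The two accepted shapes for mod can be read as a vector shared by the whole batch, \((D)\), or one vector per sample, \((B, D)\). The sketch below expresses that reading on plain shape tuples. batched_mod_shape is a hypothetical helper, and treating \((D)\) as broadcast over the batch is an assumption, not something this reference states.

```python
from typing import Sequence, Tuple


def batched_mod_shape(
    mod_shape: Sequence[int], batch: int, mod_features: int
) -> Tuple[int, int]:
    """Return the per-sample shape (B, D) implied by a modulation shape."""
    mod_shape = tuple(mod_shape)
    if mod_shape == (mod_features,):
        # Assumed behavior: a single vector is shared by every sample.
        return (batch, mod_features)
    if mod_shape == (batch, mod_features):
        # One modulation vector per sample, nothing to broadcast.
        return mod_shape
    raise ValueError(
        f"expected ({mod_features},) or ({batch}, {mod_features}), got {mod_shape}"
    )


shared = batched_mod_shape((16,), batch=8, mod_features=16)
per_sample = batched_mod_shape((8, 16), batch=8, mod_features=16)
print(shared, per_sample)
```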