azula.nn.unet
U-Net building blocks.
Classes
Descriptions
- class azula.nn.unet.UNetBlock(channels, mod_features, attention_heads=None, dropout=None, spatial=2, **kwargs)
Creates a modulated U-Net block module.
- Parameters:
channels (int) – The number of channels \(C\).
mod_features (int) – The number of modulating features \(D\).
attention_heads (int | None) – The number of attention heads. If None, the block has no attention layer.
dropout (float | None) – The dropout rate in \([0, 1]\).
spatial (int) – The number of spatial dimensions \(N\).
kwargs – Keyword arguments passed to torch.nn.Conv2d.
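This page does not spell out how the \(D\) modulating features act on the \(C\) channels. A common scheme, assumed here and not confirmed as azula's implementation, projects the modulation vector to a per-channel scale and shift (FiLM-style). The pure-Python sketch below illustrates that arithmetic for a single sample; `linear`, `film_modulate` and the toy weights are hypothetical.

```python
from typing import List, Sequence

Matrix = Sequence[Sequence[float]]


def linear(vec: Sequence[float], weight: Matrix, bias: Sequence[float]) -> List[float]:
    """Dense layer y = W v + b, where W has shape (out_features, in_features)."""
    return [sum(w * v for w, v in zip(row, vec)) + b for row, b in zip(weight, bias)]


def film_modulate(
    x: Sequence[Sequence[float]],
    mod: Sequence[float],
    weight: Matrix,
    bias: Sequence[float],
) -> List[List[float]]:
    """Hypothetical FiLM-style modulation (an assumption, not azula's code).

    x:      feature map of C channels, each a flat list of spatial values.
    mod:    modulation vector of length D.
    weight: (2C, D) projection producing C scales followed by C shifts.
    bias:   length-2C bias of that projection.
    """
    channels = len(x)
    params = linear(mod, weight, bias)
    if len(params) != 2 * channels:
        raise ValueError(f"expected {2 * channels} modulation parameters, got {len(params)}")
    scale, shift = params[:channels], params[channels:]
    # Every value of channel c is mapped to x * (1 + scale[c]) + shift[c].
    return [[v * (1.0 + s) + t for v in ch] for ch, s, t in zip(x, scale, shift)]


# Toy example: C = 2 channels with 2 spatial values each, D = 1.
x = [[1.0, 2.0], [3.0, 4.0]]
mod = [1.0]
weight = [[0.5], [0.0], [0.0], [1.0]]  # scales (0.5, 0.0), shifts (0.0, 1.0)
bias = [0.0, 0.0, 0.0, 0.0]
print(film_modulate(x, mod, weight, bias))  # [[1.5, 3.0], [4.0, 5.0]]
```

Because the scale enters as \(1 + \text{scale}\), an all-zero projection leaves the features unchanged in this sketch, so the modulation starts out as an identity map.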
- class azula.nn.unet.UNet(in_channels, out_channels, mod_features, hid_channels=(64, 128, 256), hid_blocks=(3, 3, 3), kernel_size=3, stride=2, attention_heads={}, dropout=None, spatial=2)
Creates a modulated U-Net module.
- Parameters:
in_channels (int) – The number of input channels \(C_i\).
out_channels (int) – The number of output channels \(C_o\).
mod_features (int) – The number of modulating features \(D\).
hid_channels (Sequence[int]) – The numbers of channels at each depth.
hid_blocks (Sequence[int]) – The numbers of hidden blocks at each depth.
kernel_size (int | Sequence[int]) – The kernel size of all convolutions.
stride (int | Sequence[int]) – The stride of the downsampling convolutions.
attention_heads (Dict[int, int]) – The number of attention heads at each depth, keyed by depth index. Depths without an entry have no attention.
dropout (float | None) – The dropout rate in \([0, 1]\).
spatial (int) – The number of spatial dimensions \(N\).
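The per-depth bookkeeping implied by these parameters can be sketched as follows. The sketch assumes the usual U-Net layout, which this page does not spell out: `len(hid_channels)` depths, depth 0 at the input resolution, and one strided downsampling convolution between consecutive depths. It also reads `attention_heads` as a mapping from depth index to head count and broadcasts an integer `stride` to all \(N\) spatial dimensions. `as_tuple` and `unet_layout` are hypothetical helpers that only track shapes; they do not build azula modules.

```python
from typing import Dict, List, Optional, Sequence, Tuple, Union

IntOrSeq = Union[int, Sequence[int]]


def as_tuple(value: IntOrSeq, spatial: int) -> Tuple[int, ...]:
    """Broadcast an int to one value per spatial dimension (hypothetical helper)."""
    if isinstance(value, int):
        return (value,) * spatial
    value = tuple(value)
    if len(value) != spatial:
        raise ValueError(f"expected {spatial} values, got {len(value)}")
    return value


def unet_layout(
    shape: Sequence[int],
    hid_channels: Sequence[int] = (64, 128, 256),
    hid_blocks: Sequence[int] = (3, 3, 3),
    stride: IntOrSeq = 2,
    attention_heads: Optional[Dict[int, int]] = None,
) -> List[dict]:
    """Hypothetical per-depth layout of a U-Net encoder, for a spatial shape (L_1, ..., L_N)."""
    if len(hid_channels) != len(hid_blocks):
        raise ValueError("hid_channels and hid_blocks must have the same length")
    attention_heads = attention_heads or {}
    strides = as_tuple(stride, len(shape))
    size = tuple(shape)
    layout = []
    for depth, (channels, blocks) in enumerate(zip(hid_channels, hid_blocks)):
        if depth > 0:
            # Assumed: one strided downsampling convolution between consecutive depths.
            if any(s % k for s, k in zip(size, strides)):
                raise ValueError(f"size {size} is not divisible by stride {strides}")
            size = tuple(s // k for s, k in zip(size, strides))
        layout.append({
            "depth": depth,
            "channels": channels,
            "blocks": blocks,
            "size": size,
            "heads": attention_heads.get(depth),  # None: no attention at this depth
        })
    return layout


for level in unet_layout((64, 64), attention_heads={2: 4}):
    print(level)
```

For a (64, 64) input with the default `hid_channels` and `stride`, the sketch reports sizes (64, 64), (32, 32) and (16, 16) at depths 0, 1 and 2, with 4 attention heads only at depth 2. It demands exact divisibility so that upsampled decoder maps would line up with their skip connections, and it takes `attention_heads=None` rather than a `{}` default to avoid a shared mutable default argument.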