azula.nn.unet
U-Net building blocks.
Classes
Descriptions
- class azula.nn.unet.UNet(in_channels, out_channels, cond_channels=0, hid_channels=(64, 128, 256), hid_blocks=(3, 3, 3), kernel_size=3, stride=2, spatial=2, periodic=False, identity_init=False, **kwargs)
Creates a modulated U-Net module.
- Parameters:
in_channels (int) – The number of input channels \(C_i\).
out_channels (int) – The number of output channels \(C_o\).
cond_channels (int) – The number of condition channels \(C_c\).
hid_channels (Sequence[int]) – The numbers of channels at each depth.
hid_blocks (Sequence[int]) – The numbers of hidden blocks at each depth.
kernel_size (int | Sequence[int]) – The kernel size of all convolutions.
stride (int | Sequence[int]) – The stride of the downsampling convolutions.
spatial (int) – The number of spatial dimensions \(N\).
periodic (bool) – Whether the spatial dimensions are periodic or not.
identity_init (bool) – Initialize down/upsampling convolutions as identity.
kwargs – Keyword arguments passed to UNetBlock.
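The following pure-Python sketch illustrates how hid_channels and stride determine the feature-map shapes at each depth. It assumes that depth 0 runs at the input resolution and that each deeper depth is reached through one strided downsampling convolution that divides every spatial size by the stride. The helper unet_feature_shapes is hypothetical and not part of azula. Under these assumptions, the input sizes must be divisible by the stride at every downsampling step, otherwise the upsampled features would not line up with the skip connections.

```python
from typing import List, Sequence, Tuple, Union


def unet_feature_shapes(
    shape: Sequence[int],
    hid_channels: Sequence[int] = (64, 128, 256),
    stride: Union[int, Sequence[int]] = 2,
) -> List[Tuple[int, Tuple[int, ...]]]:
    """Return (channels, spatial shape) at each depth of a U-Net.

    Hypothetical helper. Assumes depth 0 keeps the input resolution and each
    deeper depth divides every spatial size by the (per-dimension) stride.
    """
    spatial = len(shape)  # number of spatial dimensions N
    strides = (stride,) * spatial if isinstance(stride, int) else tuple(stride)
    if len(strides) != spatial:
        raise ValueError("stride must have one entry per spatial dimension")

    size = tuple(shape)
    shapes = []
    for depth, channels in enumerate(hid_channels):
        if depth > 0:
            # Non-divisible sizes would break the skip connections on the way up.
            if any(s % k != 0 for s, k in zip(size, strides)):
                raise ValueError(f"size {size} not divisible by stride {strides} at depth {depth}")
            size = tuple(s // k for s, k in zip(size, strides))
        shapes.append((channels, size))
    return shapes


# Usage: a 64 x 64 input with the default hid_channels and stride.
print(unet_feature_shapes((64, 64)))
# → [(64, (64, 64)), (128, (32, 32)), (256, (16, 16))]

# Per-dimension strides only downsample the first spatial axis.
print(unet_feature_shapes((32, 8), stride=(2, 1)))
# → [(64, (32, 8)), (128, (16, 8)), (256, (8, 8))]
```

Note that hid_blocks does not change these shapes under the stated assumptions; it only sets how many blocks run at each depth.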
- class azula.nn.unet.UNetBlock(channels, mod_features=0, norm='layer', groups=16, ffn_factor=1, spatial=2, dropout=None, checkpointing=False, **kwargs)
Creates a modulated U-Net block module.
- Parameters:
channels (int) – The number of channels \(C\).
mod_features (int) – The number of modulating features \(D\).
norm (str) – The kind of normalization. Options are 'group', 'layer' and 'rms'.
groups (int) – The number of groups in torch.nn.GroupNorm layers.
ffn_factor (int) – The channel factor in the FFN.
spatial (int) – The number of spatial dimensions \(N\).
dropout (float | None) – The dropout rate in \([0, 1]\).
checkpointing (bool) – Whether to use activation checkpointing or not.
kwargs – Keyword arguments passed to azula.nn.layers.ConvNd.
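To illustrate the norm and mod_features options, the sketch below implements the three normalization kinds on a single (C, L) feature map (channels by flattened spatial positions) in pure Python, followed by a FiLM-style scale-and-shift modulation. It assumes that 'layer' and 'rms' compute statistics across channels at each spatial position, that 'group' normalizes each group of channels over all positions as torch.nn.GroupNorm does (learned affine parameters omitted), and that modulation applies per-channel (1 + scale) and shift values which a learned layer would predict from the \(D\) modulating features. The helpers normalize and modulate are hypothetical; azula's actual block may differ.

```python
import math
from typing import List, Sequence

# A feature map of shape (C, L): C channels, L flattened spatial positions.
FeatureMap = List[List[float]]


def normalize(x: FeatureMap, kind: str = "layer", groups: int = 16, eps: float = 1e-5) -> FeatureMap:
    """Normalize a (C, L) feature map without learned affine parameters (hypothetical)."""
    C, L = len(x), len(x[0])
    out = [[0.0] * L for _ in range(C)]
    if kind in ("layer", "rms"):
        # Statistics across channels, separately at each spatial position.
        for j in range(L):
            col = [x[c][j] for c in range(C)]
            # RMS normalization skips mean-centering, so var is the mean square.
            mean = sum(col) / C if kind == "layer" else 0.0
            var = sum((v - mean) ** 2 for v in col) / C
            for c in range(C):
                out[c][j] = (col[c] - mean) / math.sqrt(var + eps)
    elif kind == "group":
        if C % groups != 0:
            raise ValueError("channels must be divisible by groups")
        size = C // groups
        for g in range(groups):
            # Statistics over the channels of group g and all positions.
            chans = range(g * size, (g + 1) * size)
            vals = [x[c][j] for c in chans for j in range(L)]
            mean = sum(vals) / len(vals)
            var = sum((v - mean) ** 2 for v in vals) / len(vals)
            for c in chans:
                for j in range(L):
                    out[c][j] = (x[c][j] - mean) / math.sqrt(var + eps)
    else:
        raise ValueError(f"unknown normalization: {kind!r}")
    return out


def modulate(x: FeatureMap, scale: Sequence[float], shift: Sequence[float]) -> FeatureMap:
    """FiLM-style modulation: (1 + scale[c]) * x[c] + shift[c] for each channel c.

    In a real block, scale and shift would be predicted from the modulating
    features by a learned layer, which this sketch omits.
    """
    return [[(1 + s) * v + b for v in row] for row, s, b in zip(x, scale, shift)]


# Usage: 2 channels, 2 spatial positions.
x = [[1.0, 2.0], [3.0, 6.0]]
y = normalize(x, kind="layer")  # approximately [[-1.0, -1.0], [1.0, 1.0]]
z = modulate(y, scale=[0.5, 0.0], shift=[0.0, 1.0])  # approximately [[-1.5, -1.5], [2.0, 2.0]]
```

Starting from a scale of zero keeps the modulated output equal to the normalized input, which is why the sketch parameterizes the multiplier as 1 + scale.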