azula.nn.utils

Module helpers.

Classes

FlattenWrapper

Creates a flatten/unflatten wrapper around a backbone.

Descriptions

class azula.nn.utils.FlattenWrapper(wrappee, shape)

Creates a flatten/unflatten wrapper around a backbone.

This module creates a flatten/unflatten frontier between azula components, which operate on one-dimensional vectors, and a backbone that operates on multi-dimensional tensors, such as \(C \times H \times W\) images.

Parameters:
  • wrappee (Module) – The wrapped backbone.

  • shape (Sequence[int]) – The tensor shape expected by the backbone, such as \((C, H, W)\) for images.
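The snippet below illustrates the role of `shape` with a flatten/unflatten round trip using plain PyTorch reshapes. It does not call azula. The image shape (3, 8, 8) is chosen only for illustration. The flat dimension D seen by vector-based components is then the product of the entries of `shape`.

```python
import math

import torch

# Hypothetical image shape (C, H, W), standing in for the `shape` parameter.
shape = (3, 8, 8)
D = math.prod(shape)  # flat dimension seen by vector-based components

x_flat = torch.randn(2, 5, D)  # batch shape (*) = (2, 5), vectors of size D

# Unflatten for the backbone: (*, D) -> (N, C, H, W), with N = 2 * 5.
x_img = x_flat.reshape(-1, *shape)

# Flatten back for vector-based components: (N, C, H, W) -> (*, D).
x_back = x_img.reshape(*x_flat.shape[:-1], -1)

print(D)                            # 192
print(x_img.shape)                  # torch.Size([10, 3, 8, 8])
print(torch.equal(x_back, x_flat))  # True
```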

forward(x_t, t, **kwargs)
Parameters:
  • x_t (Tensor) – A noisy vector \(x_t\), with shape \((*, D)\).

  • t (Tensor) – The time \(t\), with shape \((*)\).

  • kwargs – Optional keyword arguments.

Returns:

The output vector(s), with shape \((*, D)\).

Return type:

Tensor | List[Tensor]
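The following is a minimal sketch of how such a wrapper could behave, written in plain PyTorch. It is not azula's source. `FlattenWrapperSketch` and `ToyBackbone` are hypothetical names. The sketch makes three assumptions:

* The batch dimensions \((*)\) are collapsed into a single leading dimension.
* `t` is broadcast to the batch shape.
* `kwargs` are forwarded to the backbone.

Whether the backbone returns one tensor or a list of tensors, each output is reshaped back to \((*, D)\).

```python
import math
from typing import List, Sequence, Union

import torch
import torch.nn as nn
from torch import Tensor


class FlattenWrapperSketch(nn.Module):
    """Hypothetical stand-in for FlattenWrapper; illustration only, not azula's source."""

    def __init__(self, wrappee: nn.Module, shape: Sequence[int]):
        super().__init__()
        self.wrappee = wrappee
        self.shape = tuple(shape)

    def forward(self, x_t: Tensor, t: Tensor, **kwargs) -> Union[Tensor, List[Tensor]]:
        batch = x_t.shape[:-1]  # the batch dimensions (*)
        assert x_t.shape[-1] == math.prod(self.shape), "D must equal prod(shape)"

        # Unflatten: x_t (*, D) -> (N, *shape) and t (*) -> (N,), with N = prod(*).
        x = x_t.reshape(-1, *self.shape)
        t = t.expand(batch).reshape(-1)

        # Assumption: extra keyword arguments are forwarded to the backbone.
        y = self.wrappee(x, t, **kwargs)

        # Flatten each output back to (*, D), for one tensor or a list of tensors.
        if isinstance(y, Tensor):
            return y.reshape(*batch, -1)
        return [y_i.reshape(*batch, -1) for y_i in y]


class ToyBackbone(nn.Module):
    """Hypothetical backbone operating on (N, C, H, W) images and (N,) times."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x: Tensor, t: Tensor, both: bool = False):
        y = self.conv(x) * t[:, None, None, None]
        return [y, -y] if both else y  # one output, or a list of two


C, H, W = 3, 8, 8
model = FlattenWrapperSketch(ToyBackbone(C), shape=(C, H, W))

x_t = torch.randn(2, 5, C * H * W)  # batch shape (2, 5), D = 192
t = torch.rand(2, 5)

y = model(x_t, t)               # single output
ys = model(x_t, t, both=True)   # list of outputs

print(y.shape)  # torch.Size([2, 5, 192])
print([y_i.shape for y_i in ys])
```

Collapsing \((*)\) into one leading dimension lets a backbone that expects \((N, C, H, W)\) inputs handle arbitrary batch shapes, including a single unbatched vector of shape \((D)\). The cost is a reshape on each call.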