azula.nn.utils

Miscellaneous neural network helpers.

Classes

skip_init

Creates a context in which weight initialization is skipped.

Functions

checkpoint

Applies activation checkpointing to a function.

get_module_device

Returns the execution device of a module.

get_module_dtype

Returns the data type of a module.

promote_dtype

Applies data type promotion to a function.

Descriptions

azula.nn.utils.checkpoint(f, reentrant=False)

Applies activation checkpointing to a function.

Activation checkpointing reduces memory consumption by storing the inputs of the function and recomputing its graph during automatic differentiation (AD).

Reentrant checkpointing is compatible with backward and forward AD, but only propagates gradients to the explicit positional inputs of the function. Implicit inputs, such as module parameters, do not get gradients. Conversely, non-reentrant checkpointing propagates gradients to implicit inputs, but is not compatible with forward AD.

Parameters:
  • f (Callable[[...], T]) – A function.

  • reentrant (bool) – Whether to use reentrant checkpointing or not.

Returns:

The checkpointed function.

Return type:

Callable[[...], T]
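The non-reentrant case can be illustrated with plain PyTorch. The sketch below is not azula's implementation: `checkpoint_sketch` and `block` are hypothetical names, and the wrapper simply assumes that delegating to `torch.utils.checkpoint.checkpoint` with `use_reentrant=False` is a reasonable stand-in. It shows that the checkpointed function returns the same values and that gradients reach both the explicit input and the implicit inputs (the layer parameters).

```python
import functools

import torch
from torch.utils.checkpoint import checkpoint as torch_checkpoint


def checkpoint_sketch(f):
    """Hypothetical stand-in: wraps f with non-reentrant checkpointing."""

    @functools.wraps(f)
    def wrapped(*args):
        # Only the inputs are stored; the graph of f is rebuilt during backward.
        return torch_checkpoint(f, *args, use_reentrant=False)

    return wrapped


torch.manual_seed(0)
layer = torch.nn.Linear(3, 5)


def block(x):
    # The parameters of `layer` are implicit inputs of this function.
    return torch.tanh(layer(x))


x = torch.randn(2, 3, requires_grad=True)

y_ref = block(x)                      # regular forward pass
y_ckpt = checkpoint_sketch(block)(x)  # checkpointed forward pass
y_ckpt.sum().backward()
```

After the backward pass, `x.grad` and `layer.weight.grad` are both populated, which matches the non-reentrant behavior described above.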

azula.nn.utils.get_module_device(module)

Returns the execution device of a module.

The module’s device is the first device in the module’s parameters or buffers. If there is none, returns None.

Parameters:

module (Module) – A module.
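The lookup rule described above can be sketched in a few lines. This is an illustration under assumptions, not the library's source: `get_module_device_sketch` is a hypothetical name, and the iteration order (parameters first, then buffers) is an assumption based on the wording "parameters or buffers".

```python
import itertools

import torch


def get_module_device_sketch(module):
    """Hypothetical sketch: device of the first parameter or buffer, else None."""
    # Assumption: parameters are inspected before buffers.
    for tensor in itertools.chain(module.parameters(), module.buffers()):
        return tensor.device
    return None


layer = torch.nn.Linear(3, 5)
layer_device = get_module_device_sketch(layer)

# A module without parameters or buffers has no device to report.
empty_device = get_module_device_sketch(torch.nn.ReLU())
```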

azula.nn.utils.get_module_dtype(module)

Returns the data type of a module.

The module’s data type is the first floating-point type in the module’s parameters or buffers. If there is none, returns None.

Parameters:

module (Module) – A module.
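A minimal sketch of the "first floating-point type" rule, again under assumptions: `get_module_dtype_sketch` is a hypothetical name, parameters are assumed to be inspected before buffers, and "floating-point" is assumed to mean `Tensor.is_floating_point()`, which excludes integer and complex tensors.

```python
import itertools

import torch


def get_module_dtype_sketch(module):
    """Hypothetical sketch: first floating-point dtype found, else None."""
    for tensor in itertools.chain(module.parameters(), module.buffers()):
        if tensor.is_floating_point():
            return tensor.dtype
    return None


# A module whose only tensor is an integer buffer has no floating-point type.
counter = torch.nn.Module()
counter.register_buffer("steps", torch.zeros(1, dtype=torch.long))
dtype_before = get_module_dtype_sketch(counter)

# Once a floating-point buffer is registered, its dtype is reported.
counter.register_buffer("scale", torch.ones(1, dtype=torch.float64))
dtype_after = get_module_dtype_sketch(counter)

dtype_linear = get_module_dtype_sketch(torch.nn.Linear(3, 5).double())
```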

azula.nn.utils.promote_dtype(f, min_dtype=torch.float32)

Applies data type promotion to a function.

Parameters:
  • f (Callable[[...], T]) – A function.

  • min_dtype (dtype) – The minimum precision data type.

Returns:

The promoted function.

Return type:

Callable[[...], T]
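The description does not spell out the promotion rule, so the following is only one plausible reading, not the library's behavior. It assumes that floating-point tensor arguments are upcast to at least `min_dtype` (via `torch.promote_types`) before the function is called, and that everything else passes through unchanged. Whether the output is cast back is not stated in the text and is left out here. `promote_dtype_sketch` and `mean_of_squares` are hypothetical names.

```python
import functools

import torch


def promote_dtype_sketch(f, min_dtype=torch.float32):
    """Hypothetical sketch: upcast floating-point tensor inputs before calling f."""

    def promote(x):
        if torch.is_tensor(x) and x.is_floating_point():
            # E.g. float16 -> float32, while float64 stays float64.
            return x.to(torch.promote_types(x.dtype, min_dtype))
        return x

    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        args = tuple(promote(a) for a in args)
        kwargs = {k: promote(v) for k, v in kwargs.items()}
        return f(*args, **kwargs)

    return wrapped


def mean_of_squares(x):
    return (x * x).mean()


f_promoted = promote_dtype_sketch(mean_of_squares)

y_half = f_promoted(torch.ones(4, dtype=torch.float16))    # computed in float32
y_double = f_promoted(torch.ones(4, dtype=torch.float64))  # already above the minimum
```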

class azula.nn.utils.skip_init

Creates a context in which weight initialization is skipped.

Example

>>> import torch
>>> from azula.nn.utils import skip_init
>>> with skip_init():
...     layer = torch.nn.Linear(3, 5)