azula.nn.utils
Miscellaneous neural network helpers.
Classes
Creates a context in which weight initialization is skipped.
Functions
Applies activation checkpointing to a function.
Returns the execution device of a module.
Returns the data type of a module.
Applies data type promotion to a function.
Descriptions
- azula.nn.utils.checkpoint(f, reentrant=False)
Applies activation checkpointing to a function.
Activation checkpointing reduces memory consumption by storing the inputs of the function and recomputing its graph during automatic differentiation (AD).
Reentrant checkpointing is compatible with both backward and forward AD, but only propagates gradients to the explicit positional inputs of the function. Implicit inputs, such as module parameters, do not receive gradients. Conversely, non-reentrant checkpointing propagates gradients to implicit inputs, but is not compatible with forward AD.
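The non-reentrant behavior described above can be illustrated with PyTorch's own `torch.utils.checkpoint.checkpoint`. This is a minimal sketch, not azula's implementation: `checkpoint_nonreentrant` is a hypothetical helper introduced for illustration, and it assumes that a checkpointed function is obtained by wrapping `f`. The reentrant mode is not shown, since its semantics here are specific to azula.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint as torch_checkpoint


def checkpoint_nonreentrant(f):
    """Hypothetical helper: returns a version of f whose graph is recomputed during backward."""

    def wrapped(*args):
        # use_reentrant=False selects PyTorch's non-reentrant implementation.
        return torch_checkpoint(f, *args, use_reentrant=False)

    return wrapped


torch.manual_seed(0)
layer = nn.Linear(4, 4)  # weight and bias are implicit inputs of `block`


def block(x):
    return torch.tanh(layer(x))


x = torch.randn(2, 4, requires_grad=True)

# Backward through the checkpointed function: `block` is re-run to rebuild its graph.
checkpoint_nonreentrant(block)(x).sum().backward()

# Reference gradients computed without checkpointing.
ref_x, ref_w = torch.autograd.grad(block(x).sum(), (x, layer.weight))

print(torch.allclose(x.grad, ref_x))  # gradient of the explicit input
print(torch.allclose(layer.weight.grad, ref_w))  # gradient of an implicit input
```

If the sketch behaves as intended, both checks agree with the reference, which shows that the explicit input `x` and the implicit parameter `layer.weight` both receive gradients in non-reentrant mode.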
- azula.nn.utils.get_module_device(module)
Returns the execution device of a module.
The module’s device is the first device in the module’s parameters or buffers. If there is none, returns None.
- Parameters:
module (Module) – A module.
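The lookup rule above can be sketched in a few lines of PyTorch. `module_device_sketch` is a hypothetical name, and visiting parameters before buffers is an assumption; the actual function may order or implement the search differently.

```python
import itertools

import torch.nn as nn


def module_device_sketch(module):
    """Returns the device of the first parameter or buffer, or None if the module has neither."""
    # Assumption: parameters are visited before buffers.
    for tensor in itertools.chain(module.parameters(), module.buffers()):
        return tensor.device
    return None


linear = nn.Linear(3, 5)  # has parameters, so a device is found
relu = nn.ReLU()  # has neither parameters nor buffers

print(module_device_sketch(linear))
print(module_device_sketch(relu))
```

Returning None rather than a default device lets the caller decide what to do with modules that hold no tensors.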
- azula.nn.utils.get_module_dtype(module)
Returns the data type of a module.
The module’s data type is the first floating-point type in the module’s parameters or buffers. If there is none, returns None.
- Parameters:
module (Module) – A module.
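A comparable sketch for the data type rule, again under assumptions: `module_dtype_sketch` and the `Counter` module are hypothetical, parameters are visited before buffers, and "floating-point" is interpreted via `Tensor.is_floating_point()`, which excludes integer and complex tensors.

```python
import itertools

import torch
import torch.nn as nn


def module_dtype_sketch(module):
    """Returns the dtype of the first floating-point parameter or buffer, or None."""
    # Assumption: parameters are visited before buffers.
    for tensor in itertools.chain(module.parameters(), module.buffers()):
        if tensor.is_floating_point():
            return tensor.dtype
    return None


class Counter(nn.Module):
    """Hypothetical module with an integer buffer registered before a floating-point one."""

    def __init__(self):
        super().__init__()
        self.register_buffer("steps", torch.zeros((), dtype=torch.long))
        self.register_buffer("scale", torch.ones(3, dtype=torch.float64))


counter_dtype = module_dtype_sketch(Counter())  # the integer buffer is skipped
half_dtype = module_dtype_sketch(nn.Linear(3, 5).half())  # parameters cast to float16
empty_dtype = module_dtype_sketch(nn.ReLU())  # no tensors at all

print(counter_dtype, half_dtype, empty_dtype)
```

Skipping non-floating-point tensors matters for modules that carry integer bookkeeping buffers, such as step counters, alongside their floating-point weights.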