azula.nn.utils
Miscellaneous neural network helpers.
Classes
skip_init | Creates a context in which weight initialization is skipped.
Functions
get_module_dtype(module) | Returns the data type of a module.
get_module_device(module) | Returns the execution device of a module.
checkpoint(f, reentrant=False) | Applies activation checkpointing to a function.
| Applies data type promotion to a function.
Descriptions
- azula.nn.utils.get_module_dtype(module)
Returns the data type of a module.
The module’s data type is the first floating-point data type among the module’s parameters or buffers. If there is none, returns None.
- Parameters:
module (Module) – A module.
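The lookup rule above can be sketched in plain PyTorch. `module_dtype_sketch` and `Counter` are hypothetical names, and this is not azula's implementation; the sketch assumes parameters are scanned before buffers, which matches the wording "parameters or buffers" but is not confirmed by the source.

```python
import itertools
from typing import Optional

import torch
from torch import nn


def module_dtype_sketch(module: nn.Module) -> Optional[torch.dtype]:
    # Hypothetical re-implementation: scan parameters first, then buffers,
    # and return the dtype of the first floating-point tensor found.
    for tensor in itertools.chain(module.parameters(), module.buffers()):
        if tensor.is_floating_point():
            return tensor.dtype
    return None  # no floating-point parameter or buffer


class Counter(nn.Module):
    # Holds only an integer buffer, so it has no floating-point data type.
    def __init__(self):
        super().__init__()
        self.register_buffer("count", torch.zeros((), dtype=torch.long))


print(module_dtype_sketch(nn.Linear(3, 5)))           # torch.float32
print(module_dtype_sketch(nn.Linear(3, 5).double()))  # torch.float64
print(module_dtype_sketch(Counter()))                 # None
```

Integer buffers are skipped on purpose: a step counter or index table says nothing about the precision the module computes in.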
- azula.nn.utils.get_module_device(module)
Returns the execution device of a module.
The module’s device is the device of the first tensor among the module’s parameters or buffers. If there is none, returns None.
- Parameters:
module (Module) – A module.
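A minimal sketch of the same rule for devices, again with a hypothetical name (`module_device_sketch`) and assuming parameters are scanned before buffers. Unlike the data type rule, any tensor counts here, not only floating-point ones.

```python
import itertools
from typing import Optional

import torch
from torch import nn


def module_device_sketch(module: nn.Module) -> Optional[torch.device]:
    # Hypothetical re-implementation: device of the first parameter or
    # buffer (of any dtype), or None for a module without tensors.
    first = next(itertools.chain(module.parameters(), module.buffers()), None)
    return None if first is None else first.device


print(module_device_sketch(nn.Linear(3, 5)))  # cpu
print(module_device_sketch(nn.ReLU()))        # None
```

A typical use is moving inputs to the right place before a call, e.g. `x.to(module_device_sketch(net))`.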
- azula.nn.utils.checkpoint(f, reentrant=False)
Applies activation checkpointing to a function.
Activation checkpointing reduces memory consumption by storing only the inputs of the function, rather than its intermediate activations, and recomputing its graph during automatic differentiation (AD).
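The store-inputs-then-recompute idea can be observed with PyTorch's own `torch.utils.checkpoint.checkpoint` in non-reentrant mode. This uses the PyTorch API directly, not azula's `checkpoint`, whose internals may differ; the call counter `calls` is purely illustrative.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(4, 16), nn.Tanh(), nn.Linear(16, 1))
calls = [0]  # mutable counter of how many times f actually runs


def f(x: torch.Tensor) -> torch.Tensor:
    calls[0] += 1
    return net(x)


x = torch.randn(8, 4, requires_grad=True)

# Forward: f runs once, but its intermediate activations are not kept.
y = checkpoint(f, x, use_reentrant=False)
y.sum().backward()  # backward re-runs f to rebuild the graph
print(calls[0])  # 2

# Non-reentrant mode also reaches implicit inputs such as the parameters.
params_got_grad = net[0].weight.grad is not None
grad_ckpt = x.grad.clone()

# Reference: the same computation without checkpointing.
x.grad = None
net(x).sum().backward()
print(torch.allclose(grad_ckpt, x.grad))  # True
```

The trade-off is one extra forward pass of `f` in exchange for not holding its activations between the forward and backward passes.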
Reentrant checkpointing is compatible with both backward and forward AD, but only propagates gradients to the explicit positional inputs of the function. Implicit inputs, such as module parameters, do not receive gradients. Conversely, non-reentrant checkpointing propagates gradients to implicit inputs, but is not compatible with forward AD.
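To make the "explicit inputs only" behavior concrete, here is a hypothetical `ExplicitCheckpoint` built on `torch.autograd.Function`. It is a sketch of the reentrant behavior described above, not azula's code, and it does not implement forward-mode AD (that would need a `jvp` method), so it only mirrors the gradient-propagation part.

```python
import torch
from torch import nn


class ExplicitCheckpoint(torch.autograd.Function):
    """Hypothetical reentrant-style checkpoint.

    Only the positional tensor inputs are saved; backward recomputes f and
    differentiates it with respect to those explicit inputs only.
    """

    @staticmethod
    def forward(ctx, f, *args):
        ctx.f = f
        ctx.save_for_backward(*args)
        with torch.no_grad():  # no graph is kept for the forward pass
            return f(*args)

    @staticmethod
    def backward(ctx, grad_output):
        needs_grad = ctx.needs_input_grad[1:]  # skip the entry for f
        inputs = [
            x.detach().requires_grad_(n)
            for x, n in zip(ctx.saved_tensors, needs_grad)
        ]
        with torch.enable_grad():
            output = ctx.f(*inputs)  # recompute the graph from saved inputs
        tracked = [x for x in inputs if x.requires_grad]
        # autograd.grad does not touch .grad of parameters inside f.
        grads = iter(torch.autograd.grad(output, tracked, grad_output))
        # One entry per forward input: None for f and untracked inputs.
        return (None, *(next(grads) if x.requires_grad else None for x in inputs))


torch.manual_seed(0)
layer = nn.Linear(4, 3)
x = torch.randn(2, 4, requires_grad=True)

y = ExplicitCheckpoint.apply(layer, x)
y.sum().backward()
print(x.grad is not None)  # True: explicit input
print(layer.weight.grad)   # None: implicit input gets no gradient
```

Because the layer's weights never receive gradients this way, such a checkpoint suits functions whose trainable state is passed in explicitly.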
- class azula.nn.utils.skip_init
Creates a context in which weight initialization is skipped.
Example
>>> with skip_init():
...     layer = nn.Linear(3, 5)
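One possible way to build such a context, sketched under the assumption that layers initialize their weights through the in-place functions of `torch.nn.init` (as `nn.Linear.reset_parameters` does): temporarily replace those functions with no-ops. `skip_init_sketch` is a hypothetical name, and this is not necessarily how azula implements `skip_init`. Layers created inside the context hold uninitialized memory, so their weights should be overwritten, e.g. with `load_state_dict`, before use.

```python
import contextlib

import torch
from torch import nn


@contextlib.contextmanager
def skip_init_sketch():
    # Public in-place initializers, e.g. uniform_, normal_, kaiming_uniform_.
    names = [n for n in dir(nn.init) if n.endswith("_") and not n.startswith("_")]
    originals = {n: getattr(nn.init, n) for n in names}

    def no_op(tensor, *args, **kwargs):
        return tensor  # leave the tensor's (uninitialized) memory untouched

    try:
        for n in names:
            setattr(nn.init, n, no_op)
        yield
    finally:
        for n, fn in originals.items():  # always restore the real initializers
            setattr(nn.init, n, fn)


with skip_init_sketch():
    layer = nn.Linear(3, 5)  # weights allocated but never initialized

# Typical follow-up: overwrite the uninitialized weights right away.
layer.load_state_dict(nn.Linear(3, 5).state_dict())
```

Skipping initialization mainly saves time for large models whose weights are loaded from a checkpoint immediately afterwards.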