# checkpoint

`optimus_dl.modules.model_transforms.checkpoint`

Activation checkpointing (gradient checkpointing) transform using the public PyTorch API.
## ActivationCheckpointConfig

`dataclass`

Bases: `ModelTransformConfig`

Configuration for activation checkpointing.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `layer_classes` | `list[str] \| None` | Names of the layer classes to wrap with checkpointing. | `None` |
| `use_reentrant` | `bool` | Whether to use the reentrant `torch.utils.checkpoint` implementation. | `False` |
| `ops_to_save` | `list[str] \| None` | Operations whose activations are saved rather than recomputed. | `None` |
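As a sketch, a configuration enabling checkpointing on a model's transformer blocks might be built like this (the `my_model.TransformerBlock` class path is a hypothetical placeholder, not part of the library):

```python
from optimus_dl.modules.model_transforms.checkpoint import ActivationCheckpointConfig

cfg = ActivationCheckpointConfig(
    layer_classes=["my_model.TransformerBlock"],  # hypothetical target layer
    use_reentrant=False,
    ops_to_save=None,
)
```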
Source code in optimus_dl/modules/model_transforms/checkpoint.py
## ActivationCheckpointTransform

Bases: `BaseModelTransform`

Transform that injects activation checkpointing into a model.

Recursively searches the model for instances of the specified `layer_classes` and wraps them with `CheckpointWrapper`. This is a crucial optimization for fitting large models or long sequences into GPU memory.
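As a rough illustration of the recursive search-and-wrap step, here is a minimal pure-Python sketch. The real transform walks an `nn.Module` tree and uses the library's `CheckpointWrapper`; the `Block` and `Wrapper` classes below are illustrative stand-ins:

```python
class Block:
    """Stand-in for a module with named children."""
    def __init__(self, children=None):
        self.children = children or {}

class TransformerLayer(Block):
    """Stand-in for a layer class targeted by the transform."""

class Wrapper(Block):
    """Stand-in for CheckpointWrapper: holds the wrapped layer."""
    def __init__(self, inner):
        super().__init__()
        self.inner = inner

def wrap_matching(root, target_types):
    """Recursively replace children matching target_types with a Wrapper."""
    for name, child in list(root.children.items()):
        if isinstance(child, target_types):
            root.children[name] = Wrapper(child)
        else:
            wrap_matching(child, target_types)
    return root

model = Block({
    "embed": Block(),
    "layers": Block({"0": TransformerLayer(), "1": TransformerLayer()}),
})
wrap_matching(model, (TransformerLayer,))
# Both transformer layers are now wrapped; the embedding is untouched.
```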
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `cfg` | `ActivationCheckpointConfig` | Activation checkpointing configuration. | *required* |
### apply(model, **kwargs)

Find and wrap target layers in the model.
## CheckpointWrapper

Bases: `Module`

Module wrapper that applies activation checkpointing to its child.

During the forward pass, this module uses `torch.utils.checkpoint.checkpoint` to trade compute for memory: activations are discarded after the forward pass and recomputed during the backward pass.
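The compute-for-memory trade can be sketched without autograd: the wrapper keeps only its input and rebuilds the activation on demand, which is what recomputation during the backward pass amounts to. The `CheckpointedCall` class below is an illustrative analogue, not the library's implementation:

```python
class CheckpointedCall:
    """Toy analogue of checkpointing: cache the input, not the activation."""
    def __init__(self, fn):
        self.fn = fn
        self.saved_input = None
        self.recomputes = 0

    def forward(self, x):
        # Save only the input; the activation fn(x) is NOT cached.
        self.saved_input = x
        return self.fn(x)

    def recompute(self):
        # During the backward pass the activation is rebuilt from the input.
        self.recomputes += 1
        return self.fn(self.saved_input)

square = CheckpointedCall(lambda v: v * v)
y = square.forward(3)        # forward result: 9, activation discarded
again = square.recompute()   # backward recomputes the same value: 9
```

The saving comes from holding one input instead of every intermediate activation; the cost is running `fn` a second time during the backward pass.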
### forward(*args, **kwargs)

Forward pass with activation checkpointing.