Skip to content

checkpoint

optimus_dl.modules.model_transforms.checkpoint

Activation checkpointing (gradient checkpointing) transform using public PyTorch API.

ActivationCheckpointConfig dataclass

Bases: ModelTransformConfig

Configuration for activation checkpointing.

Attributes:

Name Type Description

Parameters:

Name Type Description Default
layer_classes list[str] | None
None
use_reentrant bool
False
ops_to_save list[str] | None
None
Source code in optimus_dl/modules/model_transforms/checkpoint.py
@dataclass
class ActivationCheckpointConfig(ModelTransformConfig):
    """Configuration for activation checkpointing.

    Attributes:
        layer_classes: List of layer class names to wrap (e.g., ["LlamaBlock"]).
        use_reentrant: If True, uses the legacy reentrant checkpointing. False
            is recommended for modern PyTorch and FSDP integration.
        ops_to_save: Optional list of specific operations to always save (not recompute).
    """

    # List of layer class names to wrap (e.g. ["LlamaBlock", "GPTBlock"])
    layer_classes: list[str] | None = None

    # Whether to use reentrant checkpointing.
    # False is generally recommended for newer PyTorch versions and FSDP.
    use_reentrant: bool = False

    ops_to_save: list[str] | None = None

ActivationCheckpointTransform

Bases: BaseModelTransform

Transform that injects activation checkpointing into a model.

Recursively searches the model for instances of specified layer_classes and wraps them with CheckpointWrapper. This is a crucial optimization for fitting large models or long sequences into GPU memory.

Parameters:

Name Type Description Default
cfg ActivationCheckpointConfig

Activation checkpointing configuration.

required
Source code in optimus_dl/modules/model_transforms/checkpoint.py
@register_model_transform("activation_checkpoint", ActivationCheckpointConfig)
class ActivationCheckpointTransform(BaseModelTransform):
    """Transform that injects activation checkpointing into a model.

    Recursively searches the model for instances of specified `layer_classes`
    and wraps them with `CheckpointWrapper`. This is a crucial optimization for
    fitting large models or long sequences into GPU memory.

    Args:
        cfg: Activation checkpointing configuration.
    """

    def __init__(
        self,
        cfg: ActivationCheckpointConfig,
        **kwargs: Any,
    ):
        super().__init__(cfg, **kwargs)

    def apply(self, model: BaseModel, **kwargs) -> BaseModel:
        """Find and wrap target layers in the model."""
        logger.info("Applying activation checkpointing (torch.utils.checkpoint)")

        if not self.cfg.layer_classes:
            logger.warning(
                "No layer classes specified for activation checkpointing. "
                "Please specify 'layer_classes' in the config (e.g. ['LlamaBlock'])."
            )
            return model

        target_classes = set(self.cfg.layer_classes)
        ops_to_save = [
            eval(op, {"__builtins__": None}, {"torch": torch})
            for op in (self.cfg.ops_to_save or [])
        ]
        replaced_count = self._replace_modules(
            model,
            target_classes,
            use_reentrant=self.cfg.use_reentrant,
            ops_to_save=ops_to_save,
        )

        if replaced_count == 0:
            logger.warning(f"No modules matching {target_classes} found to checkpoint.")
        else:
            logger.info(
                f"Applied activation checkpointing to {replaced_count} layers of types: {target_classes}"
            )

        return model

    def _replace_modules(
        self,
        model: nn.Module,
        target_classes: set,
        use_reentrant: bool,
        ops_to_save: list,
    ) -> int:
        """Recursively replace target modules with CheckpointWrapper."""
        count = 0
        for name, child in model.named_children():
            if child.__class__.__name__ in target_classes:
                # Replace the module
                logger.debug(f"Wrapping {name} ({child.__class__.__name__})")
                wrapped_child = CheckpointWrapper(
                    child, use_reentrant=use_reentrant, ops_to_save=ops_to_save
                )
                setattr(model, name, wrapped_child)
                count += 1
            else:
                # Recurse
                count += self._replace_modules(
                    child, target_classes, use_reentrant, ops_to_save
                )
        return count

apply(model, **kwargs)

Find and wrap target layers in the model.

Source code in optimus_dl/modules/model_transforms/checkpoint.py
def apply(self, model: BaseModel, **kwargs) -> BaseModel:
    """Find and wrap target layers in the model."""
    logger.info("Applying activation checkpointing (torch.utils.checkpoint)")

    if not self.cfg.layer_classes:
        logger.warning(
            "No layer classes specified for activation checkpointing. "
            "Please specify 'layer_classes' in the config (e.g. ['LlamaBlock'])."
        )
        return model

    target_classes = set(self.cfg.layer_classes)
    ops_to_save = [
        eval(op, {"__builtins__": None}, {"torch": torch})
        for op in (self.cfg.ops_to_save or [])
    ]
    replaced_count = self._replace_modules(
        model,
        target_classes,
        use_reentrant=self.cfg.use_reentrant,
        ops_to_save=ops_to_save,
    )

    if replaced_count == 0:
        logger.warning(f"No modules matching {target_classes} found to checkpoint.")
    else:
        logger.info(
            f"Applied activation checkpointing to {replaced_count} layers of types: {target_classes}"
        )

    return model

CheckpointWrapper

Bases: Module

Module wrapper that applies activation checkpointing to its child.

During the forward pass, this module uses torch.utils.checkpoint.checkpoint to trade compute for memory: activations are discarded after the forward pass and recomputed during the backward pass.

Source code in optimus_dl/modules/model_transforms/checkpoint.py
class CheckpointWrapper(nn.Module):
    """Module wrapper that applies activation checkpointing to its child.

    During the forward pass, this module uses `torch.utils.checkpoint.checkpoint`
    to trade compute for memory: activations are discarded after the forward
    pass and recomputed during the backward pass.
    """

    def __init__(
        self, module: nn.Module, ops_to_save: list, use_reentrant: bool = False
    ):
        super().__init__()
        self.module = module
        self.use_reentrant = use_reentrant

        def policy_fn(_, op, *__, **___):
            if op in ops_to_save:
                return CheckpointPolicy.MUST_SAVE
            else:
                return CheckpointPolicy.PREFER_RECOMPUTE

        self.policy_fn = policy_fn

    def forward(self, *args, **kwargs):
        """Forward pass with activation checkpointing."""
        # torch.utils.checkpoint.checkpoint requires a function as the first argument.
        # We pass self.module (which is callable).
        # Note: 'use_reentrant' argument is available in modern PyTorch.
        return checkpoint(
            self.module,
            *args,
            use_reentrant=self.use_reentrant,
            context_fn=functools.partial(
                create_selective_checkpoint_contexts, self.policy_fn
            ),
            **kwargs,
        )

forward(*args, **kwargs)

Forward pass with activation checkpointing.

Source code in optimus_dl/modules/model_transforms/checkpoint.py
def forward(self, *args, **kwargs):
    """Forward pass with activation checkpointing."""
    # torch.utils.checkpoint.checkpoint requires a function as the first argument.
    # We pass self.module (which is callable).
    # Note: 'use_reentrant' argument is available in modern PyTorch.
    return checkpoint(
        self.module,
        *args,
        use_reentrant=self.use_reentrant,
        context_fn=functools.partial(
            create_selective_checkpoint_contexts, self.policy_fn
        ),
        **kwargs,
    )