
compile

optimus_dl.modules.model_transforms.compile

CompileTransform

Bases: BaseModelTransform

Model transform that applies torch.compile to the entire model.

Graph compilation can significantly improve performance by fusing kernels and reducing CPU overhead. It should typically be the last transform applied.
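What the transform does can be sketched directly: wrap the whole model with torch.compile and call the wrapper as usual. This is a minimal sketch, not the transform itself; `backend="eager"` is used here only to keep the example lightweight (it skips Inductor codegen), whereas real configs would typically keep the default backend.

```python
import torch
import torch.nn as nn

# A toy model standing in for any BaseModel subclass.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Wrap the entire model, as the transform does. The first call traces and
# compiles; subsequent calls with compatible shapes reuse the compiled graph.
compiled = torch.compile(model, backend="eager")

x = torch.randn(3, 4)
out = compiled(x)
```

Because compilation happens lazily on first call, any transform that mutates module structure afterwards would invalidate the traced graph, which is why this transform should come last.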

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| cfg | Any | Compilation configuration. | None |
Source code in optimus_dl/modules/model_transforms/compile.py
@register_model_transform("compile", CompileTransformConfig)
class CompileTransform(BaseModelTransform):
    """Model transform that applies torch.compile to the entire model.

    Graph compilation can significantly improve performance by fusing kernels
    and reducing CPU overhead. It should typically be the last transform
    applied.

    Args:
        cfg: Compilation configuration.
    """

    def apply(self, model: BaseModel, **kwargs) -> BaseModel:
        """Apply torch.compile to the model.

        Args:
            model: The model to compile.
            **kwargs: Unused.

        Returns:
            The compiled model wrapper.
        """
        import torch._functorch.config

        compile_kwargs = self.cfg.compile_kwargs if self.cfg else {}
        if self.cfg is not None:
            torch._functorch.config.activation_memory_budget = (
                self.cfg.activation_memory_budget
            )

        logger.info(f"Applying torch.compile with args: {compile_kwargs}")
        model = torch.compile(model, **compile_kwargs)

        return model

apply(model, **kwargs)

Apply torch.compile to the model.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| model | BaseModel | The model to compile. | required |
| **kwargs | | Unused. | {} |

Returns:

| Type | Description |
|------|-------------|
| BaseModel | The compiled model wrapper. |

Source code in optimus_dl/modules/model_transforms/compile.py
def apply(self, model: BaseModel, **kwargs) -> BaseModel:
    """Apply torch.compile to the model.

    Args:
        model: The model to compile.
        **kwargs: Unused.

    Returns:
        The compiled model wrapper.
    """
    import torch._functorch.config

    compile_kwargs = self.cfg.compile_kwargs if self.cfg else {}
    if self.cfg is not None:
        torch._functorch.config.activation_memory_budget = (
            self.cfg.activation_memory_budget
        )

    logger.info(f"Applying torch.compile with args: {compile_kwargs}")
    model = torch.compile(model, **compile_kwargs)

    return model
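The kwargs resolution at the top of apply can be shown in isolation. The `_Cfg` dataclass below is a hypothetical stand-in for `CompileTransformConfig`, used only to illustrate that a missing config falls back to an empty kwargs dict:

```python
from dataclasses import dataclass, field


# Hypothetical stand-in for CompileTransformConfig (illustration only).
@dataclass
class _Cfg:
    compile_kwargs: dict = field(default_factory=dict)
    activation_memory_budget: float = 1.0


def resolve_compile_kwargs(cfg):
    # Mirrors the fallback in apply(): no config means no extra kwargs
    # are passed to torch.compile.
    return cfg.compile_kwargs if cfg else {}


print(resolve_compile_kwargs(None))  # {}
print(resolve_compile_kwargs(_Cfg(compile_kwargs={"mode": "max-autotune"})))
```

Note that the same None-check should guard `activation_memory_budget` as well, since that attribute is read from the config object too.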

CompileTransformConfig dataclass

Bases: ModelTransformConfig

Configuration for torch.compile model transform.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| compile_kwargs | dict | Arguments for torch.compile. See https://pytorch.org/docs/stable/generated/torch.compile.html | {} |
| activation_memory_budget | float | Activation memory budget for torch.compile. See https://pytorch.org/blog/activation-checkpointing-techniques/ | 1.0 |
Source code in optimus_dl/modules/model_transforms/compile.py
@dataclass
class CompileTransformConfig(ModelTransformConfig):
    """Configuration for torch.compile model transform."""

    compile_kwargs: dict = field(
        default_factory=dict,
        metadata={
            "description": "Arguments for torch.compile. See https://pytorch.org/docs/stable/generated/torch.compile.html"
        },
    )
    activation_memory_budget: float = field(
        default=1.0,
        metadata={
            "description": "Activation memory budget for torch.compile. See https://pytorch.org/blog/activation-checkpointing-techniques/"
        },
    )
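Since the transform is registered under the name "compile", a config file entry might look like the following. This is a hypothetical sketch: the exact schema depends on how optimus_dl loads transform configs, and the field values shown (max-autotune mode, a 0.7 memory budget) are illustrative, not recommendations.

```yaml
model_transforms:
  - name: compile
    compile_kwargs:
      mode: max-autotune
      fullgraph: false
    activation_memory_budget: 0.7
```

An activation_memory_budget below 1.0 asks the partitioner to recompute some activations in the backward pass instead of storing them, trading compute for memory.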