base

optimus_dl.modules.model_transforms.base

BaseModelTransform

Bases: ABC

Abstract base class for all model transformations.

Model transforms are applied after the model is built but before training begins. Each transform modifies the model's structure, for example by wrapping it in a distributed wrapper (DDP, FSDP), applying graph compilation (torch.compile), or injecting activation checkpointing.

Transforms are registered in the model_transform registry and can be chained together in the configuration.
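The chaining described above can be sketched in a few lines. This is a minimal, self-contained illustration rather than the library's own pipeline: `BaseModel` is replaced by a stand-in class, `BaseModelTransform` is copied from the listing on this page so the snippet runs on its own, and `RecordTransform` and `apply_chain` are hypothetical names invented for the example. The labels passed as `cfg` are placeholders, and their order is illustrative only.

```python
from abc import ABC, abstractmethod
from typing import Any, List


class BaseModel:
    """Stand-in for optimus_dl's BaseModel (assumption: the real class is richer)."""

    def __init__(self) -> None:
        self.applied: List[str] = []


class BaseModelTransform(ABC):
    """Copy of the interface documented on this page, so the sketch runs alone."""

    def __init__(self, cfg: Any = None, **kwargs):
        self.cfg = cfg

    @abstractmethod
    def apply(self, model: BaseModel, **kwargs) -> BaseModel: ...

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(cfg={self.cfg})"


class RecordTransform(BaseModelTransform):
    """Hypothetical transform: records its configured label on the model in place."""

    def apply(self, model: BaseModel, **kwargs) -> BaseModel:
        model.applied.append(self.cfg)
        return model


def apply_chain(
    model: BaseModel, transforms: List[BaseModelTransform], **kwargs
) -> BaseModel:
    """Hypothetical helper: feed each transform the output of the previous one."""
    for transform in transforms:
        model = transform.apply(model, **kwargs)
    return model


chain = [
    RecordTransform("checkpointing"),
    RecordTransform("compile"),
    RecordTransform("ddp"),
]
model = apply_chain(BaseModel(), chain, device="cpu")
print(model.applied)  # ['checkpointing', 'compile', 'ddp']
```

Because each `apply` call receives the result of the previous one, the order of the transforms in the configuration determines the order in which the model is modified.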

Source code in optimus_dl/modules/model_transforms/base.py
class BaseModelTransform(ABC):
    """Abstract base class for all model transformations.

    Model transforms are applied after the model is built but before training
    begins. They modify the model's structure, wrapping it with distributed
    wrappers (DDP, FSDP), applying graph compilation (torch.compile), or
    injecting activation checkpointing.

    Transforms are registered in the `model_transform` registry and can be
    chained together in the configuration.
    """

    def __init__(self, cfg: Any = None, **kwargs):
        """Initialize the transform.

        Args:
            cfg: Configuration object for the transform.
            **kwargs: Additional keyword arguments.
        """
        self.cfg = cfg

    @abstractmethod
    def apply(self, model: BaseModel, **kwargs) -> BaseModel:
        """Apply the transformation to the given model.

        Args:
            model: The model to transform.
            **kwargs: Additional arguments (e.g., collective, device).

        Returns:
            The transformed model (either modified in-place or a new wrapper).
        """
        pass

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(cfg={self.cfg})"

__init__(cfg=None, **kwargs)

Initialize the transform.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `cfg` | `Any` | Configuration object for the transform. | `None` |
| `**kwargs` | | Additional keyword arguments. | `{}` |

Source code in optimus_dl/modules/model_transforms/base.py
def __init__(self, cfg: Any = None, **kwargs):
    """Initialize the transform.

    Args:
        cfg: Configuration object for the transform.
        **kwargs: Additional keyword arguments.
    """
    self.cfg = cfg
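As the listing shows, the base `__init__` stores only `cfg`; extra keyword arguments are accepted but not kept. A subclass that needs one of them has to capture it itself. The sketch below illustrates this under stated assumptions: `ExampleConfig` and `ExampleTransform` are hypothetical names, `BaseModel` is a stand-in, and `BaseModelTransform` is copied from this page so the snippet is self-contained.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Optional


class BaseModel:
    """Stand-in for optimus_dl's BaseModel (assumption for this sketch)."""


class BaseModelTransform(ABC):
    """Mirror of the interface documented on this page."""

    def __init__(self, cfg: Any = None, **kwargs):
        self.cfg = cfg  # extra keyword arguments are accepted but not stored

    @abstractmethod
    def apply(self, model: BaseModel, **kwargs) -> BaseModel: ...

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(cfg={self.cfg})"


@dataclass
class ExampleConfig:
    """Hypothetical config; the real config classes are not shown on this page."""

    enabled: bool = True


class ExampleTransform(BaseModelTransform):
    """Hypothetical subclass that keeps one extra keyword argument."""

    def __init__(
        self, cfg: Optional[ExampleConfig] = None, *, name: str = "example", **kwargs
    ):
        super().__init__(cfg, **kwargs)
        self.name = name

    def apply(self, model: BaseModel, **kwargs) -> BaseModel:
        return model  # no-op, only here to satisfy the abstract method


transform = ExampleTransform(ExampleConfig(enabled=False), name="demo", unused_option=1)
print(transform)  # repr shows the class name and the stored cfg
print(transform.cfg.enabled, transform.name)

# The abstract base itself cannot be instantiated because apply() is abstract.
try:
    BaseModelTransform()
    base_instantiable = True
except TypeError:
    base_instantiable = False
print(base_instantiable)
```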

apply(model, **kwargs) abstractmethod

Apply the transformation to the given model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `BaseModel` | The model to transform. | required |
| `**kwargs` | | Additional arguments (e.g., collective, device). | `{}` |

Returns:

| Type | Description |
| --- | --- |
| `BaseModel` | The transformed model (either modified in-place or a new wrapper). |

Source code in optimus_dl/modules/model_transforms/base.py
@abstractmethod
def apply(self, model: BaseModel, **kwargs) -> BaseModel:
    """Apply the transformation to the given model.

    Args:
        model: The model to transform.
        **kwargs: Additional arguments (e.g., collective, device).

    Returns:
        The transformed model (either modified in-place or a new wrapper).
    """
    pass
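Since `apply` may return either the same object or a new wrapper, callers should always continue with the return value rather than the original reference. The following sketch contrasts the two cases. It is an illustration under assumptions, not code from the library: `BaseModel` is a stand-in with a toy `forward`, and `ScaledModel`, `MarkTransform`, and `ScaleTransform` are hypothetical names.

```python
from abc import ABC, abstractmethod
from typing import Any


class BaseModel:
    """Stand-in for optimus_dl's BaseModel (assumption for this sketch)."""

    def forward(self, x: float) -> float:
        return x


class BaseModelTransform(ABC):
    """Mirror of the interface documented on this page."""

    def __init__(self, cfg: Any = None, **kwargs):
        self.cfg = cfg

    @abstractmethod
    def apply(self, model: BaseModel, **kwargs) -> BaseModel: ...


class ScaledModel(BaseModel):
    """Hypothetical wrapper that delegates to the model it wraps."""

    def __init__(self, inner: BaseModel, scale: float) -> None:
        self.inner = inner
        self.scale = scale

    def forward(self, x: float) -> float:
        return self.scale * self.inner.forward(x)


class MarkTransform(BaseModelTransform):
    """Hypothetical in-place transform: mutates and returns the same object."""

    def apply(self, model: BaseModel, **kwargs) -> BaseModel:
        model.device = kwargs.get("device", "cpu")
        return model


class ScaleTransform(BaseModelTransform):
    """Hypothetical wrapping transform: returns a new object around the model."""

    def apply(self, model: BaseModel, **kwargs) -> BaseModel:
        return ScaledModel(model, scale=self.cfg)


original = BaseModel()
marked = MarkTransform().apply(original, device="cpu")
wrapped = ScaleTransform(cfg=2.0).apply(marked)

print(marked is original)    # True: same object, modified in place
print(wrapped is original)   # False: a new wrapper object
print(wrapped.forward(3.0))  # 6.0
```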