base
optimus_dl.modules.model_transforms.base
BaseModelTransform
Bases: ABC
Abstract base class for all model transformations.
Model transforms are applied after the model is built but before training begins. They modify the model's structure, for example by wrapping it in distributed wrappers (DDP, FSDP), applying graph compilation (torch.compile), or injecting activation checkpointing.
Transforms are registered in the model_transform registry and can be chained together in the configuration.
Source code in optimus_dl/modules/model_transforms/base.py
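As an illustration of the interface described above, here is a minimal sketch of a concrete transform. To stay self-contained it defines a local stand-in for `BaseModelTransform` that mirrors only the documented signatures; storing `cfg` on the instance is an assumption, and `ToyModel` and `EnableCheckpointing` are hypothetical names, not part of optimus_dl.

```python
from abc import ABC, abstractmethod
from typing import Any


class BaseModelTransform(ABC):
    """Local stand-in mirroring the documented interface (not the real class)."""

    def __init__(self, cfg: Any = None, **kwargs: Any) -> None:
        # Assumption: the config is simply kept on the instance.
        self.cfg = cfg

    @abstractmethod
    def apply(self, model: Any, **kwargs: Any) -> Any:
        """Return the transformed model."""


class ToyModel:
    """Hypothetical placeholder for a BaseModel."""

    def __init__(self) -> None:
        self.checkpointing = False


class EnableCheckpointing(BaseModelTransform):
    """Hypothetical transform that flips a flag on the model in place."""

    def apply(self, model: ToyModel, **kwargs: Any) -> ToyModel:
        model.checkpointing = True
        return model  # same object, modified in place


model = ToyModel()
transformed = EnableCheckpointing().apply(model)
```

Because `apply` is abstract, the base class itself cannot be instantiated; only subclasses that implement `apply` can.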
__init__(cfg=None, **kwargs)
Initialize the transform.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `cfg` | `Any` | Configuration object for the transform. | `None` |
| `**kwargs` | | Additional keyword arguments. | `{}` |
apply(model, **kwargs)
abstractmethod
Apply the transformation to the given model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `BaseModel` | The model to transform. | required |
| `**kwargs` | | Additional arguments (e.g., collective, device). | `{}` |
Returns:
| Type | Description |
|---|---|
| `BaseModel` | The transformed model (either modified in-place or a new wrapper). |
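Since `apply` may return either the same object or a new wrapper, callers should always continue with the returned value. The sketch below shows one way transforms could be chained under that contract. It uses a local stand-in for the base class, and `ToyModel`, `Wrapper`, `MoveToDevice`, `Wrap`, and `apply_transforms` are all hypothetical names introduced for illustration; the actual chaining logic in optimus_dl may differ.

```python
from abc import ABC, abstractmethod
from typing import Any, Iterable


class BaseModelTransform(ABC):
    """Local stand-in mirroring the documented interface (not the real class)."""

    def __init__(self, cfg: Any = None, **kwargs: Any) -> None:
        self.cfg = cfg  # assumption: config is stored on the instance

    @abstractmethod
    def apply(self, model: Any, **kwargs: Any) -> Any: ...


class ToyModel:
    """Hypothetical placeholder for a BaseModel."""

    def __init__(self) -> None:
        self.device = "cpu"


class Wrapper:
    """Hypothetical wrapper, loosely analogous to a distributed wrapper."""

    def __init__(self, module: Any) -> None:
        self.module = module


class MoveToDevice(BaseModelTransform):
    """In-place transform that reads an optional `device` keyword argument."""

    def apply(self, model: Any, **kwargs: Any) -> Any:
        model.device = kwargs.get("device", model.device)
        return model


class Wrap(BaseModelTransform):
    """Transform that returns a new wrapper object around the model."""

    def apply(self, model: Any, **kwargs: Any) -> Any:
        return Wrapper(model)


def apply_transforms(
    model: Any, transforms: Iterable[BaseModelTransform], **kwargs: Any
) -> Any:
    # Rebind `model` each step so wrappers returned by one transform
    # are what the next transform receives.
    for transform in transforms:
        model = transform.apply(model, **kwargs)
    return model


base = ToyModel()
result = apply_transforms(base, [MoveToDevice(), Wrap()], device="cuda:0")
```

Order matters in a chain like this: a transform placed after `Wrap` receives the wrapper, not the original model.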