compile
optimus_dl.modules.model_transforms.compile
CompileTransform
Bases: BaseModelTransform
Model transform that applies torch.compile to the entire model.
Graph compilation can significantly improve performance by fusing kernels and reducing CPU overhead. It should typically be the last transform applied.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| cfg | Any | Compilation configuration. | None |
Source code in optimus_dl/modules/model_transforms/compile.py
apply(model, **kwargs)
Apply torch.compile to the model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | BaseModel | The model to compile. | required |
| `**kwargs` | | Unused. | {} |
Returns:
| Type | Description |
|---|---|
| BaseModel | The compiled model wrapper. |
Source code in optimus_dl/modules/model_transforms/compile.py
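As a rough illustration of what "the compiled model wrapper" means, the sketch below calls `torch.compile` directly on a small module. The helper `compile_model` is a hypothetical stand-in for `apply`, not the library's implementation, and the `dynamic=False` argument is only an example of a keyword that `torch.compile` accepts.

```python
from typing import Optional

import torch
import torch.nn as nn


def compile_model(model: nn.Module, compile_kwargs: Optional[dict] = None) -> nn.Module:
    """Hypothetical stand-in for CompileTransform.apply (assumption, not library code)."""
    # torch.compile returns a wrapper module; tracing and kernel generation
    # are deferred until the wrapper is first called.
    return torch.compile(model, **(compile_kwargs or {}))


model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
compiled = compile_model(model, {"dynamic": False})

# The wrapper is still an nn.Module and shares parameters with the original.
shares_params = next(compiled.parameters()) is next(model.parameters())
```

Because the wrapper shares parameters with the original module, an optimizer built from either object updates the same tensors. Transforms that replace or re-wrap submodules would plausibly need to run before this step, which is consistent with the advice to apply compilation last.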
CompileTransformConfig
dataclass
Bases: ModelTransformConfig
Configuration for torch.compile model transform.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| compile_kwargs | dict | Arguments for torch.compile. See https://pytorch.org/docs/stable/generated/torch.compile.html | `<class 'dict'>` |
| activation_memory_budget | float | Activation memory budget for torch.compile. See https://pytorch.org/blog/activation-checkpointing-techniques/ | 1.0 |
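The `<class 'dict'>` default shown for `compile_kwargs` suggests a dataclass field declared with `default_factory=dict`, although that is an assumption and not something this page confirms. The sketch below uses a hypothetical class, `CompileConfigSketch`, that mirrors the two documented fields to show how such a config would behave. It does not import or reproduce `CompileTransformConfig` itself.

```python
from dataclasses import dataclass, field


@dataclass
class CompileConfigSketch:
    """Hypothetical mirror of the documented fields (not the library's class)."""

    # default_factory gives every instance its own empty dict.
    compile_kwargs: dict = field(default_factory=dict)
    # 1.0 matches the default listed in the table above.
    activation_memory_budget: float = 1.0


default_cfg = CompileConfigSketch()
tuned_cfg = CompileConfigSketch(
    compile_kwargs={"mode": "max-autotune"},  # a mode accepted by torch.compile
    activation_memory_budget=0.5,
)
```

With the defaults, `torch.compile` would receive no extra arguments. According to the linked blog post, the budget is a value between 0 and 1, where 1.0 corresponds to default `torch.compile` behavior and lower values trade recomputation for reduced activation memory.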