context_mixin
optimus_dl.recipe.train.mixins.execution.context_mixin
Training context mixin for AMP and gradient scaler setup.
TrainingContextMixin
Mixin for setting up the training context (precision, scaling, devices).
It initializes PyTorch's AMP (Automatic Mixed Precision) context and `GradScaler` from the optimization configuration, so that precision settings stay consistent across the training loop.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `optimization_config` | `OptimizationConfig` | Configuration containing AMP settings. | required |
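As a rough illustration of how AMP settings in a configuration can translate into precision choices, the sketch below maps a hypothetical config (the fields `amp_enabled` and `amp_dtype` are assumptions, not the real `OptimizationConfig` schema) to an autocast dtype and a decision on whether a `GradScaler` is needed. The rule it encodes is common PyTorch practice: loss scaling is only needed for `float16` on CUDA, because `bfloat16` keeps the same exponent range as `float32`.

```python
from dataclasses import dataclass

import torch


@dataclass
class AmpConfigSketch:
    """Hypothetical stand-in for the AMP fields of OptimizationConfig."""

    amp_enabled: bool = False
    amp_dtype: str = "bfloat16"  # assumed to be "float16" or "bfloat16"


_DTYPES = {"float16": torch.float16, "bfloat16": torch.bfloat16}


def resolve_precision(config: AmpConfigSketch, device: torch.device) -> dict:
    """Return the autocast settings and whether gradient scaling is needed."""
    if config.amp_dtype not in _DTYPES:
        raise ValueError(f"unsupported amp_dtype: {config.amp_dtype!r}")
    dtype = _DTYPES[config.amp_dtype]
    # float16 has a narrow exponent range, so small gradients can underflow
    # to zero; GradScaler counteracts this. bfloat16 does not need it.
    use_scaler = config.amp_enabled and dtype == torch.float16 and device.type == "cuda"
    return {
        "enabled": config.amp_enabled,
        "dtype": dtype,
        "use_scaler": use_scaler,
    }
```

For example, `resolve_precision(AmpConfigSketch(True, "float16"), torch.device("cuda"))` requests a scaler, while the same config with `"bfloat16"` does not.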
setup_training_context(device)
Initialize the AMP context and the gradient scaler.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `device` | `device` | The target compute device. | required |
Returns:
| Type | Description |
|---|---|
| `dict[str, Any]` | A dictionary containing: |
| `dict[str, Any]` | |
| `dict[str, Any]` | |
| `dict[str, Any]` | |
| `dict[str, Any]` | |