config

optimus_dl.recipe.train.config

Training recipe configuration.

This module defines the configuration classes for the training recipe, including all hyperparameters, component configurations, and training settings.

TrainConfig dataclass

Bases: RegistryConfigStrict

Complete training configuration.

This is the root configuration class for training. It contains all component configurations (model, data, optimizer, etc.) and uses the registry system for flexible component selection.
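To illustrate what selecting a component by `_name` could look like, here is a minimal sketch in plain Python. The `REGISTRY` dict, the `register` decorator, the `build` helper, and `ToyLlamaConfig` are hypothetical names chosen for illustration; they are not the optimus_dl registry API.

```python
from dataclasses import dataclass

# Hypothetical registry: maps a component name to the class registered for it.
REGISTRY: dict[str, type] = {}


def register(name: str):
    """Class decorator that records a component class under `name`."""

    def decorator(cls: type) -> type:
        REGISTRY[name] = cls
        return cls

    return decorator


@register("llama")
@dataclass
class ToyLlamaConfig:  # illustrative stand-in, not optimus_dl's ModelConfig
    n_embd: int = 512


def build(config: dict):
    """Instantiate the class selected by the `_name` key of `config`."""
    kwargs = {key: value for key, value in config.items() if key != "_name"}
    return REGISTRY[config["_name"]](**kwargs)


model_cfg = build({"_name": "llama", "n_embd": 256})
print(model_cfg)  # ToyLlamaConfig(n_embd=256)
```

Swapping the component then only requires changing the `_name` value, provided another class has been registered under the new name.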

The configuration is hierarchical and supports OmegaConf interpolation for sharing values across components. The args field serves as a "scratch space" for high-level variables that can be referenced throughout the config.

Example
config = TrainConfig(
    _name="base",
    args={"batch_size": 64, "seq_len": 1024},
    model=ModelConfig(_name="llama", n_embd=512),
    optimization=OptimizationConfig(
        batch_size="${args.batch_size}",
        lr=1e-4,
    ),
)
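The effect of a reference such as `${args.batch_size}` can be imitated with a few lines of standard-library Python. This is a toy stand-in for OmegaConf interpolation, not its implementation: `resolve_references` is a hypothetical helper that only handles values consisting of exactly one absolute reference, and it does not follow chains of references or look inside lists.

```python
import re

# Matches a value that is exactly one absolute reference, e.g. "${args.batch_size}".
_REFERENCE = re.compile(r"\$\{([A-Za-z_][\w.]*)\}")


def _lookup(root: dict, dotted_key: str):
    """Follow a dotted key such as 'args.batch_size' from the config root."""
    node = root
    for part in dotted_key.split("."):
        node = node[part]
    return node


def _resolve(root: dict, node):
    """Recursively copy `node`, replacing whole-value references."""
    if isinstance(node, dict):
        return {key: _resolve(root, value) for key, value in node.items()}
    if isinstance(node, str):
        match = _REFERENCE.fullmatch(node)
        if match:
            return _lookup(root, match.group(1))
    return node


def resolve_references(config: dict) -> dict:
    """Return a copy of `config` with absolute references replaced by their targets."""
    return _resolve(config, config)


raw = {
    "args": {"batch_size": 64, "seq_len": 1024},
    "optimization": {"batch_size": "${args.batch_size}", "lr": 1e-4},
}
resolved = resolve_references(raw)
print(resolved["optimization"])  # {'batch_size': 64, 'lr': 0.0001}
```

Keeping shared values under `args` means a value like the batch size is written once and every component that references it stays in sync.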

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `args` | `dict` | Scratch space for high-level variables that other fields can reference through OmegaConf interpolation. | `{}` |
| `common` | `TrainRecipeConfig` | Common training recipe settings: experiment metadata, logging frequency, checkpointing, evaluation, and distributed training settings. | `TrainRecipeConfig()` |
| `model` | `ModelConfig` | | `'???'` (required) |
| `data` | `DataConfig` | | `'???'` (required) |
| `criterion` | `CriterionConfig` | | `'???'` (required) |
| `optimization` | `OptimizationConfig` | | `'???'` (required) |
| `lr_scheduler` | `RegistryConfig \| None` | | `None` |
| `loggers` | `list[MetricsLoggerConfig] \| None` | List of metrics logger configurations | `None` |
| `metrics` | `dict[str, list[dict]]` | Metric configurations mapped by dataset name (e.g., 'train', 'val_slice_1') | `{}` |
| `model_transforms` | `list[ModelTransformConfig]` | List of model transforms to apply after model building | `[]` |
| `model_builder` | `RegistryConfig` | | `ModelBuilderConfig(_name='base')` |
| `optimizer_builder` | `RegistryConfig` | | `OptimizerBuilderConfig(_name='base')` |
| `criterion_builder` | `RegistryConfig` | | `CriterionBuilderConfig(_name='base')` |
| `data_builder` | `RegistryConfig` | | `DataBuilderConfig(_name='base')` |
| `scheduler_builder` | `RegistryConfig` | | `SchedulerBuilderConfig(_name='base')` |
| `logger_manager` | `RegistryConfig` | | `LoggerManagerConfig(_name='base')` |
| `checkpoint_manager` | `RegistryConfig` | | `CheckpointManagerConfig(_name='base')` |
| `evaluator` | `RegistryConfig` | | `EvaluatorConfig(_name='base', amp=AmpConfig(enabled=False, dtype='torch.bfloat16', enable_scaler='${eval: \'"${.dtype}" == "torch.float16"\'}', init_scale=65536, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000))` |

Source code in optimus_dl/recipe/train/config.py
@dataclass
class TrainConfig(RegistryConfigStrict):
    """Complete training configuration.

    This is the root configuration class for training. It contains all component
    configurations (model, data, optimizer, etc.) and uses the registry system
    for flexible component selection.

    The configuration is hierarchical and supports OmegaConf interpolation for
    sharing values across components. The `args` field serves as a "scratch space"
    for high-level variables that can be referenced throughout the config.

    Example:
        ```python
        config = TrainConfig(
            _name="base",
            args={"batch_size": 64, "seq_len": 1024},
            model=ModelConfig(_name="llama", n_embd=512),
            optimization=OptimizationConfig(
                batch_size="${args.batch_size}",
                lr=1e-4,
            ),
        )

        ```"""

    args: dict = field(default_factory=dict)
    common: TrainRecipeConfig = field(default_factory=TrainRecipeConfig)

    model: ModelConfig = field(default=MISSING)
    data: DataConfig = field(default=MISSING)
    criterion: CriterionConfig = field(default=MISSING)
    optimization: OptimizationConfig = field(default=MISSING)
    lr_scheduler: RegistryConfig | None = field(default=None)

    # Metrics logging configuration
    loggers: list[MetricsLoggerConfig] | None = field(
        default=None, metadata={"description": "List of metrics logger configurations"}
    )

    # Metrics configuration for MetricEngine, mapped by dataset name (e.g. 'train', 'val')
    metrics: dict[str, list[dict]] = field(
        default_factory=dict,
        metadata={
            "description": "Metric configurations mapped by dataset name (e.g., 'train', 'val_slice_1')"
        },
    )

    # Model transforms configuration
    model_transforms: list[ModelTransformConfig] = field(
        default_factory=list,
        metadata={
            "description": "List of model transforms to apply after model building"
        },
    )

    # Dependency Injection Configs
    model_builder: RegistryConfig = field(
        default_factory=lambda: ModelBuilderConfig(_name="base")
    )
    optimizer_builder: RegistryConfig = field(
        default_factory=lambda: OptimizerBuilderConfig(_name="base")
    )
    criterion_builder: RegistryConfig = field(
        default_factory=lambda: CriterionBuilderConfig(_name="base")
    )
    data_builder: RegistryConfig = field(
        default_factory=lambda: DataBuilderConfig(_name="base")
    )
    scheduler_builder: RegistryConfig = field(
        default_factory=lambda: SchedulerBuilderConfig(_name="base")
    )
    logger_manager: RegistryConfig = field(
        default_factory=lambda: LoggerManagerConfig(_name="base")
    )
    checkpoint_manager: RegistryConfig = field(
        default_factory=lambda: CheckpointManagerConfig(_name="base")
    )
    evaluator: RegistryConfig = field(
        default_factory=lambda: EvaluatorConfig(_name="base")
    )
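The mutable and builder fields above are declared with `default_factory` rather than a plain default. With standard-library dataclasses, the factory runs once per instance, so every config object gets its own default value instead of sharing one. A minimal sketch, where `ToyBuilderConfig`, `ToyTrainConfig`, and the metric entry are hypothetical stand-ins:

```python
from dataclasses import dataclass, field


@dataclass
class ToyBuilderConfig:  # hypothetical stand-in for a builder config
    _name: str = "base"


@dataclass
class ToyTrainConfig:  # hypothetical stand-in for a root config
    metrics: dict = field(default_factory=dict)
    model_builder: ToyBuilderConfig = field(
        default_factory=lambda: ToyBuilderConfig(_name="base")
    )


first, second = ToyTrainConfig(), ToyTrainConfig()
first.metrics["train"] = [{"_name": "loss"}]  # illustrative metric entry

print(second.metrics)  # {}
print(first.model_builder is second.model_builder)  # False
print(first.model_builder == second.model_builder)  # True
```

Mutating one instance therefore never leaks into another, while the defaults still compare equal.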

TrainRecipeConfig dataclass

Configuration for training recipe common settings.

This class contains all the common settings shared across training runs, including experiment metadata, logging frequency, checkpointing, evaluation, and distributed training settings.
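Several of these settings take their default from another setting or from the environment: `save_freq` follows `eval_freq`, `last_save_freq` follows `save_freq`, and `output_path` combines the `PERSISTENT_PATH` environment variable (falling back to `./outputs`) with `exp_name`. The sketch below mirrors those documented fallbacks in plain Python. `resolve_defaults` is a hypothetical helper; in the library these values come from OmegaConf interpolation and the training code, not from a function like this.

```python
import os
from typing import Mapping, Optional


def resolve_defaults(
    exp_name: str,
    eval_freq: int = 100,
    save_freq: Optional[int] = None,
    last_save_freq: Optional[int] = None,
    env: Optional[Mapping[str, str]] = None,
) -> dict:
    """Hypothetical helper that applies the documented fallbacks."""
    env = os.environ if env is None else env
    if save_freq is None:
        save_freq = eval_freq  # documented default: '${.eval_freq}'
    if last_save_freq is None:
        last_save_freq = save_freq  # documented: same as save_freq by default
    base_dir = env.get("PERSISTENT_PATH", "./outputs")
    return {
        "save_freq": save_freq,
        "last_save_freq": last_save_freq,
        "output_path": f"{base_dir}/{exp_name}",
    }


print(resolve_defaults("demo", env={}))
# {'save_freq': 100, 'last_save_freq': 100, 'output_path': './outputs/demo'}
print(resolve_defaults("demo", save_freq=50, env={"PERSISTENT_PATH": "/data"}))
# {'save_freq': 50, 'last_save_freq': 50, 'output_path': '/data/demo'}
```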

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `exp_name` | `str` | Experiment name | `'optimus-dl-run-${config_hash:}'` |
| `exp_description` | `str \| None` | Experiment description | `None` |
| `exp_tags` | `list[str]` | Experiment tags | `[]` |
| `log_freq` | `int` | Frequency of training-metrics logging | `16` |
| `seed` | `int` | Global random seed, applied wherever seeding is possible | `42` |
| `data_seed` | `int` | Seed for everything data-related. It will be different on each rank. | `42` |
| `deterministic` | `bool` | If True, force deterministic algorithms in PyTorch. | `True` |
| `eval_iterations` | `int \| None` | Maximum number of validation iterations for each subset | `None` |
| `eval_freq` | `int` | Frequency of evaluations. Set to 0 to disable evaluation. | `100` |
| `eval_guaranteed_same_batches` | `bool` | Whether it is guaranteed that each DP rank sees the same number of batches during evaluation. | `False` |
| `eval_checkpointing` | `int \| None` | Frequency of saving checkpoints during evaluation. If None or non-positive (for example, 0), no checkpoints are saved during evaluation. Useful for long evaluations, which can then be resumed after an interruption. Saves are fast and light, as they contain only the state of the meters and dataloader, not the model or optimizer states. | `None` |
| `eval_resumable` | `bool` | Whether to make evaluation resumable by saving a checkpoint before evaluation starts. If True, a full checkpoint is saved before evaluation and a metadata-only checkpoint is saved after evaluation completes, so an interrupted evaluation can be resumed from the same iteration without re-running training. Related: set `eval_checkpointing` to a positive integer to also save checkpoints during evaluation, which allows resuming from the middle of an interrupted evaluation. If `eval_checkpointing` is None or non-positive, no checkpoints are saved during evaluation, and an interrupted evaluation restarts from its beginning. | `True` |
| `save_freq` | `int` | Frequency of checkpoint saves. Defaults to `eval_freq`. | `'${.eval_freq}'` |
| `last_save_freq` | `int \| None` | Frequency of saving the last checkpoint. Defaults to `save_freq`. | `None` |
| `output_path` | `str` | Directory to write checkpoints to | `"${oc.env:PERSISTENT_PATH,'./outputs'}/${.exp_name}"` |
| `load_checkpoint` | `str \| None` | Path of the checkpoint to load. What is loaded from it is controlled by `load_checkpoint_strategy`. | `None` |
| `load_checkpoint_strategy` | `LoadStrategy` | Strategy that determines what to load from the checkpoint | `LoadStrategy()` |
| `use_gpu` | `bool` | | `True` |
| `distributed` | `DistributedConfig` | Distributed training configuration (GPU, TP, etc.) | `DistributedConfig()` |

Source code in optimus_dl/recipe/train/config.py
@dataclass
class TrainRecipeConfig:
    """Configuration for training recipe common settings.

    This class contains all the common settings shared across training runs,
    including experiment metadata, logging frequency, checkpointing, evaluation,
    and distributed training settings.
    """

    # Exp metadata
    exp_name: str = field(
        default="optimus-dl-run-${config_hash:}",
        metadata={"description": "Experiment name"},
    )
    exp_description: str | None = field(
        default=None, metadata={"description": "Experiment description"}
    )
    exp_tags: list[str] = field(
        default_factory=list, metadata={"description": "Experiment tags"}
    )
    log_freq: int = field(
        default=16, metadata={"description": "Frequency of train metrics logging"}
    )

    # Reproducibility
    seed: int = field(
        default=42, metadata={"description": "Seed to seed everything that's possible"}
    )
    data_seed: int = field(
        default=42,
        metadata={
            "description": "Seed to seed everything data-related. Will be different on each rank."
        },
    )
    deterministic: bool = field(
        default=True,
        metadata={"description": "If True, force deterministic algorithms in PyTorch."},
    )

    # Evaluation
    eval_iterations: int | None = field(
        default=None,
        metadata={
            "description": "Max number of iterations of validation data for every subset"
        },
    )
    eval_freq: int = field(
        default=100, metadata={"description": "Frequency of evaluations. Zero disables"}
    )
    eval_guaranteed_same_batches: bool = field(
        default=False,
        metadata={
            "description": "Whether it is guaranteed that each DP rank sees the same number of batches during evaluation."
        },
    )
    eval_checkpointing: int | None = field(
        default=None,
        metadata={
            "description": "Frequency of saving checkpoints during evaluation. If None or non-positive (for example, 0), "
            "do not save checkpoints during evaluation. This is useful for long evaluations to be able to resume "
            "evaluation if it gets interrupted. "
            "Saves are fast and light, as they only contain the state of the meters and dataloader, not the model or optimizer states."
        },
    )
    eval_resumable: bool = field(
        default=True,
        metadata={
            "description": "Whether to make evaluation resumable by saving a checkpoint before evaluation starts. "
            "If True, a full checkpoint is saved before evaluation and a metadata-only checkpoint is saved after "
            "evaluation completes. This ensures that if evaluation is interrupted, it can be resumed from the same "
            "iteration without re-running training. "
            "**Related**: `eval_checkpointing` should be set to a positive integer to save checkpoints during evaluation, "
            "which allows resuming evaluation even if it gets interrupted in the middle of evaluation. "
            "If `eval_checkpointing` is None or non-positive, no checkpoints will be saved during evaluation, "
            "and if evaluation gets interrupted, it will have to be restarted from the beginning of evaluation."
        },
    )

    # Checkpointing
    save_freq: int = field(
        default=II(".eval_freq"),
        metadata={
            "description": "Frequency of checkpoint savings. As eval_freq by default"
        },
    )
    last_save_freq: int | None = field(
        default=None,
        metadata={
            "description": "Frequency of saving last checkpoint. As save_freq by default"
        },
    )
    output_path: str = field(
        default="${oc.env:PERSISTENT_PATH,'./outputs'}/${.exp_name}",
        metadata={"description": "Directory to dump checkpoints to"},
    )

    load_checkpoint: str | None = field(
        default=None,
        metadata={
            "description": "Path to checkpoint to load from, what to load from it is controlled by load_checkpoint_strategy"
        },
    )
    load_checkpoint_strategy: LoadStrategy = field(
        default_factory=LoadStrategy,
        metadata={"description": "Strategy what to load from the checkpoint"},
    )

    # Distributed
    use_gpu: bool = True
    distributed: DistributedConfig = field(
        default_factory=DistributedConfig,
        metadata={"description": "Distributed training configuration (GPU, TP, etc.)"},
    )
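The interplay between `eval_resumable` and `eval_checkpointing` described above can be summarized as a small decision function. This is a sketch derived only from the field descriptions; `eval_checkpoint_plan` and its result keys are hypothetical and do not correspond to optimus_dl functions.

```python
from typing import Optional


def eval_checkpoint_plan(eval_resumable: bool, eval_checkpointing: Optional[int]) -> dict:
    """Summarize which checkpoints the field descriptions say are written."""
    # Mid-evaluation saves require a positive frequency; None or <= 0 turns them off.
    mid_eval = eval_checkpointing is not None and eval_checkpointing > 0
    return {
        # Full checkpoint before evaluation, so training need not be re-run.
        "full_checkpoint_before_eval": eval_resumable,
        # Metadata-only checkpoint once evaluation has completed.
        "metadata_checkpoint_after_eval": eval_resumable,
        # Light checkpoints (meter and dataloader state) during evaluation.
        "mid_eval_checkpoint_every": eval_checkpointing if mid_eval else None,
        # Without mid-evaluation saves, an interrupted evaluation starts over.
        "restart_eval_on_interrupt": not mid_eval,
    }


# Documented defaults: eval_resumable=True, eval_checkpointing=None.
print(eval_checkpoint_plan(True, None))
# Long evaluation with a light save every 10 evaluation iterations.
print(eval_checkpoint_plan(True, 10))
```

With the documented defaults, an interrupted evaluation restarts from its beginning but does not require re-running training; setting `eval_checkpointing` to a positive value additionally allows resuming partway through.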