
base

optimus_dl.recipe.train.base

TrainRecipe

Bases: TrainingContextMixin, TrainingIterationMixin, TrainingInterruptionMixin

Main training recipe that orchestrates all training components.

This class uses composition to coordinate specialized builders and managers:

  • ModelBuilder: Builds models and applies transforms
  • OptimizerBuilder: Builds optimizers
  • CriterionBuilder: Builds loss criteria
  • DataBuilder: Builds train/eval data pipelines
  • SchedulerBuilder: Builds learning rate schedulers
  • LoggerManager: Handles metrics logging setup and operations
  • CheckpointManager: Manages checkpoint saving and loading
  • Evaluator: Handles evaluation runs and metrics

It inherits from training logic mixins for the core loop execution:

  • TrainingContextMixin: Sets up training context (AMP, scaler, etc.)
  • TrainingIterationMixin: Orchestrates complete training iterations
  • TrainingInterruptionMixin: Handles training interruptions and errors
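
A minimal usage sketch (hypothetical: how cfg is produced is project-specific; TrainRecipe only requires a fully populated TrainConfig):

from optimus_dl.recipe.train.base import TrainRecipe

cfg = load_train_config("configs/exp.yaml")  # hypothetical config-loading helper

recipe = TrainRecipe(cfg)  # builds all components and validates the config
recipe.run()               # executes the full training pipeline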
Source code in optimus_dl/recipe/train/base.py
@register_train_recipe("base", TrainConfig)
class TrainRecipe(
    TrainingContextMixin,
    TrainingIterationMixin,
    TrainingInterruptionMixin,
):
    """Main training recipe that orchestrates all training components.

    This class uses composition to coordinate specialized builders and managers:

    - ModelBuilder: Builds models and applies transforms
    - OptimizerBuilder: Builds optimizers
    - CriterionBuilder: Builds loss criteria
    - DataBuilder: Builds train/eval data pipelines
    - SchedulerBuilder: Builds learning rate schedulers
    - LoggerManager: Handles metrics logging setup and operations
    - CheckpointManager: Manages checkpoint saving and loading
    - Evaluator: Handles evaluation runs and metrics

    It inherits from training logic mixins for the core loop execution:

    - TrainingContextMixin: Sets up training context (AMP, scaler, etc.)
    - TrainingIterationMixin: Orchestrates complete training iterations
    - TrainingInterruptionMixin: Handles training interruptions and errors
    """

    cfg: TrainConfig

    def __init__(self, cfg: TrainConfig) -> None:
        self.cfg = cfg

        # Initialize components via composition using registry for dependency injection
        self.model_builder = build_component(
            "model_builder",
            cfg.model_builder,
            cast_to=ModelBuilder,
            model_transforms=cfg.model_transforms,
        )
        assert self.model_builder is not None, "Model builder not initialized"
        self.optimizer_builder = build_component(
            "optimizer_builder",
            cfg.optimizer_builder,
            cast_to=OptimizerBuilder,
            optimization_config=cfg.optimization,
        )
        assert self.optimizer_builder is not None, "Optimizer builder not initialized"
        self.criterion_builder = build_component(
            "criterion_builder",
            cfg.criterion_builder,
            cast_to=CriterionBuilder,
            criterion_config=cfg.criterion,
        )
        assert self.criterion_builder is not None, "Criterion builder not initialized"
        self.data_builder = build_component(
            "data_builder",
            cfg.data_builder,
            cast_to=DataBuilder,
            data_config=cfg.data,
            data_seed=cfg.common.data_seed,
        )
        assert self.data_builder is not None, "Data builder not initialized"
        self.scheduler_builder = build_component(
            "scheduler_builder",
            cfg.scheduler_builder,
            cast_to=SchedulerBuilder,
            lr_scheduler_config=cfg.lr_scheduler,
            optimization_config=cfg.optimization,
        )
        assert self.scheduler_builder is not None, "Scheduler builder not initialized"
        self.logger_manager = build_component(
            "logger_manager",
            cfg.logger_manager,
            cast_to=LoggerManager,
            loggers_config=cfg.loggers,
        )
        assert self.logger_manager is not None, "Logger manager not initialized"
        self.checkpoint_manager = build_component(
            "checkpoint_manager",
            cfg.checkpoint_manager,
            cast_to=CheckpointManager,
        )
        assert self.checkpoint_manager is not None, "Checkpoint manager not initialized"
        self.evaluator = build_component(
            "evaluator",
            cfg.evaluator,
            cast_to=Evaluator,
            eval_freq=cfg.common.eval_freq,
            eval_iterations=cfg.common.eval_iterations,
        )
        assert self.evaluator is not None, "Evaluator not initialized"

        # Initialize training logic mixins
        TrainingContextMixin.__init__(self, cfg.optimization)
        TrainingIterationMixin.__init__(self, cfg.optimization, cfg.common.log_freq)
        TrainingInterruptionMixin.__init__(
            self,
            cfg.common.save_freq,
            cfg.common.output_path,
            self.save_checkpoint,  # Pass the checkpoint method as callback
        )
        self.validate_config()

    # Delegate methods
    def build_model(self, *args, **kwargs) -> BaseModel:
        """Delegate to ModelBuilder."""
        return self.model_builder.build_model(*args, **kwargs)

    def build_optimizer(self, *args, **kwargs) -> Optimizer:
        """Delegate to OptimizerBuilder."""
        return self.optimizer_builder.build_optimizer(*args, **kwargs)

    def build_lr_scheduler(self, *args, **kwargs):
        """Delegate to SchedulerBuilder."""
        return self.scheduler_builder.build_lr_scheduler(*args, **kwargs)

    def build_criterion(self, *args, **kwargs) -> BaseCriterion:
        """Delegate to CriterionBuilder."""
        return self.criterion_builder.build_criterion(*args, **kwargs)

    def build_train_data(self, *args, **kwargs):
        """Delegate to DataBuilder for training data."""
        return self.data_builder.build_train_data(*args, **kwargs)

    def build_eval_data(self, *args, **kwargs):
        """Delegate to DataBuilder for evaluation data."""
        return self.data_builder.build_eval_data(*args, **kwargs)

    def build_loggers(self, *args, **kwargs):
        """Delegate to LoggerManager for building loggers."""
        return self.logger_manager.build_loggers(*args, **kwargs)

    def setup_loggers(self, experiment_name: str, full_config: dict | None = None):
        """Setup logging with experiment configuration.

        Args:
            experiment_name: Name of the current experiment.
            full_config: Full configuration dictionary. If None, uses self.cfg.
        """
        if full_config is None:
            full_config = self.cfg if hasattr(self.cfg, "__dict__") else dict(self.cfg)
        return self.logger_manager.setup_loggers(experiment_name, full_config)

    def log_metrics_to_loggers(self, *args, **kwargs):
        """Delegate metrics logging to LoggerManager."""
        return self.logger_manager.log_metrics_to_loggers(*args, **kwargs)

    def close_loggers(self, *args, **kwargs):
        """Delegate logger cleanup to LoggerManager."""
        return self.logger_manager.close_loggers(*args, **kwargs)

    def save_checkpoint_if_needed(self, *args, **kwargs):
        """Check save frequency and delegate to CheckpointManager."""
        config_dict = self.cfg if hasattr(self.cfg, "__dict__") else dict(self.cfg)
        kwargs["full_config"] = config_dict
        kwargs["checkpoint_path"] = self.cfg.common.output_path
        kwargs["save_freq"] = self.cfg.common.save_freq
        kwargs["last_save_freq"] = self.cfg.common.last_save_freq
        kwargs["logger_manager"] = self.logger_manager
        return self.checkpoint_manager.save_checkpoint_if_needed(*args, **kwargs)

    def save_checkpoint(self, *args, **kwargs):
        """Save a checkpoint via CheckpointManager."""
        config_dict = self.cfg if hasattr(self.cfg, "__dict__") else dict(self.cfg)
        kwargs["full_config"] = config_dict
        if "checkpoint_path" not in kwargs:
            kwargs["checkpoint_path"] = self.cfg.common.output_path
        if "logger_manager" not in kwargs:
            kwargs["logger_manager"] = self.logger_manager
        return self.checkpoint_manager.save_checkpoint(*args, **kwargs)

    def load_checkpoint_if_exists(self, *args, **kwargs):
        """Try to resume from latest checkpoint in output path."""
        if "checkpoint_path" not in kwargs:
            kwargs["checkpoint_path"] = self.cfg.common.output_path
        if "logger_manager" not in kwargs:
            kwargs["logger_manager"] = self.logger_manager
        return self.checkpoint_manager.load_checkpoint_if_exists(*args, **kwargs)

    def load_checkpoint(self, *args, **kwargs):
        """Load a specific checkpoint."""
        if "logger_manager" not in kwargs:
            kwargs["logger_manager"] = self.logger_manager
        return self.checkpoint_manager.load_checkpoint(*args, **kwargs)

    def run_evaluation_if_needed(self, *args, **kwargs):
        """Check eval frequency and run evaluation via Evaluator."""
        return self.evaluator.run_evaluation_if_needed(*args, **kwargs)

    def validate_config(self) -> None:
        """Validate configuration before training starts."""
        # Required fields
        assert self.cfg.model is not None, "Model configuration is required"
        assert self.cfg.data is not None, "Data configuration is required"
        assert self.cfg.criterion is not None, "Criterion configuration is required"
        assert (
            self.cfg.optimization is not None
        ), "Optimization configuration is required"

        # Training parameters
        assert self.cfg.optimization.iterations > 0, "iterations must be positive"
        assert self.cfg.optimization.acc_steps > 0, "acc_steps must be positive"
        assert self.cfg.common.log_freq > 0, "log_freq must be positive"

        if self.cfg.common.save_freq > 0:
            assert (
                self.cfg.common.output_path
            ), "output_path required when save_freq > 0"

        logger.info("Configuration validation passed")

    def setup_context(self):
        """Setup global training context (precision, etc.)."""
        set_seed(self.cfg.common.seed, deterministic=self.cfg.common.deterministic)
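        # "highest" disables TF32 so fp32 matmuls run at full IEEE fp32 precision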
        torch.set_float32_matmul_precision("highest")
        if torch.cuda.is_available():
            torch.backends.cuda.matmul.fp32_precision = "ieee"
            if hasattr(torch.backends.cudnn, "conv"):
                torch.backends.cudnn.conv.fp32_precision = "ieee"

    def run(self):
        """Run the complete training pipeline."""
        self.setup_context()
        is_restart = self.checkpoint_manager.is_restart(self.cfg.common.output_path)

        with meters_group("init"):
            log_event_start("perf/init")
            logger.info(f"Using output path : {self.cfg.common.output_path}")
            logger.info(self.cfg)

            # Setup device and distributed collective
            device, collective = setup_device_and_collective(
                use_gpu=self.cfg.common.use_gpu, config=self.cfg.common.distributed
            )

            logger.info(
                "Setting up console logging. Will log from master only from now."
            )
            if not collective.is_master:
                setup_logging(logging.WARNING)

            model: BaseModel = self.build_model(
                model_config=self.cfg.model,
                collective=collective,
                is_restart=is_restart,
                checkpoint_manager=self.checkpoint_manager,
            )

            optimizer: Optimizer = self.build_optimizer(model.make_parameter_groups())
            lr_scheduler = self.build_lr_scheduler(optimizer)
            criterion: BaseCriterion = self.build_criterion(collective=collective)

            # Setup training context (AMP, scaler, etc.) using recipe mixin
            training_context = self.setup_training_context(device)

            try:
                train_datapipeline = self.build_train_data(
                    device=device, collective=collective
                )
                assert (
                    train_datapipeline is not None
                ), "Train data pipeline not initialized"
                eval_datapipeline = self.build_eval_data(
                    device=device, collective=collective
                )
                data_loaders = {
                    "train": train_datapipeline.dataloader,
                    # the eval dataloader may not be restored
                }
            except Exception as e:
                logger.error(f"Failed to build data pipelines: {e}")
                raise

            model = model.to(device)

            # cannot run after checkpoint load, as that may erase the start event
            log_event_end("perf/init")

            # Try to resume from the latest checkpoint in the output path
            start_iteration, metadata = self.load_checkpoint_if_exists(
                model=model,
                optimizer=optimizer,
                lr_scheduler=lr_scheduler,
                data_loaders=data_loaders,
                collective=collective,
                data_sources=train_datapipeline.datasets,
            )
            if is_restart:
                # a run that started but produced no artifacts is indistinguishable
                # from a run that never started at all
                assert metadata is not None, "Misaligned is_restart flag"

            logger.info(f"Considering this run as {is_restart = }")
            if not is_restart and self.cfg.common.load_checkpoint is not None:
                # if no checkpoint was loaded from the output path, this launch is not a
                # re-scheduling / preemption restart, so we can try loading the model from load_checkpoint
                metadata = self.load_checkpoint(
                    model=model,
                    optimizer=optimizer,
                    lr_scheduler=lr_scheduler,
                    data_loaders=data_loaders,
                    data_sources=train_datapipeline.datasets,
                    collective=collective,
                    load_strategy=self.cfg.common.load_checkpoint_strategy,
                    checkpoint_path=self.cfg.common.load_checkpoint,
                )
                start_iteration = metadata["iteration"] + 1
                logger.info(
                    "Loaded checkpoint from "
                    f"checkpoint_path = {self.cfg.common.load_checkpoint} path with "
                    f"load_strategy = {self.cfg.common.load_checkpoint_strategy} "
                    f"with {start_iteration = }"
                )

            if collective.is_master:
                self.build_loggers()
                self.setup_loggers(self.cfg.common.exp_name)

        init_metrics = compute_meters(
            group_name="init",
            aggregate=True,
            collective=collective,
        )
        if collective.is_local_master:
            self.log_metrics_to_loggers(init_metrics, start_iteration, "init")

        common_chkp_kwargs = {
            "model": model,
            "optimizer": optimizer,
            "collective": collective,
            "lr_scheduler": lr_scheduler,
            "data_loaders": data_loaders,
            "data_sources": train_datapipeline.datasets,
            "grad_scaler": training_context["scaler"],
        }

        train_data_iter = iter(train_datapipeline.dataloader)

        # Setup training metric engine if config exists
        train_metric_engine = None
        if self.cfg.metrics and "train" in self.cfg.metrics:
            from optimus_dl.modules.metrics.engine import MetricEngine

            train_metric_engine = MetricEngine("train", self.cfg.metrics["train"])

        collective.barrier()
        logger.info("All ranks are ready")

        pbar = trange(
            start_iteration,
            self.cfg.optimization.iterations + 1,
            initial=start_iteration,
            total=self.cfg.optimization.iterations,
            miniters=self.cfg.common.log_freq,
            maxinterval=1000000,
            disable=not collective.is_local_master,
            smoothing=0,
        )
        for iteration in pbar:
            try:
                # Execute one training iteration using recipe mixin
                self.run_training_iteration(
                    model=model,
                    optimizer=optimizer,
                    criterion=criterion,
                    train_data_iter=train_data_iter,
                    training_context=training_context,
                    lr_scheduler=lr_scheduler,
                    metric_engine=train_metric_engine,
                )

                with meters_group("train") as should_log:
                    if should_log:
                        # Get aggregated metrics for progress bar
                        current_metrics = compute_meters(
                            "train",
                            aggregate=True,
                            collective=collective,
                        )
                        if train_metric_engine:
                            current_metrics = train_metric_engine.compute(
                                current_metrics
                            )

                        if collective.is_local_master:
                            pbar.set_postfix(current_metrics, refresh=False)

                        # Log metrics to all configured loggers
                        if collective.is_master:
                            self.log_metrics_to_loggers(
                                current_metrics, iteration, "train"
                            )

                step_meters("train")  # Step the metrics logging iteration counter
                reset_meters(
                    "train"
                )  # Reset metrics after logging (keep metrics with reset=False)
                with training_context["amp_ctx"]:
                    metrics = self.run_evaluation_if_needed(
                        iteration=iteration,
                        model=model,
                        criterion=criterion,
                        eval_data={
                            k: v for k, v in eval_datapipeline.items() if v is not None
                        },
                        collective=collective,
                        all_metrics_configs=self.cfg.metrics,
                    )
                if metrics and collective.is_master:
                    for eval_name, eval_metrics in metrics.items():
                        self.log_metrics_to_loggers(eval_metrics, iteration, eval_name)

                self.save_checkpoint_if_needed(
                    iteration=iteration,
                    **common_chkp_kwargs,
                )

            except KeyboardInterrupt:
                self.handle_training_interruption(
                    iteration=iteration,
                    **common_chkp_kwargs,
                )
                break
            except Exception as e:
                logger.error(f"Training step failed at iteration {iteration}: {e}")
                raise

        # Close loggers at the end of training
        if collective.is_master:
            self.close_loggers()
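
The @register_train_recipe("base", TrainConfig) decorator registers this class in the recipe registry. A sketch of registering a custom variant (MyTrainRecipe and the "my_recipe" name are hypothetical; assumes register_train_recipe and TrainConfig are importable from the same module):

from optimus_dl.recipe.train.base import TrainConfig, TrainRecipe, register_train_recipe

@register_train_recipe("my_recipe", TrainConfig)  # hypothetical registry name
class MyTrainRecipe(TrainRecipe):
    def validate_config(self) -> None:
        super().validate_config()
        # hypothetical extra validation on top of the base checks
        assert self.cfg.common.eval_freq >= 0, "eval_freq must be non-negative"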

build_criterion(*args, **kwargs)

Delegate to CriterionBuilder.
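
All build_* methods are thin delegates to the corresponding builder. A condensed sketch of the build sequence as run() performs it (device and collective come from setup_device_and_collective):

model = recipe.build_model(
    model_config=recipe.cfg.model,
    collective=collective,
    is_restart=False,
    checkpoint_manager=recipe.checkpoint_manager,
)
optimizer = recipe.build_optimizer(model.make_parameter_groups())
lr_scheduler = recipe.build_lr_scheduler(optimizer)
criterion = recipe.build_criterion(collective=collective)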

build_eval_data(*args, **kwargs)

Delegate to DataBuilder for evaluation data.

build_loggers(*args, **kwargs)

Delegate to LoggerManager for building loggers.

build_lr_scheduler(*args, **kwargs)

Delegate to SchedulerBuilder.

build_model(*args, **kwargs)

Delegate to ModelBuilder.

build_optimizer(*args, **kwargs)

Delegate to OptimizerBuilder.

build_train_data(*args, **kwargs)

Delegate to DataBuilder for training data.
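
A sketch of the data calls as run() makes them; the train pipeline exposes .dataloader and .datasets, while build_eval_data returns a mapping of eval-set name to pipeline (inferred from run()'s use of .items()):

train_dp = recipe.build_train_data(device=device, collective=collective)
eval_dp = recipe.build_eval_data(device=device, collective=collective)

train_iter = iter(train_dp.dataloader)          # consumed by training iterations
data_loaders = {"train": train_dp.dataloader}   # the eval dataloader may not be restored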

close_loggers(*args, **kwargs)

Delegate logger cleanup to LoggerManager.

load_checkpoint(*args, **kwargs)

Load a specific checkpoint.
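
A warm-start sketch mirroring the load_checkpoint call in run() (the checkpoint path is illustrative):

metadata = recipe.load_checkpoint(
    model=model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    data_loaders=data_loaders,
    data_sources=train_dp.datasets,
    collective=collective,
    load_strategy=recipe.cfg.common.load_checkpoint_strategy,
    checkpoint_path="/checkpoints/pretrained",  # illustrative path
)
start_iteration = metadata["iteration"] + 1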

load_checkpoint_if_exists(*args, **kwargs)

Try to resume from latest checkpoint in output path.
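
A resume sketch as used at the start of run(); checkpoint_path defaults to cfg.common.output_path when not passed:

start_iteration, metadata = recipe.load_checkpoint_if_exists(
    model=model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    data_loaders=data_loaders,
    data_sources=train_dp.datasets,
    collective=collective,
)
# run() treats metadata as None when no checkpoint was found in the output path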

log_metrics_to_loggers(*args, **kwargs)

Delegate metrics logging to LoggerManager.
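
As called from run(): a metrics mapping, the current iteration, and a group name:

recipe.log_metrics_to_loggers(current_metrics, iteration, "train")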

run()

Run the complete training pipeline.

run_evaluation_if_needed(*args, **kwargs)

Check eval frequency and run evaluation via Evaluator.
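
Mirroring the call in run(); the eval-frequency check happens inside the Evaluator, and the result maps eval-set name to metrics when an evaluation ran:

with training_context["amp_ctx"]:
    metrics = recipe.run_evaluation_if_needed(
        iteration=iteration,
        model=model,
        criterion=criterion,
        eval_data={k: v for k, v in eval_dp.items() if v is not None},
        collective=collective,
        all_metrics_configs=recipe.cfg.metrics,
    )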

save_checkpoint(*args, **kwargs)

Save a checkpoint via CheckpointManager.

save_checkpoint_if_needed(*args, **kwargs)

Check save frequency and delegate to CheckpointManager.
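
full_config, checkpoint_path, the save frequencies, and the logger manager are injected from cfg before delegating, so call sites only pass training state, as run() does:

recipe.save_checkpoint_if_needed(
    iteration=iteration,
    model=model,
    optimizer=optimizer,
    collective=collective,
    lr_scheduler=lr_scheduler,
    data_loaders=data_loaders,
    data_sources=train_dp.datasets,
    grad_scaler=training_context["scaler"],
)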

setup_context()

Setup global training context (precision, etc.).

setup_loggers(experiment_name, full_config=None)

Setup logging with experiment configuration.

Parameters:

    experiment_name (str, required)
        Name of the current experiment.

    full_config (dict | None, default None)
        Full configuration dictionary. If None, uses self.cfg.
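
Logger setup runs on the master rank only in run(); a short sketch (the experiment name is illustrative, run() passes cfg.common.exp_name):

if collective.is_master:
    recipe.build_loggers()
    recipe.setup_loggers("my_experiment")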
validate_config()

Validate configuration before training starts.
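
validate_config() runs at the end of __init__, so an invalid config fails fast. For example (hypothetical values):

cfg.common.save_freq = 1000
cfg.common.output_path = ""  # hypothetical misconfiguration

TrainRecipe(cfg)  # AssertionError: output_path required when save_freq > 0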
