Skip to content

base

optimus_dl.recipe.train.base

TrainRecipe

Bases: TrainingContextMixin, TrainingIterationMixin, TrainingInterruptionMixin

Main training recipe that orchestrates all training components.

This class uses composition to coordinate specialized builders and managers:

  • ModelBuilder: Builds models and applies transforms
  • OptimizerBuilder: Builds optimizers
  • CriterionBuilder: Builds loss criteria
  • DataBuilder: Builds train/eval data pipelines
  • SchedulerBuilder: Builds learning rate schedulers
  • LoggerManager: Handles metrics logging setup and operations
  • CheckpointManager: Manages checkpoint saving and loading
  • Evaluator: Handles evaluation runs and metrics

It inherits from training logic mixins for the core loop execution:

  • TrainingContextMixin: Sets up training context (AMP, scaler, etc.)
  • TrainingIterationMixin: Orchestrates complete training iterations
  • TrainingInterruptionMixin: Handles training interruptions and errors
Source code in optimus_dl/recipe/train/base.py
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
@register_train_recipe("base", TrainConfig)
class TrainRecipe(
    TrainingContextMixin,
    TrainingIterationMixin,
    TrainingInterruptionMixin,
):
    """Main training recipe that orchestrates all training components.

    This class uses composition to coordinate specialized builders and managers:

    - ModelBuilder: Builds models and applies transforms
    - OptimizerBuilder: Builds optimizers
    - CriterionBuilder: Builds loss criteria
    - DataBuilder: Builds train/eval data pipelines
    - SchedulerBuilder: Builds learning rate schedulers
    - LoggerManager: Handles metrics logging setup and operations
    - CheckpointManager: Manages checkpoint saving and loading
    - Evaluator: Handles evaluation runs and metrics

    It inherits from training logic mixins for the core loop execution:

    - TrainingContextMixin: Sets up training context (AMP, scaler, etc.)
    - TrainingIterationMixin: Orchestrates complete training iterations
    - TrainingInterruptionMixin: Handles training interruptions and errors
    """

    cfg: TrainConfig

    def __init__(self, cfg: TrainConfig) -> None:
        """Build all components from ``cfg`` and initialize the mixins.

        Args:
            cfg: Training configuration. It is stored on the instance and
                validated via :meth:`validate_config` at the end of
                construction.
        """
        self.cfg = cfg

        # Initialize components via composition using registry for dependency injection
        self.model_builder = build_component(
            "model_builder",
            cfg.model_builder,
            cast_to=ModelBuilder,
            model_transforms=cfg.model_transforms,
        )
        assert self.model_builder is not None, "Model builder not initialized"
        self.optimizer_builder = build_component(
            "optimizer_builder",
            cfg.optimizer_builder,
            cast_to=OptimizerBuilder,
            optimization_config=cfg.optimization,
        )
        assert self.optimizer_builder is not None, "Optimizer builder not initialized"
        self.criterion_builder = build_component(
            "criterion_builder",
            cfg.criterion_builder,
            cast_to=CriterionBuilder,
            criterion_config=cfg.criterion,
        )
        assert self.criterion_builder is not None, "Criterion builder not initialized"
        self.data_builder = build_component(
            "data_builder",
            cfg.data_builder,
            cast_to=DataBuilder,
            data_config=cfg.data,
            data_seed=cfg.common.data_seed,
        )
        assert self.data_builder is not None, "Data builder not initialized"
        self.scheduler_builder = build_component(
            "scheduler_builder",
            cfg.scheduler_builder,
            cast_to=SchedulerBuilder,
            lr_scheduler_config=cfg.lr_scheduler,
            optimization_config=cfg.optimization,
        )
        assert self.scheduler_builder is not None, "Scheduler builder not initialized"
        self.logger_manager = build_component(
            "logger_manager",
            cfg.logger_manager,
            cast_to=LoggerManager,
            loggers_config=cfg.loggers,
        )
        assert self.logger_manager is not None, "Logger manager not initialized"
        self.checkpoint_manager = build_component(
            "checkpoint_manager",
            cfg.checkpoint_manager,
            cast_to=CheckpointManager,
        )
        assert self.checkpoint_manager is not None, "Checkpoint manager not initialized"
        self.evaluator = build_component(
            "evaluator",
            cfg.evaluator,
            cast_to=Evaluator,
            eval_freq=cfg.common.eval_freq,
            eval_iterations=cfg.common.eval_iterations,
            eval_guaranteed_same_batches=cfg.common.eval_guaranteed_same_batches,
            eval_checkpointing=cfg.common.eval_checkpointing,
            output_path=cfg.common.output_path,
        )
        assert self.evaluator is not None, "Evaluator not initialized"

        # Initialize training logic mixins
        TrainingContextMixin.__init__(self, cfg.optimization)
        TrainingIterationMixin.__init__(self, cfg.optimization, cfg.common.log_freq)
        TrainingInterruptionMixin.__init__(
            self,
            cfg.common.save_freq,
            cfg.common.output_path,
            self.save_checkpoint,  # Pass the checkpoint method as callback
        )
        self.validate_config()

    def _config_for_serialization(self):
        """Return the config in the form handed to loggers and checkpoints.

        The config object is passed through unchanged when it exposes a
        ``__dict__``; otherwise it is converted with ``dict(...)``.
        """
        return self.cfg if hasattr(self.cfg, "__dict__") else dict(self.cfg)

    # Delegate methods
    def build_model(self, *args, **kwargs) -> BaseModel:
        """Delegate to ModelBuilder."""
        return self.model_builder.build_model(*args, **kwargs)

    def build_optimizer(self, *args, **kwargs) -> Optimizer:
        """Delegate to OptimizerBuilder."""
        return self.optimizer_builder.build_optimizer(*args, **kwargs)

    def build_lr_scheduler(self, *args, **kwargs):
        """Delegate to SchedulerBuilder."""
        return self.scheduler_builder.build_lr_scheduler(*args, **kwargs)

    def build_criterion(self, *args, **kwargs) -> BaseCriterion:
        """Delegate to CriterionBuilder."""
        return self.criterion_builder.build_criterion(*args, **kwargs)

    def build_train_data(self, *args, **kwargs):
        """Delegate to DataBuilder for training data."""
        return self.data_builder.build_train_data(*args, **kwargs)

    def build_eval_data(self, *args, **kwargs):
        """Delegate to DataBuilder for evaluation data."""
        return self.data_builder.build_eval_data(*args, **kwargs)

    def build_loggers(self, *args, **kwargs):
        """Delegate to LoggerManager for building loggers."""
        return self.logger_manager.build_loggers(*args, **kwargs)

    def setup_loggers(
        self,
        experiment_name: str,
        full_config: dict | None = None,
        logs_parent_path: str | None = None,
        start_iteration: int | None = None,
    ):
        """Setup logging with experiment configuration.

        Args:
            experiment_name: Name of the current experiment.
            full_config: Full configuration dictionary. If None, uses self.cfg.
            logs_parent_path: Parent directory for log output; forwarded
                unchanged to the LoggerManager.
            start_iteration: Starting iteration number for logging.

        Returns:
            Whatever ``LoggerManager.setup_loggers`` returns.
        """
        if full_config is None:
            full_config = self._config_for_serialization()
        return self.logger_manager.setup_loggers(
            experiment_name=experiment_name,
            full_config=full_config,
            logs_parent_path=logs_parent_path,
            start_iteration=start_iteration,
        )

    def log_metrics_to_loggers(self, *args, **kwargs):
        """Delegate metrics logging to LoggerManager."""
        return self.logger_manager.log_metrics_to_loggers(*args, **kwargs)

    def close_loggers(self, run_status: RunStatus = RunStatus.SUCCESS, *args, **kwargs):
        """Report the final run status, then delegate logger cleanup to LoggerManager.

        Args:
            run_status: Status passed to ``LoggerManager.finished`` before the
                loggers are closed.
        """
        self.logger_manager.finished(run_status)
        return self.logger_manager.close_loggers(*args, **kwargs)

    def save_checkpoint_if_needed(self, *args, **kwargs):
        """Check save frequency and delegate to CheckpointManager.

        The config, checkpoint path, save frequencies and logger manager are
        always taken from this recipe and override any caller-supplied values
        for those keyword arguments.
        """
        kwargs["full_config"] = self._config_for_serialization()
        kwargs["checkpoint_path"] = self.cfg.common.output_path
        kwargs["save_freq"] = self.cfg.common.save_freq
        kwargs["last_save_freq"] = self.cfg.common.last_save_freq
        kwargs["logger_manager"] = self.logger_manager
        return self.checkpoint_manager.save_checkpoint_if_needed(*args, **kwargs)

    def save_checkpoint(self, *args, **kwargs):
        """Save a checkpoint via CheckpointManager.

        ``full_config`` is always set from this recipe; ``checkpoint_path``
        and ``logger_manager`` are only filled in when the caller omits them.
        """
        kwargs["full_config"] = self._config_for_serialization()
        kwargs.setdefault("checkpoint_path", self.cfg.common.output_path)
        kwargs.setdefault("logger_manager", self.logger_manager)
        return self.checkpoint_manager.save_checkpoint(*args, **kwargs)

    def load_checkpoint_if_exists(self, *args, **kwargs):
        """Try to resume from latest checkpoint in output path."""
        kwargs.setdefault("checkpoint_path", self.cfg.common.output_path)
        kwargs.setdefault("logger_manager", self.logger_manager)
        return self.checkpoint_manager.load_checkpoint_if_exists(*args, **kwargs)

    def load_checkpoint(self, *args, **kwargs):
        """Load a specific checkpoint."""
        kwargs.setdefault("logger_manager", self.logger_manager)
        return self.checkpoint_manager.load_checkpoint(*args, **kwargs)

    def run_evaluation_if_needed(self, *args, **kwargs):
        """Check eval frequency and run evaluation via Evaluator."""
        return self.evaluator.run_evaluation_if_needed(*args, **kwargs)

    def validate_config(self) -> None:
        """Validate configuration before training starts.

        Raises:
            AssertionError: If a required config section is missing, a
                training parameter is not positive, or ``save_freq > 0``
                without an ``output_path``. These are plain ``assert``
                statements, so they are skipped when Python runs with ``-O``.
        """
        # Required fields
        assert self.cfg.model is not None, "Model configuration is required"
        assert self.cfg.data is not None, "Data configuration is required"
        assert self.cfg.criterion is not None, "Criterion configuration is required"
        assert (
            self.cfg.optimization is not None
        ), "Optimization configuration is required"

        # Training parameters
        assert self.cfg.optimization.iterations > 0, "iterations must be positive"
        assert self.cfg.optimization.acc_steps > 0, "acc_steps must be positive"
        assert self.cfg.common.log_freq > 0, "log_freq must be positive"

        if self.cfg.common.save_freq > 0:
            assert (
                self.cfg.common.output_path
            ), "output_path required when save_freq > 0"

        logger.info("Configuration validation passed")

    def setup_context(self):
        """Setup global training context (precision, etc.).

        Seeds the run from ``cfg.common`` and requests the most precise
        float32 settings: matmul precision ``"highest"`` and, when CUDA is
        available, ``"ieee"`` fp32 precision for CUDA matmul and (if the
        attribute exists) cuDNN convolutions.
        """
        set_seed(self.cfg.common.seed, deterministic=self.cfg.common.deterministic)
        torch.set_float32_matmul_precision("highest")
        if torch.cuda.is_available():
            torch.backends.cuda.matmul.fp32_precision = "ieee"
            if hasattr(torch.backends.cudnn, "conv"):
                torch.backends.cudnn.conv.fp32_precision = "ieee"

    def evaluate_and_log(
        self,
        iteration,
        model,
        criterion,
        eval_datapipeline,
        collective,
        device,
    ):
        """Run evaluation and log metrics if needed.

        Args:
            iteration: Current training iteration.
            model: Model to evaluate.
            criterion: Criterion to use.
            eval_datapipeline: Evaluation data pipelines. Entries whose value
                is None are dropped before evaluation.
            collective: Distributed collective.
            device: Device to use.

        Returns:
            Dictionary of metrics if evaluation ran, else None.

        Raises:
            Exception: Any evaluation failure is logged with its traceback
                and re-raised unchanged.
        """
        try:
            metrics = self.run_evaluation_if_needed(
                iteration=iteration,
                model=model,
                criterion=criterion,
                eval_data_dict={
                    k: v for k, v in eval_datapipeline.items() if v is not None
                },
                collective=collective,
                all_metrics_configs=self.cfg.metrics,
                device=device,
            )
            # Only the global master owns configured loggers
            if metrics and collective.is_master:
                for eval_name, eval_metrics in metrics.items():
                    self.log_metrics_to_loggers(eval_metrics, iteration, eval_name)

            return metrics
        except Exception as e:
            logger.exception(f"Evaluation failed at iteration {iteration}: {e}")
            raise

    def run(self):
        """Run the complete training pipeline.

        Steps: set up the global context, device and collective; build the
        model, optimizer, scheduler, criterion and data pipelines; resume from
        a checkpoint when one exists; run the training loop with periodic
        evaluation and checkpointing; finally close loggers and the
        collective. If the resumed run has already reached the configured
        iteration count, only a final evaluation is performed.
        """
        self.setup_context()
        is_restart = self.checkpoint_manager.is_restart(self.cfg.common.output_path)
        finished_run = False
        # Final status reported to loggers; switched to INTERRUPTED on Ctrl-C so
        # that loggers are closed exactly once with the correct status.
        run_status = RunStatus.SUCCESS

        with meters_group("init"):
            log_event_start("perf/init")
            logger.info(f"Using output path : {self.cfg.common.output_path}")
            logger.info(self.cfg)

            # Setup device and distributed collective
            logger.debug("Setting up device and distributed collective...")
            device, collective = setup_device_and_collective(
                use_gpu=self.cfg.common.use_gpu, config=self.cfg.common.distributed
            )
            logger.debug(f"Device and collective setup complete. Device: {device}")

            logger.info(
                "Setting up console logging. Will log from master only from now."
            )
            logs_parent_path = pathlib.Path(self.cfg.common.output_path) / "logging"
            rank = collective.rank if collective is not None else 0
            log_path = logs_parent_path / f"rank_{rank}"
            if not collective.is_master:
                # Non-master ranks first get the WARNING-level setup
                setup_logging(
                    logging.WARNING,
                )
            # Every rank additionally logs to its own per-rank file
            setup_logging(
                log_path=log_path,
                clear_existing=False,
            )

            logger.debug("Building model...")
            model: BaseModel = self.build_model(
                model_config=self.cfg.model,
                collective=collective,
                is_restart=is_restart,
                checkpoint_manager=self.checkpoint_manager,
            )
            logger.info("Model built")

            logger.debug("Building optimizer...")
            optimizer: Optimizer = self.build_optimizer(model.make_parameter_groups())
            logger.info("Optimizer built")

            logger.debug("Building LR scheduler...")
            lr_scheduler = self.build_lr_scheduler(optimizer)
            logger.info("LR scheduler built")

            logger.debug("Building criterion...")
            criterion: BaseCriterion = self.build_criterion(collective=collective)
            logger.info("Criterion built")

            # Setup training context (AMP, scaler, etc.) using recipe mixin
            training_context = self.setup_training_context(device)

            try:
                logger.debug("Building data pipelines...")
                train_datapipeline = self.build_train_data(
                    device=device, collective=collective
                )
                assert (
                    train_datapipeline is not None
                ), "Train data pipeline not initialized"
                eval_datapipeline = self.build_eval_data(
                    device=device, collective=collective
                )
                data_loaders = {
                    "train": train_datapipeline.dataloader,
                    # eval dataloader may be not restored
                }
                logger.debug("Data pipelines built")
            except Exception as e:
                logger.error(f"Failed to build data pipelines: {e}")
                raise

            model = model.to(device)

            # cannot be after checkpoint load as may erase the start event
            log_event_end("perf/init")

            # Try to resume from checkpoint in output paths
            logger.debug("Checking for existing checkpoints to load...")
            common_chkp_kwargs = {
                "model": model,
                "optimizer": optimizer,
                "collective": collective,
                "lr_scheduler": lr_scheduler,
                "data_loaders": data_loaders,
                "data_sources": train_datapipeline.datasets,
                "grad_scaler": training_context["scaler"],
            }
            start_iteration, metadata = self.load_checkpoint_if_exists(
                **common_chkp_kwargs
            )
            if is_restart:
                # cases when training run but did not produce any artifacts is
                # indistinguishable from the case when training is not started at all
                assert metadata is not None, "Misaligned is_restart flag"

            logger.info(f"Considering this run as {is_restart = }")
            if not is_restart and self.cfg.common.load_checkpoint is not None:
                # if checkpoint from output path was not loaded, we are sure that this launch is not
                # re-scheduling / preemption re-start, so we can try loading model from load_checkpoint
                logger.debug(
                    f"Loading checkpoint from {self.cfg.common.load_checkpoint}..."
                )
                metadata = self.load_checkpoint(
                    **common_chkp_kwargs,
                    load_strategy=self.cfg.common.load_checkpoint_strategy,
                    checkpoint_path=self.cfg.common.load_checkpoint,
                )
                start_iteration = metadata["iteration"] + 1
                logger.info(
                    "Loaded checkpoint from "
                    f"checkpoint_path = {self.cfg.common.load_checkpoint} path with "
                    f"load_strategy = {self.cfg.common.load_checkpoint_strategy} "
                    f"with {start_iteration = }"
                )

            if self.cfg.optimization.iterations <= start_iteration:
                start_iteration = self.cfg.optimization.iterations
                finished_run = True
                logger.info(
                    "This run resumed from a checkpoint at iteration "
                    f"{start_iteration}, but the configured max iterations is "
                    f"{self.cfg.optimization.iterations}. Treating the run as already "
                    "finished and performing only final evaluations."
                )

            if collective.is_master:
                self.build_loggers()
                self.setup_loggers(
                    experiment_name=self.cfg.common.exp_name,
                    logs_parent_path=logs_parent_path,
                    start_iteration=start_iteration,
                )

        init_metrics = compute_meters(
            group_name="init",
            aggregate=True,
            collective=collective,
        )
        # Loggers are built and set up on the global master only (see above),
        # so init metrics use the same guard as every other logging call.
        if collective.is_master:
            self.log_metrics_to_loggers(init_metrics, start_iteration, "init")

        logger.debug("Initializing training data iterator...")
        train_data_iter = iter(train_datapipeline.dataloader)
        logger.debug("Training data iterator initialized")

        # Setup training metric engine if config exists
        train_metric_engine = None
        if self.cfg.metrics and "train" in self.cfg.metrics:
            from optimus_dl.modules.metrics.engine import MetricEngine

            train_metric_engine = MetricEngine("train", self.cfg.metrics["train"])

        logger.debug("Reaching pre-training barrier...")
        collective.barrier()
        logger.info("All ranks are ready")

        if not finished_run:
            # Determine if we need to keep any evaluation checkpoints for resumption
            iteration_to_keep = None
            if metadata is not None and not metadata.get("eval_finished", True):
                iteration_to_keep = metadata["iteration"]

            # Global cleanup of evaluation checkpoints, preserving only the one needed for resumption
            if collective.is_master:
                self.evaluator.cleanup_all_eval_checkpoints(
                    exclude_iteration=iteration_to_keep
                )
            collective.barrier()

            if iteration_to_keep is not None:
                logger.info(
                    f"Previous checkpoint (iter {iteration_to_keep}) was saved before evaluation, running evaluation before resuming training..."
                )
                self.evaluate_and_log(
                    iteration=iteration_to_keep,
                    model=model,
                    criterion=criterion,
                    eval_datapipeline=eval_datapipeline,
                    collective=collective,
                    device=device,
                )
                # Mark the interrupted evaluation as completed (metadata only)
                self.save_checkpoint_if_needed(
                    iteration=iteration_to_keep,
                    force_save=True,
                    **common_chkp_kwargs,
                    extra_metadata={"eval_finished": True},
                    metadata_only=True,
                )
                if collective.is_master:
                    self.evaluator.cleanup_all_eval_checkpoints(iteration_to_keep)
                collective.barrier()

            pbar = trange(
                start_iteration,
                self.cfg.optimization.iterations + 1,
                initial=start_iteration,
                total=self.cfg.optimization.iterations,
                miniters=self.cfg.common.log_freq,
                maxinterval=1000000,
                disable=not collective.is_local_master,
                smoothing=0,
            )
            for iteration in pbar:
                try:
                    logger.debug(f"Starting training iteration {iteration}")
                    # Execute one training iteration using recipe mixin
                    self.run_training_iteration(
                        model=model,
                        optimizer=optimizer,
                        criterion=criterion,
                        train_data_iter=train_data_iter,
                        training_context=training_context,
                        lr_scheduler=lr_scheduler,
                        metric_engine=train_metric_engine,
                    )
                    logger.debug(f"Finished training iteration {iteration}")

                    with meters_group("train") as should_log:
                        if should_log:
                            # Get aggregated metrics for progress bar
                            logger.debug(
                                f"Computing training metrics for iteration {iteration}"
                            )
                            current_metrics = compute_meters(
                                "train",
                                aggregate=True,
                                collective=collective,
                            )
                            if train_metric_engine:
                                current_metrics = train_metric_engine.compute(
                                    current_metrics
                                )

                            if collective.is_local_master:
                                pbar.set_postfix(current_metrics, refresh=False)

                            # Log metrics to all configured loggers
                            if collective.is_master:
                                self.log_metrics_to_loggers(
                                    current_metrics, iteration, "train"
                                )

                    step_meters("train")  # Step the metrics logging iteration counter
                    reset_meters(
                        "train"
                    )  # Reset metrics after logging (keep metrics with reset=False)

                    logger.debug(
                        f"Running evaluation if needed for iteration {iteration}"
                    )

                    eval_needed = self.evaluator.should_run_evaluation(
                        iteration, eval_datapipeline
                    )

                    if eval_needed:
                        if self.cfg.common.eval_resumable:
                            logger.debug(
                                f"Saving pre-evaluation checkpoint for iteration {iteration}..."
                            )
                            # eval_finished=False lets a restart detect that the
                            # evaluation for this iteration still has to run
                            self.save_checkpoint_if_needed(
                                iteration=iteration,
                                force_save=True,
                                **common_chkp_kwargs,
                                extra_metadata={"eval_finished": False},
                            )
                            logger.debug(
                                f"Pre-eval checkpoint saved for iteration {iteration}"
                            )

                        self.evaluate_and_log(
                            iteration=iteration,
                            model=model,
                            criterion=criterion,
                            eval_datapipeline=eval_datapipeline,
                            collective=collective,
                            device=device,
                        )

                        if self.cfg.common.eval_resumable:
                            logger.debug(
                                f"Saving post-evaluation metadata-only checkpoint for iteration {iteration}"
                            )
                            self.save_checkpoint_if_needed(
                                iteration=iteration,
                                force_save=True,
                                **common_chkp_kwargs,
                                extra_metadata={"eval_finished": True},
                                metadata_only=True,
                            )
                            if collective.is_master:
                                self.evaluator.cleanup_all_eval_checkpoints(iteration)
                            collective.barrier()
                    else:
                        # Regular checkpointing if no evaluation ran
                        self.save_checkpoint_if_needed(
                            iteration=iteration,
                            **common_chkp_kwargs,
                            extra_metadata={"eval_finished": True},
                        )

                except KeyboardInterrupt:
                    logger.info("Training interrupted by user")
                    self.handle_training_interruption(
                        iteration=iteration,
                        **common_chkp_kwargs,
                    )
                    # Do not close loggers here: the shared shutdown path below
                    # closes them once, and closing twice would overwrite the
                    # INTERRUPTED status with SUCCESS.
                    run_status = RunStatus.INTERRUPTED
                    break
                except Exception as e:
                    logger.error(f"Training step failed at iteration {iteration}: {e}")
                    raise
        else:
            logger.info("As finished run was resumed, assuming evaluation objective.")
            self.evaluate_and_log(
                iteration=start_iteration,
                model=model,
                criterion=criterion,
                eval_datapipeline=eval_datapipeline,
                collective=collective,
                device=device,
            )

        # Close loggers at the end of training, reporting the actual outcome
        if collective.is_master:
            logger.debug("Closing loggers...")
            self.close_loggers(run_status=run_status)

        gc.collect()
        collective.close()
        logger.info("Training run complete")

build_criterion(*args, **kwargs)

Delegate to CriterionBuilder.

Source code in optimus_dl/recipe/train/base.py
def build_criterion(self, *args, **kwargs) -> BaseCriterion:
    """Build the loss criterion by forwarding every argument to the CriterionBuilder."""
    builder = self.criterion_builder
    return builder.build_criterion(*args, **kwargs)

build_eval_data(*args, **kwargs)

Delegate to DataBuilder for evaluation data.

Source code in optimus_dl/recipe/train/base.py
def build_eval_data(self, *args, **kwargs):
    """Build the evaluation data pipelines by forwarding every argument to the DataBuilder."""
    builder = self.data_builder
    return builder.build_eval_data(*args, **kwargs)

build_loggers(*args, **kwargs)

Delegate to LoggerManager for building loggers.

Source code in optimus_dl/recipe/train/base.py
def build_loggers(self, *args, **kwargs):
    """Build the metrics loggers by forwarding every argument to the LoggerManager."""
    manager = self.logger_manager
    return manager.build_loggers(*args, **kwargs)

build_lr_scheduler(*args, **kwargs)

Delegate to SchedulerBuilder.

Source code in optimus_dl/recipe/train/base.py
def build_lr_scheduler(self, *args, **kwargs):
    """Build the learning-rate scheduler by forwarding every argument to the SchedulerBuilder."""
    builder = self.scheduler_builder
    return builder.build_lr_scheduler(*args, **kwargs)

build_model(*args, **kwargs)

Delegate to ModelBuilder.

Source code in optimus_dl/recipe/train/base.py
def build_model(self, *args, **kwargs) -> BaseModel:
    """Build the model by forwarding every argument to the ModelBuilder."""
    builder = self.model_builder
    return builder.build_model(*args, **kwargs)

build_optimizer(*args, **kwargs)

Delegate to OptimizerBuilder.

Source code in optimus_dl/recipe/train/base.py
def build_optimizer(self, *args, **kwargs) -> Optimizer:
    """Build the optimizer by forwarding every argument to the OptimizerBuilder."""
    builder = self.optimizer_builder
    return builder.build_optimizer(*args, **kwargs)

build_train_data(*args, **kwargs)

Delegate to DataBuilder for training data.

Source code in optimus_dl/recipe/train/base.py
def build_train_data(self, *args, **kwargs):
    """Build the training data pipeline through ``self.data_builder``.

    All positional and keyword arguments are forwarded unchanged, and the
    builder's return value is passed straight back to the caller.
    """
    builder = self.data_builder
    return builder.build_train_data(*args, **kwargs)

close_loggers(run_status=RunStatus.SUCCESS, *args, **kwargs)

Delegate logger cleanup to LoggerManager.

Source code in optimus_dl/recipe/train/base.py
def close_loggers(self, run_status: RunStatus = RunStatus.SUCCESS, *args, **kwargs):
    """Report the final run status, then close the loggers.

    Args:
        run_status: Status handed to ``logger_manager.finished`` before the
            loggers are closed. Defaults to ``RunStatus.SUCCESS``.
        *args: Extra positional arguments forwarded to
            ``logger_manager.close_loggers``.
        **kwargs: Extra keyword arguments forwarded to
            ``logger_manager.close_loggers``.

    Returns:
        Whatever ``logger_manager.close_loggers`` returns.
    """
    manager = self.logger_manager
    # The status must be reported before the loggers are torn down.
    manager.finished(run_status)
    return manager.close_loggers(*args, **kwargs)

evaluate_and_log(iteration, model, criterion, eval_datapipeline, collective, device)

Run evaluation and log metrics if needed.

Parameters:

Name Type Description Default
iteration

Current training iteration.

required
model

Model to evaluate.

required
criterion

Criterion to use.

required
eval_datapipeline

Evaluation data pipelines.

required
collective

Distributed collective.

required
device

Device to use.

required

Returns:

Type Description

Dictionary of metrics if evaluation ran, else None.

Source code in optimus_dl/recipe/train/base.py
def evaluate_and_log(
    self,
    iteration,
    model,
    criterion,
    eval_datapipeline,
    collective,
    device,
):
    """Evaluate the model (if due) and push the resulting metrics to the loggers.

    Entries of ``eval_datapipeline`` whose value is ``None`` are dropped before
    the remaining pipelines are handed to ``run_evaluation_if_needed``. Metrics
    are forwarded to the loggers only on the master rank.

    Args:
        iteration: Current training iteration.
        model: Model to evaluate.
        criterion: Criterion to use.
        eval_datapipeline: Mapping of evaluation name to data pipeline; ``None``
            values are skipped.
        collective: Distributed collective.
        device: Device to use.

    Returns:
        The value returned by ``run_evaluation_if_needed`` (a dictionary of
        metrics if evaluation ran, else None).

    Raises:
        Exception: Any error raised while evaluating or logging is logged with
            its traceback and then re-raised unchanged.
    """
    try:
        # Keep only the evaluation pipelines that were actually built.
        active_pipelines = {}
        for name, pipeline in eval_datapipeline.items():
            if pipeline is not None:
                active_pipelines[name] = pipeline

        metrics = self.run_evaluation_if_needed(
            iteration=iteration,
            model=model,
            criterion=criterion,
            eval_data_dict=active_pipelines,
            collective=collective,
            all_metrics_configs=self.cfg.metrics,
            device=device,
        )

        # Nothing to report, or this rank is not responsible for reporting.
        if not metrics or not collective.is_master:
            return metrics

        for eval_name in metrics:
            self.log_metrics_to_loggers(metrics[eval_name], iteration, eval_name)
        return metrics
    except Exception as e:
        logger.exception(f"Evaluation failed at iteration {iteration}: {e}")
        raise

load_checkpoint(*args, **kwargs)

Load a specific checkpoint.

Source code in optimus_dl/recipe/train/base.py
def load_checkpoint(self, *args, **kwargs):
    """Load a specific checkpoint through ``self.checkpoint_manager``.

    If the caller did not pass ``logger_manager``, this recipe's own
    ``self.logger_manager`` is supplied. Every other argument is forwarded
    unchanged.

    Returns:
        Whatever ``checkpoint_manager.load_checkpoint`` returns.
    """
    caller_supplied_logger = "logger_manager" in kwargs
    if not caller_supplied_logger:
        kwargs["logger_manager"] = self.logger_manager
    manager = self.checkpoint_manager
    return manager.load_checkpoint(*args, **kwargs)

load_checkpoint_if_exists(*args, **kwargs)

Try to resume from latest checkpoint in output path.

Source code in optimus_dl/recipe/train/base.py
def load_checkpoint_if_exists(self, *args, **kwargs):
    """Try to resume from the latest checkpoint via ``self.checkpoint_manager``.

    Two keyword arguments are filled in when the caller omits them:

    * ``checkpoint_path`` falls back to ``self.cfg.common.output_path``;
    * ``logger_manager`` falls back to ``self.logger_manager``.

    Every other argument is forwarded unchanged.

    Returns:
        Whatever ``checkpoint_manager.load_checkpoint_if_exists`` returns.
    """
    caller_supplied_path = "checkpoint_path" in kwargs
    if not caller_supplied_path:
        kwargs["checkpoint_path"] = self.cfg.common.output_path

    caller_supplied_logger = "logger_manager" in kwargs
    if not caller_supplied_logger:
        kwargs["logger_manager"] = self.logger_manager

    manager = self.checkpoint_manager
    return manager.load_checkpoint_if_exists(*args, **kwargs)

log_metrics_to_loggers(*args, **kwargs)

Delegate metrics logging to LoggerManager.

Source code in optimus_dl/recipe/train/base.py
def log_metrics_to_loggers(self, *args, **kwargs):
    """Send metrics to the configured loggers through ``self.logger_manager``.

    All positional and keyword arguments are forwarded unchanged, and the
    manager's return value is passed straight back to the caller.
    """
    manager = self.logger_manager
    return manager.log_metrics_to_loggers(*args, **kwargs)

run()

Run the complete training pipeline.

Source code in optimus_dl/recipe/train/base.py
def run(self):
    """Run the complete training pipeline.

    The method proceeds in these stages:

    1. Set up the global context, the device and the distributed collective,
       and per-rank file logging under ``<output_path>/logging/rank_<rank>``.
    2. Build the model, optimizer, LR scheduler, criterion, training context
       and the train/eval data pipelines.
    3. Try to resume from a checkpoint in ``cfg.common.output_path``. If this
       is not a restart and ``cfg.common.load_checkpoint`` is set, load that
       checkpoint instead.
    4. If the start iteration has already reached
       ``cfg.optimization.iterations``, only run a final evaluation.
       Otherwise run the training loop with periodic metric logging,
       evaluation and checkpointing.
    5. Close the loggers (master rank only) and the collective.

    A ``KeyboardInterrupt`` raised inside a training iteration is handled:
    ``handle_training_interruption`` is called, the loggers are closed with
    ``RunStatus.INTERRUPTED`` and the loop stops. In that case the loggers
    are not closed a second time with the default success status.

    Raises:
        AssertionError: If the train data pipeline is ``None`` or the restart
            flag disagrees with the loaded checkpoint metadata.
        Exception: Errors from building data pipelines or from a training
            iteration are logged and re-raised.
    """
    self.setup_context()
    is_restart = self.checkpoint_manager.is_restart(self.cfg.common.output_path)
    finished_run = False
    # Set once the interruption handler has closed the loggers, so the final
    # cleanup does not close them again and report a success status.
    loggers_closed = False

    with meters_group("init"):
        log_event_start("perf/init")
        logger.info(f"Using output path : {self.cfg.common.output_path}")
        logger.info(self.cfg)

        # Setup device and distributed collective
        logger.debug("Setting up device and distributed collective...")
        device, collective = setup_device_and_collective(
            use_gpu=self.cfg.common.use_gpu, config=self.cfg.common.distributed
        )
        logger.debug(f"Device and collective setup complete. Device: {device}")

        logger.info(
            "Setting up console logging. Will log from master only from now."
        )
        logs_parent_path = pathlib.Path(self.cfg.common.output_path) / "logging"
        # NOTE(review): `collective` is guarded against None here but is
        # dereferenced unconditionally everywhere below — confirm whether it
        # can ever be None.
        rank = collective.rank if collective is not None else 0
        log_path = logs_parent_path / f"rank_{rank}"
        if not collective.is_master:
            # Non-master ranks: reconfigure logging at WARNING level first,
            # then attach the per-rank log file without clearing that setup.
            setup_logging(
                logging.WARNING,
            )
            setup_logging(
                log_path=log_path,
                clear_existing=False,
            )
        else:
            setup_logging(
                log_path=log_path,
                clear_existing=False,
            )

        logger.debug("Building model...")
        model: BaseModel = self.build_model(
            model_config=self.cfg.model,
            collective=collective,
            is_restart=is_restart,
            checkpoint_manager=self.checkpoint_manager,
        )
        logger.info("Model built")

        logger.debug("Building optimizer...")
        optimizer: Optimizer = self.build_optimizer(model.make_parameter_groups())
        logger.info("Optimizer built")

        logger.debug("Building LR scheduler...")
        lr_scheduler = self.build_lr_scheduler(optimizer)
        logger.info("LR scheduler built")

        logger.debug("Building criterion...")
        criterion: BaseCriterion = self.build_criterion(collective=collective)
        logger.info("Criterion built")

        # Setup training context (AMP, scaler, etc.) using recipe mixin
        training_context = self.setup_training_context(device)

        try:
            logger.debug("Building data pipelines...")
            train_datapipeline = self.build_train_data(
                device=device, collective=collective
            )
            assert (
                train_datapipeline is not None
            ), "Train data pipeline not initialized"
            eval_datapipeline = self.build_eval_data(
                device=device, collective=collective
            )
            data_loaders = {
                "train": train_datapipeline.dataloader,
                # eval dataloader may be not restored
            }
            logger.debug("Data pipelines built")
        except Exception as e:
            logger.error(f"Failed to build data pipelines: {e}")
            raise

        model = model.to(device)

        # cannot be after checkpoint load as may erase the start event
        log_event_end("perf/init")

        # Try to resume from checkpoint in output paths
        logger.debug("Checking for existing checkpoints to load...")
        # State shared by every checkpoint load/save call in this method.
        common_chkp_kwargs = {
            "model": model,
            "optimizer": optimizer,
            "collective": collective,
            "lr_scheduler": lr_scheduler,
            "data_loaders": data_loaders,
            "data_sources": train_datapipeline.datasets,
            "grad_scaler": training_context["scaler"],
        }
        start_iteration, metadata = self.load_checkpoint_if_exists(
            **common_chkp_kwargs
        )
        if is_restart:
            # cases when training run but did not produce any artifacts is
            # indistinguishable from the case when training is not started at all
            assert metadata is not None, "Misaligned is_restart flag"

        logger.info(f"Considering this run as {is_restart = }")
        if not is_restart and self.cfg.common.load_checkpoint is not None:
            # if checkpoint from output path was not loaded, we are sure that this launch is not
            # re-scheduling / preemption re-start, so we can try loading model from load_checkpoint
            logger.debug(
                f"Loading checkpoint from {self.cfg.common.load_checkpoint}..."
            )
            metadata = self.load_checkpoint(
                **common_chkp_kwargs,
                load_strategy=self.cfg.common.load_checkpoint_strategy,
                checkpoint_path=self.cfg.common.load_checkpoint,
            )
            # Continue with the iteration after the one stored in the checkpoint.
            start_iteration = metadata["iteration"] + 1
            logger.info(
                "Loaded checkpoint from "
                f"checkpoint_path = {self.cfg.common.load_checkpoint} path with "
                f"load_strategy = {self.cfg.common.load_checkpoint_strategy} "
                f"with {start_iteration = }"
            )

        if self.cfg.optimization.iterations <= start_iteration:
            # Clamp to the configured maximum and skip the training loop.
            start_iteration = self.cfg.optimization.iterations
            finished_run = True
            logger.info(
                "This run resumed from a checkpoint at iteration "
                f"{start_iteration}, but the configured max iterations is "
                f"{self.cfg.optimization.iterations}. Treating the run as already "
                "finished and performing only final evaluations."
            )

        # Metric loggers are built on the master rank only.
        if collective.is_master:
            self.build_loggers()
            self.setup_loggers(
                experiment_name=self.cfg.common.exp_name,
                logs_parent_path=logs_parent_path,
                start_iteration=start_iteration,
            )

    init_metrics = compute_meters(
        group_name="init",
        aggregate=True,
        collective=collective,
    )
    # NOTE(review): loggers are built only when `is_master`, but init metrics
    # are logged when `is_local_master` — confirm this is intended for
    # multi-node runs.
    if collective.is_local_master:
        self.log_metrics_to_loggers(init_metrics, start_iteration, "init")

    logger.debug("Initializing training data iterator...")
    train_data_iter = iter(train_datapipeline.dataloader)
    logger.debug("Training data iterator initialized")

    # Setup training metric engine if config exists
    train_metric_engine = None
    if self.cfg.metrics and "train" in self.cfg.metrics:
        from optimus_dl.modules.metrics.engine import MetricEngine

        train_metric_engine = MetricEngine("train", self.cfg.metrics["train"])

    logger.debug("Reaching pre-training barrier...")
    collective.barrier()
    logger.info("All ranks are ready")

    if not finished_run:
        # Determine if we need to keep any evaluation checkpoints for resumption
        iteration_to_keep = None
        if metadata is not None and not metadata.get("eval_finished", True):
            iteration_to_keep = metadata["iteration"]

        # Global cleanup of evaluation checkpoints, preserving only the one needed for resumption
        if collective.is_master:
            self.evaluator.cleanup_all_eval_checkpoints(
                exclude_iteration=iteration_to_keep
            )
        collective.barrier()

        if iteration_to_keep is not None:
            logger.info(
                f"Previous checkpoint (iter {iteration_to_keep}) was saved before evaluation, running evaluation before resuming training..."
            )
            self.evaluate_and_log(
                iteration=iteration_to_keep,
                model=model,
                criterion=criterion,
                eval_datapipeline=eval_datapipeline,
                collective=collective,
                device=device,
            )
            # Record that the pending evaluation has now completed.
            self.save_checkpoint_if_needed(
                iteration=iteration_to_keep,
                force_save=True,
                **common_chkp_kwargs,
                extra_metadata={"eval_finished": True},
                metadata_only=True,
            )
            if collective.is_master:
                self.evaluator.cleanup_all_eval_checkpoints(iteration_to_keep)
            collective.barrier()

        # Iterates start_iteration..iterations inclusive; the bar is shown on
        # local-master ranks only.
        pbar = trange(
            start_iteration,
            self.cfg.optimization.iterations + 1,
            initial=start_iteration,
            total=self.cfg.optimization.iterations,
            miniters=self.cfg.common.log_freq,
            maxinterval=1000000,
            disable=not collective.is_local_master,
            smoothing=0,
        )
        for iteration in pbar:
            try:
                logger.debug(f"Starting training iteration {iteration}")
                # Execute one training iteration using recipe mixin
                self.run_training_iteration(
                    model=model,
                    optimizer=optimizer,
                    criterion=criterion,
                    train_data_iter=train_data_iter,
                    training_context=training_context,
                    lr_scheduler=lr_scheduler,
                    metric_engine=train_metric_engine,
                )
                logger.debug(f"Finished training iteration {iteration}")

                with meters_group("train") as should_log:
                    if should_log:
                        # Get aggregated metrics for progress bar
                        logger.debug(
                            f"Computing training metrics for iteration {iteration}"
                        )
                        current_metrics = compute_meters(
                            "train",
                            aggregate=True,
                            collective=collective,
                        )
                        if train_metric_engine:
                            current_metrics = train_metric_engine.compute(
                                current_metrics
                            )

                        if collective.is_local_master:
                            pbar.set_postfix(current_metrics, refresh=False)

                        # Log metrics to all configured loggers
                        if collective.is_master:
                            self.log_metrics_to_loggers(
                                current_metrics, iteration, "train"
                            )

                step_meters("train")  # Step the metrics logging iteration counter
                reset_meters(
                    "train"
                )  # Reset metrics after logging (keep metrics with reset=False)

                logger.debug(
                    f"Running evaluation if needed for iteration {iteration}"
                )

                eval_needed = self.evaluator.should_run_evaluation(
                    iteration, eval_datapipeline
                )

                if eval_needed:
                    if self.cfg.common.eval_resumable:
                        # Save state marked eval_finished=False so a restart can
                        # tell that evaluation for this iteration is pending.
                        logger.debug(
                            f"Saving pre-evaluation checkpoint for iteration {iteration}..."
                        )
                        self.save_checkpoint_if_needed(
                            iteration=iteration,
                            force_save=True,
                            **common_chkp_kwargs,
                            extra_metadata={"eval_finished": False},
                        )
                        logger.debug(
                            f"Pre-eval checkpoint saved for iteration {iteration}"
                        )

                    self.evaluate_and_log(
                        iteration=iteration,
                        model=model,
                        criterion=criterion,
                        eval_datapipeline=eval_datapipeline,
                        collective=collective,
                        device=device,
                    )

                    if self.cfg.common.eval_resumable:
                        logger.debug(
                            f"Saving post-evaluation metadata-only checkpoint for iteration {iteration}"
                        )
                        self.save_checkpoint_if_needed(
                            iteration=iteration,
                            force_save=True,
                            **common_chkp_kwargs,
                            extra_metadata={"eval_finished": True},
                            metadata_only=True,
                        )
                        if collective.is_master:
                            self.evaluator.cleanup_all_eval_checkpoints(iteration)
                        collective.barrier()
                else:
                    # Regular checkpointing if no evaluation ran
                    self.save_checkpoint_if_needed(
                        iteration=iteration,
                        **common_chkp_kwargs,
                        extra_metadata={"eval_finished": True},
                    )

            except KeyboardInterrupt:
                logger.info("Training interrupted by user")
                self.handle_training_interruption(
                    iteration=iteration,
                    **common_chkp_kwargs,
                )
                if collective.is_master:
                    logger.debug("Closing loggers...")
                    self.close_loggers(run_status=RunStatus.INTERRUPTED)
                    # Remember the close so the final cleanup below does not
                    # repeat it with the default success status.
                    loggers_closed = True
                break
            except Exception as e:
                logger.error(f"Training step failed at iteration {iteration}: {e}")
                raise
    else:
        logger.info("As finished run was resumed, assuming evaluation objective.")
        self.evaluate_and_log(
            iteration=start_iteration,
            model=model,
            criterion=criterion,
            eval_datapipeline=eval_datapipeline,
            collective=collective,
            device=device,
        )

    # Close loggers at the end of training, unless the interruption handler
    # already closed them with an INTERRUPTED status.
    if collective.is_master and not loggers_closed:
        logger.debug("Closing loggers...")
        self.close_loggers()

    gc.collect()
    collective.close()
    logger.info("Training run complete")

run_evaluation_if_needed(*args, **kwargs)

Check eval frequency and run evaluation via Evaluator.

Source code in optimus_dl/recipe/train/base.py
def run_evaluation_if_needed(self, *args, **kwargs):
    """Run evaluation through ``self.evaluator`` when it decides one is due.

    All positional and keyword arguments are forwarded unchanged, and the
    evaluator's return value is passed straight back to the caller.
    """
    evaluator = self.evaluator
    return evaluator.run_evaluation_if_needed(*args, **kwargs)

save_checkpoint(*args, **kwargs)

Save a checkpoint via CheckpointManager.

Source code in optimus_dl/recipe/train/base.py
def save_checkpoint(self, *args, **kwargs):
    """Save a checkpoint through ``self.checkpoint_manager``.

    ``full_config`` is always set from ``self.cfg`` (used as-is when it has a
    ``__dict__``, otherwise converted with ``dict``), replacing any value the
    caller passed. ``checkpoint_path`` and ``logger_manager`` are filled in
    from ``self.cfg.common.output_path`` and ``self.logger_manager`` only when
    the caller omitted them.

    Returns:
        Whatever ``checkpoint_manager.save_checkpoint`` returns.
    """
    cfg = self.cfg
    if hasattr(cfg, "__dict__"):
        kwargs["full_config"] = cfg
    else:
        kwargs["full_config"] = dict(cfg)

    caller_supplied_path = "checkpoint_path" in kwargs
    if not caller_supplied_path:
        kwargs["checkpoint_path"] = self.cfg.common.output_path

    caller_supplied_logger = "logger_manager" in kwargs
    if not caller_supplied_logger:
        kwargs["logger_manager"] = self.logger_manager

    manager = self.checkpoint_manager
    return manager.save_checkpoint(*args, **kwargs)

save_checkpoint_if_needed(*args, **kwargs)

Check save frequency and delegate to CheckpointManager.

Source code in optimus_dl/recipe/train/base.py
def save_checkpoint_if_needed(self, *args, **kwargs):
    """Delegate a conditional checkpoint save to ``self.checkpoint_manager``.

    The following keyword arguments are always set here, replacing any value
    the caller passed: ``full_config`` (``self.cfg`` as-is when it has a
    ``__dict__``, otherwise ``dict(self.cfg)``), ``checkpoint_path``,
    ``save_freq``, ``last_save_freq`` (all from ``self.cfg.common``) and
    ``logger_manager``.

    Returns:
        Whatever ``checkpoint_manager.save_checkpoint_if_needed`` returns.
    """
    cfg = self.cfg
    if hasattr(cfg, "__dict__"):
        full_config = cfg
    else:
        full_config = dict(cfg)

    kwargs.update(
        full_config=full_config,
        checkpoint_path=self.cfg.common.output_path,
        save_freq=self.cfg.common.save_freq,
        last_save_freq=self.cfg.common.last_save_freq,
        logger_manager=self.logger_manager,
    )
    manager = self.checkpoint_manager
    return manager.save_checkpoint_if_needed(*args, **kwargs)

setup_context()

Setup global training context (precision, etc.).

Source code in optimus_dl/recipe/train/base.py
def setup_context(self):
    """Seed the run and pin float32 math to its most precise settings.

    Seeds via ``set_seed`` using ``cfg.common.seed`` and
    ``cfg.common.deterministic``, then sets the float32 matmul precision to
    ``"highest"``. When CUDA is available, the CUDA matmul (and, if present,
    the cuDNN conv) ``fp32_precision`` is set to ``"ieee"``.
    """
    common_cfg = self.cfg.common
    set_seed(common_cfg.seed, deterministic=common_cfg.deterministic)
    torch.set_float32_matmul_precision("highest")

    # The backend-specific knobs below only apply when CUDA is usable.
    if not torch.cuda.is_available():
        return

    torch.backends.cuda.matmul.fp32_precision = "ieee"
    cudnn_backend = torch.backends.cudnn
    # The conv sub-namespace is probed first and only configured if it exists.
    if hasattr(cudnn_backend, "conv"):
        cudnn_backend.conv.fp32_precision = "ieee"

setup_loggers(experiment_name, full_config=None, logs_parent_path=None, start_iteration=None)

Setup logging with experiment configuration.

Parameters:

Name Type Description Default
experiment_name str

Name of the current experiment.

required
full_config dict | None

Full configuration dictionary. If None, uses self.cfg.

None
start_iteration int | None

Starting iteration number for logging.

None
Source code in optimus_dl/recipe/train/base.py
def setup_loggers(
    self,
    experiment_name: str,
    full_config: dict | None = None,
    logs_parent_path: str | None = None,
    start_iteration: int | None = None,
):
    """Configure the loggers for an experiment via ``self.logger_manager``.

    Args:
        experiment_name: Name of the current experiment.
        full_config: Full configuration to hand to the loggers. When None,
            ``self.cfg`` is used as-is if it has a ``__dict__``, otherwise it
            is converted with ``dict``.
        logs_parent_path: Parent path for log output; forwarded unchanged.
        start_iteration: Starting iteration number for logging.

    Returns:
        Whatever ``logger_manager.setup_loggers`` returns.
    """
    if full_config is None:
        cfg = self.cfg
        if hasattr(cfg, "__dict__"):
            full_config = cfg
        else:
            full_config = dict(cfg)

    forwarded = {
        "experiment_name": experiment_name,
        "full_config": full_config,
        "logs_parent_path": logs_parent_path,
        "start_iteration": start_iteration,
    }
    return self.logger_manager.setup_loggers(**forwarded)

validate_config()

Validate configuration before training starts.

Source code in optimus_dl/recipe/train/base.py
def validate_config(self) -> None:
    """Check that the configuration is usable before training starts.

    The required sections are checked first, then the numeric training
    parameters, and finally that an output path exists whenever periodic
    saving is enabled (``save_freq > 0``).

    Raises:
        AssertionError: On the first check that fails, with a message naming
            the offending setting.
    """
    cfg = self.cfg

    # Required sections
    assert cfg.model is not None, "Model configuration is required"
    assert cfg.data is not None, "Data configuration is required"
    assert cfg.criterion is not None, "Criterion configuration is required"
    optimization = cfg.optimization
    assert optimization is not None, "Optimization configuration is required"

    # Numeric training parameters
    assert optimization.iterations > 0, "iterations must be positive"
    assert optimization.acc_steps > 0, "acc_steps must be positive"
    common = cfg.common
    assert common.log_freq > 0, "log_freq must be positive"

    # Periodic saving needs somewhere to write to.
    if common.save_freq > 0:
        assert common.output_path, "output_path required when save_freq > 0"

    logger.info("Configuration validation passed")