Skip to content

Index

optimus_dl.recipe.train.mixins.managers

Evaluator

Manager for running periodic evaluations during training.

Handles iterating over validation datasets, computing loss and other metrics, and aggregating results across distributed ranks.

Parameters:

Name Type Description Default
cfg EvaluatorConfig

Evaluator configuration.

required
eval_freq int

Frequency of evaluation runs (in iterations).

0
eval_iterations int | None

Max number of batches to process per evaluation dataset. If None, zero, or negative, the entire dataset is processed (only positive values limit the number of batches).

None
eval_guaranteed_same_batches bool

If True, assumes all ranks will see the same number of batches, allowing for simpler stopping logic. If False, uses collective communication to determine when to stop if any rank exhausts its dataloader.

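False
eval_checkpointing int | None

If a positive integer N, an evaluation checkpoint is saved every N processed batches (requires output_path). None or non-positive values disable it.

None
output_path str | pathlib.Path | None

Path handed to the evaluation checkpoint manager. If empty or None, evaluation checkpointing (save, resume, cleanup) is disabled.

None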
Source code in optimus_dl/recipe/train/mixins/managers/evaluation_manager.py
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
class Evaluator:
    """Manager for running periodic evaluations during training.

    Handles iterating over validation datasets, computing loss and other metrics,
    and aggregating results across distributed ranks.

    Per-dataset settings (``eval_freq``, ``eval_iterations``,
    ``eval_guaranteed_same_batches``, ``eval_checkpointing``) are read from each
    evaluation pipeline when present and not None; otherwise the values given
    here are used.

    Args:
        cfg: Evaluator configuration.
        eval_freq: Frequency of evaluation runs (in iterations). Values <= 0
            disable evaluation for datasets that do not override it.
        eval_iterations: Max number of batches to process per evaluation dataset.
            If None, zero, or negative, processes the entire dataset.
        eval_guaranteed_same_batches: If True, assumes all ranks will see the same
            number of batches, allowing for simpler stopping logic. If False, uses
            collective communication to determine when to stop if any rank exhausts
            its dataloader.
        eval_checkpointing: If a positive integer N, save an evaluation checkpoint
            every N processed batches (requires ``output_path``). None or values
            <= 0 disable it.
        output_path: Path handed to ``EvaluationCheckpointManager``. If falsy,
            evaluation checkpointing (save, resume, cleanup) is disabled.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        cfg: EvaluatorConfig,
        eval_freq: int = 0,
        eval_iterations: int | None = None,
        eval_guaranteed_same_batches: bool = False,
        eval_checkpointing: int | None = None,
        output_path: str | pathlib.Path | None = None,
        **kwargs: Any,
    ):
        self.cfg = cfg
        self.eval_freq = eval_freq
        self.eval_iterations = eval_iterations
        self.eval_guaranteed_same_batches = eval_guaranteed_same_batches
        self.eval_checkpointing = eval_checkpointing
        self.output_path = output_path
        # Without an output path there is nowhere to persist partial evaluations.
        self.eval_checkpoint_manager = None
        if output_path:
            self.eval_checkpoint_manager = EvaluationCheckpointManager(output_path)

    @staticmethod
    def _override_or_default(eval_data: Any, attr: str, default: Any) -> Any:
        """Return ``eval_data.<attr>`` unless it is missing or None, else ``default``.

        Using ``getattr`` with a fallback lets a raw dataloader (which carries no
        per-dataset overrides) be passed where an ``EvalDataPipeline`` is expected.
        """
        value = getattr(eval_data, attr, None)
        return default if value is None else value

    @staticmethod
    def _normalize_max_iterations(max_iterations: int | None) -> int | None:
        """Map None, zero and negative batch limits to None (meaning unlimited).

        The stopping logic only ever enforced positive limits, so zero and negative
        values already meant "unlimited"; normalizing them in one place keeps the
        log message and progress-bar total consistent with that.
        """
        if max_iterations is None or max_iterations <= 0:
            return None
        return max_iterations

    @contextlib.contextmanager
    def forward_context(self, device: torch.device):
        """Context manager for evaluation forward pass.

        Can be used to set up any necessary context (e.g., mixed precision) for the
        forward pass during evaluation. When ``self.cfg.amp.enabled`` is set, the
        body runs under ``torch.autocast`` for ``device.type`` with the dtype named
        by ``self.cfg.amp.dtype``.

        Args:
            device: Device whose ``type`` selects the autocast backend.

        Returns:
            A context manager (e.g., from contextlib) that sets up the desired context.
        """
        if self.cfg.amp.enabled:
            amp_cfg = self.cfg.amp
            dtype = str_to_dtype(amp_cfg.dtype)
            amp_ctx = torch.autocast(device.type, dtype=dtype, enabled=amp_cfg.enabled)
            with amp_ctx:
                yield
        else:
            yield

    def should_run_evaluation(
        self,
        iteration: int,
        eval_data_dict: dict[str, EvalDataPipeline | None],
    ) -> bool:
        """Check if any of the evaluation datasets match the current iteration frequency.

        Args:
            iteration: Current training step.
            eval_data_dict: Dictionary mapping dataset names to eval data pipelines.
                ``None`` entries are ignored.

        Returns:
            True if at least one evaluation should run, False otherwise.
        """
        for eval_data in eval_data_dict.values():
            if eval_data is None:
                continue
            eval_freq = self._override_or_default(eval_data, "eval_freq", self.eval_freq)
            if eval_freq > 0 and iteration % eval_freq == 0:
                return True
        return False

    def run_evaluation_if_needed(
        self,
        iteration: int,
        model: BaseModel,
        criterion: BaseCriterion,
        eval_data_dict: dict[str, EvalDataPipeline | None],
        device: torch.device,
        collective: Collective | None = None,
        all_metrics_configs: dict[str, list[dict]] | None = None,
    ) -> None | dict:
        """Run evaluation if the current iteration matches the frequency for any dataset.

        Args:
            iteration: Current training step.
            model: The model to evaluate.
            criterion: The loss function.
            eval_data_dict: Dictionary mapping dataset names to dataloaders.
                ``None`` entries are skipped, matching ``should_run_evaluation``.
            device: Device whose type selects the autocast backend.
            collective: Distributed collective for metric aggregation.
            all_metrics_configs: Root metrics configuration from TrainConfig.

        Returns:
            Dictionary of computed metrics if evaluation ran, else None.

        Raises:
            Exception: Any error raised while evaluating a dataset is logged and
                re-raised.
        """
        result = {}

        # Deterministic order so every rank runs the same evaluations in sequence.
        for eval_name in sorted(eval_data_dict):
            eval_data = eval_data_dict[eval_name]
            if eval_data is None:
                continue

            eval_freq = self._override_or_default(eval_data, "eval_freq", self.eval_freq)
            if eval_freq <= 0 or iteration % eval_freq != 0:
                continue

            max_iterations = self._normalize_max_iterations(
                self._override_or_default(
                    eval_data, "eval_iterations", self.eval_iterations
                )
            )

            try:
                result |= self.run_evaluation(
                    iteration=iteration,
                    model=model,
                    criterion=criterion,
                    eval_data_dict={eval_name: eval_data},
                    max_iterations=max_iterations,
                    collective=collective,
                    all_metrics_configs=all_metrics_configs,
                    show_progress=True,
                    device=device,
                )
            except Exception as e:
                logger.exception(
                    f"Error during evaluation of {eval_name} at iteration {iteration}: {e}"
                )
                raise

        return result or None

    def run_evaluation(
        self,
        model: BaseModel,
        criterion: BaseCriterion,
        eval_data_dict: dict[str, EvalDataPipeline],
        device: torch.device,
        max_iterations: int | None = None,
        collective: Collective | None = None,
        all_metrics_configs: dict[str, list[dict]] | None = None,
        metrics_prefix: str = "eval",
        show_progress: bool = False,
        iteration: int | None = None,
    ):
        """Execute the evaluation loop for all provided datasets.

        Sets the model to eval mode, disables gradients, and runs the forward pass
        for each batch. Metrics are aggregated globally.

        Args:
            model: Model to evaluate. ``model.eval()`` is called and the mode is
                not restored afterwards.
            criterion: Loss function, called as
                ``criterion(model, batch, requested_protocols=...)`` and returning
                ``(loss, exposed)``.
            eval_data_dict: Dictionary of {name: dataloader/DataPipeline}. ``None``
                entries are skipped.
            device: Device whose type selects the autocast backend.
            max_iterations: Limit on number of batches per dataset, used when the
                dataset does not set its own ``eval_iterations``. None, zero or
                negative means unlimited.
            collective: Distributed collective used to synchronize stopping and to
                aggregate meters; None for single-process evaluation.
            all_metrics_configs: Root metrics configuration mapping dataset names to configs.
            metrics_prefix: Prefix for metric groups (e.g., "eval" or "metrics").
            show_progress: Whether to show a progress bar.
            iteration: Current training iteration, used for naming checkpoints.

        Returns:
            Nested dictionary of results:
            {f"{metrics_prefix}/{dataset_name}": {metric_name: value}}.
        """
        model.eval()
        total_metrics = {}
        all_metrics_configs = all_metrics_configs or {}

        # Deterministic order so every rank issues the same collectives in sequence.
        for eval_name in sorted(eval_data_dict):
            eval_data = eval_data_dict[eval_name]
            if eval_data is None:
                continue

            # Per-dataset overrides win over the arguments / manager defaults.
            max_iterations_local = self._normalize_max_iterations(
                self._override_or_default(eval_data, "eval_iterations", max_iterations)
            )
            guaranteed_same_batches_local = self._override_or_default(
                eval_data,
                "eval_guaranteed_same_batches",
                self.eval_guaranteed_same_batches,
            )
            eval_checkpointing = self._override_or_default(
                eval_data, "eval_checkpointing", self.eval_checkpointing
            )

            logger.info(
                f"Running evaluation {eval_name} for {max_iterations_local if max_iterations_local is not None else 'unlimited'} iterations (on each rank)"
            )

            # Handle both raw dataloader and DataPipeline object
            dataloader = getattr(eval_data, "dataloader", eval_data)

            group_name = f"{metrics_prefix}/{eval_name}"

            engine = None
            requested_protocols = None
            dataset_metrics = all_metrics_configs.get(eval_name)
            if dataset_metrics:
                from optimus_dl.modules.metrics.engine import MetricEngine

                engine = MetricEngine(group_name, dataset_metrics)
                requested_protocols = engine.required_external_protocols

            with (
                torch.no_grad(),
                meters_group(group_name, log_freq=1, force_recreate=True),
            ):
                eval_iter = iter(dataloader)

                # Resume a partially finished evaluation if a checkpoint exists.
                eval_iterations_processed = 0
                if self.eval_checkpoint_manager is not None and iteration is not None:
                    eval_iterations_processed = (
                        self.eval_checkpoint_manager.load_iteration_state(
                            iteration=iteration,
                            eval_name=eval_name,
                            group_name=group_name,
                            eval_iter=eval_iter,
                            collective=collective,
                        )
                    )
                # Batches restored from a checkpoint are not timed by this run.
                resumed_iterations = eval_iterations_processed

                log_event_start("perf/total_run")
                start_time = time.perf_counter()

                pbar = None
                if show_progress:
                    pbar = tqdm(
                        desc=f"Eval {eval_name}",
                        disable=(
                            collective is not None and not collective.is_local_master
                        ),
                        unit="batch",
                        total=max_iterations_local,
                        initial=eval_iterations_processed,
                    )

                check_exhaustion = (
                    collective is not None and not guaranteed_same_batches_local
                )

                def should_stop(_iterations, max_iterations_local=max_iterations_local):
                    # Limits are normalized above, so None means unlimited.
                    return (
                        max_iterations_local is not None
                        and _iterations >= max_iterations_local
                    )

                try:
                    stop_flag = None
                    stopped_before_loop = False
                    if check_exhaustion:
                        assert collective is not None
                        # A resumed rank may already be at its limit. Agree on that
                        # before the loop so all ranks issue the same collectives.
                        stop_flag = torch.tensor(
                            [int(should_stop(eval_iterations_processed))],
                            device=collective.default_device,
                            dtype=torch.int32,
                        )
                        collective.all_reduce(
                            stop_flag,
                            op=Collective.ReduceOp.MAX,
                        )
                        if stop_flag.item() == 1:
                            logger.info(
                                "Some ranks have already finished evaluation before starting the loop, stopping immediately."
                            )
                            stopped_before_loop = True

                    while not stopped_before_loop and not should_stop(
                        eval_iterations_processed
                    ):
                        logger.debug(
                            f"Eval {eval_name}: Starting iteration {eval_iterations_processed}"
                        )
                        log_event_occurence("perf/full_iteration")
                        exhausted = False
                        try:
                            logger.debug(
                                f"Eval {eval_name}: Fetching batch for iteration {eval_iterations_processed}"
                            )
                            elapsed_batch_get, batch = measured_next(eval_iter)
                            info_once(logger, f"Batch has keys {batch.keys()}")
                        except StopIteration:
                            logger.debug(
                                f"Eval {eval_name}: Dataloader exhausted on this rank"
                            )
                            exhausted = True

                        if check_exhaustion:
                            assert collective is not None
                            assert stop_flag is not None
                            logger.debug(
                                f"Eval {eval_name}: Synchronizing exhaustion state (all_reduce MAX)"
                            )
                            # Every rank, exhausted or not, takes part in this
                            # all_reduce so the collective sequences stay aligned.
                            stop_flag[0] = int(exhausted)
                            collective.all_reduce(
                                stop_flag,
                                op=Collective.ReduceOp.MAX,
                            )
                            if stop_flag.item() == 1:
                                # at least one rank is exhausted, stop evaluation
                                logger.info(
                                    f"Eval {eval_name}: At least one rank exhausted its dataloader, stopping evaluation."
                                )
                                break
                        elif exhausted:
                            # If we are guaranteed that all ranks see the same number of
                            # batches, we can just stop when this rank is exhausted.
                            logger.info(
                                f"Eval {eval_name}: Dataloader exhausted, stopping evaluation."
                            )
                            break

                        logger.debug(
                            f"Eval {eval_name}: Running forward pass for iteration {eval_iterations_processed}"
                        )
                        with self.forward_context(device=device):
                            loss, exposed = criterion(
                                model, batch, requested_protocols=requested_protocols
                            )

                        if engine:
                            logger.debug(f"Eval {eval_name}: Updating metric engine")
                            computed_data = exposed.copy()
                            computed_data["loss"] = loss
                            engine.update(
                                data=dict(model=model, batch=batch),
                                computed_data=computed_data,
                            )

                        log_summed("num_batches", lambda: 1)
                        log_averaged(
                            "perf/batch_get",
                            elapsed_batch_get,
                        )

                        eval_iterations_processed += 1
                        if pbar is not None:
                            pbar.update(1)

                        # Step metrics for each evaluation iteration
                        step_meters(group_name)
                        logger.debug(
                            f"Eval {eval_name}: Finished iteration {eval_iterations_processed-1}"
                        )

                        if (
                            self.eval_checkpoint_manager is not None
                            and eval_checkpointing is not None
                            and eval_checkpointing > 0
                            and eval_iterations_processed % eval_checkpointing == 0
                        ):
                            assert (
                                iteration is not None
                            ), "Iteration must be provided for checkpointing"
                            self.eval_checkpoint_manager.save_iteration_state(
                                iteration=iteration,
                                eval_name=eval_name,
                                dataloader_state=eval_iter.state_dict(),
                                group_name=group_name,
                                collective=collective,
                                eval_iterations_processed=eval_iterations_processed,
                            )
                            logger.info(
                                f"Saved evaluation metrics checkpoint at iteration {eval_iterations_processed}"
                            )
                finally:
                    if pbar is not None:
                        pbar.refresh()
                        pbar.close()

                total_time = time.perf_counter() - start_time
                log_event_end("perf/total_run")

            logger.debug(f"Eval {eval_name}: Computing aggregated meters")
            eval_metrics = compute_meters(
                group_name,
                aggregate=True,
                collective=collective,
            )

            if engine:
                eval_metrics = engine.compute(eval_metrics)

            # Basic performance stats. Both cover only the batches timed in this
            # run, not those restored from an evaluation checkpoint.
            eval_metrics["perf/total_run_ms"] = total_time * 1000
            batches_this_run = eval_iterations_processed - resumed_iterations
            if batches_this_run > 0:
                eval_metrics["perf/ms_per_batch"] = (
                    total_time / batches_this_run
                ) * 1000

            logger.info(
                f"Finished eval {eval_name}: {eval_metrics} in {total_time:.1f}s"
            )
            total_metrics[group_name] = eval_metrics
        return total_metrics

    def cleanup_all_eval_checkpoints(
        self, iteration: int | None = None, exclude_iteration: int | None = None
    ) -> None:
        """Cleanup evaluation checkpoints.

        If iteration is provided, cleans up checkpoints for that iteration only.
        If iteration is None, cleans up ALL evaluation checkpoints in the output path,
        optionally excluding one specific iteration. Does nothing (beyond a debug
        log) when no checkpoint manager was created.
        """
        if self.eval_checkpoint_manager is not None:
            self.eval_checkpoint_manager.cleanup(
                iteration=iteration, exclude_iteration=exclude_iteration
            )
        else:
            logger.debug(
                "No evaluation checkpoint manager initialized, skipping cleanup."
            )

cleanup_all_eval_checkpoints(iteration=None, exclude_iteration=None)

Cleanup evaluation checkpoints.

If iteration is provided, cleans up checkpoints for that iteration only. If iteration is None, cleans up ALL evaluation checkpoints in the output path, optionally excluding one specific iteration.

Source code in optimus_dl/recipe/train/mixins/managers/evaluation_manager.py
def cleanup_all_eval_checkpoints(
    self, iteration: int | None = None, exclude_iteration: int | None = None
) -> None:
    """Cleanup evaluation checkpoints.

    If iteration is provided, cleans up checkpoints for that iteration only.
    If iteration is None, cleans up ALL evaluation checkpoints in the output path,
    optionally excluding one specific iteration.
    """
    if self.eval_checkpoint_manager is not None:
        self.eval_checkpoint_manager.cleanup(
            iteration=iteration, exclude_iteration=exclude_iteration
        )
    else:
        logger.debug(
            "No evaluation checkpoint manager initialized, skipping cleanup."
        )

forward_context(device)

Context manager for evaluation forward pass.

Can be used to set up any necessary context (e.g., mixed precision) for the forward pass during evaluation.

Returns:

Type Description

A context manager (e.g., from contextlib) that sets up the desired context.

Source code in optimus_dl/recipe/train/mixins/managers/evaluation_manager.py
@contextlib.contextmanager
def forward_context(self, device: torch.device):
    """Wrap the evaluation forward pass in the configured autocast context.

    When ``self.cfg.amp.enabled`` is set, the body runs under ``torch.autocast``
    for ``device.type`` with the dtype named by ``self.cfg.amp.dtype``; otherwise
    the body runs without any extra context.

    Args:
        device: Device whose ``type`` selects the autocast backend.

    Returns:
        A context manager (e.g., from contextlib) that sets up the desired context.
    """
    amp_cfg = self.cfg.amp
    if not amp_cfg.enabled:
        yield
        return
    with torch.autocast(
        device.type, dtype=str_to_dtype(amp_cfg.dtype), enabled=amp_cfg.enabled
    ):
        yield

run_evaluation(model, criterion, eval_data_dict, device, max_iterations=None, collective=None, all_metrics_configs=None, metrics_prefix='eval', show_progress=False, iteration=None)

Execute the evaluation loop for all provided datasets.

Sets the model to eval mode, disables gradients, and runs the forward pass for each batch. Metrics are aggregated globally.

Parameters:

Name Type Description Default
model BaseModel

Model to evaluate.

required
criterion BaseCriterion

Loss function.

required
eval_data_dict dict[str, EvalDataPipeline]

Dictionary of {name: dataloader/DataPipeline}.

required
max_iterations int | None

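Limit on number of batches per dataset, used when a dataset does not set its own eval_iterations. None, zero, or negative values mean unlimited.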

None
collective Collective | None

Distributed collective.

None
all_metrics_configs dict[str, list[dict]] | None

Root metrics configuration mapping dataset names to configs.

None
metrics_prefix str

Prefix for metric groups (e.g., "eval" or "metrics").

'eval'
show_progress bool

Whether to show a progress bar.

False
iteration int | None

Current training iteration, used for naming checkpoints.

None

Returns:

Type Description

Nested dictionary of results keyed by metric group: {f"{metrics_prefix}/{dataset_name}": {metric_name: value}}.

Source code in optimus_dl/recipe/train/mixins/managers/evaluation_manager.py
def run_evaluation(
    self,
    model: BaseModel,
    criterion: BaseCriterion,
    eval_data_dict: dict[str, EvalDataPipeline],
    device: torch.device,
    max_iterations: int | None = None,
    collective: Collective | None = None,
    all_metrics_configs: dict[str, list[dict]] | None = None,
    metrics_prefix: str = "eval",
    show_progress: bool = False,
    iteration: int | None = None,
):
    """Execute the evaluation loop for all provided datasets.

    Sets the model to eval mode, disables gradients, and runs the forward pass
    for each batch. Metrics are aggregated globally.

    Args:
        model: Model to evaluate.
        criterion: Loss function.
        eval_data_dict: Dictionary of {name: dataloader/DataPipeline}.
        max_iterations: Limit on number of batches.
        collective: Distributed collective.
        all_metrics_configs: Root metrics configuration mapping dataset names to configs.
        metrics_prefix: Prefix for metric groups (e.g., "eval" or "metrics").
        show_progress: Whether to show a progress bar.
        iteration: Current training iteration, used for naming checkpoints.

    Returns:
        Nested dictionary of results: {dataset_name: {metric_name: value}}.
    """
    model.eval()
    total_metrics = {}
    all_metrics_configs = all_metrics_configs or {}

    eval_data_dict_keys = sorted(eval_data_dict.keys())
    for eval_name in eval_data_dict_keys:
        eval_data = eval_data_dict[eval_name]
        max_iterations_local = (
            eval_data.eval_iterations
            if eval_data.eval_iterations is not None
            else max_iterations
        )
        guaranteed_same_batches_local = (
            eval_data.eval_guaranteed_same_batches
            if eval_data.eval_guaranteed_same_batches is not None
            else self.eval_guaranteed_same_batches
        )
        if max_iterations_local is not None and max_iterations_local < 0:
            max_iterations_local = None

        eval_checkpointing = (
            eval_data.eval_checkpointing
            if eval_data.eval_checkpointing is not None
            else self.eval_checkpointing
        )

        logger.info(
            f"Running evaluation {eval_name} for {max_iterations_local if max_iterations_local is not None else 'unlimited'} iterations (on each rank)"
        )

        # Handle both raw dataloader and DataPipeline object
        dataloader = getattr(eval_data, "dataloader", eval_data)

        engine = None
        requested_protocols = None
        dataset_metrics = all_metrics_configs.get(eval_name)
        if dataset_metrics:
            from optimus_dl.modules.metrics.engine import MetricEngine

            engine = MetricEngine(f"{metrics_prefix}/{eval_name}", dataset_metrics)
            requested_protocols = engine.required_external_protocols

        group_name = f"{metrics_prefix}/{eval_name}"
        with (
            torch.no_grad(),
            meters_group(group_name, log_freq=1, force_recreate=True),
        ):
            eval_iter = iter(dataloader)

            eval_iterations_processed = 0
            if self.eval_checkpoint_manager is not None and iteration is not None:
                eval_iterations_processed = (
                    self.eval_checkpoint_manager.load_iteration_state(
                        iteration=iteration,
                        eval_name=eval_name,
                        group_name=group_name,
                        eval_iter=eval_iter,
                        collective=collective,
                    )
                )

            log_event_start("perf/total_run")
            start_time = time.perf_counter()

            pbar = None
            if show_progress:
                pbar = tqdm(
                    desc=f"Eval {eval_name}",
                    disable=(
                        collective is not None and not collective.is_local_master
                    ),
                    unit="batch",
                    total=(
                        max_iterations_local
                        if max_iterations_local is not None
                        else None
                    ),
                    initial=eval_iterations_processed,
                )

            check_exhaustion = (
                collective is not None and not guaranteed_same_batches_local
            )

            def should_stop(_iterations, max_iterations_local=max_iterations_local):
                return (
                    max_iterations_local is not None
                    and max_iterations_local > 0
                    and _iterations >= max_iterations_local
                )

            try:
                # Consider we loaded state where one rank is already exhausted.
                # We should check it before starting the loop to have the same collectives set.

                stop_flag = None
                if check_exhaustion:
                    assert collective is not None
                    stop_flag = torch.tensor(
                        [int(should_stop(eval_iterations_processed))],
                        device=collective.default_device,
                        dtype=torch.int32,
                    )
                    collective.all_reduce(
                        stop_flag,
                        op=Collective.ReduceOp.MAX,
                    )
                    if stop_flag.item() == 1:
                        logger.info(
                            "Some ranks have already finished evaluation before starting the loop, stopping immediately."
                        )
                        raise StopIteration

                while not should_stop(eval_iterations_processed):
                    logger.debug(
                        f"Eval {eval_name}: Starting iteration {eval_iterations_processed}"
                    )
                    log_event_occurence("perf/full_iteration")
                    exhausted = False
                    try:
                        logger.debug(
                            f"Eval {eval_name}: Fetching batch for iteration {eval_iterations_processed}"
                        )
                        elapsed_batch_get, batch = measured_next(eval_iter)
                        info_once(logger, f"Batch has keys {batch.keys()}")
                    except StopIteration:
                        logger.debug(
                            f"Eval {eval_name}: Dataloader exhausted on this rank"
                        )
                        exhausted = True

                    if check_exhaustion:
                        assert collective is not None
                        assert stop_flag is not None
                        logger.debug(
                            f"Eval {eval_name}: Synchronizing exhaustion state (all_reduce MAX)"
                        )
                        stop_flag[0] = int(exhausted)
                        collective.all_reduce(
                            stop_flag,
                            op=Collective.ReduceOp.MAX,
                        )
                        if stop_flag.item() == 1:
                            # at least one rank is exhausted, stop evaluation
                            logger.info(
                                f"Eval {eval_name}: At least one rank exhausted its dataloader, stopping evaluation."
                            )
                            raise StopIteration
                    else:
                        # If we are guaranteed that all ranks see the same number of batches,
                        # we can just stop when this rank is exhausted
                        if exhausted:
                            logger.info(
                                f"Eval {eval_name}: Dataloader exhausted, stopping evaluation."
                            )
                            raise StopIteration

                    logger.debug(
                        f"Eval {eval_name}: Running forward pass for iteration {eval_iterations_processed}"
                    )
                    with self.forward_context(device=device):
                        loss, exposed = criterion(
                            model, batch, requested_protocols=requested_protocols
                        )

                    if engine:
                        logger.debug(f"Eval {eval_name}: Updating metric engine")
                        computed_data = exposed.copy()
                        computed_data["loss"] = loss
                        engine.update(
                            data=dict(model=model, batch=batch),
                            computed_data=computed_data,
                        )

                    log_summed("num_batches", lambda: 1)
                    log_averaged(
                        "perf/batch_get",
                        elapsed_batch_get,
                    )

                    eval_iterations_processed += 1
                    if pbar is not None:
                        pbar.update(1)

                    # Step metrics for each evaluation iteration
                    step_meters(f"{metrics_prefix}/{eval_name}")
                    logger.debug(
                        f"Eval {eval_name}: Finished iteration {eval_iterations_processed-1}"
                    )

                    if (
                        self.eval_checkpoint_manager is not None
                        and eval_checkpointing is not None
                        and eval_checkpointing > 0
                        and eval_iterations_processed % eval_checkpointing == 0
                    ):
                        assert (
                            iteration is not None
                        ), "Iteration must be provided for checkpointing"
                        self.eval_checkpoint_manager.save_iteration_state(
                            iteration=iteration,
                            eval_name=eval_name,
                            dataloader_state=eval_iter.state_dict(),
                            group_name=group_name,
                            collective=collective,
                            eval_iterations_processed=eval_iterations_processed,
                        )
                        logger.info(
                            f"Saved evaluation metrics checkpoint at iteration {eval_iterations_processed}"
                        )

            except StopIteration:
                pass
            finally:
                if pbar is not None:
                    pbar.refresh()
                    pbar.close()

            total_time = time.perf_counter() - start_time
            log_event_end("perf/total_run")

        logger.debug(f"Eval {eval_name}: Computing aggregated meters")
        eval_metrics = compute_meters(
            f"{metrics_prefix}/{eval_name}",
            aggregate=True,
            collective=collective,
        )

        if engine:
            eval_metrics = engine.compute(eval_metrics)

        # Add basic performance stats
        eval_metrics["perf/total_run_ms"] = total_time * 1000
        if eval_iterations_processed > 0:
            eval_metrics["perf/ms_per_batch"] = (
                total_time / eval_iterations_processed
            ) * 1000

        logger.info(
            f"Finished eval {eval_name}: {eval_metrics} in {total_time:.1f}s"
        )
        total_metrics[f"{metrics_prefix}/{eval_name}"] = eval_metrics
    return total_metrics

run_evaluation_if_needed(iteration, model, criterion, eval_data_dict, device, collective=None, all_metrics_configs=None)

Run evaluation if the current iteration matches the frequency for any dataset.

Parameters:

Name Type Description Default
iteration int

Current training step.

required
model BaseModel

The model to evaluate.

required
criterion BaseCriterion

The loss function.

required
eval_data_dict dict[str, EvalDataPipeline]

Dictionary mapping dataset names to evaluation data pipelines.

required
collective Collective | None

Distributed collective for metric aggregation.

None
all_metrics_configs dict[str, list[dict]] | None

Root metrics configuration from TrainConfig.

None

Returns:

Type Description
None | dict

Dictionary of computed metrics if evaluation ran, else None.

Source code in optimus_dl/recipe/train/mixins/managers/evaluation_manager.py
def run_evaluation_if_needed(
    self,
    iteration: int,
    model: BaseModel,
    criterion: BaseCriterion,
    eval_data_dict: dict[str, EvalDataPipeline | None],
    device: torch.device,
    collective: Collective | None = None,
    all_metrics_configs: dict[str, list[dict]] | None = None,
) -> None | dict:
    """Run evaluation for every dataset whose frequency matches ``iteration``.

    For each dataset, a per-dataset ``eval_freq`` / ``eval_iterations`` that
    is not None overrides the manager-wide ``self.eval_freq`` /
    ``self.eval_iterations``. A non-positive frequency disables the dataset,
    and a negative iteration limit is treated as "no limit" (None). Entries
    whose pipeline is ``None`` are skipped, matching ``should_run_evaluation``.

    Args:
        iteration: Current training step.
        model: The model to evaluate.
        criterion: The loss function.
        eval_data_dict: Dictionary mapping dataset names to eval data
            pipelines (``None`` values are ignored).
        device: Device passed through to ``run_evaluation``.
        collective: Distributed collective for metric aggregation.
        all_metrics_configs: Root metrics configuration from TrainConfig.

    Returns:
        Merged dictionary of computed metrics if at least one evaluation ran,
        else None.

    Raises:
        Exception: Any error from ``run_evaluation`` is logged (with
            traceback) and re-raised.
    """
    result = {}

    # Deterministic order: presumably so every rank evaluates datasets in the
    # same sequence, since the evaluation loop issues collective operations.
    for eval_name in sorted(eval_data_dict):
        eval_data = eval_data_dict[eval_name]
        if eval_data is None:
            # Disabled / missing pipeline entry; nothing to evaluate.
            continue

        eval_freq = (
            eval_data.eval_freq
            if eval_data.eval_freq is not None
            else self.eval_freq
        )
        if eval_freq <= 0 or iteration % eval_freq != 0:
            continue

        max_iterations = (
            eval_data.eval_iterations
            if eval_data.eval_iterations is not None
            else self.eval_iterations
        )
        if max_iterations is not None and max_iterations < 0:
            # Negative limits mean "process the entire dataset".
            max_iterations = None

        try:
            result |= self.run_evaluation(
                iteration=iteration,
                model=model,
                criterion=criterion,
                eval_data_dict={eval_name: eval_data},
                max_iterations=max_iterations,
                collective=collective,
                all_metrics_configs=all_metrics_configs,
                show_progress=True,
                device=device,
            )
        except Exception as e:
            logger.exception(
                f"Error during evaluation of {eval_name} at iteration {iteration}: {e}"
            )
            raise

    return result or None

should_run_evaluation(iteration, eval_data_dict)

Check if any of the evaluation datasets match the current iteration frequency.

Parameters:

Name Type Description Default
iteration int

Current training step.

required
eval_data_dict dict[str, EvalDataPipeline | None]

Dictionary mapping dataset names to eval data pipelines.

required

Returns:

Type Description
bool

True if at least one evaluation should run, False otherwise.

Source code in optimus_dl/recipe/train/mixins/managers/evaluation_manager.py
def should_run_evaluation(
    self,
    iteration: int,
    eval_data_dict: dict[str, EvalDataPipeline | None],
) -> bool:
    """Tell whether at least one eval dataset is due at ``iteration``.

    A dataset's own ``eval_freq`` wins when it is not None; otherwise the
    manager-wide ``self.eval_freq`` applies. Frequencies of zero or below
    never trigger, and ``None`` pipelines are ignored.

    Args:
        iteration: Current training step.
        eval_data_dict: Dictionary mapping dataset names to eval data pipelines.

    Returns:
        True as soon as one dataset is due, False if none are.
    """

    def _is_due(pipeline) -> bool:
        freq = self.eval_freq if pipeline.eval_freq is None else pipeline.eval_freq
        return freq > 0 and iteration % freq == 0

    return any(
        _is_due(pipeline)
        for pipeline in eval_data_dict.values()
        if pipeline is not None
    )

LoggerManager

Manager for multiple metrics loggers.

This class instantiates and orchestrates a list of logging backends (e.g., JSONL, WandB). It provides a unified interface for setting up, logging to, and closing all configured loggers.

Parameters:

Name Type Description Default
cfg LoggerManagerConfig

Manager configuration.

required
loggers_config list[MetricsLoggerConfig] | None

List of configurations for individual loggers. If None, metrics logging is disabled.

required
Source code in optimus_dl/recipe/train/mixins/managers/logger_manager.py
class LoggerManager:
    """Manager for multiple metrics loggers.

    This class instantiates and orchestrates a list of logging backends (e.g.,
    JSONL, WandB). It provides a unified interface for setting up, logging to,
    and closing all configured loggers.

    Args:
        cfg: Manager configuration.
        loggers_config: List of configurations for individual loggers.
    """

    def __init__(
        self,
        cfg: LoggerManagerConfig,
        loggers_config: list[MetricsLoggerConfig] | None,
        **kwargs: Any,
    ):
        self.loggers_config = loggers_config
        self.previous_state = {}
        self.loggers: list[BaseMetricsLogger] | None = None
        self.closed = False
        self.spent_logging = None

        self._atexit_handlers = []

    def build_loggers(self, **kwargs):
        """Instantiate all configured loggers.

        Uses the registry to build logger instances. If previous state is available
        (from a checkpoint), it is passed to the logger builders for resumption.

        Returns:
            List of active logger instances.
        """
        if self.loggers_config is None:
            logger.info("No loggers configuration found, metrics logging disabled")
            return
        assert self.loggers is None, "Loggers already built"

        loggers = []
        for logger_config in self.loggers_config:
            try:
                logger_instance = build(
                    "metrics_logger",
                    logger_config,
                    state_dict=self.previous_state.get(logger_config.id),
                    **kwargs,
                )
                loggers.append(logger_instance)
                logger.info(f"Built logger: {logger_instance.__class__.__name__}")

                close_handle = atexit.register(logger_instance.close)
                self._atexit_handlers.append(close_handle)
            except Exception as e:
                logger.error(f"Failed to build logger from config {logger_config}: {e}")
                raise

        self.loggers = loggers

    def setup_loggers(
        self,
        experiment_name: str,
        full_config: dict,
        logs_parent_path: str | None = None,
        start_iteration: int | None = None,
    ):
        """Initialize all loggers with experiment context.

        Args:
            experiment_name: Name of the experiment.
            full_config: Complete training configuration dictionary.
            logs_parent_path: Optional filesystem path as a string under which
                logger-specific log files or run directories are created.
                Use this to log stdout / stderr if applicable
            start_iteration: Starting iteration number for the logging.
        """
        for logger_instance in self.loggers or []:
            try:
                logger_instance.setup(
                    experiment_name=experiment_name,
                    config=full_config,
                    logs_parent_path=logs_parent_path,
                    start_iteration=start_iteration,
                )
            except Exception as e:
                logger.error(
                    f"Failed to setup logger {logger_instance.__class__.__name__}: {e}"
                )

    def log_metrics_to_loggers(self, metrics, step: int, group: str = "train"):
        """Dispatch metrics to all active loggers.

        Args:
            metrics: Dictionary of metric values.
            step: Current iteration.
            group: Metric group name.
        """
        start_time = time.perf_counter()
        if self.spent_logging is not None:
            metrics["ms_spent_logging"] = self.spent_logging * 1000

        for logger_instance in self.loggers or []:
            try:
                logger_instance.log_metrics(metrics, step, group)
            except Exception as e:
                logger.error(
                    f"Failed to log metrics with {logger_instance.__class__.__name__}: {e}"
                )
        end_time = time.perf_counter()
        self.spent_logging = end_time - start_time

    def close_loggers(self):
        """Clean up all loggers."""
        if self.closed:
            return
        self.closed = True

        for logger_instance, close_handle in zip(
            self.loggers or [], self._atexit_handlers, strict=True
        ):
            try:
                logger_instance.close()
            except Exception as e:
                logger.error(
                    f"Failed to close logger {logger_instance.__class__.__name__}: {e}"
                )
            atexit.unregister(close_handle)

    def state_dict(self):
        """Collect state from all loggers for checkpointing."""
        return {
            logger_instance.cfg.id: logger_instance.state_dict()
            for logger_instance in self.loggers or []
        }

    def load_state_dict(self, state_dict):
        """Load logger states from a checkpoint."""
        self.previous_state = state_dict

    def finished(self, status: RunStatus):
        """Hook for when training finishes, to log final status."""
        for logger_instance in self.loggers or []:
            try:
                logger_instance.finished(status)
            except Exception as e:
                logger.error(
                    f"Failed to log finished status with {logger_instance.__class__.__name__}: {e}"
                )

build_loggers(**kwargs)

Instantiate all configured loggers.

Uses the registry to build logger instances. If previous state is available (from a checkpoint), it is passed to the logger builders for resumption.

Returns:

Type Description

None. The built logger instances are stored in `self.loggers` rather than returned.

Source code in optimus_dl/recipe/train/mixins/managers/logger_manager.py
def build_loggers(self, **kwargs):
    """Instantiate all configured loggers and store them in ``self.loggers``.

    Uses the registry to build logger instances. If previous state is available
    (from a checkpoint), it is passed to the logger builders for resumption.
    Each logger's ``close`` is registered with ``atexit`` so it runs even if
    ``close_loggers`` is never called.

    The atexit handles are recorded on the manager only after every logger has
    been built. If one build fails, the error is logged and re-raised, and the
    manager keeps no loggers. Loggers built before the failure are still closed
    at interpreter exit by their atexit callbacks. This keeps ``self.loggers``
    and ``self._atexit_handlers`` aligned for ``close_loggers``.

    Args:
        **kwargs: Extra keyword arguments forwarded to every logger builder.

    Returns:
        None. The built instances are stored in ``self.loggers``, which stays
        None when no loggers are configured.

    Raises:
        AssertionError: If the loggers were already built.
        Exception: Any error raised while building a logger is logged and
            re-raised.
    """
    if self.loggers_config is None:
        logger.info("No loggers configuration found, metrics logging disabled")
        return
    assert self.loggers is None, "Loggers already built"

    loggers = []
    close_handles = []
    for logger_config in self.loggers_config:
        try:
            logger_instance = build(
                "metrics_logger",
                logger_config,
                state_dict=self.previous_state.get(logger_config.id),
                **kwargs,
            )
            loggers.append(logger_instance)
            logger.info(f"Built logger: {logger_instance.__class__.__name__}")

            close_handles.append(atexit.register(logger_instance.close))
        except Exception as e:
            logger.error(f"Failed to build logger from config {logger_config}: {e}")
            raise

    # Commit only once every logger was built. A partial failure must not leave
    # handles without matching loggers, or the strict zip in close_loggers()
    # would raise ValueError.
    self._atexit_handlers.extend(close_handles)
    self.loggers = loggers

close_loggers()

Clean up all loggers.

Source code in optimus_dl/recipe/train/mixins/managers/logger_manager.py
def close_loggers(self):
    """Clean up all loggers."""
    if self.closed:
        return
    self.closed = True

    for logger_instance, close_handle in zip(
        self.loggers or [], self._atexit_handlers, strict=True
    ):
        try:
            logger_instance.close()
        except Exception as e:
            logger.error(
                f"Failed to close logger {logger_instance.__class__.__name__}: {e}"
            )
        atexit.unregister(close_handle)

finished(status)

Hook for when training finishes, to log final status.

Source code in optimus_dl/recipe/train/mixins/managers/logger_manager.py
def finished(self, status: RunStatus):
    """Notify every active logger that training has ended.

    A logger that raises is reported via ``logger.error``; the remaining
    loggers are still notified.

    Args:
        status: Final run status forwarded to each logger's ``finished``.
    """
    active = self.loggers if self.loggers else []
    for backend in active:
        try:
            backend.finished(status)
        except Exception as err:
            backend_name = backend.__class__.__name__
            logger.error(
                f"Failed to log finished status with {backend_name}: {err}"
            )

load_state_dict(state_dict)

Load logger states from a checkpoint.

Source code in optimus_dl/recipe/train/mixins/managers/logger_manager.py
def load_state_dict(self, state_dict):
    """Stash logger states restored from a checkpoint.

    Nothing is applied yet. ``build_loggers`` later looks up each logger's
    entry in this mapping by the logger's config id.
    """
    restored = state_dict
    self.previous_state = restored

log_metrics_to_loggers(metrics, step, group='train')

Dispatch metrics to all active loggers.

Parameters:

Name Type Description Default
metrics

Dictionary of metric values.

required
step int

Current iteration.

required
group str

Metric group name.

'train'
Source code in optimus_dl/recipe/train/mixins/managers/logger_manager.py
def log_metrics_to_loggers(self, metrics, step: int, group: str = "train"):
    """Dispatch metrics to all active loggers.

    From the second call on, the time spent in the previous call is attached
    as ``ms_spent_logging``. It is added to a shallow copy, so the caller's
    ``metrics`` mapping is never modified. A failure in one logger is logged
    and does not stop the others.

    Args:
        metrics: Mapping of metric names to values.
        step: Current iteration.
        group: Metric group name.
    """
    start_time = time.perf_counter()
    if self.spent_logging is not None:
        # Copy rather than mutate: callers may reuse their metrics dict.
        metrics = {**metrics, "ms_spent_logging": self.spent_logging * 1000}

    for logger_instance in self.loggers or []:
        try:
            logger_instance.log_metrics(metrics, step, group)
        except Exception as e:
            logger.error(
                f"Failed to log metrics with {logger_instance.__class__.__name__}: {e}"
            )
    end_time = time.perf_counter()
    self.spent_logging = end_time - start_time

setup_loggers(experiment_name, full_config, logs_parent_path=None, start_iteration=None)

Initialize all loggers with experiment context.

Parameters:

Name Type Description Default
experiment_name str

Name of the experiment.

required
full_config dict

Complete training configuration dictionary.

required
logs_parent_path str | None

Optional filesystem path, as a string, under which logger-specific log files or run directories are created. Use this to log stdout/stderr, if applicable.

None
start_iteration int | None

Starting iteration number for the logging.

None
Source code in optimus_dl/recipe/train/mixins/managers/logger_manager.py
def setup_loggers(
    self,
    experiment_name: str,
    full_config: dict,
    logs_parent_path: str | None = None,
    start_iteration: int | None = None,
):
    """Initialize all loggers with experiment context.

    Args:
        experiment_name: Name of the experiment.
        full_config: Complete training configuration dictionary.
        logs_parent_path: Optional filesystem path as a string under which
            logger-specific log files or run directories are created.
            Use this to log stdout / stderr if applicable
        start_iteration: Starting iteration number for the logging.
    """
    for logger_instance in self.loggers or []:
        try:
            logger_instance.setup(
                experiment_name=experiment_name,
                config=full_config,
                logs_parent_path=logs_parent_path,
                start_iteration=start_iteration,
            )
        except Exception as e:
            logger.error(
                f"Failed to setup logger {logger_instance.__class__.__name__}: {e}"
            )

state_dict()

Collect state from all loggers for checkpointing.

Source code in optimus_dl/recipe/train/mixins/managers/logger_manager.py
def state_dict(self):
    """Gather per-logger checkpoint state, keyed by each logger's config id."""
    collected = {}
    for backend in self.loggers or []:
        collected[backend.cfg.id] = backend.state_dict()
    return collected

Modules and Sub-packages