Skip to content

evaluation_manager

optimus_dl.recipe.train.mixins.managers.evaluation_manager

Evaluation mixin for evaluation functionality.

Evaluator

Manager for running periodic evaluations during training.

Handles iterating over validation datasets, computing loss and other metrics, and aggregating results across distributed ranks.

Parameters:

Name Type Description Default
cfg EvaluatorConfig

Evaluator configuration.

required
eval_freq int

Frequency of evaluation runs (in iterations).

0
eval_iterations int | None

Max number of batches to process per evaluation dataset. If None, zero, or negative, the limit is disabled and the entire dataset is processed.

None
eval_guaranteed_same_batches bool

If True, assumes all ranks will see the same number of batches, allowing for simpler stopping logic. If False, uses collective communication to determine when to stop if any rank exhausts its dataloader.

False
Source code in optimus_dl/recipe/train/mixins/managers/evaluation_manager.py
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
class Evaluator:
    """Manager for running periodic evaluations during training.

    Handles iterating over validation datasets, computing loss and other metrics,
    and aggregating results across distributed ranks.

    Every default below can be overridden per dataset: if the dataset object
    exposes an attribute of the same name whose value is not None, that value
    takes precedence over the evaluator-level default.

    Args:
        cfg: Evaluator configuration. Only ``cfg.amp`` (``enabled``, ``dtype``) is
            read by this class, in ``forward_context``.
        eval_freq: Frequency of evaluation runs (in iterations). Values <= 0
            disable evaluation.
        eval_iterations: Max number of batches to process per evaluation dataset.
            If None, zero or negative, the limit is disabled and the entire
            dataset is processed.
        eval_guaranteed_same_batches: If True, assumes all ranks will see the same
            number of batches, allowing for simpler stopping logic. If False, uses
            collective communication to determine when to stop if any rank exhausts
            its dataloader.
        eval_checkpointing: Save the intermediate evaluation state every this many
            processed batches. None or a non-positive value disables saving. Has no
            effect unless ``output_path`` is set.
        output_path: If truthy, an ``EvaluationCheckpointManager`` is created for
            this path and used to save/resume/clean up evaluation state.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(
        self,
        cfg: EvaluatorConfig,
        eval_freq: int = 0,
        eval_iterations: int | None = None,
        eval_guaranteed_same_batches: bool = False,
        eval_checkpointing: int | None = None,
        output_path: str | pathlib.Path | None = None,
        **kwargs: Any,
    ):
        self.cfg = cfg
        self.eval_freq = eval_freq
        self.eval_iterations = eval_iterations
        self.eval_guaranteed_same_batches = eval_guaranteed_same_batches
        self.eval_checkpointing = eval_checkpointing
        self.output_path = output_path
        # Stays None when no output path is configured; all checkpoint-related
        # code paths check for that before using the manager.
        self.eval_checkpoint_manager = None
        if output_path:
            self.eval_checkpoint_manager = EvaluationCheckpointManager(output_path)

    @staticmethod
    def _resolve_setting(eval_data: Any, name: str, default: Any) -> Any:
        """Return the per-dataset override ``name`` of ``eval_data``, or ``default``.

        A missing attribute (e.g. when ``eval_data`` is a raw dataloader rather
        than a pipeline object) is treated the same as an override set to None.
        """
        value = getattr(eval_data, name, None)
        return default if value is None else value

    @contextlib.contextmanager
    def forward_context(self, device: torch.device):
        """Context manager for evaluation forward pass.

        Wraps the body in ``torch.autocast`` when ``cfg.amp.enabled`` is set and
        is a no-op otherwise.

        Args:
            device: Device of the forward pass; only ``device.type`` is used.

        Yields:
            None. The body of the ``with`` statement runs inside the context.
        """
        if self.cfg.amp.enabled:
            amp_cfg = self.cfg.amp
            dtype = str_to_dtype(amp_cfg.dtype)
            amp_ctx = torch.autocast(device.type, dtype=dtype, enabled=amp_cfg.enabled)
            with amp_ctx:
                yield
        else:
            yield

    def should_run_evaluation(
        self,
        iteration: int,
        eval_data_dict: dict[str, EvalDataPipeline | None],
    ) -> bool:
        """Check if any of the evaluation datasets match the current iteration frequency.

        Args:
            iteration: Current training step.
            eval_data_dict: Dictionary mapping dataset names to eval data pipelines.
                Entries that are None are ignored.

        Returns:
            True if at least one evaluation should run, False otherwise.
        """
        for eval_data in eval_data_dict.values():
            if eval_data is None:
                continue
            eval_freq = self._resolve_setting(eval_data, "eval_freq", self.eval_freq)
            if eval_freq > 0 and iteration % eval_freq == 0:
                return True
        return False

    def run_evaluation_if_needed(
        self,
        iteration: int,
        model: BaseModel,
        criterion: BaseCriterion,
        eval_data_dict: dict[str, EvalDataPipeline | None],
        device: torch.device,
        collective: Collective | None = None,
        all_metrics_configs: dict[str, list[dict]] | None = None,
    ) -> None | dict:
        """Run evaluation if the current iteration matches the frequency for any dataset.

        Datasets are visited in sorted name order and each due dataset is
        evaluated through ``run_evaluation`` with a progress bar enabled.

        Args:
            iteration: Current training step.
            model: The model to evaluate.
            criterion: The loss function.
            eval_data_dict: Dictionary mapping dataset names to dataloaders.
                Entries that are None are skipped, matching
                ``should_run_evaluation``.
            device: Device passed on to ``run_evaluation`` for the forward context.
            collective: Distributed collective for metric aggregation.
            all_metrics_configs: Root metrics configuration from TrainConfig.

        Returns:
            Dictionary of computed metrics if evaluation ran, else None.

        Raises:
            Exception: Any error raised by ``run_evaluation`` is logged and re-raised.
        """
        result = {}

        # deterministic order
        for eval_name in sorted(eval_data_dict):
            eval_data = eval_data_dict[eval_name]
            if eval_data is None:
                # should_run_evaluation tolerates None entries; do the same here
                # instead of failing on attribute access.
                continue

            eval_freq = self._resolve_setting(eval_data, "eval_freq", self.eval_freq)
            if eval_freq <= 0 or iteration % eval_freq != 0:
                continue

            max_iterations = self._resolve_setting(
                eval_data, "eval_iterations", self.eval_iterations
            )
            if max_iterations is not None and max_iterations < 0:
                # negative means "no limit"
                max_iterations = None

            try:
                result |= self.run_evaluation(
                    iteration=iteration,
                    model=model,
                    criterion=criterion,
                    eval_data_dict={eval_name: eval_data},
                    max_iterations=max_iterations,
                    collective=collective,
                    all_metrics_configs=all_metrics_configs,
                    show_progress=True,
                    device=device,
                )
            except Exception as e:
                logger.exception(
                    f"Error during evaluation of {eval_name} at iteration {iteration}: {e}"
                )
                raise

        return result or None

    def run_evaluation(
        self,
        model: BaseModel,
        criterion: BaseCriterion,
        eval_data_dict: dict[str, EvalDataPipeline],
        device: torch.device,
        max_iterations: int | None = None,
        collective: Collective | None = None,
        all_metrics_configs: dict[str, list[dict]] | None = None,
        metrics_prefix: str = "eval",
        show_progress: bool = False,
        iteration: int | None = None,
    ):
        """Execute the evaluation loop for all provided datasets.

        Sets the model to eval mode, disables gradients, and runs the forward pass
        for each batch. Metrics are aggregated globally. The model is not switched
        back to training mode here.

        Args:
            model: Model to evaluate.
            criterion: Loss function. Called as
                ``criterion(model, batch, requested_protocols=...)`` and expected to
                return a ``(loss, exposed)`` pair.
            eval_data_dict: Dictionary of {name: dataloader/DataPipeline}. Entries
                that are None are skipped.
            device: Device whose type selects the autocast context of the forward pass.
            max_iterations: Limit on number of batches. None, zero or negative
                disables the limit. A per-dataset ``eval_iterations`` takes precedence.
            collective: Distributed collective.
            all_metrics_configs: Root metrics configuration mapping dataset names to configs.
            metrics_prefix: Prefix for metric groups (e.g., "eval" or "metrics").
            show_progress: Whether to show a progress bar.
            iteration: Current training iteration, used for naming checkpoints.
                Required whenever intermediate evaluation state is saved.

        Returns:
            Nested dictionary of results:
            {"<metrics_prefix>/<dataset_name>": {metric_name: value}}.
        """
        model.eval()
        total_metrics = {}
        all_metrics_configs = all_metrics_configs or {}

        for eval_name in sorted(eval_data_dict):
            eval_data = eval_data_dict[eval_name]
            if eval_data is None:
                continue

            # Per-dataset overrides win over the values handed to this call; raw
            # dataloaders without these attributes simply use the defaults.
            max_iterations_local = self._resolve_setting(
                eval_data, "eval_iterations", max_iterations
            )
            guaranteed_same_batches_local = self._resolve_setting(
                eval_data,
                "eval_guaranteed_same_batches",
                self.eval_guaranteed_same_batches,
            )
            if max_iterations_local is not None and max_iterations_local < 0:
                max_iterations_local = None

            eval_checkpointing = self._resolve_setting(
                eval_data, "eval_checkpointing", self.eval_checkpointing
            )

            logger.info(
                f"Running evaluation {eval_name} for {max_iterations_local if max_iterations_local is not None else 'unlimited'} iterations (on each rank)"
            )

            # Handle both raw dataloader and DataPipeline object
            dataloader = getattr(eval_data, "dataloader", eval_data)

            group_name = f"{metrics_prefix}/{eval_name}"

            engine = None
            requested_protocols = None
            dataset_metrics = all_metrics_configs.get(eval_name)
            if dataset_metrics:
                from optimus_dl.modules.metrics.engine import MetricEngine

                engine = MetricEngine(group_name, dataset_metrics)
                requested_protocols = engine.required_external_protocols

            with (
                torch.no_grad(),
                meters_group(group_name, log_freq=1, force_recreate=True),
            ):
                eval_iter = iter(dataloader)

                # Resume from a previously saved evaluation state if available.
                # NOTE(review): load_iteration_state presumably restores the
                # iterator position and meters and returns the number of batches
                # already processed - confirm against its implementation.
                eval_iterations_processed = 0
                if self.eval_checkpoint_manager is not None and iteration is not None:
                    eval_iterations_processed = (
                        self.eval_checkpoint_manager.load_iteration_state(
                            iteration=iteration,
                            eval_name=eval_name,
                            group_name=group_name,
                            eval_iter=eval_iter,
                            collective=collective,
                        )
                    )

                log_event_start("perf/total_run")
                start_time = time.perf_counter()

                pbar = None
                if show_progress:
                    pbar = tqdm(
                        desc=f"Eval {eval_name}",
                        disable=(
                            collective is not None and not collective.is_local_master
                        ),
                        unit="batch",
                        total=max_iterations_local,
                        initial=eval_iterations_processed,
                    )

                # Only synchronize exhaustion across ranks when running distributed
                # and ranks may see a different number of batches.
                check_exhaustion = (
                    collective is not None and not guaranteed_same_batches_local
                )

                # The limit is bound as a default argument so the closure keeps
                # this dataset's value. A limit of 0 never stops the loop.
                def should_stop(_iterations, max_iterations_local=max_iterations_local):
                    return (
                        max_iterations_local is not None
                        and max_iterations_local > 0
                        and _iterations >= max_iterations_local
                    )

                try:
                    # Consider we loaded state where one rank is already exhausted.
                    # We should check it before starting the loop to have the same collectives set.

                    stop_flag = None
                    if check_exhaustion:
                        assert collective is not None
                        stop_flag = torch.tensor(
                            [int(should_stop(eval_iterations_processed))],
                            device=collective.default_device,
                            dtype=torch.int32,
                        )
                        collective.all_reduce(
                            stop_flag,
                            op=Collective.ReduceOp.MAX,
                        )
                        if stop_flag.item() == 1:
                            logger.info(
                                "Some ranks have already finished evaluation before starting the loop, stopping immediately."
                            )
                            # StopIteration doubles as the internal "leave the
                            # loop" signal; it is caught just below.
                            raise StopIteration

                    while not should_stop(eval_iterations_processed):
                        logger.debug(
                            f"Eval {eval_name}: Starting iteration {eval_iterations_processed}"
                        )
                        log_event_occurence("perf/full_iteration")
                        exhausted = False
                        try:
                            logger.debug(
                                f"Eval {eval_name}: Fetching batch for iteration {eval_iterations_processed}"
                            )
                            elapsed_batch_get, batch = measured_next(eval_iter)
                            info_once(logger, f"Batch has keys {batch.keys()}")
                        except StopIteration:
                            logger.debug(
                                f"Eval {eval_name}: Dataloader exhausted on this rank"
                            )
                            exhausted = True

                        if check_exhaustion:
                            assert collective is not None
                            assert stop_flag is not None
                            logger.debug(
                                f"Eval {eval_name}: Synchronizing exhaustion state (all_reduce MAX)"
                            )
                            # Every rank takes part in this all_reduce on every
                            # iteration, so all ranks leave the loop together.
                            stop_flag[0] = int(exhausted)
                            collective.all_reduce(
                                stop_flag,
                                op=Collective.ReduceOp.MAX,
                            )
                            if stop_flag.item() == 1:
                                # at least one rank is exhausted, stop evaluation
                                logger.info(
                                    f"Eval {eval_name}: At least one rank exhausted its dataloader, stopping evaluation."
                                )
                                raise StopIteration
                        else:
                            # If we are guaranteed that all ranks see the same number of batches,
                            # we can just stop when this rank is exhausted
                            if exhausted:
                                logger.info(
                                    f"Eval {eval_name}: Dataloader exhausted, stopping evaluation."
                                )
                                raise StopIteration

                        logger.debug(
                            f"Eval {eval_name}: Running forward pass for iteration {eval_iterations_processed}"
                        )
                        with self.forward_context(device=device):
                            loss, exposed = criterion(
                                model, batch, requested_protocols=requested_protocols
                            )

                        if engine:
                            logger.debug(f"Eval {eval_name}: Updating metric engine")
                            # Copy so the criterion's own mapping is not mutated
                            # when the loss is added.
                            computed_data = exposed.copy()
                            computed_data["loss"] = loss
                            engine.update(
                                data=dict(model=model, batch=batch),
                                computed_data=computed_data,
                            )

                        log_summed("num_batches", lambda: 1)
                        log_averaged(
                            "perf/batch_get",
                            elapsed_batch_get,
                        )

                        eval_iterations_processed += 1
                        if pbar is not None:
                            pbar.update(1)

                        # Step metrics for each evaluation iteration
                        step_meters(group_name)
                        logger.debug(
                            f"Eval {eval_name}: Finished iteration {eval_iterations_processed-1}"
                        )

                        if (
                            self.eval_checkpoint_manager is not None
                            and eval_checkpointing is not None
                            and eval_checkpointing > 0
                            and eval_iterations_processed % eval_checkpointing == 0
                        ):
                            assert (
                                iteration is not None
                            ), "Iteration must be provided for checkpointing"
                            self.eval_checkpoint_manager.save_iteration_state(
                                iteration=iteration,
                                eval_name=eval_name,
                                dataloader_state=eval_iter.state_dict(),
                                group_name=group_name,
                                collective=collective,
                                eval_iterations_processed=eval_iterations_processed,
                            )
                            logger.info(
                                f"Saved evaluation metrics checkpoint at iteration {eval_iterations_processed}"
                            )

                except StopIteration:
                    # Normal termination: exhaustion or a cross-rank stop signal.
                    pass
                finally:
                    if pbar is not None:
                        pbar.refresh()
                        pbar.close()

                total_time = time.perf_counter() - start_time
                log_event_end("perf/total_run")

            logger.debug(f"Eval {eval_name}: Computing aggregated meters")
            eval_metrics = compute_meters(
                group_name,
                aggregate=True,
                collective=collective,
            )

            if engine:
                eval_metrics = engine.compute(eval_metrics)

            # Add basic performance stats
            eval_metrics["perf/total_run_ms"] = total_time * 1000
            if eval_iterations_processed > 0:
                eval_metrics["perf/ms_per_batch"] = (
                    total_time / eval_iterations_processed
                ) * 1000

            logger.info(
                f"Finished eval {eval_name}: {eval_metrics} in {total_time:.1f}s"
            )
            total_metrics[group_name] = eval_metrics
        return total_metrics

    def cleanup_all_eval_checkpoints(
        self, iteration: int | None = None, exclude_iteration: int | None = None
    ) -> None:
        """Cleanup evaluation checkpoints.

        If iteration is provided, cleans up checkpoints for that iteration only.
        If iteration is None, cleans up ALL evaluation checkpoints in the output path,
        optionally excluding one specific iteration.

        Does nothing (besides a debug log) when no checkpoint manager was created.

        Args:
            iteration: Iteration whose checkpoints should be removed, or None.
            exclude_iteration: Iteration to keep, or None. Forwarded unchanged to
                the checkpoint manager.
        """
        if self.eval_checkpoint_manager is not None:
            self.eval_checkpoint_manager.cleanup(
                iteration=iteration, exclude_iteration=exclude_iteration
            )
        else:
            logger.debug(
                "No evaluation checkpoint manager initialized, skipping cleanup."
            )

cleanup_all_eval_checkpoints(iteration=None, exclude_iteration=None)

Cleanup evaluation checkpoints.

If iteration is provided, cleans up checkpoints for that iteration only. If iteration is None, cleans up ALL evaluation checkpoints in the output path, optionally excluding one specific iteration.

Source code in optimus_dl/recipe/train/mixins/managers/evaluation_manager.py
def cleanup_all_eval_checkpoints(
    self, iteration: int | None = None, exclude_iteration: int | None = None
) -> None:
    """Cleanup evaluation checkpoints.

    If iteration is provided, cleans up checkpoints for that iteration only.
    If iteration is None, cleans up ALL evaluation checkpoints in the output path,
    optionally excluding one specific iteration.
    """
    if self.eval_checkpoint_manager is not None:
        self.eval_checkpoint_manager.cleanup(
            iteration=iteration, exclude_iteration=exclude_iteration
        )
    else:
        logger.debug(
            "No evaluation checkpoint manager initialized, skipping cleanup."
        )

forward_context(device)

Context manager for evaluation forward pass.

Can be used to set up any necessary context (e.g., mixed precision) for the forward pass during evaluation.

Returns:

Type Description

A context manager (e.g., from contextlib) that sets up the desired context.

Source code in optimus_dl/recipe/train/mixins/managers/evaluation_manager.py
@contextlib.contextmanager
def forward_context(self, device: torch.device):
    """Wrap the evaluation forward pass in its execution context.

    When automatic mixed precision is enabled in the config, the body of the
    ``with`` statement runs under ``torch.autocast`` for the device's type;
    otherwise it runs unchanged.

    Yields:
        None. Control is handed to the ``with`` body inside the context.
    """
    if not self.cfg.amp.enabled:
        # AMP is off: plain pass-through context.
        yield
        return

    amp_settings = self.cfg.amp
    autocast_dtype = str_to_dtype(amp_settings.dtype)
    with torch.autocast(
        device.type, dtype=autocast_dtype, enabled=amp_settings.enabled
    ):
        yield

run_evaluation(model, criterion, eval_data_dict, device, max_iterations=None, collective=None, all_metrics_configs=None, metrics_prefix='eval', show_progress=False, iteration=None)

Execute the evaluation loop for all provided datasets.

Sets the model to eval mode, disables gradients, and runs the forward pass for each batch. Metrics are aggregated globally.

Parameters:

Name Type Description Default
model BaseModel

Model to evaluate.

required
criterion BaseCriterion

Loss function.

required
eval_data_dict dict[str, EvalDataPipeline]

Dictionary of {name: dataloader/DataPipeline}.

required
max_iterations int | None

Limit on number of batches.

None
collective Collective | None

Distributed collective.

None
all_metrics_configs dict[str, list[dict]] | None

Root metrics configuration mapping dataset names to configs.

None
metrics_prefix str

Prefix for metric groups (e.g., "eval" or "metrics").

'eval'
show_progress bool

Whether to show a progress bar.

False
iteration int | None

Current training iteration, used for naming checkpoints.

None

Returns:

Type Description

Nested dictionary of results: {dataset_name: {metric_name: value}}.

Source code in optimus_dl/recipe/train/mixins/managers/evaluation_manager.py
def run_evaluation(
    self,
    model: BaseModel,
    criterion: BaseCriterion,
    eval_data_dict: dict[str, EvalDataPipeline],
    device: torch.device,
    max_iterations: int | None = None,
    collective: Collective | None = None,
    all_metrics_configs: dict[str, list[dict]] | None = None,
    metrics_prefix: str = "eval",
    show_progress: bool = False,
    iteration: int | None = None,
):
    """Execute the evaluation loop for all provided datasets.

    Sets the model to eval mode, disables gradients, and runs the forward pass
    for each batch. Metrics are aggregated globally.

    Args:
        model: Model to evaluate.
        criterion: Loss function.
        eval_data_dict: Dictionary of {name: dataloader/DataPipeline}.
        max_iterations: Limit on number of batches.
        collective: Distributed collective.
        all_metrics_configs: Root metrics configuration mapping dataset names to configs.
        metrics_prefix: Prefix for metric groups (e.g., "eval" or "metrics").
        show_progress: Whether to show a progress bar.
        iteration: Current training iteration, used for naming checkpoints.

    Returns:
        Nested dictionary of results: {dataset_name: {metric_name: value}}.
    """
    model.eval()
    total_metrics = {}
    all_metrics_configs = all_metrics_configs or {}

    eval_data_dict_keys = sorted(eval_data_dict.keys())
    for eval_name in eval_data_dict_keys:
        eval_data = eval_data_dict[eval_name]
        max_iterations_local = (
            eval_data.eval_iterations
            if eval_data.eval_iterations is not None
            else max_iterations
        )
        guaranteed_same_batches_local = (
            eval_data.eval_guaranteed_same_batches
            if eval_data.eval_guaranteed_same_batches is not None
            else self.eval_guaranteed_same_batches
        )
        if max_iterations_local is not None and max_iterations_local < 0:
            max_iterations_local = None

        eval_checkpointing = (
            eval_data.eval_checkpointing
            if eval_data.eval_checkpointing is not None
            else self.eval_checkpointing
        )

        logger.info(
            f"Running evaluation {eval_name} for {max_iterations_local if max_iterations_local is not None else 'unlimited'} iterations (on each rank)"
        )

        # Handle both raw dataloader and DataPipeline object
        dataloader = getattr(eval_data, "dataloader", eval_data)

        engine = None
        requested_protocols = None
        dataset_metrics = all_metrics_configs.get(eval_name)
        if dataset_metrics:
            from optimus_dl.modules.metrics.engine import MetricEngine

            engine = MetricEngine(f"{metrics_prefix}/{eval_name}", dataset_metrics)
            requested_protocols = engine.required_external_protocols

        group_name = f"{metrics_prefix}/{eval_name}"
        with (
            torch.no_grad(),
            meters_group(group_name, log_freq=1, force_recreate=True),
        ):
            eval_iter = iter(dataloader)

            eval_iterations_processed = 0
            if self.eval_checkpoint_manager is not None and iteration is not None:
                eval_iterations_processed = (
                    self.eval_checkpoint_manager.load_iteration_state(
                        iteration=iteration,
                        eval_name=eval_name,
                        group_name=group_name,
                        eval_iter=eval_iter,
                        collective=collective,
                    )
                )

            log_event_start("perf/total_run")
            start_time = time.perf_counter()

            pbar = None
            if show_progress:
                pbar = tqdm(
                    desc=f"Eval {eval_name}",
                    disable=(
                        collective is not None and not collective.is_local_master
                    ),
                    unit="batch",
                    total=(
                        max_iterations_local
                        if max_iterations_local is not None
                        else None
                    ),
                    initial=eval_iterations_processed,
                )

            check_exhaustion = (
                collective is not None and not guaranteed_same_batches_local
            )

            def should_stop(_iterations, max_iterations_local=max_iterations_local):
                return (
                    max_iterations_local is not None
                    and max_iterations_local > 0
                    and _iterations >= max_iterations_local
                )

            try:
                # Consider we loaded state where one rank is already exhausted.
                # We should check it before starting the loop to have the same collectives set.

                stop_flag = None
                if check_exhaustion:
                    assert collective is not None
                    stop_flag = torch.tensor(
                        [int(should_stop(eval_iterations_processed))],
                        device=collective.default_device,
                        dtype=torch.int32,
                    )
                    collective.all_reduce(
                        stop_flag,
                        op=Collective.ReduceOp.MAX,
                    )
                    if stop_flag.item() == 1:
                        logger.info(
                            "Some ranks have already finished evaluation before starting the loop, stopping immediately."
                        )
                        raise StopIteration

                while not should_stop(eval_iterations_processed):
                    logger.debug(
                        f"Eval {eval_name}: Starting iteration {eval_iterations_processed}"
                    )
                    log_event_occurence("perf/full_iteration")
                    exhausted = False
                    try:
                        logger.debug(
                            f"Eval {eval_name}: Fetching batch for iteration {eval_iterations_processed}"
                        )
                        elapsed_batch_get, batch = measured_next(eval_iter)
                        info_once(logger, f"Batch has keys {batch.keys()}")
                    except StopIteration:
                        logger.debug(
                            f"Eval {eval_name}: Dataloader exhausted on this rank"
                        )
                        exhausted = True

                    if check_exhaustion:
                        assert collective is not None
                        assert stop_flag is not None
                        logger.debug(
                            f"Eval {eval_name}: Synchronizing exhaustion state (all_reduce MAX)"
                        )
                        stop_flag[0] = int(exhausted)
                        collective.all_reduce(
                            stop_flag,
                            op=Collective.ReduceOp.MAX,
                        )
                        if stop_flag.item() == 1:
                            # at least one rank is exhausted, stop evaluation
                            logger.info(
                                f"Eval {eval_name}: At least one rank exhausted its dataloader, stopping evaluation."
                            )
                            raise StopIteration
                    else:
                        # If we are guaranteed that all ranks see the same number of batches,
                        # we can just stop when this rank is exhausted
                        if exhausted:
                            logger.info(
                                f"Eval {eval_name}: Dataloader exhausted, stopping evaluation."
                            )
                            raise StopIteration

                    logger.debug(
                        f"Eval {eval_name}: Running forward pass for iteration {eval_iterations_processed}"
                    )
                    with self.forward_context(device=device):
                        loss, exposed = criterion(
                            model, batch, requested_protocols=requested_protocols
                        )

                    if engine:
                        logger.debug(f"Eval {eval_name}: Updating metric engine")
                        computed_data = exposed.copy()
                        computed_data["loss"] = loss
                        engine.update(
                            data=dict(model=model, batch=batch),
                            computed_data=computed_data,
                        )

                    log_summed("num_batches", lambda: 1)
                    log_averaged(
                        "perf/batch_get",
                        elapsed_batch_get,
                    )

                    eval_iterations_processed += 1
                    if pbar is not None:
                        pbar.update(1)

                    # Step metrics for each evaluation iteration
                    step_meters(f"{metrics_prefix}/{eval_name}")
                    logger.debug(
                        f"Eval {eval_name}: Finished iteration {eval_iterations_processed-1}"
                    )

                    if (
                        self.eval_checkpoint_manager is not None
                        and eval_checkpointing is not None
                        and eval_checkpointing > 0
                        and eval_iterations_processed % eval_checkpointing == 0
                    ):
                        assert (
                            iteration is not None
                        ), "Iteration must be provided for checkpointing"
                        self.eval_checkpoint_manager.save_iteration_state(
                            iteration=iteration,
                            eval_name=eval_name,
                            dataloader_state=eval_iter.state_dict(),
                            group_name=group_name,
                            collective=collective,
                            eval_iterations_processed=eval_iterations_processed,
                        )
                        logger.info(
                            f"Saved evaluation metrics checkpoint at iteration {eval_iterations_processed}"
                        )

            except StopIteration:
                pass
            finally:
                if pbar is not None:
                    pbar.refresh()
                    pbar.close()

            total_time = time.perf_counter() - start_time
            log_event_end("perf/total_run")

        logger.debug(f"Eval {eval_name}: Computing aggregated meters")
        eval_metrics = compute_meters(
            f"{metrics_prefix}/{eval_name}",
            aggregate=True,
            collective=collective,
        )

        if engine:
            eval_metrics = engine.compute(eval_metrics)

        # Add basic performance stats
        eval_metrics["perf/total_run_ms"] = total_time * 1000
        if eval_iterations_processed > 0:
            eval_metrics["perf/ms_per_batch"] = (
                total_time / eval_iterations_processed
            ) * 1000

        logger.info(
            f"Finished eval {eval_name}: {eval_metrics} in {total_time:.1f}s"
        )
        total_metrics[f"{metrics_prefix}/{eval_name}"] = eval_metrics
    return total_metrics

run_evaluation_if_needed(iteration, model, criterion, eval_data_dict, device, collective=None, all_metrics_configs=None)

Run evaluation if the current iteration matches the frequency for any dataset.

Parameters:

Name Type Description Default
iteration int

Current training step.

required
model BaseModel

The model to evaluate.

required
criterion BaseCriterion

The loss function.

required
eval_data_dict dict[str, EvalDataPipeline]

Dictionary mapping dataset names to dataloaders.

required
device torch.device

Device forwarded to run_evaluation for the evaluation passes.

required
collective Collective | None

Distributed collective for metric aggregation.

None
all_metrics_configs dict[str, list[dict]] | None

Root metrics configuration from TrainConfig.

None

Returns:

Type Description
None | dict

Dictionary of computed metrics if evaluation ran, else None.

Source code in optimus_dl/recipe/train/mixins/managers/evaluation_manager.py
def run_evaluation_if_needed(
    self,
    iteration: int,
    model: BaseModel,
    criterion: BaseCriterion,
    eval_data_dict: dict[str, EvalDataPipeline | None],
    device: torch.device,
    collective: Collective | None = None,
    all_metrics_configs: dict[str, list[dict]] | None = None,
) -> None | dict:
    """Run evaluation for every dataset whose frequency matches ``iteration``.

    Datasets are visited in sorted-name order so that the sequence of
    ``run_evaluation`` calls is deterministic. Each dataset may override the
    evaluator-wide ``eval_freq`` / ``eval_iterations`` with its own values;
    a per-dataset value of ``None`` falls back to the evaluator setting.

    Args:
        iteration: Current training step.
        model: The model to evaluate.
        criterion: The loss function.
        eval_data_dict: Dictionary mapping dataset names to eval data
            pipelines. Entries whose value is ``None`` are skipped, matching
            what ``should_run_evaluation`` accepts.
        device: Device forwarded to ``run_evaluation``.
        collective: Distributed collective for metric aggregation.
        all_metrics_configs: Root metrics configuration from TrainConfig.

    Returns:
        Dictionary of computed metrics (the union of the dictionaries
        returned by ``run_evaluation`` for each dataset that ran), or None
        if no dataset was due or nothing was produced.

    Raises:
        Exception: Any error raised by ``run_evaluation`` is logged with its
            traceback and re-raised unchanged.
    """
    result: dict = {}

    # Sorted names keep the evaluation order deterministic.
    for eval_name in sorted(eval_data_dict):
        eval_data = eval_data_dict[eval_name]
        if eval_data is None:
            # Nothing to evaluate for this name; should_run_evaluation skips
            # such entries too, so do the same instead of crashing.
            continue

        eval_freq = (
            eval_data.eval_freq
            if eval_data.eval_freq is not None
            else self.eval_freq
        )
        # A non-positive frequency disables evaluation for this dataset; the
        # check also protects the modulo below from a zero divisor.
        if eval_freq <= 0 or iteration % eval_freq != 0:
            continue

        # Resolve the batch limit only for datasets that actually run.
        max_iterations = (
            eval_data.eval_iterations
            if eval_data.eval_iterations is not None
            else self.eval_iterations
        )
        # Negative limits mean "unlimited", which run_evaluation receives as None.
        if max_iterations is not None and max_iterations < 0:
            max_iterations = None

        try:
            result |= self.run_evaluation(
                iteration=iteration,
                model=model,
                criterion=criterion,
                eval_data_dict={eval_name: eval_data},
                max_iterations=max_iterations,
                collective=collective,
                all_metrics_configs=all_metrics_configs,
                show_progress=True,
                device=device,
            )
        except Exception as e:
            # Log which dataset/iteration failed, then let the error propagate.
            logger.exception(
                f"Error during evaluation of {eval_name} at iteration {iteration}: {e}"
            )
            raise

    return result or None

should_run_evaluation(iteration, eval_data_dict)

Check if any of the evaluation datasets match the current iteration frequency.

Parameters:

Name Type Description Default
iteration int

Current training step.

required
eval_data_dict dict[str, EvalDataPipeline | None]

Dictionary mapping dataset names to eval data pipelines.

required

Returns:

Type Description
bool

True if at least one evaluation should run, False otherwise.

Source code in optimus_dl/recipe/train/mixins/managers/evaluation_manager.py
def should_run_evaluation(
    self,
    iteration: int,
    eval_data_dict: dict[str, EvalDataPipeline | None],
) -> bool:
    """Tell whether at least one evaluation dataset is due at ``iteration``.

    Args:
        iteration: Current training step.
        eval_data_dict: Dictionary mapping dataset names to eval data
            pipelines; ``None`` values are ignored.

    Returns:
        True if some dataset's effective frequency is positive and divides
        ``iteration``, False otherwise.
    """
    # Drop missing pipelines first; everything below is evaluated lazily so
    # the scan stops at the first dataset that is due.
    present = (
        pipeline for pipeline in eval_data_dict.values() if pipeline is not None
    )
    # A per-dataset frequency wins; None falls back to the evaluator-wide one.
    frequencies = (
        self.eval_freq if pipeline.eval_freq is None else pipeline.eval_freq
        for pipeline in present
    )
    # The positivity test comes first so a zero frequency never reaches the modulo.
    return any(freq > 0 and iteration % freq == 0 for freq in frequencies)

EvaluatorConfig dataclass

Bases: RegistryConfig

Configuration for the Evaluator.

Parameters:

Name Type Description Default
amp AmpConfig

AmpConfig(enabled: bool = False, dtype: str = 'torch.bfloat16', enable_scaler: bool = '${eval: \'"${.dtype}" == "torch.float16"\'}', init_scale: float = 65536, growth_factor: float = 2.0, backoff_factor: float = 0.5, growth_interval: int = 2000)

<dynamic>
Source code in optimus_dl/recipe/train/mixins/managers/evaluation_manager.py
@dataclass
class EvaluatorConfig(RegistryConfig):
    """Configuration for the Evaluator.

    Attributes:
        amp: ``AmpConfig`` settings for the evaluator. NOTE(review):
            presumably these drive mixed-precision in the evaluation forward
            pass -- confirm against where the Evaluator reads ``cfg.amp``.
    """

    # default_factory constructs a separate AmpConfig() for each
    # EvaluatorConfig instance, so the default object is never shared.
    amp: AmpConfig = field(default_factory=AmpConfig)