optimus_dl.recipe.train.mixins.managers

Evaluator

Manager for running periodic evaluations during training.

Handles iterating over validation datasets, computing loss and other metrics, and aggregating results across distributed ranks.

Parameters:

- cfg (EvaluatorConfig, required): Evaluator configuration.
- eval_freq (int, default 0): Frequency of evaluation runs (in iterations).
- eval_iterations (int | None, default None): Max number of batches to process per evaluation dataset. If None, processes the entire dataset.
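A minimal usage sketch of the gating behavior. The model, criterion, optimizer, eval_data, and collective objects are placeholders for whatever the training recipe already provides, and the no-argument EvaluatorConfig() construction is an assumption:

evaluator = Evaluator(cfg=EvaluatorConfig(), eval_freq=500, eval_iterations=100)

for iteration in range(1, total_iterations + 1):
    train_step(model, criterion, optimizer)  # hypothetical training step
    metrics = evaluator.run_evaluation_if_needed(
        iteration=iteration,
        model=model,
        criterion=criterion,
        eval_data=eval_data,  # dict[str, EvalDataPipeline]
        collective=collective,
    )
    if metrics is not None:  # None unless at least one dataset was due this step
        print(metrics)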
Source code in optimus_dl/recipe/train/mixins/managers/evaluation_manager.py
class Evaluator:
    """Manager for running periodic evaluations during training.

    Handles iterating over validation datasets, computing loss and other metrics,
    and aggregating results across distributed ranks.

    Args:
        cfg: Evaluator configuration.
        eval_freq: Frequency of evaluation runs (in iterations).
        eval_iterations: Max number of batches to process per evaluation dataset.
            If None, processes the entire dataset.
    """

    def __init__(
        self,
        cfg: EvaluatorConfig,
        eval_freq: int = 0,
        eval_iterations: int | None = None,
        **kwargs: Any,
    ):
        self.cfg = cfg
        self.eval_freq = eval_freq
        self.eval_iterations = eval_iterations

    def run_evaluation_if_needed(
        self,
        iteration: int,
        model: BaseModel,
        criterion: BaseCriterion,
        eval_data: dict[str, EvalDataPipeline],
        collective: Any = None,
        all_metrics_configs: dict[str, list[dict]] | None = None,
    ) -> None | dict:
        """Run evaluation if the current iteration matches the frequency.

        Args:
            iteration: Current training step.
            model: The model to evaluate.
            criterion: The loss function.
            eval_data: Dictionary mapping dataset names to dataloaders.
            collective: Distributed collective for metric aggregation.
            all_metrics_configs: Root metrics configuration from TrainConfig.

        Returns:
            Dictionary of computed metrics if evaluation ran, else None.
        """
        result = {}
        for k, v in eval_data.items():
            max_iterations = (
                v.eval_iterations
                if v.eval_iterations is not None
                else self.eval_iterations
            )
            eval_freq = v.eval_freq if v.eval_freq is not None else self.eval_freq
            if eval_freq <= 0 or iteration % eval_freq != 0:
                continue

            try:
                result |= self.run_evaluation(
                    model=model,
                    criterion=criterion,
                    eval_data_dict={k: v},
                    max_iterations=max_iterations,
                    collective=collective,
                    all_metrics_configs=all_metrics_configs,
                )
            except Exception:
                logger.exception(f"Evaluation for {k} failed.")

        if len(result) == 0:
            return None
        return result

    def run_evaluation(
        self,
        model: BaseModel,
        criterion: BaseCriterion,
        eval_data_dict: dict,
        max_iterations: int | None = None,
        collective: Any = None,
        all_metrics_configs: dict[str, list[dict]] | None = None,
        metrics_prefix: str = "eval",
        show_progress: bool = False,
    ):
        """Execute the evaluation loop for all provided datasets.

        Sets the model to eval mode, disables gradients, and runs the forward pass
        for each batch. Metrics are aggregated globally.

        Args:
            model: Model to evaluate.
            criterion: Loss function.
            eval_data_dict: Dictionary of {name: dataloader/DataPipeline}.
            max_iterations: Limit on number of batches.
            collective: Distributed collective.
            all_metrics_configs: Root metrics configuration mapping dataset names to configs.
            metrics_prefix: Prefix for metric groups (e.g., "eval" or "metrics").
            show_progress: Whether to show a progress bar.

        Returns:
            Nested dictionary of results: {dataset_name: {metric_name: value}}.
        """
        model.eval()
        total_metrics = {}
        all_metrics_configs = all_metrics_configs or {}

        for eval_name, eval_data in eval_data_dict.items():
            # A DataPipeline may carry a per-dataset eval_iterations override; a raw
            # dataloader has no such attribute, so fall back to the max_iterations argument.
            eval_iterations = getattr(eval_data, "eval_iterations", None)
            max_iterations_local = (
                eval_iterations if eval_iterations is not None else max_iterations
            )
            logger.info(f"Running evaluation {eval_name}")

            # Handle both raw dataloader and DataPipeline object
            dataloader = getattr(eval_data, "dataloader", eval_data)

            engine = None
            requested_protocols = None
            dataset_metrics = all_metrics_configs.get(eval_name)
            if dataset_metrics:
                from optimus_dl.modules.metrics.engine import MetricEngine

                engine = MetricEngine(f"{metrics_prefix}/{eval_name}", dataset_metrics)
                requested_protocols = engine.required_external_protocols

            with (
                torch.no_grad(),
                meters_group(
                    f"{metrics_prefix}/{eval_name}", log_freq=1, force_recreate=True
                ),
            ):
                log_event_start("perf/total_run")
                start_time = time.perf_counter()

                eval_iter = iter(dataloader)
                iterations = 0

                pbar = None
                if show_progress:
                    pbar = tqdm(
                        desc=f"Eval {eval_name}",
                        disable=collective is not None
                        and not collective.is_local_master,
                        unit="batch",
                        total=max_iterations_local,
                    )

                try:
                    while (
                        max_iterations_local is None
                        or max_iterations_local < 0
                        or iterations < max_iterations_local
                    ):
                        log_event_occurence("perf/full_iteration")

                        elapsed_batch_get, batch = measured_next(eval_iter)
                        loss, exposed = criterion(
                            model, batch, requested_protocols=requested_protocols
                        )

                        if engine:
                            computed_data = exposed.copy()
                            computed_data["loss"] = loss
                            engine.update(
                                data=dict(model=model, batch=batch),
                                computed_data=computed_data,
                            )

                        log_summed("num_batches", lambda: 1)
                        log_averaged(
                            "perf/batch_get",
                            elapsed_batch_get,
                        )

                        iterations += 1
                        if pbar:
                            pbar.update(1)

                        # Step metrics for each evaluation iteration
                        step_meters(f"{metrics_prefix}/{eval_name}")

                except StopIteration:
                    pass
                finally:
                    if pbar:
                        pbar.close()

                total_time = time.perf_counter() - start_time
                log_event_end("perf/total_run")

            eval_metrics = compute_meters(
                f"{metrics_prefix}/{eval_name}",
                aggregate=True,
                collective=collective,
            )

            if engine:
                eval_metrics = engine.compute(eval_metrics)

            # Add basic performance stats
            eval_metrics["perf/total_run_ms"] = total_time * 1000
            if iterations > 0:
                eval_metrics["perf/ms_per_batch"] = (total_time / iterations) * 1000

            logger.info(f"Finished eval {eval_name}: {eval_metrics}")
            total_metrics[f"{metrics_prefix}/{eval_name}"] = eval_metrics
        return total_metrics

run_evaluation(model, criterion, eval_data_dict, max_iterations=None, collective=None, all_metrics_configs=None, metrics_prefix='eval', show_progress=False)

Execute the evaluation loop for all provided datasets.

Sets the model to eval mode, disables gradients, and runs the forward pass for each batch. Metrics are aggregated globally.

Parameters:

- model (BaseModel, required): Model to evaluate.
- criterion (BaseCriterion, required): Loss function.
- eval_data_dict (dict, required): Dictionary of {name: dataloader/DataPipeline}.
- max_iterations (int | None, default None): Limit on number of batches.
- collective (Any, default None): Distributed collective.
- all_metrics_configs (dict[str, list[dict]] | None, default None): Root metrics configuration mapping dataset names to configs.
- metrics_prefix (str, default 'eval'): Prefix for metric groups (e.g., "eval" or "metrics").
- show_progress (bool, default False): Whether to show a progress bar.

Returns:

- Nested dictionary of results: {dataset_name: {metric_name: value}}.
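Because run_evaluation ignores eval_freq, it can also be called directly for a one-off pass, e.g. a final evaluation after training. A sketch, with val_pipeline standing in for one of your eval dataloaders or DataPipeline objects:

results = evaluator.run_evaluation(
    model=model,
    criterion=criterion,
    eval_data_dict={"val": val_pipeline},
    max_iterations=None,  # None (or a negative value) iterates the full dataset
    collective=collective,
    show_progress=True,   # with a collective, the bar appears on the local master only
)
# results is keyed by metric group, e.g. {"eval/val": {"perf/total_run_ms": ..., ...}}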


run_evaluation_if_needed(iteration, model, criterion, eval_data, collective=None, all_metrics_configs=None)

Run evaluation if the current iteration matches the frequency.

Parameters:

- iteration (int, required): Current training step.
- model (BaseModel, required): The model to evaluate.
- criterion (BaseCriterion, required): The loss function.
- eval_data (dict[str, EvalDataPipeline], required): Dictionary mapping dataset names to dataloaders.
- collective (Any, default None): Distributed collective for metric aggregation.
- all_metrics_configs (dict[str, list[dict]] | None, default None): Root metrics configuration from TrainConfig.

Returns:

- None | dict: Dictionary of computed metrics if evaluation ran, else None.
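Per-dataset settings take precedence over the manager-level defaults; schematically, the values resolved for each (name, pipeline) pair in eval_data are:

eval_freq = pipeline.eval_freq if pipeline.eval_freq is not None else evaluator.eval_freq
max_iters = pipeline.eval_iterations if pipeline.eval_iterations is not None else evaluator.eval_iterations
# The dataset is evaluated only when eval_freq > 0 and iteration % eval_freq == 0;
# a failure in one dataset is logged and the remaining datasets still run.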


LoggerManager

Manager for multiple metrics loggers.

This class instantiates and orchestrates a list of logging backends (e.g., JSONL, WandB). It provides a unified interface for setting up, logging to, and closing all configured loggers.

Parameters:

- cfg (LoggerManagerConfig, required): Manager configuration.
- loggers_config (list[MetricsLoggerConfig] | None, required): List of configurations for individual loggers.
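A hedged lifecycle sketch; loggers_config is assumed to be a list of MetricsLoggerConfig instances built elsewhere in the recipe, train_config_dict a plain dict of the full training configuration, and the no-argument LoggerManagerConfig() construction is an assumption:

manager = LoggerManager(cfg=LoggerManagerConfig(), loggers_config=loggers_config)
manager.build_loggers()  # instantiate all configured backends
manager.setup_loggers("my-experiment", train_config_dict)

manager.log_metrics_to_loggers({"loss": 2.31}, step=100, group="train")

manager.close_loggers()  # clean up all backends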
Source code in optimus_dl/recipe/train/mixins/managers/logger_manager.py
class LoggerManager:
    """Manager for multiple metrics loggers.

    This class instantiates and orchestrates a list of logging backends (e.g.,
    JSONL, WandB). It provides a unified interface for setting up, logging to,
    and closing all configured loggers.

    Args:
        cfg: Manager configuration.
        loggers_config: List of configurations for individual loggers.
    """

    def __init__(
        self,
        cfg: LoggerManagerConfig,
        loggers_config: list[MetricsLoggerConfig] | None,
        **kwargs: Any,
    ):
        self.loggers_config = loggers_config
        self.previous_state = {}
        self.loggers: list[BaseMetricsLogger] | None = None

    def build_loggers(self, **kwargs):
        """Instantiate all configured loggers.

        Uses the registry to build logger instances. If previous state is available
        (from a checkpoint), it is passed to the logger builders for resumption.

        The built instances are stored on self.loggers; nothing is returned.
        """
        if self.loggers_config is None:
            logger.info("No loggers configuration found, metrics logging disabled")
            return
        assert self.loggers is None, "Loggers already built"

        loggers = []
        for logger_config in self.loggers_config:
            try:
                logger_instance = build(
                    "metrics_logger",
                    logger_config,
                    state_dict=self.previous_state.get(logger_config.id),
                    **kwargs,
                )
                loggers.append(logger_instance)
                logger.info(f"Built logger: {logger_instance.__class__.__name__}")
            except Exception as e:
                logger.error(f"Failed to build logger from config {logger_config}: {e}")
                raise

        self.loggers = loggers

    def setup_loggers(self, experiment_name: str, full_config: dict):
        """Initialize all loggers with experiment context.

        Args:
            experiment_name: Name of the experiment.
            full_config: Complete training configuration dictionary.
        """
        for logger_instance in self.loggers or []:
            try:
                logger_instance.setup(experiment_name, full_config)
            except Exception as e:
                logger.error(
                    f"Failed to setup logger {logger_instance.__class__.__name__}: {e}"
                )

    def log_metrics_to_loggers(self, metrics, step: int, group: str = "train"):
        """Dispatch metrics to all active loggers.

        Args:
            metrics: Dictionary of metric values.
            step: Current iteration.
            group: Metric group name.
        """
        for logger_instance in self.loggers or []:
            try:
                logger_instance.log_metrics(metrics, step, group)
            except Exception as e:
                logger.error(
                    f"Failed to log metrics with {logger_instance.__class__.__name__}: {e}"
                )

    def close_loggers(self):
        """Clean up all loggers."""
        for logger_instance in self.loggers or []:
            try:
                logger_instance.close()
            except Exception as e:
                logger.error(
                    f"Failed to close logger {logger_instance.__class__.__name__}: {e}"
                )

    def state_dict(self):
        """Collect state from all loggers for checkpointing."""
        return {
            logger_instance.cfg.id: logger_instance.state_dict()
            for logger_instance in self.loggers or []
        }

    def load_state_dict(self, state_dict):
        """Load logger states from a checkpoint."""
        self.previous_state = state_dict

build_loggers(**kwargs)

Instantiate all configured loggers.

Uses the registry to build logger instances. If previous state is available (from a checkpoint), it is passed to the logger builders for resumption.

The built instances are stored on self.loggers; nothing is returned.


close_loggers()

Clean up all loggers.


load_state_dict(state_dict)

Load logger states from a checkpoint.


log_metrics_to_loggers(metrics, step, group='train')

Dispatch metrics to all active loggers.

Parameters:

- metrics (required): Dictionary of metric values.
- step (int, required): Current iteration.
- group (str, default 'train'): Metric group name.
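For example, training and evaluation metrics can be routed under different group names (the names here are illustrative):

manager.log_metrics_to_loggers({"loss": 2.31, "lr": 3e-4}, step=1000, group="train")
manager.log_metrics_to_loggers({"loss": 2.54}, step=1000, group="eval/val")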

setup_loggers(experiment_name, full_config)

Initialize all loggers with experiment context.

Parameters:

- experiment_name (str, required): Name of the experiment.
- full_config (dict, required): Complete training configuration dictionary.

state_dict()

Collect state from all loggers for checkpointing.

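State survives a restart by round-tripping through these two methods. A sketch of the resume path, with the checkpoint dict ckpt and its save/load machinery assumed to live elsewhere:

# Saving: capture per-logger state, keyed by each logger's cfg.id.
ckpt["logger_manager"] = manager.state_dict()

# Resuming: load state before build_loggers so each backend can restore itself.
new_manager = LoggerManager(cfg=LoggerManagerConfig(), loggers_config=loggers_config)
new_manager.load_state_dict(ckpt["logger_manager"])
new_manager.build_loggers()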

Modules and Sub-packages