optimus_dl.recipe.train.mixins.execution

TrainingContextMixin

Mixin for setting up the training context (precision, scaling, devices).

Responsible for initializing PyTorch's AMP (Automatic Mixed Precision) and GradScaler based on the optimization configuration. This ensures consistent precision settings across the training loop.

Parameters:

  • optimization_config (OptimizationConfig, required): Configuration containing AMP settings.
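
The mixin reads only the amp sub-config. A minimal sketch of the fields setup_training_context accesses (the dataclass shapes here are hypothetical; the attribute names are taken from the source below):

from dataclasses import dataclass, field

@dataclass
class AMPConfig:  # hypothetical container; attribute names match the source below
    enabled: bool = True          # master switch for autocast
    enable_scaler: bool = True    # scaler is active only when AMP is also enabled
    dtype: str = "bfloat16"       # resolved via a string-to-dtype map, default float16
    init_scale: float = 2.0**16
    growth_factor: float = 2.0
    backoff_factor: float = 0.5
    growth_interval: int = 2000

@dataclass
class OptimizationConfig:  # only the `amp` field is used by this mixin
    amp: AMPConfig = field(default_factory=AMPConfig)
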
Source code in optimus_dl/recipe/train/mixins/execution/context_mixin.py
class TrainingContextMixin:
    """Mixin for setting up the training context (precision, scaling, devices).

    Responsible for initializing PyTorch's AMP (Automatic Mixed Precision) and
    GradScaler based on the optimization configuration. This ensures consistent
    precision settings across the training loop.

    Args:
        optimization_config: Configuration containing AMP settings.
    """

    def __init__(self, optimization_config: OptimizationConfig):
        self.optimization_config = optimization_config

    def setup_training_context(self, device: torch.device) -> dict[str, Any]:
        """Initialize AMP context and Gradient Scaler.

        Args:
            device: The target compute device.

        Returns:
            A dictionary containing:

            - "scaler": The torch.cuda.amp.GradScaler instance.
            - "amp_ctx": The torch.autocast context manager.
            - "amp_cfg": The raw AMP configuration object.
            - "device": The device being used.
        """
        amp_cfg = self.optimization_config.amp
        scaler = torch.GradScaler(
            device=device.type,
            enabled=amp_cfg.enabled and amp_cfg.enable_scaler,
            init_scale=amp_cfg.init_scale,
            growth_factor=amp_cfg.growth_factor,
            backoff_factor=amp_cfg.backoff_factor,
            growth_interval=amp_cfg.growth_interval,
        )
        logger.info(f"Using grad scaler: {scaler.is_enabled()}")
        # Safe dtype conversion without eval()
        dtype_map = {
            "torch.float16": torch.float16,
            "torch.float32": torch.float32,
            "torch.bfloat16": torch.bfloat16,
            "float16": torch.float16,
            "float32": torch.float32,
            "bfloat16": torch.bfloat16,
        }

        dtype = dtype_map.get(amp_cfg.dtype, torch.float16)
        if amp_cfg.dtype not in dtype_map:
            logger.warning(f"Unknown dtype '{amp_cfg.dtype}', defaulting to float16")

        amp_ctx = torch.autocast(device.type, dtype=dtype, enabled=amp_cfg.enabled)

        return {
            "scaler": scaler,
            "amp_ctx": amp_ctx,
            "amp_cfg": amp_cfg,
            "device": device,
        }

setup_training_context(device)

Initialize AMP context and Gradient Scaler.

Parameters:

  • device (torch.device, required): The target compute device.

Returns:

  • dict[str, Any]: A dictionary containing:
      - "scaler": The torch.GradScaler instance.
      - "amp_ctx": The torch.autocast context manager.
      - "amp_cfg": The raw AMP configuration object.
      - "device": The device being used.

TrainingInterruptionMixin

Mixin for gracefully handling training interruptions.

Provides a mechanism to catch KeyboardInterrupt (Ctrl+C) and trigger a safe shutdown sequence, which typically involves saving a final checkpoint to ensure progress is not lost.

Parameters:

  • save_freq (int, default 0): Frequency of regular checkpoints. If 0, saving is disabled.
  • output_path (str | None, default None): Path where checkpoints are saved.
  • checkpoint_callback (Callable[..., None] | None, default None): Callable to execute for saving the checkpoint.
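
The mixin itself does not install a signal handler; the driver loop is expected to catch KeyboardInterrupt and delegate, roughly like the following sketch (the loop condition and step function are placeholders):

# `recipe` stands for any object mixing in TrainingInterruptionMixin.
iteration = 0
try:
    while iteration < max_iterations:  # placeholder loop condition
        run_one_iteration()            # placeholder training step
        iteration += 1
except KeyboardInterrupt:
    # Saves a final checkpoint via the configured callback, if one is set.
    recipe.handle_training_interruption(iteration=iteration, collective=None)
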
Source code in optimus_dl/recipe/train/mixins/execution/interruption_mixin.py
class TrainingInterruptionMixin:
    """Mixin for gracefully handling training interruptions.

    Provides a mechanism to catch `KeyboardInterrupt` (Ctrl+C) and trigger a
    safe shutdown sequence, which typically involves saving a final checkpoint
    to ensure progress is not lost.

    Args:
        save_freq: Frequency of regular checkpoints. If 0, saving is disabled.
        output_path: Path where checkpoints are saved.
        checkpoint_callback: Callable to execute for saving the checkpoint.
    """

    def __init__(
        self,
        save_freq: int = 0,
        output_path: str | None = None,
        checkpoint_callback: Callable[..., None] | None = None,
    ):
        self.save_freq = save_freq
        self.output_path = output_path
        self.checkpoint_callback = checkpoint_callback

    def handle_training_interruption(
        self,
        iteration: int,
        collective: Collective | None,
        **kwargs: Any,
    ) -> None:
        """Handle interruption by saving a final checkpoint.

        Args:
            iteration: The current training iteration count.
            collective: The distributed collective instance.
            **kwargs: Additional arguments to pass to the checkpoint callback.
        """
        logger.info("Training interrupted by user")

        # Check if we have checkpoint saving configured and callback available
        if self.save_freq > 0 and self.output_path and self.checkpoint_callback:
            try:
                logger.info("Saving final checkpoint...")

                # Call the checkpoint callback with the required parameters
                self.checkpoint_callback(
                    checkpoint_path=self.output_path,
                    iteration=iteration,
                    collective=collective,
                    **kwargs,
                )
                logger.info("Final checkpoint saved")

            except Exception as e:
                logger.error(f"Failed to save final checkpoint: {e}")
                raise
        elif self.save_freq > 0:
            logger.warning(
                "Checkpoint saving requested but no callback provided or output_path missing"
            )

handle_training_interruption(iteration, collective, **kwargs)

Handle interruption by saving a final checkpoint.

Parameters:

  • iteration (int, required): The current training iteration count.
  • collective (Collective | None, required): The distributed collective instance.
  • **kwargs (Any, default {}): Additional arguments to pass to the checkpoint callback.
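
The callback is invoked with keyword arguments only, so a compatible checkpoint function (hypothetical, but the parameter names match the call site in the source) looks like:

from typing import Any

def save_checkpoint(
    checkpoint_path: str,
    iteration: int,
    collective: Any,   # Collective | None at the call site
    **kwargs: Any,     # any extra state forwarded by the caller
) -> None:
    """Hypothetical checkpoint callback matching how the mixin invokes it."""
    ...  # serialize model/optimizer state under checkpoint_path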

TrainingIterationMixin

Mixin for executing a complete training step with gradient accumulation.

Encapsulates the core training logic:

  1. Forward Pass: Runs the model and criterion, measuring time.
  2. Backward Pass: Scales gradients and backpropagates, handling loss parallelism if applicable.
  3. Optimization: Unscales gradients, clips norms, and steps the optimizer.
  4. Logging: Records detailed performance metrics (forward/backward times, grad norms, etc.).

Parameters:

  • optimization_config (OptimizationConfig, required): Configuration for optimization (accumulation steps, clipping).
  • log_freq (int, default 1): Frequency of metric logging.
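
Composing the three execution mixins into one recipe class might look like the following sketch (the class name and wiring are hypothetical; the constructor arguments mirror each mixin's __init__):

class TrainRecipe(
    TrainingContextMixin,
    TrainingIterationMixin,
    TrainingInterruptionMixin,
):
    """Hypothetical recipe composing the execution mixins."""

    def __init__(self, optimization_config, save_freq=0, output_path=None,
                 checkpoint_callback=None):
        TrainingContextMixin.__init__(self, optimization_config)
        TrainingIterationMixin.__init__(self, optimization_config, log_freq=10)
        TrainingInterruptionMixin.__init__(
            self, save_freq, output_path, checkpoint_callback
        )

A driver would then call setup_training_context once at startup and run_training_iteration once per optimizer step.
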
Source code in optimus_dl/recipe/train/mixins/execution/iteration_mixin.py
class TrainingIterationMixin:
    """Mixin for executing a complete training step with gradient accumulation.

    Encapsulates the core training logic:
    1.  **Forward Pass**: Runs the model and criterion, measuring time.
    2.  **Backward Pass**: Scales gradients and backpropagates, handling
        loss parallelism if applicable.
    3.  **Optimization**: Unscales gradients, clips norms, and steps the optimizer.
    4.  **Logging**: Records detailed performance metrics (forward/backward times,
        grad norms, etc.).

    Args:
        optimization_config: Configuration for optimization (accumulation steps, clipping).
        log_freq: Frequency of metric logging.
    """

    def __init__(self, optimization_config: OptimizationConfig, log_freq: int = 1):
        self.optimization_config = optimization_config
        self.log_freq = log_freq

    def log_memory_usage(self):
        """Log GPU memory usage statistics."""
        if torch.cuda.is_available():
            log_summed("gpu_gb_allocated", torch.cuda.memory_allocated() / (1024**3))
            log_summed("gpu_gb_used", torch.cuda.max_memory_allocated() / (1024**3))

    def execute_forward_pass(
        self,
        model: BaseModel,
        criterion: BaseCriterion,
        batch: Any,
        amp_ctx: Any,
        requested_protocols: set[str] | None = None,
    ) -> ForwardPassResult:
        """Run the forward pass inside an AMP context.

        Args:
            model: The model to run.
            criterion: The loss function.
            batch: The input data.
            amp_ctx: The autocast context manager.
            requested_protocols: Protocols requested by the metrics system.

        Returns:
            ForwardPassResult with the computed loss, exposed protocols, and execution time.
        """
        with amp_ctx:
            elapsed_forward, (loss, exposed) = measured_lambda(
                lambda: criterion(model, batch, requested_protocols=requested_protocols)
            )
        return ForwardPassResult(
            loss=loss, exposed_protocols=exposed, elapsed_time=elapsed_forward
        )

    def execute_backward_pass(self, loss: torch.Tensor, scaler: Any) -> float:
        """Run the backward pass with gradient scaling.

        Handles `loss_parallel` context if the loss is a DTensor.

        Args:
            loss: The computed loss tensor.
            scaler: The gradient scaler.

        Returns:
            Execution time in milliseconds.
        """

        def backward():
            with loss_parallel() if isinstance(loss, DTensor) else nullcontext():
                scaler.scale(loss).backward()

        elapsed_backward, _ = measured_lambda(backward)
        return elapsed_backward

    def execute_optimizer_step(
        self,
        optimizer: Optimizer,
        model: BaseModel,
        scaler: Any,
        clip_grad_norm: float | None = None,
    ) -> OptimizerStepResult:
        """Perform the optimization step.

        Includes gradient unscaling, optional gradient clipping, and the
        optimizer step itself. Updates the scaler state afterwards.

        Args:
            optimizer: The optimizer.
            model: The model (needed for clipping gradients).
            scaler: The gradient scaler.
            clip_grad_norm: Maximum norm for gradient clipping.

        Returns:
            OptimizerStepResult with execution time and the computed gradient norm.
        """
        scaler.unscale_(optimizer)

        grad_norm = None
        if clip_grad_norm is not None:
            from torch.distributed.tensor.experimental import implicit_replication

            with implicit_replication():
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm=clip_grad_norm
                )

        elapsed, _ = measured_lambda(lambda: scaler.step(optimizer))
        scaler.update()

        if scaler.is_enabled():
            log_averaged("grad_scale", scaler.get_scale())

        return OptimizerStepResult(elapsed_time=elapsed, grad_norm=grad_norm)

    def log_batch_metrics(
        self,
        elapsed_batch_get: float,
        elapsed_forward: float,
        elapsed_backward: float,
        acc_steps: int,
    ) -> None:
        """Log timing metrics for data loading and forward/backward passes."""
        weight = 1 / acc_steps

        log_averaged(
            "perf/batch_get",
            value=elapsed_batch_get,
            weight=weight,
            priority=999,
        )
        log_averaged(
            "perf/forward",
            value=elapsed_forward,
            weight=weight,
            priority=1000,
        )
        log_averaged(
            "perf/backward",
            value=elapsed_backward,
            weight=weight,
            priority=1001,
        )

    def log_optimizer_metrics(
        self,
        elapsed_optimizer: float,
        grad_norm: torch.Tensor | None,
        lr_scheduler: Any | None,
        optimizer: Optimizer,
    ) -> None:
        """Log optimizer performance, gradient norms, and learning rates."""
        log_averaged("perf/optimizer", value=elapsed_optimizer, priority=1002)

        # Log gradient norm if clipping was performed
        if grad_norm is not None:
            log_averaged(
                "grad_norm",
                lambda: (float(grad_norm) if grad_norm is not None else 0.0),
            )

        # Learning rate (cheap but we only need it periodically)
        if lr_scheduler is not None:
            log_averaged("learning_rate", lambda: lr_scheduler.get_last_lr()[0])
        else:
            log_averaged("learning_rate", lambda: optimizer.param_groups[0]["lr"])

    def run_training_iteration(
        self,
        model: BaseModel,
        optimizer: Optimizer,
        criterion: BaseCriterion,
        train_data_iter: Iterator,
        training_context: dict[str, Any],
        lr_scheduler: Any | None = None,
        metric_engine: Any | None = None,
    ) -> None:
        """Execute one full training iteration, including gradient accumulation.

        This is the main driver for a training step. It loops `acc_steps` times
        to accumulate gradients before performing a single optimizer update.

        Args:
            model: The model to train.
            optimizer: The optimizer.
            criterion: The loss function.
            train_data_iter: Iterator yielding training batches.
            training_context: Dict with scaler, amp_ctx, etc.
            lr_scheduler: Optional learning rate scheduler.
            metric_engine: Optional MetricEngine for training metrics.
        """
        with meters_group("train", log_freq=self.log_freq) as should_log:
            optimizer.zero_grad()
            model.train()

            requested_protocols = None
            if metric_engine and should_log:
                requested_protocols = metric_engine.required_external_protocols

            # Gradient accumulation loop
            for microbatch_idx in range(self.optimization_config.acc_steps):
                is_last_microbatch = (
                    microbatch_idx == self.optimization_config.acc_steps - 1
                )

                try:
                    elapsed_batch_get, batch = measured_next(train_data_iter)
                except StopIteration:
                    logger.error("Training data iterator exhausted unexpectedly")
                    break
                except Exception as e:
                    logger.error(f"Error getting batch: {e}")
                    continue

                with self.accumulation_context(model, is_last_microbatch):
                    forward_result = self.execute_forward_pass(
                        model,
                        criterion,
                        batch,
                        training_context["amp_ctx"],
                        requested_protocols=requested_protocols,
                    )
                    loss = forward_result.loss / self.optimization_config.acc_steps

                    if metric_engine and should_log:
                        # Pass computed data (loss, logits, etc.) to avoid redundant work in engine
                        computed_data = forward_result.exposed_protocols.copy()
                        computed_data["loss"] = forward_result.loss
                        metric_engine.update(
                            data=dict(model=model, batch=batch),
                            computed_data=computed_data,
                        )

                    elapsed_backward = self.execute_backward_pass(
                        loss, training_context["scaler"]
                    )

                # Log performance metrics using the training metrics mixin
                self.log_batch_metrics(
                    elapsed_batch_get,
                    forward_result.elapsed_time,
                    elapsed_backward,
                    self.optimization_config.acc_steps,
                )

            # Optimizer step
            optimizer_result = self.execute_optimizer_step(
                optimizer,
                model,
                training_context["scaler"],
                self.optimization_config.clip_grad_norm,
            )

            # Log optimizer metrics
            self.log_optimizer_metrics(
                optimizer_result.elapsed_time,
                optimizer_result.grad_norm,
                lr_scheduler,
                optimizer,
            )
            self.log_memory_usage()
            optimizer.zero_grad()

            if lr_scheduler is not None:
                lr_scheduler.step()

    def accumulation_context(self, model, is_last_microbatch):
        """Get the appropriate context manager for gradient accumulation.

        For FSDP/DDP models, this handles synchronization (e.g., disabling
        all-reduce during accumulation steps).
        """
        if hasattr(model, "accumulation_context"):
            ctx = model.accumulation_context(is_last_microbatch=is_last_microbatch)
            if not is_last_microbatch:
                warn_once(logger, "Using accumulation context")
            return ctx
        else:
            warn_once(logger, "Model does not support accumulation context, skipping")
            return nullcontext()

accumulation_context(model, is_last_microbatch)

Get the appropriate context manager for gradient accumulation.

For FSDP/DDP models, this handles synchronization (e.g., disabling all-reduce during accumulation steps).
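
The hook is duck-typed: the mixin only checks hasattr(model, "accumulation_context"). For a plain DistributedDataParallel model, an equivalent hook could be built on no_sync(); a hypothetical wrapper:

from contextlib import nullcontext

class DDPWrapper:
    """Hypothetical wrapper exposing the accumulation_context hook."""

    def __init__(self, ddp_module):
        self.module = ddp_module  # a torch.nn.parallel.DistributedDataParallel

    def accumulation_context(self, is_last_microbatch: bool):
        # Suppress gradient all-reduce on every microbatch except the last.
        return nullcontext() if is_last_microbatch else self.module.no_sync()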


execute_backward_pass(loss, scaler)

Run the backward pass with gradient scaling.

Handles loss_parallel context if the loss is a DTensor.

Parameters:

  • loss (Tensor, required): The computed loss tensor.
  • scaler (Any, required): The gradient scaler.

Returns:

  • float: Execution time in milliseconds.


execute_forward_pass(model, criterion, batch, amp_ctx, requested_protocols=None)

Run the forward pass inside an AMP context.

Parameters:

  • model (BaseModel, required): The model to run.
  • criterion (BaseCriterion, required): The loss function.
  • batch (Any, required): The input data.
  • amp_ctx (Any, required): The autocast context manager.
  • requested_protocols (set[str] | None, default None): Protocols requested by the metrics system.

Returns:

  • ForwardPassResult: The computed loss, exposed protocols, and execution time.
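
The criterion is expected to return a (loss, exposed_protocols) pair, since the call site unpacks two values. A minimal compatible criterion (hypothetical; the protocol name "logits" is illustrative):

import torch.nn.functional as F

def cross_entropy_criterion(model, batch, requested_protocols=None):
    """Hypothetical criterion matching the (loss, exposed) contract."""
    inputs, targets = batch
    logits = model(inputs)
    loss = F.cross_entropy(logits, targets)

    exposed = {}
    if requested_protocols and "logits" in requested_protocols:
        exposed["logits"] = logits.detach()  # expose only what metrics requested
    return loss, exposed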


execute_optimizer_step(optimizer, model, scaler, clip_grad_norm=None)

Perform the optimization step.

Includes gradient unscaling, optional gradient clipping, and the optimizer step itself. Updates the scaler state afterwards.

Parameters:

  • optimizer (Optimizer, required): The optimizer.
  • model (BaseModel, required): The model (needed for clipping gradients).
  • scaler (Any, required): The gradient scaler.
  • clip_grad_norm (float | None, default None): Maximum norm for gradient clipping.

Returns:

  • OptimizerStepResult: Execution time and the computed gradient norm.
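
Unscaling before clipping matters: while gradients still carry the loss scale S, their norm is S times the true norm, so clipping them at max_norm would constrain the true gradients to max_norm / S instead. A small numeric check of that arithmetic:

import torch

scale, max_norm = 1024.0, 1.0
g = torch.tensor([3.0, 4.0])                 # true gradient, norm 5.0

scaled = g * scale                           # what .grad holds before unscale_()
clipped_scaled = scaled * (max_norm / scaled.norm())
print((clipped_scaled / scale).norm())       # ~0.001: effectively max_norm / scale

unscaled = scaled / scale                    # what unscale_() restores
clipped = unscaled * (max_norm / unscaled.norm())
print(clipped.norm())                        # 1.0: the intended max_norm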


log_batch_metrics(elapsed_batch_get, elapsed_forward, elapsed_backward, acc_steps)

Log timing metrics for data loading and forward/backward passes.


log_memory_usage()

Log GPU memory usage statistics.


log_optimizer_metrics(elapsed_optimizer, grad_norm, lr_scheduler, optimizer)

Log optimizer performance, gradient norms, and learning rates.


run_training_iteration(model, optimizer, criterion, train_data_iter, training_context, lr_scheduler=None, metric_engine=None)

Execute one full training iteration, including gradient accumulation.

This is the main driver for a training step. It loops acc_steps times to accumulate gradients before performing a single optimizer update.

Parameters:

  • model (BaseModel, required): The model to train.
  • optimizer (Optimizer, required): The optimizer.
  • criterion (BaseCriterion, required): The loss function.
  • train_data_iter (Iterator, required): Iterator yielding training batches.
  • training_context (dict[str, Any], required): Dict with scaler, amp_ctx, etc.
  • lr_scheduler (Any | None, default None): Optional learning rate scheduler.
  • metric_engine (Any | None, default None): Optional MetricEngine for training metrics.
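
Because each microbatch loss is divided by acc_steps before backward(), the accumulated gradient equals the gradient of the mean microbatch loss, matching what one large batch with a mean-reduced loss would produce (for equal-sized microbatches). A tiny check:

import torch

# Gradients accumulate additively, so sum_i grad(L_i / n) == grad(mean_i L_i).
w = torch.ones(3, requires_grad=True)
losses = [(w * c).sum() for c in (1.0, 2.0, 3.0)]

for loss in losses:             # accumulation loop, as in run_training_iteration
    (loss / len(losses)).backward()

print(w.grad)                   # tensor([2., 2., 2.]): the mean of 1, 2, 3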

Modules and Sub-packages

  • context_mixin: Training context mixin for AMP and gradient scaler setup.
  • interruption_mixin: Training interruption mixin for handling errors and keyboard interrupts.
  • iteration_mixin: Training iteration mixin for orchestrating complete training iterations.