Skip to content

optimizer_builder

optimus_dl.recipe.train.builders.optimizer_builder

Optimizer builder mixin for building optimizers.

OptimizerBuilder

Builder class responsible for creating the optimizer.

Takes parameter groups from the model and instantiates the configured optimizer (e.g., AdamW). It also logs the optimizer and the number of trainable parameters it optimizes, both in total and per parameter group.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `cfg` | `OptimizerBuilderConfig` | Builder configuration. | *required* |
| `optimization_config` | `OptimizationConfig` | Optimization settings including the optimizer config. | *required* |
Source code in optimus_dl/recipe/train/builders/optimizer_builder.py
class OptimizerBuilder:
    """Builder class responsible for creating the optimizer.

    Takes parameter groups from the model and instantiates the configured
    optimizer (e.g., AdamW). It also logs the optimizer and the number of
    trainable parameters it optimizes, both in total and per parameter group.

    Args:
        cfg: Builder configuration.
        optimization_config: Optimization settings including the optimizer config.
        **kwargs: Extra keyword arguments; accepted and ignored.
    """

    def __init__(
        self,
        cfg: OptimizerBuilderConfig,
        optimization_config: OptimizationConfig,
        **kwargs: Any,
    ):
        # NOTE: `cfg` is intentionally not stored on `self`; if this class is
        # mixed into a host that owns its own `self.cfg`, storing it here
        # would overwrite the host's attribute.
        self.optimization_config = optimization_config

    def build_optimizer(self, params, **kwargs) -> Optimizer:
        """Build and validate the optimizer.

        Args:
            params: Iterable of parameters or dicts defining parameter groups
                (typically from `model.make_parameter_groups()`).
            **kwargs: Additional arguments passed to the optimizer constructor.

        Returns:
            Instantiated Optimizer.

        Raises:
            TypeError: If the configured optimizer does not build an
                ``Optimizer`` instance.
        """
        optimizer = build(
            "optimizer", self.optimization_config.optimizer, params=params, **kwargs
        )
        # Explicit check rather than `assert`, so validation is not stripped
        # when Python runs with -O.
        if not isinstance(optimizer, Optimizer):
            raise TypeError(
                "Expected the optimizer config to build an Optimizer, "
                f"got {type(optimizer).__name__}"
            )
        logger.info(f"Optimizer \n{optimizer}")

        # Per-group count of trainable elements; parameters with
        # requires_grad=False are excluded.
        optimized_params = [
            sum(p.numel() for p in param_group["params"] if p.requires_grad)
            for param_group in optimizer.param_groups
        ]
        optimized_params_num = sum(optimized_params)
        logger.info(
            f"Optimized {optimized_params_num:,} parameters. Per group: {[f'{i:,}' for i in optimized_params]}"
        )
        # Report the total as a metric; aggregation behaviour is whatever
        # `log_averaged` implements.
        log_averaged("optimized_params", optimized_params_num)

        return optimizer

build_optimizer(params, **kwargs)

Build and validate the optimizer.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `params` | | Iterable of parameters or dicts defining parameter groups (typically from `model.make_parameter_groups()`). | *required* |
| `**kwargs` | | Additional arguments passed to the optimizer constructor. | `{}` |

Returns:

| Type | Description |
|------|-------------|
| `Optimizer` | Instantiated Optimizer. |

Source code in optimus_dl/recipe/train/builders/optimizer_builder.py
def build_optimizer(self, params, **kwargs) -> Optimizer:
    """Build and validate the optimizer.

    Args:
        params: Iterable of parameters or dicts defining parameter groups
            (typically from `model.make_parameter_groups()`).
        **kwargs: Additional arguments passed to the optimizer constructor.

    Returns:
        Instantiated Optimizer.

    Raises:
        TypeError: If the configured optimizer does not build an
            ``Optimizer`` instance.
    """
    optimizer = build(
        "optimizer", self.optimization_config.optimizer, params=params, **kwargs
    )
    # Explicit check rather than `assert`, so validation is not stripped
    # when Python runs with -O.
    if not isinstance(optimizer, Optimizer):
        raise TypeError(
            "Expected the optimizer config to build an Optimizer, "
            f"got {type(optimizer).__name__}"
        )
    logger.info(f"Optimizer \n{optimizer}")

    # Per-group count of trainable elements; parameters with
    # requires_grad=False are excluded.
    optimized_params = [
        sum(p.numel() for p in param_group["params"] if p.requires_grad)
        for param_group in optimizer.param_groups
    ]
    optimized_params_num = sum(optimized_params)
    logger.info(
        f"Optimized {optimized_params_num:,} parameters. Per group: {[f'{i:,}' for i in optimized_params]}"
    )
    # Report the total as a metric; aggregation behaviour is whatever
    # `log_averaged` implements.
    log_averaged("optimized_params", optimized_params_num)

    return optimizer

OptimizerBuilderConfig dataclass

Bases: RegistryConfig

Configuration for OptimizerBuilder.

Source code in optimus_dl/recipe/train/builders/optimizer_builder.py
@dataclass
class OptimizerBuilderConfig(RegistryConfig):
    """Configuration for ``OptimizerBuilder``.

    Declares no fields of its own; all configuration comes from the
    ``RegistryConfig`` base class.
    """