muon

optimus_dl.modules.optim.muon

Muon optimizer

MuonConfig dataclass

Bases: RegistryConfigStrict

Configuration for Muon optimizer.

Muon is a momentum-based optimizer that uses Newton-Schulz iteration for preconditioning. It's designed for efficient training of large models.
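
The preconditioning step can be illustrated with a minimal pure-Python sketch of the quintic Newton-Schulz iteration as it is commonly written for Muon. It is an assumption about the general technique, not a copy of the code in optimus_dl: the helper names `matmul`, `transpose`, and `newton_schulz` are hypothetical, and the default arguments simply reuse the values of `ns_coefficients`, `ns_steps`, and `eps` listed on this page.

```python
import math


def matmul(x, y):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)] for row in x]


def transpose(x):
    """Return the transpose of a matrix stored as a list of rows."""
    return [list(col) for col in zip(*x)]


def newton_schulz(grad, coefficients=(3.4445, -4.775, 2.0315), steps=5, eps=1e-7):
    """Approximately orthogonalize `grad` (hypothetical helper, not the library API)."""
    a, b, c = coefficients
    # Work on the wide orientation so the Gram matrix X @ X^T is the smaller one.
    tall = len(grad) > len(grad[0])
    x = transpose(grad) if tall else [row[:] for row in grad]
    # Scale by the Frobenius norm so every singular value starts at or below 1.
    norm = math.sqrt(sum(v * v for row in x for v in row))
    x = [[v / (norm + eps) for v in row] for row in x]
    for _ in range(steps):
        gram = matmul(x, transpose(x))   # A = X X^T
        gram2 = matmul(gram, gram)       # A^2
        # B = b * A + c * A^2
        poly = [[b * g + c * g2 for g, g2 in zip(r1, r2)] for r1, r2 in zip(gram, gram2)]
        px = matmul(poly, x)             # B X
        # X <- a * X + B X
        x = [[a * xv + pv for xv, pv in zip(xr, pr)] for xr, pr in zip(x, px)]
    return transpose(x) if tall else x


# Two singular values of different size (3 and 4) end up at a similar scale.
result = newton_schulz([[3.0, 0.0], [0.0, 4.0]])
```

With these coefficients the iteration does not converge to exactly orthogonal output. It pushes the singular values into a band around 1, which is the trade-off that lets a small `ns_steps` suffice.
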

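Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| lr | float | Learning rate for parameter updates. | 0.001 |
| weight_decay | float | Weight decay (L2 penalty) coefficient applied to parameters. | 0.1 |
| momentum | float | Momentum factor for the moving average of gradients. | 0.95 |
| nesterov | bool | Whether to use Nesterov momentum. | True |
| ns_coefficients | tuple[float, float, float] | Coefficients (a, b, c) for the Newton-Schulz iteration. | (3.4445, -4.775, 2.0315) |
| eps | float | Small constant added for numerical stability. | 1e-07 |
| ns_steps | int | Number of Newton-Schulz iteration steps used for preconditioning. | 5 |
| adjust_lr_fn | str \| None | Optional name of a learning rate adjustment function. | None |
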
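
To make the roles of `momentum` and `nesterov` concrete, here is a minimal sketch of one common way Muon implementations blend the gradient with its moving average before the Newton-Schulz step. The function name `momentum_update` and the interpolation form are assumptions chosen for illustration. The source shown on this page only defines the configuration, so the optimizer itself may differ.

```python
def momentum_update(grad, buf, momentum=0.95, nesterov=True):
    """Return (update, new_buf) for matrices stored as lists of rows.

    Hypothetical helper: assumes an exponential moving average buffer and a
    Nesterov-style blend of the gradient with the updated buffer.
    """
    # Moving average of gradients: buf <- momentum * buf + (1 - momentum) * grad
    new_buf = [
        [momentum * m + (1.0 - momentum) * g for g, m in zip(g_row, m_row)]
        for g_row, m_row in zip(grad, buf)
    ]
    if nesterov:
        # Blend the current gradient with the freshly updated buffer.
        update = [
            [(1.0 - momentum) * g + momentum * m for g, m in zip(g_row, m_row)]
            for g_row, m_row in zip(grad, new_buf)
        ]
    else:
        update = [row[:] for row in new_buf]
    return update, new_buf


# First step from a zero buffer with a unit gradient.
update, buf = momentum_update([[1.0]], [[0.0]])
```

In this sketch the returned `update` is the matrix that would be handed to the preconditioning step, and `new_buf` is the state carried to the next iteration.
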
Source code in optimus_dl/modules/optim/muon.py
@dataclass
class MuonConfig(RegistryConfigStrict):
    """Configuration for Muon optimizer.

    Muon is a momentum-based optimizer that uses Newton-Schulz iteration for
    preconditioning. It's designed for efficient training of large models.

    Attributes:
        lr: Learning rate for parameter updates.
        weight_decay: Weight decay (L2 penalty) coefficient applied to parameters.
        momentum: Momentum factor for the moving average of gradients.
        nesterov: Whether to use Nesterov momentum.
        ns_coefficients: Coefficients (a, b, c) for Newton-Schulz iteration algorithm.
            These control the convergence and stability of the preconditioning step.
            Default values are tuned for typical use cases.
        eps: Small constant added for numerical stability in computations.
        ns_steps: Number of Newton-Schulz iteration steps to perform for preconditioning.
            Higher values can improve accuracy but increase computational cost.
        adjust_lr_fn: Optional learning rate adjustment function name. If provided,
            applies dynamic learning rate scaling during training.
    """

    lr: float = 1e-3
    weight_decay: float = 0.1
    momentum: float = 0.95
    nesterov: bool = True
    ns_coefficients: tuple[float, float, float] = (DEFAULT_A, DEFAULT_B, DEFAULT_C)
    eps: float = EPS
    ns_steps: int = DEFAULT_NS_STEPS
    adjust_lr_fn: str | None = None