Skip to content

base

optimus_dl.modules.lr_scheduler.base

BaseLRScheduler

Bases: ABC

Abstract base class for learning rate schedulers.

This class provides a uniform interface for learning rate scheduling that is decoupled from specific optimizer implementations. It manages the stepping of learning rates across multiple parameter groups and handles state serialization for checkpointing.

Attributes:

Name Type Description
optimizer

The PyTorch optimizer whose learning rates are managed.

base_lrs

Initial learning rates for each parameter group.

Source code in optimus_dl/modules/lr_scheduler/base.py
class BaseLRScheduler(ABC):
    """Abstract base class for learning rate schedulers.

    This class provides a uniform interface for learning rate scheduling that
    is decoupled from specific optimizer implementations. It manages the
    stepping of learning rates across multiple parameter groups and handles
    state serialization for checkpointing.

    Attributes:
        optimizer: The PyTorch optimizer whose learning rates are managed.
        base_lrs: Initial learning rates for each parameter group.
    """

    def __init__(self, optimizer: Optimizer, **kwargs):
        """Initialize the scheduler.

        Args:
            optimizer: The optimizer to manage.
            **kwargs: Additional keyword arguments.
        """
        self.optimizer = optimizer
        self._step_count = 0
        self.base_lrs = [group["lr"] for group in optimizer.param_groups]

    @abstractmethod
    def get_lr(self) -> list[float]:
        """Calculate the target learning rates for the current step.

        Returns:
            List of floats representing the new learning rates for each
            parameter group in the optimizer.
        """
        pass

    def step(self) -> None:
        """Update the optimizer's learning rates based on the current step count.

        This should be called at the end of each training iteration.
        """
        self._step_count += 1
        self.set()

    def set(self) -> None:
        """Set the learning rates of the optimizer to the current values."""
        values = self.get_lr()
        for param_group, lr in zip(self.optimizer.param_groups, values, strict=True):
            param_group["lr"] = lr

    def get_last_lr(self) -> list[float]:
        """Return the most recently computed learning rates."""
        return [group["lr"] for group in self.optimizer.param_groups]

    def state_dict(self) -> dict[str, Any]:
        """Return the scheduler's state for checkpointing."""
        return {
            "step_count": self._step_count,
            "base_lrs": self.base_lrs,
        }

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Restore the scheduler's state from a checkpoint."""
        self._step_count = state_dict["step_count"]
        self.base_lrs = state_dict["base_lrs"]

    @property
    def last_epoch(self) -> int:
        """The current step count (for compatibility with PyTorch schedulers)."""
        return self._step_count

last_epoch property

The current step count (for compatibility with PyTorch schedulers).

__init__(optimizer, **kwargs)

Initialize the scheduler.

Parameters:

Name Type Description Default
optimizer Optimizer

The optimizer to manage.

required
**kwargs

Additional keyword arguments.

{}
Source code in optimus_dl/modules/lr_scheduler/base.py
def __init__(self, optimizer: Optimizer, **kwargs):
    """Initialize the scheduler.

    Args:
        optimizer: The optimizer to manage.
        **kwargs: Additional keyword arguments.
    """
    self.optimizer = optimizer
    # Counter starts at 0; step() increments it before recomputing rates.
    self._step_count = 0
    # Snapshot each group's initial lr at construction time.
    self.base_lrs = [group["lr"] for group in optimizer.param_groups]

get_last_lr()

Return the most recently computed learning rates.

Source code in optimus_dl/modules/lr_scheduler/base.py
def get_last_lr(self) -> list[float]:
    """Return the most recently computed learning rates."""
    # Reads the live "lr" values straight from the optimizer's param groups.
    return [group["lr"] for group in self.optimizer.param_groups]

get_lr() abstractmethod

Calculate the target learning rates for the current step.

Returns:

Type Description
list[float]

List of floats representing the new learning rates for each parameter group in the optimizer.

Source code in optimus_dl/modules/lr_scheduler/base.py
@abstractmethod
def get_lr(self) -> list[float]:
    """Calculate the target learning rates for the current step.

    Returns:
        List of floats representing the new learning rates for each
        parameter group in the optimizer.
    """
    # Subclasses must override; order must match optimizer.param_groups.
    pass

load_state_dict(state_dict)

Restore the scheduler's state from a checkpoint.

Source code in optimus_dl/modules/lr_scheduler/base.py
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
    """Restore the scheduler's state from a checkpoint."""
    self._step_count = state_dict["step_count"]
    # NOTE(review): optimizer lrs are not re-applied here; they are only
    # rewritten on the next step()/set() call.
    self.base_lrs = state_dict["base_lrs"]

set()

Set the learning rates of the optimizer to the current values.

Source code in optimus_dl/modules/lr_scheduler/base.py
def set(self) -> None:
    """Set the learning rates of the optimizer to the current values."""
    values = self.get_lr()
    # strict=True raises if get_lr() returns a different number of rates
    # than there are parameter groups.
    for param_group, lr in zip(self.optimizer.param_groups, values, strict=True):
        param_group["lr"] = lr

state_dict()

Return the scheduler's state for checkpointing.

Source code in optimus_dl/modules/lr_scheduler/base.py
def state_dict(self) -> dict[str, Any]:
    """Return the scheduler's state for checkpointing."""
    # NOTE(review): base_lrs is returned by reference, not copied — mutating
    # the checkpoint dict would mutate the scheduler.
    return {
        "step_count": self._step_count,
        "base_lrs": self.base_lrs,
    }

step()

Update the optimizer's learning rates based on the current step count.

This should be called at the end of each training iteration.

Source code in optimus_dl/modules/lr_scheduler/base.py
def step(self) -> None:
    """Update the optimizer's learning rates based on the current step count.

    This should be called at the end of each training iteration.
    """
    # Increment first, then apply: rates are computed for the new step count.
    self._step_count += 1
    self.set()

BaseLRSchedulerConfig dataclass

Bases: RegistryConfig

Base configuration for learning rate schedulers.

Source code in optimus_dl/modules/lr_scheduler/base.py
@dataclass
class BaseLRSchedulerConfig(RegistryConfig):
    """Base configuration for learning rate schedulers."""

    # Intentionally empty: concrete scheduler configs add their own fields.
    pass