config

optimus_dl.modules.optim.config

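General optimizer configuration: the optimizer, training length, gradient accumulation and clipping, and automatic mixed precision (AMP) settings.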

AmpConfig dataclass

AmpConfig(enabled: bool = False, dtype: str = 'torch.bfloat16', enable_scaler: bool = '${eval: \'"${.dtype}" == "torch.float16"\'}', init_scale: float = 65536, growth_factor: float = 2.0, backoff_factor: float = 0.5, growth_interval: int = 2000)

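Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `enabled` | `bool` | Turn mixed precision on or off | `False` |
| `dtype` | `str` | Reduced-precision dtype, given as a string | `'torch.bfloat16'` |
| `enable_scaler` | `bool` | Use a loss scaler; by default true exactly when `dtype` is `"torch.float16"` | `'${eval: \'"${.dtype}" == "torch.float16"\'}'` |
| `init_scale` | `float` | | `65536` |
| `growth_factor` | `float` | | `2.0` |
| `backoff_factor` | `float` | | `0.5` |
| `growth_interval` | `int` | | `2000` |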
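The `enable_scaler` default is not a literal boolean. It is an OmegaConf-style interpolation: `${.dtype}` reads the sibling `dtype` key, and the resulting comparison string is handed to an `eval` resolver. `eval` is not one of OmegaConf's built-in resolvers, so it is presumably registered by optimus_dl itself. The stdlib-only sketch below emulates the intended result under that assumption; `resolve_enable_scaler` is a hypothetical helper, not part of the library. The rationale for the default: float16 has only 5 exponent bits, so small gradients can underflow to zero unless the loss is scaled up, while bfloat16 keeps float32's 8 exponent bits and usually does not need a scaler.

```python
def resolve_enable_scaler(amp: dict) -> bool:
    """Hypothetical stand-in for the `eval` interpolation on AmpConfig.enable_scaler.

    Handles only this one expression, not arbitrary `eval` input.
    """
    value = amp.get("enable_scaler")
    if isinstance(value, bool):
        # An explicit override in the user's config wins over the default.
        return value
    # Step 1: `${.dtype}` is replaced by the sibling `dtype` value.
    # Step 2: the comparison string is evaluated to a bool.
    return amp.get("dtype", "torch.bfloat16") == "torch.float16"


for dtype in ("torch.bfloat16", "torch.float16", "torch.float32"):
    print(dtype, resolve_enable_scaler({"dtype": dtype}))
# torch.bfloat16 False
# torch.float16 True
# torch.float32 False
```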
Source code in optimus_dl/modules/optim/config.py
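```python
from dataclasses import dataclass


@dataclass
class AmpConfig:
    enabled: bool = False
    dtype: str = "torch.bfloat16"

    # OmegaConf interpolation resolved when the config is loaded: true only for
    # float16. `eval` is a custom resolver, assumed to be registered by optimus_dl.
    enable_scaler: bool = '${eval: \'"${.dtype}" == "torch.float16"\'}'
    init_scale: float = 2**16
    growth_factor: float = 2.0
    backoff_factor: float = 0.5
    growth_interval: int = 2000
```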
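The four scaler fields share their names and default values with the constructor arguments of PyTorch's `GradScaler` (`torch.amp.GradScaler`, formerly `torch.cuda.amp.GradScaler`: `init_scale=2.**16`, `growth_factor=2.0`, `backoff_factor=0.5`, `growth_interval=2000`), which suggests they are forwarded to it, although this page does not confirm that. The plain-Python sketch below reproduces the dynamic loss-scaling rule `GradScaler` documents; `LossScaleState` is a hypothetical class used only for illustration.

```python
from dataclasses import dataclass


@dataclass
class LossScaleState:
    """Hypothetical mirror of AmpConfig's scaler fields (not an optimus_dl class)."""

    scale: float = 2.0**16  # init_scale
    growth_factor: float = 2.0
    backoff_factor: float = 0.5
    growth_interval: int = 2000
    good_steps: int = 0  # consecutive steps with finite gradients

    def update(self, found_inf: bool) -> bool:
        """Adjust the scale after one step; return True if the step may be applied."""
        if found_inf:
            # Overflow: skip this optimizer step, shrink the scale, restart the count.
            self.scale *= self.backoff_factor
            self.good_steps = 0
            return False
        self.good_steps += 1
        if self.good_steps == self.growth_interval:
            # A long run of finite gradients: try a larger scale.
            self.scale *= self.growth_factor
            self.good_steps = 0
        return True


state = LossScaleState(growth_interval=3)
for found_inf in [False, False, False, True, False]:
    applied = state.update(found_inf)
    print(applied, state.scale)
# True 65536.0
# True 65536.0
# True 131072.0
# False 65536.0
# True 65536.0
```

A large `growth_interval` keeps the scale stable, while the aggressive `backoff_factor` reacts immediately to overflow, so at most a few steps are skipped before the scale fits the gradients again.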

OptimizationConfig dataclass

OptimizationConfig(optimizer: optimus_dl.core.registry.RegistryConfig, iterations: int = 1000, acc_steps: int = 1, clip_grad_norm: float | None = None, amp: optimus_dl.modules.optim.config.AmpConfig = <factory>)

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `optimizer` | `RegistryConfig` | | required |
| `iterations` | `int` | Total number of training steps | `1000` |
| `acc_steps` | `int` | Number of steps over which gradients are accumulated | `1` |
| `clip_grad_norm` | `float \| None` | Maximum gradient norm for clipping | `None` |
| `amp` | `AmpConfig` | Automatic mixed precision settings (see `AmpConfig`) | `AmpConfig()`, via `default_factory` |
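With `acc_steps > 1`, gradients from several micro-batches are combined before a single optimizer update, so the effective batch is `acc_steps` times the micro-batch size. The framework-free sketch below uses a hypothetical one-parameter model and a hypothetical `train` helper to show the common pattern of scaling each micro-batch's contribution by `1 / acc_steps`, assuming equal-sized micro-batches. This page does not say how optimus_dl's trainer counts `iterations` against micro-batches, so the sketch does not model that.

```python
import math


def train(micro_batches, acc_steps, lr=0.1, w=0.0):
    """Fit y = w * x by gradient descent, updating once per `acc_steps` micro-batches.

    Hypothetical toy trainer; assumes all micro-batches have the same size.
    Returns the final weight and the number of optimizer steps taken.
    """
    grad_acc = 0.0
    steps = 0
    for i, batch in enumerate(micro_batches, start=1):
        # Gradient of the mean squared error on this micro-batch w.r.t. w.
        g = sum(2 * (w * x - y) * x for x, y in batch) / len(batch)
        # Scaling by 1 / acc_steps makes the accumulated value a mean, not a sum.
        grad_acc += g / acc_steps
        if i % acc_steps == 0:
            w -= lr * grad_acc  # one optimizer step
            grad_acc = 0.0
            steps += 1
    return w, steps


data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
w_full, _ = train([data], acc_steps=1)  # one batch of 4
w_acc, _ = train([[p] for p in data], acc_steps=4)  # 4 micro-batches of 1
print(math.isclose(w_full, w_acc))  # True
```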
Source code in optimus_dl/modules/optim/config.py
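```python
from dataclasses import dataclass, field

from optimus_dl.core.registry import RegistryConfig

# AmpConfig is defined earlier in this module.


@dataclass
class OptimizationConfig:
    optimizer: RegistryConfig

    iterations: int = field(default=1000, metadata={"description": "Total train steps"})
    acc_steps: int = field(
        default=1, metadata={"description": "Steps to accumulate gradient"}
    )
    clip_grad_norm: float | None = field(
        default=None, metadata={"description": "Clip gradient norm"}
    )
    # default_factory gives each OptimizationConfig its own AmpConfig instance.
    amp: AmpConfig = field(default_factory=AmpConfig)
```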