base
optimus_dl.modules.model.base
Base model class for all language models in Optimus-DL.
This module defines the BaseModel class that all model architectures must inherit from. It provides common functionality for parameter grouping, distributed sharding, and tensor parallelism.
BaseModel
Bases: Module
Base class for all language model architectures in the framework.
All model implementations should inherit from this class. It provides a standardized interface for:
- Forward Pass: Standard PyTorch forward method.
- Optimizer Integration: Custom parameter grouping (e.g., weight decay exclusion for norms/biases).
- FSDP2 Sharding: Support for fully sharded data parallelism via a custom fully_shard method.
- Tensor Parallelism: Support for sharding parameters across multiple devices via apply_tp.
Subclasses must implement:
- forward(): The main computation loop.
Example
Source code in optimus_dl/modules/model/base.py
__init__()
apply_tp(mesh, **kwargs)
Apply Tensor Parallelism (sharding) to the model's parameters.
This method should use parallelize_module or similar utilities to
shard specific linear or embedding layers across the provided mesh.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| mesh | | The DeviceMesh for tensor parallelism. | required |
| **kwargs | | Additional model-specific TP flags (e.g., sequence_parallel). | {} |
fully_shard(**fsdp_kwargs)
Define the FSDP2 sharding strategy for this model.
This method should wrap sub-modules (e.g., transformer blocks) with
fully_shard to enable efficient distributed training.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| **fsdp_kwargs | | Arguments for the FSDP sharding process (e.g., mesh). | {} |
make_parameter_groups()
Create parameter groups for optimizer configuration.
This method allows models to specify which parameters should have weight decay applied, or to use different learning rates for different sub-modules.
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | Dictionary with a 'params' key, or a list of such dictionaries, compatible with PyTorch optimizers. |
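
The grouping described above can be sketched in plain Python. This is only an illustration, not the Optimus-DL implementation: the helper name split_decay_groups, the FakeParam stand-in, the default weight_decay value, and the heuristic (skip decay for parameters with fewer than two dimensions or names ending in ".bias") are all assumptions. In a real model the (name, parameter) pairs would typically come from named_parameters().

```python
from dataclasses import dataclass
from typing import Any, Iterable


@dataclass
class FakeParam:
    """Hypothetical stand-in for a tensor parameter; only `ndim` is used."""
    ndim: int


def split_decay_groups(
    named_params: Iterable[tuple[str, Any]],
    weight_decay: float = 0.1,  # assumed default, not taken from the library
) -> list[dict[str, Any]]:
    """Split parameters into a decayed group and a non-decayed group."""
    decay, no_decay = [], []
    for name, param in named_params:
        # Assumed heuristic: 1-D tensors (norm scales) and biases skip decay.
        if param.ndim < 2 or name.endswith(".bias"):
            no_decay.append(param)
        else:
            decay.append(param)
    # Each dictionary has a 'params' key, the shape PyTorch optimizers accept.
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


params = [
    ("embed.weight", FakeParam(ndim=2)),
    ("norm.weight", FakeParam(ndim=1)),
    ("proj.weight", FakeParam(ndim=2)),
    ("proj.bias", FakeParam(ndim=1)),
]
groups = split_decay_groups(params)
```

A list in this form can be passed directly as the first argument to a PyTorch optimizer, which reads the per-group weight_decay override from each dictionary.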
register_arch(arch_name)
classmethod
Decorator for registering an architecture variant of this model.
This method is dynamically populated on the class during registration in the model registry. It allows registering variants like '7b', '13b', etc.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| arch_name | str | Name of the architecture variant. | required |
Returns:
| Type | Description |
|---|---|
| Callable[[Callable[[], Any]], Any] | A decorator function. |