Mixin for building models and applying transformations.
Encapsulates the logic for:
1. Instantiating a BaseModel from a configuration object.
2. Sequentially applying a list of ModelTransforms (e.g., FSDP, DDP, compile).
3. Logging model statistics (parameter count).
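Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| cfg | `ModelBuilderConfig` | Builder configuration. | required |
| model_transforms | `list[ModelTransformConfig] \| None` | List of configurations for transforms to apply. | `None` |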
Source code in optimus_dl/recipe/mixins/model_builder.py
class ModelBuilder:
    """Recipe mixin that constructs a model and runs it through transforms.

    Responsibilities:

    1. Create a `BaseModel` from a model configuration object.
    2. Apply each configured `ModelTransform` in list order (e.g., FSDP, DDP,
       compile).
    3. Log the parameter count before and after the transforms.

    Args:
        cfg: Builder configuration.
        model_transforms: Configurations of the transforms to apply, in order.
            A falsy value (``None`` or an empty list) yields an empty list.
        **kwargs: Extra keyword arguments; accepted but not used here.
    """

    def __init__(
        self,
        cfg: ModelBuilderConfig,
        model_transforms: list[ModelTransformConfig] | None = None,
        **kwargs: Any,
    ):
        # Only the transform list is retained; `cfg` and `kwargs` are not stored.
        self.model_transforms = model_transforms if model_transforms else []

    def build_model(
        self, model_config: ModelConfig | None, collective: Collective, **kwargs
    ) -> BaseModel:
        """Construct the model, then run every configured transform over it.

        The parameter count is logged twice: once right after construction and
        once after all transforms have been applied.

        Args:
            model_config: Configuration for the model architecture.
            collective: Distributed collective; it is forwarded to the
                transforms together with its ``default_device``.
            **kwargs: Forwarded both to the model constructor and to the
                transforms.

        Returns:
            The fully constructed and transformed model.

        Raises:
            ValueError: If ``model_config`` is ``None``.
        """
        if model_config is None:
            raise ValueError(
                "model_config is None. Use build_model_from_checkpoint for evaluation."
            )

        # `build_model` here is the module-level factory, not this method.
        model = build_model(model_config, **kwargs)

        params_before = get_num_parameters(model)
        logger.info(f"Params num (before model transforms): {params_before:,}")
        log_averaged("model/num_params_before_transforms", params_before)
        assert isinstance(model, BaseModel)

        # Transforms (including distributed setup) get the collective and its
        # default device as context, on top of the caller-supplied kwargs.
        target_device = collective.default_device
        model = self._apply_model_transforms(
            model, collective=collective, device=target_device, **kwargs
        )

        params_after = get_num_parameters(model)
        logger.info(f"Model \n{model}")
        logger.info(f"Params num (after model transforms): {params_after:,}")
        log_averaged("model/num_params_after_transforms", params_after)
        return model

    def _apply_model_transforms(self, model: BaseModel, **kwargs) -> BaseModel:
        """Run the configured transforms over ``model`` one after another.

        A config for which the transform factory returns ``None`` is skipped
        with a warning. Any exception is logged and re-raised unchanged.

        Args:
            model: The base model.
            **kwargs: Context arguments (device, collective, etc.) handed to
                both the transform factory and ``transform.apply``.

        Returns:
            The model after the last transform (the input model when no
            transforms are configured).
        """
        for transform_cfg in self.model_transforms:
            try:
                transform = build_model_transform(transform_cfg, **kwargs)
                if transform is None:
                    logger.warning(
                        f"Failed to build model transform from config: {transform_cfg}"
                    )
                    continue
                logger.info(f"Applying model transform: {transform}")
                model = transform.apply(model, **kwargs)
            except Exception as e:
                # Log which config failed, then let the caller see the error.
                logger.error(f"Failed to apply model transform {transform_cfg}: {e}")
                raise
        return model
|
build_model(model_config, collective, **kwargs)
Build the model and apply all configured transforms.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model_config | `ModelConfig \| None` | Configuration for the model architecture. | required |
| collective | `Collective` | Distributed collective for transforms that need it. | required |
| **kwargs | | Additional arguments passed to model constructor and transforms. | `{}` |
Returns:
| Type | Description |
| --- | --- |
| `BaseModel` | The fully constructed and transformed model. |
Source code in optimus_dl/recipe/mixins/model_builder.py
def build_model(
    self, model_config: ModelConfig | None, collective: Collective, **kwargs
) -> BaseModel:
    """Construct the model, then run every configured transform over it.

    The parameter count is logged twice: once right after construction and
    once after all transforms have been applied.

    Args:
        model_config: Configuration for the model architecture.
        collective: Distributed collective; it is forwarded to the
            transforms together with its ``default_device``.
        **kwargs: Forwarded both to the model constructor and to the
            transforms.

    Returns:
        The fully constructed and transformed model.

    Raises:
        ValueError: If ``model_config`` is ``None``.
    """
    if model_config is None:
        raise ValueError(
            "model_config is None. Use build_model_from_checkpoint for evaluation."
        )

    # In the source module this call resolves to the module-level factory,
    # not to this method.
    model = build_model(model_config, **kwargs)

    params_before = get_num_parameters(model)
    logger.info(f"Params num (before model transforms): {params_before:,}")
    log_averaged("model/num_params_before_transforms", params_before)
    assert isinstance(model, BaseModel)

    # Transforms (including distributed setup) get the collective and its
    # default device as context, on top of the caller-supplied kwargs.
    target_device = collective.default_device
    model = self._apply_model_transforms(
        model, collective=collective, device=target_device, **kwargs
    )

    params_after = get_num_parameters(model)
    logger.info(f"Model \n{model}")
    logger.info(f"Params num (after model transforms): {params_after:,}")
    log_averaged("model/num_params_after_transforms", params_after)
    return model
|