Skip to content

model_info

optimus_dl.modules.metrics.sources.model_info

ModelInfoSource

Bases: MetricSource

Source for extracting static information about a model.

Source code in optimus_dl/modules/metrics/sources/model_info.py
@register_metric_source("model_info", ModelInfoSourceConfig)
class ModelInfoSource(MetricSource):
    """Metric source that reports static size statistics of a model.

    Each call walks ``model.parameters()`` and publishes, under
    ``StandardProtocols.MODEL_PARAMETERS_COUNT``, the total element count,
    the total size in bytes, and the element count of the parameters whose
    ``requires_grad`` flag is set.
    """

    cfg: ModelInfoSourceConfig

    @property
    def provides(self) -> set[str]:
        """Return the set of protocol keys emitted by this source."""
        emitted = {StandardProtocols.MODEL_PARAMETERS_COUNT}
        return emitted

    @torch.no_grad()
    def __call__(
        self,
        dependencies: dict[str, dict[str, Any]],
        model: Any,
        batch: Any,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Compute parameter statistics for ``model``.

        Args:
            dependencies: Not read by this source.
            model: Object exposing ``parameters()``; every yielded item must
                provide ``numel()``, ``element_size()`` and ``requires_grad``.
            batch: Not read by this source.
            **kwargs: Not read by this source.

        Returns:
            A single-entry mapping from
            ``StandardProtocols.MODEL_PARAMETERS_COUNT`` to a dict with the
            keys ``parameters_count``, ``parameters_bytes`` and
            ``trainable_parameters_count``.
        """
        total_elements = 0
        total_bytes = 0
        trainable_elements = 0

        # One pass gathers all three totals.
        for param in model.parameters():
            n_elements = param.numel()
            total_elements += n_elements
            total_bytes += n_elements * param.element_size()
            if param.requires_grad:
                trainable_elements += n_elements

        stats = {
            "parameters_count": total_elements,
            "parameters_bytes": total_bytes,
            "trainable_parameters_count": trainable_elements,
        }
        return {StandardProtocols.MODEL_PARAMETERS_COUNT: stats}

ModelInfoSourceConfig dataclass

Bases: MetricSourceConfig

Configuration for ModelInfoSource.

Source code in optimus_dl/modules/metrics/sources/model_info.py
@dataclass
class ModelInfoSourceConfig(MetricSourceConfig):
    """Configuration for ModelInfoSource.

    Declares no fields of its own; any options come from
    ``MetricSourceConfig``.
    """