Skip to content

source

optimus_dl.modules.metrics.source

MetricSource

Bases: ABC

Base class for data producers that extract information from the model.

Source code in optimus_dl/modules/metrics/source.py
class MetricSource(ABC):
    """Base class for data producers that extract information from the model.

    Subclasses must implement ``provides`` and ``__call__``. They may override
    ``requires`` when they consume the output of other sources.
    """

    def __init__(self, cfg: "MetricSourceConfig"):
        """Store the configuration and initialise the lazy hash cache.

        Args:
            cfg: Configuration object for this source.
        """
        self.cfg = cfg
        # Filled in on first access of ``config_hash`` and reused afterwards.
        self._hash: str | None = None

    @property
    def config_hash(self) -> str:
        """Returns a deterministic hash of the source's configuration for cross-group caching.

        The configuration is reduced to a canonical nested-tuple form (dict
        items sorted by key, sets sorted by element repr), prefixed with the
        class name and digested with MD5. The result is computed once and
        cached on the instance, so later mutations of ``cfg`` are not
        reflected.

        Returns:
            The hexadecimal MD5 digest (32 characters).
        """
        if self._hash is None:
            import dataclasses

            if dataclasses.is_dataclass(self.cfg):
                cfg_dict = dataclasses.asdict(self.cfg)
            else:
                cfg_dict = (
                    self.cfg.__dict__
                    if hasattr(self.cfg, "__dict__")
                    else str(self.cfg)
                )

            def make_hashable(obj: Any) -> Any:
                """Recursively convert containers into order-stable tuples."""
                if isinstance(obj, (tuple, list)):
                    return tuple(make_hashable(e) for e in obj)
                if isinstance(obj, (set, frozenset)):
                    # Set iteration order depends on insertion history and,
                    # for strings, on per-process hash randomisation; sort by
                    # repr so equal sets always serialise identically.
                    return tuple(sorted((make_hashable(e) for e in obj), key=repr))
                if isinstance(obj, dict):
                    items = [(k, make_hashable(v)) for k, v in obj.items()]
                    try:
                        return tuple(sorted(items, key=lambda kv: kv[0]))
                    except TypeError:
                        # Keys of mutually unorderable types (e.g. int and
                        # str): fall back to ordering by their repr.
                        return tuple(sorted(items, key=lambda kv: repr(kv[0])))
                return obj

            stable_repr = str(make_hashable(cfg_dict))
            # Include the class name so different source types with same config don't collide
            stable_repr = f"{self.__class__.__name__}:{stable_repr}"
            # The digest is only a cache key, not a security primitive; saying
            # so keeps MD5 usable on FIPS-restricted Python builds.
            self._hash = hashlib.md5(
                stable_repr.encode(), usedforsecurity=False
            ).hexdigest()
        return self._hash

    @property
    @abstractmethod
    def provides(self) -> set[str]:
        """Returns the set of protocol strings this source provides."""
        raise NotImplementedError

    @property
    def requires(self) -> set[str]:
        """Returns the set of strings naming what this source needs from other sources.

        The base implementation returns an empty set, i.e. no dependencies.
        Override this if your source depends on the output of other sources.

        NOTE(review): this was previously described as a mapping from role
        name to protocol strings, but the annotation and the default are a
        flat set. Confirm whether the elements are role names or protocol
        strings.
        """
        return set()

    @abstractmethod
    def __call__(
        self, dependencies: dict[str, dict[str, Any]], **kwargs
    ) -> dict[str, Any]:
        """Execute the source and return a dictionary mapping Protocol string to data.

        Args:
            dependencies: Data from required sources, mapped by protocol.
                NOTE(review): the type is a dict of dicts; presumably the
                outer key identifies the dependency and the inner dict maps
                protocol string to data. Confirm against the caller.
            **kwargs: Additional keyword arguments; their meaning is not
                defined by this base class.

        Returns:
            A dictionary keyed by protocol string.

        Raises:
            NotImplementedError: Always, in this abstract base implementation.
        """
        raise NotImplementedError

config_hash property

Returns a deterministic hash of the source's configuration for cross-group caching.

provides abstractmethod property

Returns the set of protocol strings this source provides.

requires property

Returns the set of strings naming what this source requires from other sources (a `set[str]`, empty by default).

Override this if your source depends on the output of other sources.

__call__(dependencies, **kwargs) abstractmethod

Execute the source and return a dictionary mapping Protocol string to data.

Parameters:

Name Type Description Default
dependencies dict[str, dict[str, Any]]

Data from required sources, mapped by protocol.

required
Source code in optimus_dl/modules/metrics/source.py
@abstractmethod
def __call__(
    self, dependencies: dict[str, dict[str, Any]], **kwargs
) -> dict[str, Any]:
    """Execute the source and return a dictionary mapping Protocol string to data.

    Args:
        dependencies: Data from required sources, mapped by protocol.
        **kwargs: Additional keyword arguments; their meaning is not defined
            by this abstract declaration.

    Returns:
        A dictionary keyed by protocol string.

    Raises:
        NotImplementedError: Always, until a subclass overrides this method.
    """
    raise NotImplementedError

MetricSourceConfig dataclass

Bases: RegistryConfigStrict

Base configuration for metric sources.

Attributes:

Name Type Description

Parameters:

Name Type Description Default
dependencies dict[str, str]

Maps internal role requirements to source names within the group.

dict (default factory: each instance gets its own new empty dict)
Source code in optimus_dl/modules/metrics/source.py
@dataclass
class MetricSourceConfig(RegistryConfigStrict):
    """Base configuration for metric sources.

    Attributes:
        dependencies: Maps internal role requirements to source names within
            the group. Defaults to a new empty dict for each instance.
    """

    # default_factory gives every config its own dict rather than a shared one.
    dependencies: dict[str, str] = field(default_factory=dict)

StandardProtocols

Standardized string constants for common metric data protocols.

Source code in optimus_dl/modules/metrics/source.py
class StandardProtocols:
    """Standardized string constants for common metric data protocols.

    The class is a plain namespace of class-level string attributes. Each
    constant's value is the lower-case form of its attribute name.
    """

    MODEL_PARAMETERS_COUNT = "model_parameters_count"
    LOGITS = "logits"
    LOGITS_TARGETS = "logits_targets"
    LOGITS_MASK = "logits_mask"
    INPUT_TOKENS = "input_tokens"
    LOSS = "loss"
    GENERATED_TOKENS = "generated_tokens"
    CLASSIFICATION = "classification"