Skip to content

registry

optimus_dl.core.registry

Registry system for dependency injection and component management.

This module provides a flexible registry pattern that allows components (models, optimizers, data loaders, etc.) to be registered and instantiated from configuration. The registry system is the foundation of Optimus-DL's modular architecture, enabling easy component swapping and configuration-driven instantiation.

Example

Basic usage:

# Create a registry
registry, register, build = make_registry("my_component")

# Register a component
@register("my_impl", MyConfig)
class MyImplementation:
    def __init__(self, cfg: MyConfig):
        self.cfg = cfg

# Build from config
config = MyConfig(_name="my_impl", param=1)
obj = build(config)

RegistryConfig dataclass

Bases: dict[str, Any]

Flexible configuration base class for registry components.

This extends RegistryConfigStrict to allow arbitrary additional fields via dictionary inheritance. Use this when you need custom configuration parameters beyond static fields or for dynamic configurations.

Attributes:

Name Type Description
_name str | None

The registered name of the component to instantiate.

extra_fields str | None

Additional fields can be added as dictionary keys.

Source code in optimus_dl/core/registry.py
@dataclass
class RegistryConfig(dict[str, Any]):
    """Flexible configuration base class for registry components.

    This extends RegistryConfigStrict to allow arbitrary additional fields
    via dictionary inheritance. Use this when you need custom configuration
    parameters beyond static fields or for dynamic configurations.

    Attributes:
        _name: The registered name of the component to instantiate.
        extra_fields: Additional fields can be added as dictionary keys.
    """

    _name: str | None = None

RegistryConfigStrict dataclass

Strict configuration base class for registry components.

This is a minimal configuration class that only requires a _name field. Use this when you don't need additional configuration fields.

Attributes:

Name Type Description
_name str | None

The registered name of the component to instantiate.

Source code in optimus_dl/core/registry.py
@dataclass
class RegistryConfigStrict:
    """Strict configuration base class for registry components.

    This is a minimal configuration class that only requires a `_name` field.
    Use this when you don't need additional configuration fields.

    Attributes:
        _name: The registered name of the component to instantiate.
    """

    _name: str | None = None

build(registry_name, cfg, cast_to=None, **kwargs)

build(registry_name: str, cfg: CorrectCfg, cast_to: type[T], **kwargs: Any) -> T
build(registry_name: str, cfg: CorrectCfg, cast_to: None = None, **kwargs: Any) -> Any

Build a component from a named registry.

This is a convenience function that builds a component from a registry by name. It's useful when you know the registry name but don't have the registry functions directly available.

Parameters:

Name Type Description Default
registry_name str

Name of the registry to build from (e.g., "model", "optimizer").

required
cfg CorrectCfg | None

Configuration object with _name field specifying the component.

required
cast_to type[T] | None

Optional type to cast the result to.

None
**kwargs Any

Additional arguments passed to the component constructor.

{}

Returns:

Type Description
T | Any | None

An instance of the registered component.

Raises:

Type Description
AssertionError

If the registry doesn't exist, the component name is not found, or if cast_to is provided and the object is not of that type.

Example
config = RegistryConfig(_name="llama", n_embd=512)
model = build("model", config, cast_to=BaseModel)
Source code in optimus_dl/core/registry.py
def build(
    registry_name: str,
    cfg: CorrectCfg | None,
    cast_to: type[T] | None = None,
    **kwargs: Any,
) -> T | Any | None:
    """Build a component from a named registry.

    This is a convenience function that builds a component from a registry by name.
    It's useful when you know the registry name but don't have the registry
    functions directly available.

    Args:
        registry_name: Name of the registry to build from (e.g., "model", "optimizer").
        cfg: Configuration object with `_name` field specifying the component.
        cast_to: Optional type to cast the result to.
        **kwargs: Additional arguments passed to the component constructor.

    Returns:
        An instance of the registered component.

    Raises:
        AssertionError: If the registry doesn't exist, the component name is not
            found, or if cast_to is provided and the object is not of that type.

    Example:
        ```python
        config = RegistryConfig(_name="llama", n_embd=512)
        model = build("model", config, cast_to=BaseModel)

        ```"""
    assert registry_name in registries, f"Unknown registry {registry_name}"
    _, _, build_fn = make_registry(registry_name)
    obj = build_fn(cfg, **kwargs)
    if cast_to is not None:
        assert isinstance(obj, cast_to), f"Expected {cast_to}, got {type(obj)}"
    return obj

make_registry(registry_name, base_class=None)

Create or retrieve a component registry.

This function creates a new registry or retrieves an existing one. Each registry maintains a mapping of component names to their classes and configuration classes. The registry pattern enables dependency injection and configuration-driven component instantiation.

Parameters:

Name Type Description Default
registry_name str

Unique name for the registry (e.g., "model", "optimizer").

required
base_class type | type[Any] | None

Optional base class that all registered components must inherit from. Used for type checking. If None, any class can be registered.

None

Returns:

Name Type Description
registry_dict dict

The actual registry dictionary (for inspection/debugging)

register callable

Decorator function to register components

build callable

Function to build instances from configuration

Example
registry, register, build = make_registry("model", BaseModel)

@register("llama", LlamaConfig)
class Llama(BaseModel):
    def __init__(self, cfg: LlamaConfig):
        self.cfg = cfg

# Later, build from config
config = LlamaConfig(_name="llama", n_embd=512)
model = build(config)

Note: The registry is stored globally. Multiple calls with the same registry_name will return the same registry instance.

Source code in optimus_dl/core/registry.py
def make_registry(registry_name: str, base_class: type | type[Any] | None = None):
    """Create or retrieve a component registry.

    This function creates a new registry or retrieves an existing one. Each registry
    maintains a mapping of component names to their classes and configuration classes.
    The registry pattern enables dependency injection and configuration-driven
    component instantiation.

    Args:
        registry_name: Unique name for the registry (e.g., "model", "optimizer").
        base_class: Optional base class that all registered components must inherit
            from. Used for type checking. If None, any class can be registered.

    Returns:
        registry_dict (dict): The actual registry dictionary (for inspection/debugging)
        register (callable): Decorator function to register components
        build (callable): Function to build instances from configuration

    Example:
        ```python
        registry, register, build = make_registry("model", BaseModel)

        @register("llama", LlamaConfig)
        class Llama(BaseModel):
            def __init__(self, cfg: LlamaConfig):
                self.cfg = cfg

        # Later, build from config
        config = LlamaConfig(_name="llama", n_embd=512)
        model = build(config)

        ```
    Note:
        The registry is stored globally. Multiple calls with the same `registry_name`
        will return the same registry instance.
    """
    if registry_name in registries:
        registry = registries[registry_name]
    else:
        registries[registry_name] = {}
        registry = registries[registry_name]

    def register_arch(name: str, class_name: str, registered_class: type):
        """Register an architecture variant of a base class.

        This allows registering multiple variants (architectures) of a base class.
        For example, you might have a base "llama" model with variants like
        "llama-7b", "llama-13b", "llama-70b".

        Args:
            name: Name of the architecture variant.
            class_name: Name of the base class this variant belongs to.
            registered_class: The class implementing this architecture.

        Returns:
            A decorator function that takes a config factory method.

        Raises:
            AssertionError: If the architecture is already registered or if the
                base class doesn't exist or doesn't have a structured config.

        Example:
            ```python
            @register("llama", LlamaConfig)
            class Llama(BaseModel):
                pass

            @Llama.register_arch("7b")
            def llama_7b_config():
                return LlamaConfig(n_layers=32, n_embd=4096)

            ```"""
        full_name = f"{class_name}-{name}"

        assert (
            full_name not in registry
        ), f"Double registering of {full_name} in {registry_name} registry"

        base_registry = registry.get(class_name, None)
        assert (
            base_registry is not None
        ), f"Base class {class_name} not found in {registry_name} registry"
        assert (
            base_registry[1] is not None
        ), f"Base class {class_name} must have a structured config in {registry_name} registry to register architectures"

        def wrapper(method):
            cfg = method()
            assert (
                cfg is None
                or isinstance(cfg, RegistryConfig)
                or isinstance(cfg, RegistryConfigStrict)
            ), "Configs must be subclasses of RegistryConfig"
            registry[full_name] = (registered_class, cfg)
            return method

        return wrapper

    def register(name: str, cfg: type | None = None):
        """Register a component in the registry.

        This is the main decorator for registering components. It associates a
        component class with a name and optional configuration class.

        Args:
            name: Unique name to register the component under (e.g., "llama", "adamw").
            cfg: Optional configuration class that must be a subclass of
                RegistryConfig or RegistryConfigStrict. If None, the component
                will be instantiated without configuration.

        Returns:
            A decorator function that registers the decorated class.

        Raises:
            AssertionError: If the name is already registered or if cfg is not
                a valid config class.

        Example:
            ```python
            @register("my_model", MyModelConfig)
            class MyModel:
                def __init__(self, cfg: MyModelConfig):
                    self.cfg = cfg

            ```"""
        assert (
            name not in registry
        ), f"Double registering of {name} in {registry_name} registry"
        assert (
            cfg is None
            or issubclass(cfg, RegistryConfig)
            or issubclass(cfg, RegistryConfigStrict)
        ), "Configs must be subclasses of RegistryConfig or RegistryConfigStrict"

        def wrapper(registered_class: C) -> C:
            """Decorator that registers a class in the registry.

            Args:
                registered_class: The class to register.

            Returns:
                The same class (for use as a decorator).
            """
            registry[name] = (registered_class, cfg)
            registered_class.register_arch = functools.partial(
                register_arch,
                class_name=name,
                registered_class=registered_class,
            )
            return registered_class

        return wrapper

    @overload
    def build(
        cfg: CorrectCfg,
        cast_to: type[T],
        **kwargs: Any,
    ) -> T: ...

    @overload
    def build(
        cfg: CorrectCfg,
        cast_to: None = None,
        **kwargs: Any,
    ) -> Any: ...

    def build(
        cfg: CorrectCfg | None,
        cast_to: type[T] | None = None,
        **kwargs: Any,
    ) -> T | Any | None:
        """Build a component instance from configuration.

        This function instantiates a registered component based on its configuration.
        It handles merging default configuration values with provided overrides,
        and supports both structured configs and plain dictionaries.

        Args:
            cfg: Configuration object containing `_name` field specifying which
                component to build. Can be a RegistryConfig, dict, or None.
                If None, returns None. If a string, treats it as the component name.
            cast_to: Optional type to cast the result to. If provided, raises
                AssertionError if the built object is not an instance of this type.
            **kwargs: Additional keyword arguments passed to the component constructor.

        Returns:
            An instance of the registered component, or None if cfg is None.

        Raises:
            AssertionError: If the component name is not found in the registry, or
                if cast_to is provided and the built object is not of that type.

        Example:
            ```python
            config = RegistryConfig(_name="llama", n_embd=512)
            model = build(config, cast_to=BaseModel)
            assert isinstance(model, Llama)

            ```"""
        if cfg is None:
            return None
        cfg_orig = cfg
        if isinstance(cfg, str):
            name: str = cfg
            cfg = {}
        else:
            if not omegaconf.OmegaConf.is_config(cfg):
                cfg = omegaconf.OmegaConf.structured(cfg)
            name: str = cfg["_name"]
        assert name in registry, f"Unknown {name} in {registry_name} registry"
        registered_class, structured_cfg = registry[name]
        structured_cfg_original = structured_cfg
        is_strict = False
        if type(structured_cfg) is type:
            is_strict = issubclass(structured_cfg, RegistryConfigStrict)
            structured_cfg = omegaconf.OmegaConf.merge(
                omegaconf.OmegaConf.structured(structured_cfg), structured_cfg()
            )
        if structured_cfg is not None:
            if is_strict:
                expected_keys = set(structured_cfg.keys())
                try:
                    actual_keys = set(cfg.keys())
                except AttributeError as e:
                    raise ValueError(
                        f"Cannot get true keys for config {type(cfg)}: {cfg}"
                    ) from e

                maybe_path = ".".join(_get_cfg_path(cfg_orig) or ["<root>"])
                assert actual_keys.issubset(expected_keys), (
                    f"For {maybe_path} {structured_cfg_original} expected keys {expected_keys}, "
                    f"got {actual_keys},\n"
                    f"diff: {actual_keys - expected_keys}"
                )
            try:
                cfg = omegaconf.OmegaConf.merge(
                    structured_cfg,
                    omegaconf.OmegaConf.to_container(cfg=cfg, resolve=True),
                )
                if is_strict:
                    cfg = validate_and_cast(structured_cfg_original, cfg)
            except omegaconf.errors.ConfigTypeError:
                logger.error(f"{structured_cfg = }\n====\n{cfg = }")
                raise
            obj = registered_class(cfg, **kwargs)
        else:
            obj = registered_class(**kwargs)

        if cast_to is not None:
            assert isinstance(obj, cast_to), f"Expected {cast_to}, got {type(obj)}"
        if base_class is not None:
            assert isinstance(
                obj, base_class
            ), f"Expected {base_class}, got {type(obj)}"
        return obj

    return registry, register, build

validate_and_cast(cls, cfg, path='')

Recursively validates and casts config based on cls typings.

Parameters:

Name Type Description Default
cls Any

The class or type to validate against.

required
cfg Any

The configuration object to validate and cast.

required
path str

Path prefix to use in error messages, indicating the logical location of this config within a larger configuration structure.

''

Returns:

Type Description
Any

The casted and validated configuration, as an OmegaConf object if applicable.

Source code in optimus_dl/core/registry.py
def validate_and_cast(cls: Any, cfg: Any, path: str = "") -> Any:
    """Recursively validates and casts config based on cls typings.

    Args:
        cls: The class or type to validate against.
        cfg: The configuration object to validate and cast.
        path: Path prefix to use in error messages, indicating the logical
            location of this config within a larger configuration structure.

    Returns:
        The casted and validated configuration, as an OmegaConf object if applicable.
    """
    res = _validate_and_cast_recursive(cls, cfg, path)
    # Only wrap in OmegaConf if it's a collection or dataclass at the top level.
    # Recursive calls return raw objects to ensure proper dataclass construction.
    if is_dataclass(cls) or get_origin(cls) in (list, dict, list, dict):
        return omegaconf.OmegaConf.create(res)
    return res