base

optimus_dl.modules.data.datasets.base

Base dataset class for data sources.

This module defines the BaseDataset class that all data sources must inherit from. It provides integration with torchdata's pipeline system and checkpointing support.

BaseDataset

Bases: BaseNode

Base class for all dataset implementations.

All data sources in Optimus-DL should inherit from this class. It provides:

  • Integration with torchdata's pipeline system
  • Checkpointing support for resuming data iteration
  • Configuration storage

Subclasses should:

  • Implement the data iteration logic (the iteration interface is inherited from torchdata.nodes.BaseNode)
  • Optionally override load_state_dict() for custom checkpointing
Example

```python
@register_dataset("my_dataset", MyDatasetConfig)
class MyDataset(BaseDataset):
    def __init__(self, cfg: MyDatasetConfig, **kwargs):
        super().__init__(cfg, **kwargs)
        self.data = load_data(cfg.data_path)

    def __iter__(self):
        for item in self.data:
            yield item
```
Source code in `optimus_dl/modules/data/datasets/base.py`

````python
class BaseDataset(torchdata.nodes.BaseNode):
    """Base class for all dataset implementations.

    All data sources in Optimus-DL should inherit from this class. It provides:

    - Integration with torchdata's pipeline system
    - Checkpointing support for resuming data iteration
    - Configuration storage

    Subclasses should implement:

    - The data iteration logic (inherited from torchdata.nodes.BaseNode)
    - Optionally override `load_state_dict()` for custom checkpointing

    Example:
        ```python
        @register_dataset("my_dataset", MyDatasetConfig)
        class MyDataset(BaseDataset):
            def __init__(self, cfg: MyDatasetConfig, **kwargs):
                super().__init__(cfg, **kwargs)
                self.data = load_data(cfg.data_path)

            def __iter__(self):
                for item in self.data:
                    yield item

        ```"""

    def __init__(self, cfg, **kwargs):
        """Initialize the base dataset.

        Args:
            cfg: Configuration object for this dataset.
            **kwargs: Additional keyword arguments passed from the data builder.
        """
        super().__init__()
        self.cfg = cfg

    def __repr__(self):
        return f"{self.__class__.__name__}()"

    def load_state_dict(self, state_dict: dict) -> None:
        """Load dataset state from checkpoint.

        This method restores the dataset's iteration state, allowing training
        to resume from the same position in the dataset. The default implementation
        uses torchdata's `reset()` method.

        Args:
            state_dict: Dictionary containing the dataset's saved state.
                Typically includes iteration position, random state, etc.

        Note:
            Subclasses can override this to handle custom state restoration.
            The state_dict is typically saved by the checkpoint manager.
        """
        self.reset(state_dict)
````

__init__(cfg, **kwargs)

Initialize the base dataset.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `cfg` | | Configuration object for this dataset. | *required* |
| `**kwargs` | | Additional keyword arguments passed from the data builder. | `{}` |
Source code in `optimus_dl/modules/data/datasets/base.py`

```python
def __init__(self, cfg, **kwargs):
    """Initialize the base dataset.

    Args:
        cfg: Configuration object for this dataset.
        **kwargs: Additional keyword arguments passed from the data builder.
    """
    super().__init__()
    self.cfg = cfg
```

load_state_dict(state_dict)

Load dataset state from checkpoint.

This method restores the dataset's iteration state, allowing training to resume from the same position in the dataset. The default implementation uses torchdata's reset() method.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `state_dict` | `dict` | Dictionary containing the dataset's saved state. Typically includes iteration position, random state, etc. | *required* |
Note

Subclasses can override this to handle custom state restoration. The state_dict is typically saved by the checkpoint manager.

Source code in `optimus_dl/modules/data/datasets/base.py`

```python
def load_state_dict(self, state_dict: dict) -> None:
    """Load dataset state from checkpoint.

    This method restores the dataset's iteration state, allowing training
    to resume from the same position in the dataset. The default implementation
    uses torchdata's `reset()` method.

    Args:
        state_dict: Dictionary containing the dataset's saved state.
            Typically includes iteration position, random state, etc.

    Note:
        Subclasses can override this to handle custom state restoration.
        The state_dict is typically saved by the checkpoint manager.
    """
    self.reset(state_dict)
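As a sketch of the custom-checkpointing override mentioned in the Note, the subclass below restores its own `position` key before delegating to the default behavior. `_BaseNodeStub` mimics only the relevant part of `torchdata.nodes.BaseNode` (a `reset()` that accepts an initial state) so the example runs standalone; `CountingDataset` and the `position` key are hypothetical:

```python
class _BaseNodeStub:
    """Stand-in for torchdata.nodes.BaseNode: reset() accepts an initial state."""

    def __init__(self):
        self._state = None

    def reset(self, initial_state=None):
        self._state = initial_state


class BaseDataset(_BaseNodeStub):
    def __init__(self, cfg, **kwargs):
        super().__init__()
        self.cfg = cfg

    def load_state_dict(self, state_dict: dict) -> None:
        # Default: hand the saved state to torchdata's reset().
        self.reset(state_dict)


class CountingDataset(BaseDataset):
    """Hypothetical dataset that resumes iteration from a saved position."""

    def __init__(self, cfg, **kwargs):
        super().__init__(cfg, **kwargs)
        self.position = 0

    def load_state_dict(self, state_dict: dict) -> None:
        # Custom restoration: pull our own key, then run the default logic.
        self.position = state_dict.get("position", 0)
        super().load_state_dict(state_dict)


ds = CountingDataset(cfg=None)
ds.load_state_dict({"position": 7})
print(ds.position)  # 7
```

Calling `super().load_state_dict()` last keeps the default `reset()` path intact, so the override only adds state on top of what the checkpoint manager already saves.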