Skip to content

composite

optimus_dl.modules.data.datasets.composite

CompositeDataset

Bases: BaseDataset

Dataset that combines multiple sub-datasets with weighted sampling.

This class orchestrates a collection of datasets, sampling from them according to specified weights. It handles:

  • Weighted Sampling: Using a rank-safe multinomial sampler.
  • Exhaustion Policies: Can stop training when one/all datasets are exhausted or cycle through them indefinitely.
  • Hierarchical Checkpointing: Correctly saves and restores the state of all sub-datasets and the global sampling state.

Parameters:

Name Type Description Default
cfg CompositeDatasetConfig

Composite dataset configuration.

required
rank int

Distributed rank.

required
world_size int

Total number of ranks.

required
Source code in optimus_dl/modules/data/datasets/composite.py
@register_dataset("composite", CompositeDatasetConfig)
class CompositeDataset(BaseDataset):
    """Dataset that combines multiple sub-datasets with weighted sampling.

    This class orchestrates a collection of datasets, sampling from them according
    to specified weights. It handles:

    - **Weighted Sampling**: Using a rank-safe multinomial sampler.
    - **Exhaustion Policies**: Can stop training when one/all datasets are
      exhausted or cycle through them indefinitely.
    - **Hierarchical Checkpointing**: Correctly saves and restores the state
      of all sub-datasets and the global sampling state.

    Args:
        cfg: Composite dataset configuration.
        rank: Distributed rank.
        world_size: Total number of ranks.
        seed: Base seed for the rank-safe weighted sampler.
        **kwargs: Forwarded to ``build_dataset`` for each sub-dataset.
    """

    # Checkpoint state-dict keys. Kept as stable string constants so old
    # checkpoints remain loadable if attribute names change.
    DATASET_NODE_STATES_KEY = "dataset_node_states"
    DATASETS_EXHAUSTED_KEY = "datasets_exhausted"
    EPOCH_KEY = "epoch"
    NUM_YIELDED_KEY = "num_yielded"
    WEIGHTED_SAMPLER_STATE_KEY = "weighted_sampler_state"

    def __init__(
        self,
        cfg: CompositeDatasetConfig,
        rank: int,
        world_size: int,
        seed: int,
        **kwargs,
    ):
        super().__init__(cfg)
        self.rank = rank
        self.world_size = world_size

        # Per-name bookkeeping, all keyed by the sub-dataset name from cfg.
        self.datasets: dict[str, Any] = {}
        self.weights: dict[str, float] = {}
        self.cycle_flags: dict[str, bool] = {}

        for name, ds_cfg in cfg.datasets.items():
            logger.info(f"Initializing sub-dataset {name} with weight {ds_cfg.weight}")
            # Sub-datasets are likely BaseNodes themselves
            self.datasets[name] = build_dataset(
                ds_cfg.dataset, rank=rank, world_size=world_size, **kwargs
            )
            self.weights[name] = ds_cfg.weight
            self.cycle_flags[name] = ds_cfg.cycle

        self.stop_criteria = cfg.stop_criteria
        self.seed = seed
        self.strict_load = cfg.strict_load

        self._validate()

        # Mutable iteration state; (re)initialized properly in reset().
        self._epoch = 0
        self._num_yielded = 0
        self._weighted_sampler = None
        self._datasets_exhausted: dict[str, bool] = {}

    def _validate(self):
        """Validate sampling weights at construction time.

        Raises:
            ValueError: If any weight is negative, or if datasets are
                configured but every weight is zero (the sampler could never
                pick anything and would only fail later with an opaque
                RuntimeError at iteration time).
        """
        for weight in self.weights.values():
            if weight < 0:
                raise ValueError("Weights must be non-negative")
        # Fail fast on an all-zero weight configuration.
        if self.weights and not any(self.weights.values()):
            raise ValueError("At least one dataset weight must be positive")

    def reset(self, initial_state: dict[str, Any] | None = None):
        """Reset or restore the composite dataset state.

        Restores global epoch/yield counters, the weighted sampler state, and
        recursively calls reset() on all sub-datasets.

        Args:
            initial_state: A state dict previously produced by get_state(),
                or None for a fresh start.

        Raises:
            ValueError: If strict_load is enabled and the checkpoint's dataset
                set does not match the configured dataset set.
        """
        super().reset(initial_state)

        config_datasets = self.datasets.keys()

        if initial_state is not None:
            # Handle strict_load
            state_datasets = initial_state.get(self.DATASET_NODE_STATES_KEY, {}).keys()

            if self.strict_load:
                if set(state_datasets) != set(config_datasets):
                    raise ValueError(
                        f"Strict load enabled. Mismatch in datasets.\n"
                        f"Config: {list(config_datasets)}\nState: {list(state_datasets)}"
                    )

            self._num_yielded = initial_state.get(self.NUM_YIELDED_KEY, 0)
            self._epoch = initial_state.get(self.EPOCH_KEY, 0)
            self._datasets_exhausted = initial_state.get(
                self.DATASETS_EXHAUSTED_KEY, dict.fromkeys(config_datasets, False)
            )

            # If config matches state datasets, we can load sampler state
            if set(state_datasets) == set(config_datasets):
                sampler_state = initial_state.get(self.WEIGHTED_SAMPLER_STATE_KEY)
                self._weighted_sampler = self._get_new_weighted_sampler(sampler_state)
            else:
                # Mismatch and strict_load=False: Reset sampler
                logger.warning(
                    "Dataset configuration changed (or strict_load=False), resetting weighted sampler state."
                )
                self._weighted_sampler = self._get_new_weighted_sampler(None)

            # Load sub-datasets: restore saved state where available, otherwise
            # fall back to a fresh reset (only legal when strict_load=False).
            for name, dataset in self.datasets.items():
                if name in initial_state.get(self.DATASET_NODE_STATES_KEY, {}):
                    dataset.reset(initial_state[self.DATASET_NODE_STATES_KEY][name])
                else:
                    if self.strict_load:
                        # Should have been caught above, but safety check
                        raise ValueError(f"Missing state for dataset {name}")
                    logger.info(f"Resetting dataset {name} (not found in state).")
                    dataset.reset()
        else:
            # Fresh start
            self._num_yielded = 0
            self._epoch = 0

            self._weighted_sampler = self._get_new_weighted_sampler()
            self._datasets_exhausted = dict.fromkeys(self.datasets, False)
            for dataset in self.datasets.values():
                dataset.reset()

    def _get_new_weighted_sampler(self, initial_state=None):
        """Build a weighted sampler over sub-dataset names, optionally
        restoring it from a checkpointed sampler state."""
        return _WeightedSampler(
            weights=self.weights,
            seed=self.seed,
            rank=self.rank,
            world_size=self.world_size,
            epoch=self._epoch,
            initial_state=initial_state,
        )

    def _check_for_stop_iteration(self) -> None:
        """Raise StopIteration if the configured stop criteria is met.

        Never raises under CYCLE_FOREVER. Note that with no datasets
        configured, all() over an empty exhaustion map is True, so any
        non-cycling criteria stops immediately.
        """
        if self.stop_criteria == StopCriteria.CYCLE_FOREVER:
            return

        if all(self._datasets_exhausted.values()):
            raise StopIteration()

        if self.stop_criteria == StopCriteria.FIRST_DATASET_EXHAUSTED and any(
            self._datasets_exhausted.values()
        ):
            raise StopIteration()

    def next(self) -> Any:
        """Sample the next item from one of the sub-datasets.

        Uses the internal weighted sampler to choose a dataset, then delegates
        to that dataset's next() method. If a dataset is exhausted, it is either
        reset (cycled) or removed from sampling depending on configuration.

        Raises:
            StopIteration: When the configured stop criteria is met.
            RuntimeError: If the sampler itself is exhausted, or if a dataset
                yields nothing even immediately after being reset for cycling.
        """
        while True:
            self._check_for_stop_iteration()

            # Get next dataset to sample from
            try:
                ds_name = next(self._weighted_sampler)
            except StopIteration as err:
                # If sampler is empty (all weights 0), we should have caught it in check_for_stop_iteration
                # unless there's a sync issue. Treat as end of data.
                raise RuntimeError(
                    "Exhausted all datasets and cannot cycle through"
                ) from err
            try:
                # Internal invariant: the sampler never returns a dataset that
                # was marked exhausted and deactivated below.
                assert not self._datasets_exhausted[ds_name]
                item = next(self.datasets[ds_name])
                self._num_yielded += 1
                return item

            except StopIteration:
                self._datasets_exhausted[ds_name] = True

                if self.cycle_flags[ds_name]:
                    # Reset this dataset and immediately draw from it so a
                    # cycling dataset never produces a gap in the stream.
                    logger.debug(f"Cycling dataset {ds_name}")
                    self.datasets[ds_name].reset()
                    self._datasets_exhausted[ds_name] = False
                    try:
                        item = next(self.datasets[ds_name])
                        self._num_yielded += 1
                        return item
                    except StopIteration as err:
                        raise RuntimeError(
                            "Cannot yield at least one item from dataset after resetting and trying to cycle"
                        ) from err
                else:
                    # Not cycling: Disable this dataset in sampler to avoid polling it again
                    self._weighted_sampler.set_active(ds_name, False)

                self._check_for_stop_iteration()

    def get_state(self) -> dict[str, Any]:
        """Collect current state for checkpointing.

        Returns:
            A dict holding the exhaustion flags (deep-copied so later mutation
            does not corrupt the snapshot), each sub-dataset's state dict, the
            global epoch/yield counters, and the sampler state (None if the
            sampler was never created).
        """
        return {
            self.DATASETS_EXHAUSTED_KEY: copy.deepcopy(self._datasets_exhausted),
            self.DATASET_NODE_STATES_KEY: {
                k: v.state_dict() for k, v in self.datasets.items()
            },
            self.EPOCH_KEY: self._epoch,
            self.NUM_YIELDED_KEY: self._num_yielded,
            self.WEIGHTED_SAMPLER_STATE_KEY: (
                self._weighted_sampler.state_dict() if self._weighted_sampler else None
            ),
        }

get_state()

Collect the current composite state for checkpointing: exhaustion flags, per-sub-dataset state dicts, epoch/yield counters, and the weighted sampler state.

Source code in optimus_dl/modules/data/datasets/composite.py
def get_state(self) -> dict[str, Any]:
    """Snapshot the composite dataset's state for checkpointing.

    The snapshot contains the per-dataset exhaustion flags (deep-copied so the
    snapshot is isolated from later mutation), each sub-dataset's own state
    dict, the global epoch/yield counters, and the weighted sampler state
    (None when no sampler has been created yet).
    """
    sub_states = {name: ds.state_dict() for name, ds in self.datasets.items()}
    sampler = self._weighted_sampler
    sampler_state = sampler.state_dict() if sampler else None

    state: dict[str, Any] = {}
    state[self.DATASETS_EXHAUSTED_KEY] = copy.deepcopy(self._datasets_exhausted)
    state[self.DATASET_NODE_STATES_KEY] = sub_states
    state[self.EPOCH_KEY] = self._epoch
    state[self.NUM_YIELDED_KEY] = self._num_yielded
    state[self.WEIGHTED_SAMPLER_STATE_KEY] = sampler_state
    return state

next()

Sample the next item from one of the sub-datasets.

Uses the internal weighted sampler to choose a dataset, then delegates to that dataset's next() method. If a dataset is exhausted, it is either reset (cycled) or removed from sampling depending on configuration.

Source code in optimus_dl/modules/data/datasets/composite.py
def next(self) -> Any:
    """Sample the next item from one of the sub-datasets.

    Uses the internal weighted sampler to choose a dataset, then delegates
    to that dataset's next() method. If a dataset is exhausted, it is either
    reset (cycled) or removed from sampling depending on configuration.

    Raises:
        StopIteration: When the configured stop criteria is met.
        RuntimeError: If the sampler itself is exhausted, or if a dataset
            yields nothing even immediately after being reset for cycling.
    """
    while True:
        self._check_for_stop_iteration()

        # Get next dataset to sample from
        try:
            ds_name = next(self._weighted_sampler)
        except StopIteration as err:
            # If sampler is empty (all weights 0), we should have caught it in check_for_stop_iteration
            # unless there's a sync issue. Treat as end of data.
            raise RuntimeError(
                "Exhausted all datasets and cannot cycle through"
            ) from err
        try:
            # Internal invariant: the sampler never returns a dataset that was
            # marked exhausted and deactivated below.
            assert not self._datasets_exhausted[ds_name]
            item = next(self.datasets[ds_name])
            self._num_yielded += 1
            return item

        except StopIteration:
            self._datasets_exhausted[ds_name] = True

            if self.cycle_flags[ds_name]:
                # Reset this dataset and draw from it immediately so cycling
                # never produces a gap in the output stream.
                logger.debug(f"Cycling dataset {ds_name}")
                self.datasets[ds_name].reset()
                self._datasets_exhausted[ds_name] = False
                try:
                    item = next(self.datasets[ds_name])
                    self._num_yielded += 1
                    return item
                except StopIteration as err:
                    raise RuntimeError(
                        "Cannot yield at least one item from dataset after resetting and trying to cycle"
                    ) from err
            else:
                # Not cycling: Disable this dataset in sampler to avoid polling it again
                self._weighted_sampler.set_active(ds_name, False)

            self._check_for_stop_iteration()

reset(initial_state=None)

Reset or restore the composite dataset state.

Restores global epoch/yield counters, the weighted sampler state, and recursively calls reset() on all sub-datasets.

Source code in optimus_dl/modules/data/datasets/composite.py
def reset(self, initial_state: dict[str, Any] | None = None):
    """Reset or restore the composite dataset state.

    Restores global epoch/yield counters, the weighted sampler state, and
    recursively calls reset() on all sub-datasets. With no initial_state,
    everything starts from scratch.
    """
    super().reset(initial_state)

    expected = self.datasets.keys()

    if initial_state is None:
        # Fresh start: zero counters, new sampler, reset every sub-dataset.
        self._num_yielded = 0
        self._epoch = 0

        self._weighted_sampler = self._get_new_weighted_sampler()
        self._datasets_exhausted = dict.fromkeys(self.datasets, False)
        for ds in self.datasets.values():
            ds.reset()
        return

    # Restoring from a checkpoint.
    saved_nodes = initial_state.get(self.DATASET_NODE_STATES_KEY, {})
    saved_names = saved_nodes.keys()

    # Under strict loading, the checkpoint must cover exactly the configured
    # dataset set.
    if self.strict_load and set(saved_names) != set(expected):
        raise ValueError(
            f"Strict load enabled. Mismatch in datasets.\n"
            f"Config: {list(expected)}\nState: {list(saved_names)}"
        )

    self._num_yielded = initial_state.get(self.NUM_YIELDED_KEY, 0)
    self._epoch = initial_state.get(self.EPOCH_KEY, 0)
    self._datasets_exhausted = initial_state.get(
        self.DATASETS_EXHAUSTED_KEY, dict.fromkeys(expected, False)
    )

    if set(saved_names) == set(expected):
        # Same dataset set: safe to restore the sampler from its saved state.
        self._weighted_sampler = self._get_new_weighted_sampler(
            initial_state.get(self.WEIGHTED_SAMPLER_STATE_KEY)
        )
    else:
        # Mismatch with strict_load=False: start the sampler from scratch.
        logger.warning(
            "Dataset configuration changed (or strict_load=False), resetting weighted sampler state."
        )
        self._weighted_sampler = self._get_new_weighted_sampler(None)

    # Restore each sub-dataset from its saved state when available; otherwise
    # fall back to a fresh reset (only legal when strict_load=False).
    for name, ds in self.datasets.items():
        if name in saved_nodes:
            ds.reset(saved_nodes[name])
        else:
            if self.strict_load:
                # Should have been caught above, but safety check
                raise ValueError(f"Missing state for dataset {name}")
            logger.info(f"Resetting dataset {name} (not found in state).")
            ds.reset()

CompositeDatasetConfig dataclass

Bases: RegistryConfigStrict

CompositeDatasetConfig(_name: str | None = None, datasets: dict[str, optimus_dl.modules.data.datasets.composite.DatasetConfig] = &lt;empty dict&gt;, strict_load: bool = True, stop_criteria: optimus_dl.modules.data.datasets.composite.StopCriteria = StopCriteria.CYCLE_FOREVER)

Parameters:

Name Type Description Default
datasets dict[str, DatasetConfig]

Datasets to load: name -> config

<class 'dict'>
strict_load bool

Whether to raise an error if state dict does not contain all required keys

True
stop_criteria StopCriteria

Stop criteria for the composite dataset

<StopCriteria.CYCLE_FOREVER: 'CYCLE_FOREVER'>
Source code in optimus_dl/modules/data/datasets/composite.py
@dataclass
class CompositeDatasetConfig(RegistryConfigStrict):
    """Configuration for the composite dataset: which sub-datasets to combine,
    how strictly checkpoints must match the config, and when to stop."""

    # Sub-dataset name -> per-dataset config (inner dataset, weight, cycle flag).
    datasets: dict[str, DatasetConfig] = field(
        default_factory=dict,
        metadata={"description": "Datasets to load: name -> config"},
    )
    # When True, restoring from a checkpoint whose dataset set differs from
    # this config raises instead of silently resetting the mismatched parts.
    strict_load: bool = field(
        default=True,
        metadata={
            "description": "Whether to raise an error if state dict does not contain all required keys"
        },
    )
    # Policy for ending iteration as sub-datasets run out of data.
    stop_criteria: StopCriteria = field(
        default=StopCriteria.CYCLE_FOREVER,
        metadata={"description": "Stop criteria for the composite dataset"},
    )

DatasetConfig dataclass

DatasetConfig(dataset: optimus_dl.core.registry.RegistryConfig = MISSING, weight: float = 1.0, cycle: bool = True)

Parameters:

Name Type Description Default
dataset RegistryConfig

Dataset config to load

'???'
weight float

Weight of the dataset for sampling

1.0
cycle bool

Whether to cycle through the dataset after it is exhausted

True
Source code in optimus_dl/modules/data/datasets/composite.py
@dataclass
class DatasetConfig:
    """Per-sub-dataset entry in a composite dataset configuration."""

    # Registry config of the wrapped dataset; MISSING means it must be
    # provided by the user (resolution fails otherwise).
    dataset: RegistryConfig = field(
        default=MISSING, metadata={"description": "Dataset config to load"}
    )
    # Relative (non-negative) sampling weight used by the weighted sampler.
    weight: float = field(
        default=1.0, metadata={"description": "Weight of the dataset for sampling"}
    )
    # If True, the dataset is reset and re-read once exhausted; if False it is
    # removed from sampling after exhaustion.
    cycle: bool = field(
        default=True,
        metadata={
            "description": "Whether to cycle through the dataset after it is exhausted"
        },
    )