Skip to content

tokenized_flat_dataset

optimus_dl.modules.data.datasets.tokenized_flat_dataset

TokenizedFlatDataset

Bases: BaseDataset

Dataset that treats multiple token files as a single contiguous stream.

This dataset memory-maps all provided files and calculates a global token offset. It then partitions this global stream into equal segments for each distributed rank. This is ideal for pre-training on very large corpora where data is stored as raw token IDs.

Parameters:

Name Type Description Default
cfg TokenizedFlatDatasetConfig

Flat tokenized dataset configuration.

required
rank int

Distributed rank.

required
world_size int

Total number of ranks.

required
Source code in optimus_dl/modules/data/datasets/tokenized_flat_dataset.py
@register_dataset("tokenized_flat", TokenizedFlatDatasetConfig)
class TokenizedFlatDataset(BaseDataset):
    """Dataset that treats multiple token files as a single contiguous stream.

    This dataset memory-maps all provided files and calculates a global token
    offset. It then partitions this global stream into equal segments for each
    distributed rank. This is ideal for pre-training on very large corpora
    where data is stored as raw token IDs.

    Args:
        cfg: Flat tokenized dataset configuration.
        rank: Distributed rank.
        world_size: Total number of ranks.
    """

    def __init__(
        self, cfg: TokenizedFlatDatasetConfig, rank: int, world_size: int, **kwargs
    ):
        super().__init__(cfg)
        self.files = cfg.files
        self.seq_len = cfg.seq_len
        self.batch_size = cfg.batch_size
        self.dtype = cfg.dtype
        self.rank = rank
        self.world_size = world_size

    def _remap_files(self):
        """Memory-map all files and calculate rank-specific token boundaries."""
        # Safe dtype conversion without eval()
        dtype_map = {
            "np.uint8": np.uint8,
            "np.uint16": np.uint16,
            "np.uint32": np.uint32,
            "np.int8": np.int8,
            "np.int16": np.int16,
            "np.int32": np.int32,
            "np.float32": np.float32,
            "np.float64": np.float64,
            "uint8": np.uint8,
            "uint16": np.uint16,
            "uint32": np.uint32,
            "int8": np.int8,
            "int16": np.int16,
            "int32": np.int32,
            "float32": np.float32,
            "float64": np.float64,
        }

        dtype = dtype_map.get(self.dtype, np.uint16)
        if self.dtype not in dtype_map:
            import logging

            logger = logging.getLogger(__name__)
            logger.warning(f"Unknown dtype '{self.dtype}', defaulting to np.uint16")

        self.files_mapped = [np.memmap(i, dtype=dtype, mode="r") for i in self.files]
        self.cumlens = np.cumsum([len(i) for i in self.files_mapped])

        total_tokens = self.cumlens[-1]
        tokens_per_rank = math.floor(total_tokens / self.world_size)

        self.index = tokens_per_rank * (self.rank)
        self.limit = tokens_per_rank * (self.rank + 1)

    @property
    def file_index(self):
        """Find the index of the file containing the current token offset."""
        if self.index >= self.cumlens[-1]:
            return None
        return np.min(np.arange(len(self.files_mapped))[self.index < self.cumlens])

    def reset(self, initial_state: dict | None = None):
        """Restore dataset state or recalculate boundaries for a fresh start."""
        super().reset(initial_state)

        initial_state = initial_state or {}

        old_files = self.files
        self.files = initial_state.get("files", self.files)
        self.seq_len = initial_state.get("seq_len", self.seq_len)
        self.batch_size = initial_state.get("batch_size", self.batch_size)

        assert initial_state.get("world_size", self.world_size) == self.world_size
        assert initial_state.get("rank", self.rank) == self.rank

        if self.files != old_files or not hasattr(self, "index"):
            self._remap_files()

        self.index = initial_state.get("index", self.index)
        self.limit = initial_state.get("limit", self.limit)

    def get_state(self):
        """Return current token offset for checkpointing."""
        return {
            "files": self.files,
            "index": self.index,
            "limit": self.limit,
            "seq_len": self.seq_len,
            "batch_size": self.batch_size,
            "rank": self.rank,
            "world_size": self.world_size,
        }

    def _take_at_most(self, size):
        """Read at most `size` tokens from the current file, advancing the pointer."""
        file_index = self.file_index
        if file_index is None:
            raise StopIteration

        infile_index = self.index
        if file_index > 0:
            infile_index = self.index - self.cumlens[file_index - 1]

        to_take = min(size, len(self.files_mapped[file_index]) - infile_index)
        chunk = self.files_mapped[file_index][infile_index : infile_index + to_take]
        self.index += to_take
        return chunk

    def next(self):
        """Yield the next batch of sequences."""
        target_size = self.batch_size * self.seq_len
        result = self._take_at_most(target_size)

        while len(result) != target_size:
            left = target_size - len(result)
            result = np.concatenate((result, self._take_at_most(left)))

        if self.index > self.limit:
            raise StopIteration

        return {"input_ids": result.reshape(self.batch_size, self.seq_len)}

file_index property

Find the index of the file containing the current token offset.

get_state()

Return current token offset for checkpointing.

Source code in optimus_dl/modules/data/datasets/tokenized_flat_dataset.py
def get_state(self):
    """Return current token offset for checkpointing."""
    return {
        "files": self.files,
        "index": self.index,
        "limit": self.limit,
        "seq_len": self.seq_len,
        "batch_size": self.batch_size,
        "rank": self.rank,
        "world_size": self.world_size,
    }

next()

Yield the next batch of sequences.

Source code in optimus_dl/modules/data/datasets/tokenized_flat_dataset.py
def next(self):
    """Yield the next batch of sequences."""
    target_size = self.batch_size * self.seq_len
    result = self._take_at_most(target_size)

    while len(result) != target_size:
        left = target_size - len(result)
        result = np.concatenate((result, self._take_at_most(left)))

    if self.index > self.limit:
        raise StopIteration

    return {"input_ids": result.reshape(self.batch_size, self.seq_len)}

reset(initial_state=None)

Restore dataset state or recalculate boundaries for a fresh start.

Source code in optimus_dl/modules/data/datasets/tokenized_flat_dataset.py
def reset(self, initial_state: dict | None = None):
    """Restore dataset state or recalculate boundaries for a fresh start."""
    super().reset(initial_state)

    initial_state = initial_state or {}

    old_files = self.files
    self.files = initial_state.get("files", self.files)
    self.seq_len = initial_state.get("seq_len", self.seq_len)
    self.batch_size = initial_state.get("batch_size", self.batch_size)

    assert initial_state.get("world_size", self.world_size) == self.world_size
    assert initial_state.get("rank", self.rank) == self.rank

    if self.files != old_files or not hasattr(self, "index"):
        self._remap_files()

    self.index = initial_state.get("index", self.index)
    self.limit = initial_state.get("limit", self.limit)

TokenizedFlatDatasetConfig dataclass

Bases: RegistryConfigStrict

Configuration for flat tokenized datasets.

Attributes:

Name Type Description

Parameters:

Name Type Description Default
dtype str
'np.uint16'
files list[str]

Built-in mutable sequence.

If no argument is given, the constructor creates a new empty list. The argument must be an iterable if specified.

<dynamic>
seq_len int
'???'
batch_size int
'???'
Source code in optimus_dl/modules/data/datasets/tokenized_flat_dataset.py
@dataclass
class TokenizedFlatDatasetConfig(RegistryConfigStrict):
    """Configuration for flat tokenized datasets.

    Attributes:
        dtype: Numpy dtype of the token files.
        files: List of paths to tokenized `.npy` or raw binary files.
        seq_len: Sequence length for each batch.
        batch_size: Number of sequences per batch.
    """

    dtype: str = "np.uint16"
    files: list[str] = field(
        default_factory=list,
    )
    seq_len: int = field(default=MISSING)
    batch_size: int = field(default=MISSING)