Skip to content

chunk_tokens

optimus_dl.modules.data.transforms.chunk_tokens

ChunkTransform

Bases: BaseTransform

Transform that splits variable-length documents into fixed-size chunks.

Useful when datasets yield full documents that are longer than the desired training sequence length.

Parameters:

Name Type Description Default
cfg ChunkTransformConfig

Chunking configuration.

required
Source code in optimus_dl/modules/data/transforms/chunk_tokens.py
@register_transform("chunk_tokens", ChunkTransformConfig)
class ChunkTransform(BaseTransform):
    """Split variable-length documents into fixed-size chunks.

    Handy when a dataset emits whole documents that are longer than the
    desired training sequence length.

    Args:
        cfg: Chunking configuration.
    """

    def __init__(self, cfg: ChunkTransformConfig, **kwargs):
        super().__init__(**kwargs)
        self.cfg = cfg

    def build(self, source: BaseNode) -> BaseNode:
        """Wrap ``source`` in a node that performs the chunking."""
        chunked = ChunkTransformNode(source, self.cfg)
        return chunked

build(source)

Apply the chunking transformation to a source node.

Source code in optimus_dl/modules/data/transforms/chunk_tokens.py
def build(self, source: BaseNode) -> BaseNode:
    """Wrap ``source`` in a node that performs the chunking."""
    chunked = ChunkTransformNode(source, self.cfg)
    return chunked

ChunkTransformConfig dataclass

Bases: RegistryConfigStrict

Configuration for chunking token sequences.

Attributes:

Name Type Description
max_seq_len int Maximum length of each produced chunk.
add_one_for_shift bool If True, adds 1 to max_seq_len (useful for causal LM training).

Parameters:

Name Type Description Default
max_seq_len int
Maximum length of each produced chunk.
'???'
add_one_for_shift bool
If True, adds 1 to max_seq_len (useful for causal LM training).
True
Source code in optimus_dl/modules/data/transforms/chunk_tokens.py
@dataclass
class ChunkTransformConfig(RegistryConfigStrict):
    """Settings that control how token sequences are chunked.

    Attributes:
        max_seq_len: Upper bound on the number of tokens per emitted chunk.
        add_one_for_shift: When True, one extra token is taken per chunk
            (useful for causal LM input/target shifting).
    """

    max_seq_len: int = MISSING
    add_one_for_shift: bool = True

ChunkTransformNode

Bases: BaseNode

Internal node for performing sequence chunking.

Maintains a buffer of tokens from the source node and yields segments of at most max_seq_len tokens (plus one extra token when add_one_for_shift is enabled).

Source code in optimus_dl/modules/data/transforms/chunk_tokens.py
class ChunkTransformNode(BaseNode):
    """Node that slices token streams into fixed-size segments.

    Buffers tokens pulled from the wrapped source node and hands out
    pieces of at most ``max_seq_len`` tokens (plus one extra token when
    ``add_one_for_shift`` is enabled).
    """

    def __init__(self, node: BaseNode, cfg: ChunkTransformConfig, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The buffer holds leftover tokens from the most recently
        # consumed source item.
        self.cfg = cfg
        self.node = node
        self.buffer = []

    def reset(self, initial_state: dict | None = None):
        """Restore the buffer and source node state."""
        super().reset(initial_state)
        self.buffer = []
        if not initial_state:
            self.node.reset()
            return
        # Resume from a checkpoint: recover leftover tokens, the config
        # snapshot, and the wrapped node's position.
        self.buffer = initial_state["buffer"]
        self.cfg = initial_state["cfg"]
        self.node.reset(initial_state["source_state"])

    def get_state(self):
        """Snapshot leftover tokens, config, and source position for checkpointing."""
        snapshot = {
            "buffer": self.buffer,
            "cfg": self.cfg,
            "source_state": self.node.state_dict(),
        }
        return snapshot

    def next(self):
        """Return the next chunk of tokens, refilling the buffer if empty."""
        if not self.buffer:
            self.buffer = next(self.node)["input_ids"]

        # Chunks may include one extra token so targets can be shifted.
        limit = self.cfg.max_seq_len
        if self.cfg.add_one_for_shift:
            limit += 1
        taken = min(limit, len(self.buffer))
        chunk, self.buffer = self.buffer[:taken], self.buffer[taken:]
        return {"input_ids": chunk}

get_state()

Collect current buffer and source state for checkpointing.

Source code in optimus_dl/modules/data/transforms/chunk_tokens.py
def get_state(self):
    """Snapshot leftover tokens, config, and source position for checkpointing."""
    snapshot = {
        "buffer": self.buffer,
        "cfg": self.cfg,
        "source_state": self.node.state_dict(),
    }
    return snapshot

next()

Yield the next chunk of tokens, refilling the buffer if empty.

Source code in optimus_dl/modules/data/transforms/chunk_tokens.py
def next(self):
    """Return the next chunk of tokens, refilling the buffer if empty."""
    if not self.buffer:
        self.buffer = next(self.node)["input_ids"]

    # Chunks may include one extra token so targets can be shifted.
    limit = self.cfg.max_seq_len
    if self.cfg.add_one_for_shift:
        limit += 1
    taken = min(limit, len(self.buffer))
    chunk, self.buffer = self.buffer[:taken], self.buffer[taken:]
    return {"input_ids": chunk}

reset(initial_state=None)

Restore the buffer and source node state.

Source code in optimus_dl/modules/data/transforms/chunk_tokens.py
def reset(self, initial_state: dict | None = None):
    """Restore the buffer and source node state."""
    super().reset(initial_state)
    self.buffer = []
    if not initial_state:
        self.node.reset()
        return
    # Resume from a checkpoint: recover leftover tokens, the config
    # snapshot, and the wrapped node's position.
    self.buffer = initial_state["buffer"]
    self.cfg = initial_state["cfg"]
    self.node.reset(initial_state["source_state"])