Skip to content

tokenize

optimus_dl.modules.data.transforms.tokenize

TokenizeTransform

Bases: BaseTransform

Transform that converts raw text strings into sequences of token IDs.

Uses the registry system to instantiate a tokenizer and applies it to the input data. Supports parallel mapping via torchdata.nodes.ParallelMapper.

Parameters:

Name Type Description Default
cfg TokenizeTransformConfig

Tokenization configuration.

required
Source code in optimus_dl/modules/data/transforms/tokenize.py
@register_transform("tokenize", TokenizeTransformConfig)
class TokenizeTransform(BaseTransform):
    """Transform that converts raw text strings into sequences of token IDs.

    Uses the registry system to instantiate a tokenizer and applies it to the
    input data. Supports parallel mapping via `torchdata.nodes.ParallelMapper`.

    Args:
        cfg: Tokenization configuration.
    """

    def __init__(self, cfg: TokenizeTransformConfig, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.tokenizer = build_tokenizer(cfg.tokenizer_config)

        # Convert MapperConfig dataclass to dict for torchdata
        from dataclasses import asdict

        if hasattr(cfg.worker_cfg, "__dataclass_fields__"):
            # It's a dataclass
            self.mapper_cfg = asdict(cfg.worker_cfg)
        else:
            # Try OmegaConf conversion for backwards compatibility
            self.mapper_cfg = OmegaConf.to_container(cfg.worker_cfg)

        self.debug_counter = 0
        self.debug_samples = cfg.debug_samples

    def _map(self, sample):
        """Map raw text sample to input_ids using the internal tokenizer."""
        text = sample["text"]
        ids = self.tokenizer.encode(text)
        if self.debug_counter < self.debug_samples:
            self.debug_counter += 1
            tokens_debug = []
            for (
                token_id
            ) in (
                ids
            ):  # Renamed 'token' to 'token_id' to avoid confusion with token strings
                token_decoded = self.tokenizer.decode([token_id])
                tokens_debug.append(f"{token_id}({token_decoded!r})")

            tokens_debug = ", ".join(tokens_debug)
            logger.info(f"Debugging tokenizer sample: \n{tokens_debug}\n=======")

        return {
            "input_ids": ids,
        }

    def build(self, source: BaseNode) -> BaseNode:
        """Wrap the source node with a parallel mapper using the tokenizer function."""
        return torchdata.nodes.ParallelMapper(
            source=source,
            map_fn=self._map,
            **self.mapper_cfg,
        )

build(source)

Wrap the source node with a parallel mapper using the tokenizer function.

Source code in optimus_dl/modules/data/transforms/tokenize.py
def build(self, source: BaseNode) -> BaseNode:
    """Wrap the source node with a parallel mapper using the tokenizer function."""
    return torchdata.nodes.ParallelMapper(
        source=source,
        map_fn=self._map,
        **self.mapper_cfg,
    )

TokenizeTransformConfig dataclass

Bases: RegistryConfigStrict

Configuration for text tokenization.

Attributes:

Name Type Description

Parameters:

Name Type Description Default
tokenizer_config Any
'???'
debug_samples int
0
worker_cfg MapperConfig

Config with process-based parallelism by default.

<dynamic>
Source code in optimus_dl/modules/data/transforms/tokenize.py
@dataclass
class TokenizeTransformConfig(RegistryConfigStrict):
    """Configuration for text tokenization.

    Attributes:
        tokenizer_config: Registry configuration for the tokenizer to use.
        debug_samples: Number of samples to print to the log for debugging.
        worker_cfg: Configuration for parallel processing workers.
    """

    tokenizer_config: Any = MISSING
    debug_samples: int = 0
    worker_cfg: MapperConfig = field(
        default_factory=ProcessMapperConfig,
    )