Skip to content

base

optimus_dl.modules.data.transforms.base

Base transform classes for data pipeline.

This module defines the base classes for data transforms, which are components that process data as it flows through the pipeline. Transforms can be chained together to create complex data processing pipelines.

BaseTransform

Base class for all data transforms.

All data transforms in Optimus-DL should inherit from this class. Transforms take a data source (BaseNode) and return a new BaseNode that applies the transformation. Transforms can be chained together using CompositeTransform.

Subclasses should implement:

  • build(): Apply the transform to a data source and return a new node
Example
@register_transform("tokenize", TokenizeConfig)
class TokenizeTransform(BaseTransform):
    def __init__(self, cfg: TokenizeConfig, **kwargs):
        super().__init__(**kwargs)
        self.tokenizer = build_tokenizer(cfg.tokenizer_config)

    def build(self, source: BaseNode) -> BaseNode:
        def tokenize_fn(item):
            return {"input_ids": self.tokenizer.encode(item["text"])}
        return source.map(tokenize_fn)
Source code in optimus_dl/modules/data/transforms/base.py
class BaseTransform:
    """Base class for all data transforms.

    All data transforms in Optimus-DL should inherit from this class. Transforms
    take a data source (BaseNode) and return a new BaseNode that applies the
    transformation. Transforms can be chained together using CompositeTransform.

    Subclasses should implement:

    - `build()`: Apply the transform to a data source and return a new node

    Example:
        ```python
        @register_transform("tokenize", TokenizeConfig)
        class TokenizeTransform(BaseTransform):
            def __init__(self, cfg: TokenizeConfig, **kwargs):
                super().__init__(**kwargs)
                self.tokenizer = build_tokenizer(cfg.tokenizer_config)

            def build(self, source: BaseNode) -> BaseNode:
                def tokenize_fn(item):
                    return {"input_ids": self.tokenizer.encode(item["text"])}
                return source.map(tokenize_fn)

        ```"""

    def __init_subclass__(cls, **kwargs):
        """Transparently instrument subclass ``build()`` overrides for profiling.

        Whenever a subclass defines its own ``build``, replace it with a
        wrapper that, if a profiler is active at call time, wraps the returned
        node in a ``ProfilingProxyNode``. Subclasses do not need to opt in.
        """
        super().__init_subclass__(**kwargs)
        # Only wrap a `build` defined directly on this subclass; checking
        # cls.__dict__ (not hasattr) avoids re-wrapping an inherited,
        # already-wrapped implementation.
        if "build" in cls.__dict__:
            original_build = cls.build

            def wrapped_build(self, *args, **kwargs) -> torchdata.nodes.BaseNode:
                # Imported lazily inside the wrapper — presumably to avoid a
                # circular import with the profiling module; confirm.
                from optimus_dl.modules.data.profiling import (
                    ProfilingProxyNode,
                    get_active_profiler,
                )

                profiler = get_active_profiler()
                if not profiler:
                    # No profiler active: behave exactly like the
                    # undecorated subclass implementation.
                    return original_build(self, *args, **kwargs)

                # Track nesting of build() calls on the profiler itself so
                # that composite transforms (which call build() on their
                # children) register only the outermost proxy as a root.
                if not hasattr(profiler, "_build_stack"):
                    profiler._build_stack = []
                stack = profiler._build_stack
                is_outermost = len(stack) == 0

                stack.append(self)
                try:
                    node = original_build(self, *args, **kwargs)
                finally:
                    # Always unwind, even when build() raises, so the stack
                    # stays balanced for subsequent pipeline builds.
                    stack.pop()

                # repr(self) (see __repr__ below) becomes the node's display
                # name in profiling output.
                proxy = ProfilingProxyNode(node, name=repr(self), profiler=profiler)
                if is_outermost:
                    profiler.root_nodes.append(proxy)
                return proxy

            cls.build = wrapped_build

    def __init__(self, *args, **kwargs) -> None:
        """Initialize the transform.

        Args:
            *args: Positional arguments (typically unused, for compatibility).
            **kwargs: Keyword arguments passed from the data builder.
        """
        pass

    def build(self, source: torchdata.nodes.BaseNode) -> torchdata.nodes.BaseNode:
        """Apply the transform to a data source.

        This method takes a data source node and returns a new node that applies
        the transformation. The transformation is applied lazily as data flows
        through the pipeline.

        Args:
            source: The data source node to transform.

        Returns:
            A new BaseNode that applies the transformation.

        Raises:
            NotImplementedError: Must be implemented by subclasses.

        Example:
            ```python
            transform = TokenizeTransform(cfg)
            transformed_source = transform.build(raw_source)
            # transformed_source now yields tokenized data

            ```"""
        raise NotImplementedError

    def __repr__(self):
        # Default class-name repr; also used as the profiling proxy's name
        # when build() is instrumented.
        return f"{self.__class__.__name__}()"

__init__(*args, **kwargs)

Initialize the transform.

Parameters:

Name Type Description Default
*args

Positional arguments (typically unused, for compatibility).

()
**kwargs

Keyword arguments passed from the data builder.

{}
Source code in optimus_dl/modules/data/transforms/base.py
def __init__(self, *args, **kwargs) -> None:
    """Set up the transform instance.

    The base implementation deliberately accepts and ignores arbitrary
    arguments so that all subclasses constructed by the data builder can
    share a uniform construction signature.

    Args:
        *args: Positional arguments; unused by the base class.
        **kwargs: Keyword arguments forwarded by the data builder; unused here.
    """

build(source)

Apply the transform to a data source.

This method takes a data source node and returns a new node that applies the transformation. The transformation is applied lazily as data flows through the pipeline.

Parameters:

Name Type Description Default
source BaseNode

The data source node to transform.

required

Returns:

Type Description
BaseNode

A new BaseNode that applies the transformation.

Raises:

Type Description
NotImplementedError

Must be implemented by subclasses.

Example
transform = TokenizeTransform(cfg)
transformed_source = transform.build(raw_source)
# transformed_source now yields tokenized data
Source code in optimus_dl/modules/data/transforms/base.py
def build(self, source: torchdata.nodes.BaseNode) -> torchdata.nodes.BaseNode:
    """Turn an upstream source node into a transformed pipeline node.

    Given an upstream node, produce a new node that applies this transform
    lazily, item by item, as data is pulled through the pipeline. The base
    class provides no implementation; every concrete transform must
    override this method.

    Args:
        source: Upstream node supplying the items to be transformed.

    Returns:
        A BaseNode that wraps ``source`` and applies the transformation.

    Raises:
        NotImplementedError: Always, in the base class.

    Example:
        ```python
        transform = TokenizeTransform(cfg)
        transformed_source = transform.build(raw_source)
        # transformed_source now yields tokenized data

        ```"""
    raise NotImplementedError

MapperConfig dataclass

Configuration for map operations in data transforms.

This configuration is used by transforms that apply map operations to data. It controls parallelism, ordering, and batching behavior.

Attributes:

Name Type Description
num_workers int Number of worker processes/threads for parallel processing.
in_order bool If True, preserve item order; if False, allow out-of-order processing for better performance.
method str Parallelization method: "thread" (threading) or "process" (multiprocessing).
snapshot_frequency int How often to snapshot the iterator state for checkpointing.
prebatch int Number of items to batch together before processing, for efficiency.

Parameters:

Name Type Description Default
num_workers int
4
in_order bool
True
method str
'thread'
snapshot_frequency int
128
prebatch int
32
Source code in optimus_dl/modules/data/transforms/base.py
@dataclass
class MapperConfig:
    """Settings controlling how map-style transforms execute.

    Consumed by transforms that run a mapping function over pipeline items;
    governs worker parallelism, result ordering, checkpoint snapshotting,
    and pre-batching behavior.
    """

    # Number of parallel workers (threads or processes) running the map fn.
    num_workers: int = 4
    # Preserve input order when True; False permits faster out-of-order output.
    in_order: bool = True
    # Parallelism backend: "thread" (threading) or "process" (multiprocessing).
    method: str = "thread"
    # Snapshot the iterator state every N items for checkpoint support.
    snapshot_frequency: int = 128
    # Items grouped together before dispatch to a worker, for efficiency.
    prebatch: int = 32

ProcessMapperConfig dataclass

Bases: MapperConfig

Config with process-based parallelism by default.

Parameters:

Name Type Description Default
method str
'process'
Source code in optimus_dl/modules/data/transforms/base.py
@dataclass
class ProcessMapperConfig(MapperConfig):
    """MapperConfig variant whose workers default to multiprocessing."""

    # Override the base default ("thread") with process-based parallelism.
    method: str = "process"

ThreadedMapperConfig dataclass

Bases: MapperConfig

Config with thread-based parallelism by default.

Parameters:

Name Type Description Default
method str
'thread'
Source code in optimus_dl/modules/data/transforms/base.py
@dataclass
class ThreadedMapperConfig(MapperConfig):
    """MapperConfig variant whose workers default to threading.

    The default matches the base class; this subclass exists to make the
    choice of backend explicit at the configuration site.
    """

    method: str = "thread"