Skip to content

txt_lines

optimus_dl.modules.data.datasets.txt_lines

TxtLinesDataset

Bases: BaseDataset

Dataset that reads and shards a text file line-by-line.

This dataset handles:

  • Remote Loading: Automatically downloads files from URLs and caches them.
  • Line Filtering: Optional removal of empty lines.
  • Distributed Sharding: Partitions the total number of lines equally across ranks.
Note

This dataset loads the entire file into memory on each rank. It is intended for small to medium-sized text files (e.g., TinyShakespeare).

Parameters:

Name Type Description Default
cfg TxtLinesDatasetConfig

Text lines dataset configuration.

required
rank int

Distributed rank.

required
world_size int

Total number of ranks.

required
Source code in optimus_dl/modules/data/datasets/txt_lines.py
@register_dataset("txt_lines", TxtLinesDatasetConfig)
class TxtLinesDataset(BaseDataset):
    """Dataset that reads and shards a text file line-by-line.

    This dataset handles:

    - **Remote Loading**: Automatically downloads files from URLs and caches them.
    - **Line Filtering**: Optional removal of empty lines.
    - **Distributed Sharding**: Partitions the total number of lines equally
      across ranks.

    Note:
      This dataset loads the entire file into memory on each rank. It is intended
      for small to medium-sized text files (e.g., TinyShakespeare).

    Args:
        cfg: Text lines dataset configuration.
        rank: Distributed rank.
        world_size: Total number of ranks.
    """

    def __init__(
        self, cfg: TxtLinesDatasetConfig, rank: int, world_size: int, **kwargs
    ):
        super().__init__(cfg)
        self.file_link = cfg.file_link
        self.cache_dir = cfg.cache_dir
        self.skip_empty_lines = cfg.skip_empty_lines
        self.rank = rank
        self.world_size = world_size

        self.lines = []
        self.index = 0
        self.limit = 0

    def _prepare_data(self):
        """Download (if needed) and shard the text data into lines."""
        # 1. Resolve path / download
        local_path = self.file_link
        if self.file_link.startswith("http://") or self.file_link.startswith(
            "https://"
        ):
            # Encode URL to filename to avoid collisions
            url_hash = hashlib.sha256(self.file_link.encode("utf-8")).hexdigest()
            # Try to keep the extension if present
            _, ext = os.path.splitext(self.file_link)
            if not ext:
                ext = ".txt"
            # Sanitize extension (just in case)
            if len(ext) > 10:
                ext = ".txt"

            filename = f"{url_hash}{ext}"
            local_path = os.path.join(self.cache_dir, filename)

            if not os.path.exists(local_path):
                os.makedirs(self.cache_dir, exist_ok=True)
                logger.info(f"Downloading {self.file_link} to {local_path}")
                try:
                    response = requests.get(self.file_link, stream=True)
                    response.raise_for_status()
                    with open(local_path, "wb") as f:
                        for chunk in response.iter_content(chunk_size=8192):
                            f.write(chunk)
                except Exception as e:
                    logger.error(f"Failed to download {self.file_link}: {e}")
                    raise

        if not os.path.exists(local_path):
            raise FileNotFoundError(f"File not found: {local_path}")

        # 2. Read and filter
        logger.info(f"Loading data from {local_path}")
        with open(local_path, encoding="utf-8") as f:
            raw_lines = f.readlines()

        # Strip trailing newlines
        self.lines = [i.rstrip("\n") for i in raw_lines]

        if self.skip_empty_lines:
            self.lines = [i for i in self.lines if i.strip()]

        # 3. Shard
        total_lines = len(self.lines)
        if total_lines == 0:
            logger.warning(f"Dataset at {local_path} is empty after filtering.")
            self.limit = 0
            self.index = 0
            return

        lines_per_rank = total_lines // self.world_size
        self.index = lines_per_rank * self.rank
        self.limit = lines_per_rank * (self.rank + 1)

        # Ensure index is within bounds (just in case)
        self.index = max(0, min(self.index, total_lines))
        self.limit = max(0, min(self.limit, total_lines))

    def next(self):
        """Yield the next line of text."""
        if self.index >= self.limit:
            raise StopIteration

        line = self.lines[self.index]
        self.index += 1
        return {"text": line}

    def reset(self, initial_state: dict | None = None):
        """Restore dataset state or prepare the file for a fresh start."""
        super().reset(initial_state)
        initial_state = initial_state or {}

        if "file_link" in initial_state:
            self.file_link = initial_state["file_link"]

        if "skip_empty_lines" in initial_state:
            self.skip_empty_lines = initial_state["skip_empty_lines"]

        self._prepare_data()

        self.index = initial_state.get("index", self.index)
        assert initial_state.get("rank", self.rank) == self.rank
        assert initial_state.get("world_size", self.world_size) == self.world_size

    def get_state(self):
        """Return current line index for checkpointing."""
        return {
            "index": self.index,
            "file_link": self.file_link,
            "skip_empty_lines": self.skip_empty_lines,
            "rank": self.rank,
            "world_size": self.world_size,
        }

get_state()

Return current line index for checkpointing.

Source code in optimus_dl/modules/data/datasets/txt_lines.py
def get_state(self):
    """Return current line index for checkpointing."""
    return {
        "index": self.index,
        "file_link": self.file_link,
        "skip_empty_lines": self.skip_empty_lines,
        "rank": self.rank,
        "world_size": self.world_size,
    }

next()

Yield the next line of text.

Source code in optimus_dl/modules/data/datasets/txt_lines.py
def next(self):
    """Yield the next line of text."""
    if self.index >= self.limit:
        raise StopIteration

    line = self.lines[self.index]
    self.index += 1
    return {"text": line}

reset(initial_state=None)

Restore dataset state or prepare the file for a fresh start.

Source code in optimus_dl/modules/data/datasets/txt_lines.py
def reset(self, initial_state: dict | None = None):
    """Restore dataset state or prepare the file for a fresh start."""
    super().reset(initial_state)
    initial_state = initial_state or {}

    if "file_link" in initial_state:
        self.file_link = initial_state["file_link"]

    if "skip_empty_lines" in initial_state:
        self.skip_empty_lines = initial_state["skip_empty_lines"]

    self._prepare_data()

    self.index = initial_state.get("index", self.index)
    assert initial_state.get("rank", self.rank) == self.rank
    assert initial_state.get("world_size", self.world_size) == self.world_size

TxtLinesDatasetConfig dataclass

Bases: RegistryConfigStrict

Configuration for line-based text datasets.

Attributes:

Name Type Description

Parameters:

Name Type Description Default
file_link str
'???'
cache_dir str

Returns tempfile.tempdir as str.

'/tmp'
skip_empty_lines bool
True
Source code in optimus_dl/modules/data/datasets/txt_lines.py
@dataclass
class TxtLinesDatasetConfig(RegistryConfigStrict):
    """Configuration for line-based text datasets.

    Attributes:
        file_link: Local path or HTTP(S) URL to the text file.
        cache_dir: Directory to cache downloaded files.
        skip_empty_lines: If True, lines that are empty or only whitespace are ignored.
    """

    file_link: str = MISSING
    cache_dir: str = field(default_factory=tempfile.gettempdir)
    skip_empty_lines: bool = True