Skip to content

base

optimus_dl.modules.data.datasets.strategies.base

BaseStrategy

Bases: ABC

Base class for dataset sampling strategies.

Source code in optimus_dl/modules/data/datasets/strategies/base.py
class BaseStrategy(ABC):
    """Abstract interface shared by dataset sampling strategies.

    An instance keeps the configuration, rank and world size it was
    constructed with, receives the document lengths through
    `initialize`, and hands out samples one at a time through
    `next_sample`. Concrete strategies must implement `next_sample`,
    `reset` and `get_state`.
    """

    def __init__(self, cfg: BaseStrategyConfig, rank: int, world_size: int):
        self.cfg, self.rank, self.world_size = cfg, rank, world_size
        # Stays None until `initialize` supplies the document lengths.
        self.doc_lengths: np.ndarray | None = None

    def initialize(self, doc_lengths: np.ndarray):
        """Store the document lengths the strategy will work with.

        Args:
            doc_lengths: Array of document lengths; kept by reference
                on ``self.doc_lengths``.
        """
        self.doc_lengths = doc_lengths

    @abstractmethod
    def next_sample(self) -> list[tuple[int, tuple[int, int]]]:
        """Return the segments that make up the next sample.

        A sample is described as a list of segments. Each segment is a
        tuple ``(doc_id, (start_offset, end_offset))`` where:

        - ``doc_id`` is the global document index.
        - ``start_offset`` is the start token index within the document
          (inclusive).
        - ``end_offset`` is the end token index within the document
          (exclusive).

        Returns:
            The list of segments required to construct the sample.

        Raises:
            StopIteration: When the strategy is exhausted.
        """

    @abstractmethod
    def reset(self, initial_state: dict[str, Any] | None = None):
        """Put the strategy back into its initial or a checkpointed state.

        Args:
            initial_state: Optional state dictionary to restore from.
        """

    @abstractmethod
    def get_state(self) -> dict[str, Any]:
        """Export the current state so it can be checkpointed.

        Returns:
            Dictionary containing the current state.
        """

get_state() abstractmethod

Get state for checkpointing.

Returns:

Type Description
dict[str, Any]

Dictionary containing the current state.

Source code in optimus_dl/modules/data/datasets/strategies/base.py
@abstractmethod
def get_state(self) -> dict[str, Any]:
    """Export the current state so it can be checkpointed.

    Abstract: the base implementation has no body of its own and
    concrete strategies are expected to override it.

    Returns:
        Dictionary containing the current state.
    """

initialize(doc_lengths)

Initialize the strategy with document lengths.

Source code in optimus_dl/modules/data/datasets/strategies/base.py
def initialize(self, doc_lengths: np.ndarray):
    """Store the document lengths the strategy will work with.

    Args:
        doc_lengths: Array of document lengths; kept by reference on
            ``self.doc_lengths``.
    """
    self.doc_lengths = doc_lengths

next_sample() abstractmethod

Return the next sample.

Returns:

Type Description
list[tuple[int, tuple[int, int]]]

A list of segments required to construct the sample. Each segment is a tuple: (doc_id, (start_offset, end_offset)).

  • doc_id: Global document index.
  • start_offset: Start token index within the document (inclusive).
  • end_offset: End token index within the document (exclusive).

Raises:

Type Description
StopIteration

When the strategy is exhausted.

Source code in optimus_dl/modules/data/datasets/strategies/base.py
@abstractmethod
def next_sample(self) -> list[tuple[int, tuple[int, int]]]:
    """Return the segments that make up the next sample.

    A sample is described as a list of segments. Each segment is a
    tuple ``(doc_id, (start_offset, end_offset))`` where:

    - ``doc_id`` is the global document index.
    - ``start_offset`` is the start token index within the document
      (inclusive).
    - ``end_offset`` is the end token index within the document
      (exclusive).

    Returns:
        The list of segments required to construct the sample.

    Raises:
        StopIteration: When the strategy is exhausted.
    """

reset(initial_state=None) abstractmethod

Reset state to initial or checkpointed state.

Parameters:

Name Type Description Default
initial_state dict[str, Any] | None

State dictionary to restore from (optional).

None
Source code in optimus_dl/modules/data/datasets/strategies/base.py
@abstractmethod
def reset(self, initial_state: dict[str, Any] | None = None):
    """Reset state to initial or checkpointed state.

    Args:
        initial_state: State dictionary to restore from (optional).
    """
    pass

BaseStrategyConfig dataclass

Bases: RegistryConfigStrict

Source code in optimus_dl/modules/data/datasets/strategies/base.py
class BaseStrategyConfig(RegistryConfigStrict):
    """Base configuration type accepted by ``BaseStrategy``.

    Declares no fields of its own; everything comes from
    ``RegistryConfigStrict``.
    """