Skip to content

causal_lm

optimus_dl.modules.metrics.sources.causal_lm

CausalLMSource

Bases: MetricSource

Source for Causal LM that extracts logits and labels from the model and batch.

Source code in optimus_dl/modules/metrics/sources/causal_lm.py
@register_metric_source("causal_lm", CausalLMSourceConfig)
class CausalLMSource(MetricSource):
    """Source for Causal LM that extracts logits and labels from the model and batch."""

    cfg: CausalLMSourceConfig

    @property
    def provides(self) -> set[str]:
        """Protocols this source supplies to downstream metrics."""
        return {StandardProtocols.LOGITS, StandardProtocols.CLASSIFICATION}

    @torch.no_grad()
    def __call__(
        self,
        dependencies: dict[str, dict[str, Any]],
        model: Any,
        batch: Any,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Execute the source.

        Runs the model on the batch with inputs shifted left by one token
        and returns next-token logits plus a classification payload
        (argmax predictions, shifted targets, validity mask).

        Args:
            dependencies: Data from required sources (none for this source).
            model: The model to run forward pass on.
            batch: The input batch, expected to contain 'input_ids'.
            **kwargs: Additional arguments (like criterion if needed).

        Returns:
            Mapping from protocol name to payload: raw logits under
            ``StandardProtocols.LOGITS`` and the classification dict under
            ``StandardProtocols.CLASSIFICATION``.

        Raises:
            KeyError: If the model returns a dict without a 'logits' key.
        """
        # Shallow-copy so the caller's batch dict is not mutated by the pop.
        batch = copy.copy(batch)
        input_ids = batch.pop("input_ids")
        # Standard next-token shift: inputs are tokens [0, L-1),
        # targets (below) are tokens [1, L).
        batch["input_ids"] = input_ids[:, :-1]

        output = model(**batch)

        targets = input_ids[:, 1:]

        # Handle different output types (dict, Namespace, or raw Tensor).
        if isinstance(output, dict):
            # Index directly instead of .get(): a missing key should raise
            # a clear KeyError here, not a confusing AttributeError on
            # None.argmax() further down.
            logits = output["logits"]
        elif hasattr(output, "logits"):
            logits = output.logits
        else:
            logits = output  # Assume it's the logits tensor

        # Exclude padding positions from the metric mask.
        mask = targets != self.cfg.padding_token_id
        if "seq_lens" in batch:
            # NOTE(review): targets are shifted by one, so only seq_len - 1
            # target positions per row are valid; `< seq_lens` keeps one
            # extra position. It is harmless whenever that token equals
            # padding_token_id — confirm seq_lens semantics before
            # tightening this to `< seq_lens - 1`.
            mask = mask & (
                torch.arange(mask.shape[1], device=mask.device)
                < batch["seq_lens"][:, None]
            )

        classification = dict(
            predictions=logits.argmax(dim=-1),
            targets=targets,
            mask=mask,
        )

        return {
            StandardProtocols.LOGITS: logits,
            StandardProtocols.CLASSIFICATION: classification,
        }

__call__(dependencies, model, batch, **kwargs)

Execute the source.

Parameters:

Name Type Description Default
dependencies dict[str, dict[str, Any]]

Data from required sources (none for this source).

required
model Any

The model to run forward pass on.

required
batch Any

The input batch, expected to contain 'input_ids'.

required
**kwargs Any

Additional arguments (like criterion if needed).

{}
Source code in optimus_dl/modules/metrics/sources/causal_lm.py
@torch.no_grad()
def __call__(
    self,
    dependencies: dict[str, dict[str, Any]],
    model: Any,
    batch: Any,
    **kwargs: Any,
) -> dict[str, Any]:
    """Execute the source.

    Runs the model on the batch with inputs shifted left by one token and
    returns next-token logits plus a classification payload (argmax
    predictions, shifted targets, and a validity mask).

    Args:
        dependencies: Data from required sources (none for this source).
        model: The model to run forward pass on.
        batch: The input batch, expected to contain 'input_ids'.
        **kwargs: Additional arguments (like criterion if needed).
    """
    # Shallow copy so the caller's batch dict is not mutated by the pop below.
    batch = copy.copy(batch)
    input_ids = batch.pop("input_ids")
    # Standard next-token shift: inputs are all tokens but the last;
    # targets (below) are all tokens but the first.
    batch["input_ids"] = input_ids[:, :-1]

    output = model(**batch)

    targets = input_ids[:, 1:]

    # Handle different output types (dict, Namespace, or raw Tensor)
    if isinstance(output, dict):
        # NOTE(review): .get() yields None when 'logits' is absent, which
        # only surfaces later at .argmax() — confirm callers guarantee the key.
        logits = output.get("logits")
    elif hasattr(output, "logits"):
        logits = output.logits
    else:
        logits = output  # Assume it's the logits tensor

    # Exclude padding positions from the metric mask.
    mask = targets != self.cfg.padding_token_id
    if "seq_lens" in batch:
        # Also mask positions beyond each row's sequence length.
        # Assumes seq_lens counts valid tokens per row — TODO confirm.
        mask = mask & (
            torch.arange(mask.shape[1], device=mask.device)
            < batch["seq_lens"][:, None]
        )

    classification = dict(
        predictions=logits.argmax(dim=-1),
        targets=targets,
        mask=mask,
    )

    return {
        StandardProtocols.LOGITS: logits,
        StandardProtocols.CLASSIFICATION: classification,
    }

CausalLMSourceConfig dataclass

Bases: MetricSourceConfig

Configuration for CausalLMSource.

Parameters:

Name Type Description Default
padding_token_id int

Token id treated as padding; target positions equal to this id are excluded from the metric mask.

-100
Source code in optimus_dl/modules/metrics/sources/causal_lm.py
@dataclass
class CausalLMSourceConfig(MetricSourceConfig):
    """Configuration for CausalLMSource."""

    # Registry key under which this source is registered.
    _name: str = "causal_lm"
    # Token id treated as padding; matching target positions are masked
    # out of the classification payload.
    padding_token_id: int = -100