base
optimus_dl.modules.criterion.base
Base criterion (loss function) class.
This module defines the BaseCriterion class that all loss functions must inherit from. Criteria compute the loss given a model and a batch of data.
BaseCriterion
Base class for all loss criteria (loss functions).
All loss functions in Optimus-DL should inherit from this class. A criterion computes the loss from a model and a batch of data: it calls the model on the batch to obtain outputs and compares them with the targets.
Subclasses should implement:
- __call__(): Compute the loss given a model and a batch.
Example
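    import torch
    import torch.nn.functional as F

    @register_criterion("cross_entropy", CrossEntropyConfig)
    class CrossEntropyCriterion(BaseCriterion):
        def __init__(self, cfg: CrossEntropyConfig):
            self.cfg = cfg

        def __call__(self, model: BaseModel, batch: dict, requested_protocols=None):
            logits = model(batch["input_ids"])
            # Flatten logits to (N, vocab_size) and labels to (N,) before the loss.
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), batch["labels"].view(-1))
            # Return (loss, exposed_protocols), matching the documented __call__ contract.
            return loss, {"logits": logits}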
Source code in optimus_dl/modules/criterion/base.py
__call__(model, batch, requested_protocols=None)
Compute the loss for a given model and batch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| model | BaseModel | The model to compute loss for. Should be called with the batch to get model outputs. | required |
| batch | dict | Dictionary containing input data and targets. Typically includes "input_ids" (token IDs for the input sequence), "labels" (target token IDs for computing loss), and other model-specific fields. | required |
| requested_protocols | set[str] \| None | Optional set of protocol strings (e.g., {'logits', 'classification'}) that are requested by the metrics system. Subclasses can use this to avoid computing data that won't be used. | None |
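As a rough illustration of how a subclass might use these parameters, the sketch below calls the model on the batch and exposes logits only when the 'logits' protocol is requested. It is deliberately dependency-free: plain Python lists and floats stand in for tensors, and `toy_model` and `ToyCrossEntropy` are hypothetical names that are not part of Optimus-DL. A real criterion would inherit from BaseCriterion and return a tensor loss.

```python
from __future__ import annotations

import math
from typing import Any, Callable

# Hypothetical stand-in: a "model" maps token IDs to one row of logits per token.
Model = Callable[[list], list]


def toy_model(input_ids: list[int]) -> list[list[float]]:
    # Three-class logits that favour class (token_id % 3); illustrative only.
    return [[2.0 if c == t % 3 else 0.0 for c in range(3)] for t in input_ids]


class ToyCrossEntropy:
    """Sketch of the documented __call__ signature (not the real BaseCriterion)."""

    def __call__(
        self,
        model: Model,
        batch: dict,
        requested_protocols: set[str] | None = None,
    ) -> tuple[float, dict[str, Any]]:
        requested = requested_protocols or set()
        logits = model(batch["input_ids"])

        # Mean negative log-likelihood of the target class under a softmax.
        total = 0.0
        for row, label in zip(logits, batch["labels"]):
            log_norm = math.log(sum(math.exp(x) for x in row))
            total += log_norm - row[label]
        loss = total / len(logits)

        # Expose only what was requested, so unused data is not passed along.
        exposed: dict[str, Any] = {}
        if "logits" in requested:
            exposed["logits"] = logits
        return loss, exposed


criterion = ToyCrossEntropy()
batch = {"input_ids": [0, 1, 2], "labels": [0, 1, 2]}
loss, exposed = criterion(toy_model, batch, requested_protocols={"logits"})
loss_only, nothing = criterion(toy_model, batch)
```

Treating a missing `requested_protocols` as an empty set keeps the default call cheap; whether Optimus-DL criteria follow the same convention is an assumption.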
Returns:
| Type | Description |
|---|---|
| tuple[Tensor, dict[str, Any]] | A tuple of (loss, exposed_protocols), where loss is a scalar tensor containing the loss value and exposed_protocols is a dictionary mapping protocol names (e.g., 'logits') to their computed values for reuse in metrics. |
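One way a caller could consume this return value is sketched below: a metric reads the exposed 'logits' entry instead of running the model a second time. Everything here is an assumption made for illustration. `counting_model`, `toy_criterion`, and `accuracy_from_protocols` are hypothetical names, the loss is a placeholder, and plain Python values stand in for tensors.

```python
from __future__ import annotations

from typing import Any

calls = {"model": 0}  # counts forward passes to show the model runs only once


def counting_model(input_ids: list[int]) -> list[list[float]]:
    calls["model"] += 1
    # Two-class logits: even token IDs favour class 0, odd ones class 1.
    return [[1.0, 0.0] if t % 2 == 0 else [0.0, 1.0] for t in input_ids]


def toy_criterion(
    model,
    batch: dict,
    requested_protocols: set[str] | None = None,
) -> tuple[float, dict[str, Any]]:
    logits = model(batch["input_ids"])
    # Placeholder loss: fraction of tokens whose top-scoring class is wrong.
    wrong = sum(row.index(max(row)) != y for row, y in zip(logits, batch["labels"]))
    loss = wrong / len(logits)
    exposed = {"logits": logits} if "logits" in (requested_protocols or set()) else {}
    return loss, exposed


def accuracy_from_protocols(exposed: dict[str, Any], labels: list[int]) -> float:
    # Reuses the logits computed by the criterion; no second forward pass.
    logits = exposed["logits"]
    hits = sum(row.index(max(row)) == y for row, y in zip(logits, labels))
    return hits / len(labels)


batch = {"input_ids": [0, 1, 2, 3], "labels": [0, 1, 0, 0]}
loss, exposed = toy_criterion(counting_model, batch, requested_protocols={"logits"})
accuracy = accuracy_from_protocols(exposed, batch["labels"])
```

The counter makes the benefit visible: the loss and the metric are both derived from a single forward pass.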
Raises:
| Type | Description |
|---|---|
| NotImplementedError | Must be implemented by subclasses. |
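The contract above follows a common Python pattern, sketched here with a hypothetical stand-in class. `SketchBaseCriterion` is not the real BaseCriterion, whose source is not shown on this page; it only mirrors the documented signature. The base `__call__` raises NotImplementedError, and a subclass overrides it.

```python
from __future__ import annotations

from typing import Any


class SketchBaseCriterion:
    """Hypothetical stand-in mirroring the documented __call__ signature."""

    def __call__(
        self,
        model,
        batch: dict,
        requested_protocols: set[str] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        raise NotImplementedError("Subclasses must implement __call__")


class ConstantCriterion(SketchBaseCriterion):
    """Trivial override that ignores the model and returns a fixed loss."""

    def __call__(self, model, batch, requested_protocols=None):
        return 0.0, {}


def try_call(criterion: SketchBaseCriterion) -> str:
    # Report whether the criterion provides its own __call__ implementation.
    try:
        criterion(model=None, batch={})
    except NotImplementedError:
        return "not implemented"
    return "ok"


base_result = try_call(SketchBaseCriterion())
sub_result = try_call(ConstantCriterion())
```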