Skip to content

cross_entropy

optimus_dl.modules.criterion.cross_entropy

CrossEntropyCriterion

Bases: BaseCriterion

Standard Cross Entropy loss with distributed and kernel optimizations.

This criterion implements standard Cross Entropy but adds support for:

  • Loss Parallelism: Computes loss directly on sharded logits (DTensors) to save memory and communication.
  • Liger Kernel: Optional high-performance kernel for faster computation and lower memory usage on GPUs.
  • Metrics: Automatically logs accuracy, perplexity, and token counts.

Parameters:

Name Type Description Default
cfg CrossEntropyCriterionConfig

Configuration for cross entropy.

required
collective Collective

Collective object for distributed operations.

required
Source code in optimus_dl/modules/criterion/cross_entropy.py
(source lines 40–438 of cross_entropy.py; line-number gutter omitted)
@register_criterion("cross_entropy", CrossEntropyCriterionConfig)
class CrossEntropyCriterion(BaseCriterion):
    """Standard Cross Entropy loss with distributed and kernel optimizations.

    This criterion implements standard Cross Entropy but adds support for:

    - **Loss Parallelism**: Computes loss directly on sharded logits (DTensors) to
      save memory and communication.
    - **Liger Kernel**: Optional high-performance kernel for faster computation
      and lower memory usage on GPUs.
    - **Metrics**: Automatically logs accuracy, perplexity, and token counts.

    Args:
        cfg: Configuration for cross entropy.
        collective: Collective object for distributed operations.
    """

    def __init__(
        self, cfg: CrossEntropyCriterionConfig, collective: Collective, **kwargs
    ):
        self.cfg = cfg
        self.collective = collective
        self.tp_size = collective.tp_world_size
        self.padding_token_id = cfg.padding_token_id
        self._liger_available = False
        # use_liger_kernel is tri-state: True = force (warn if the package is
        # missing), None = auto-detect silently, False = never attempt the import.
        if self.cfg.use_liger_kernel or self.cfg.use_liger_kernel is None:
            try:
                from liger_kernel.transformers.functional import liger_cross_entropy

                self._liger_cross_entropy = liger_cross_entropy
                self._liger_available = True
                if self.cfg.use_liger_kernel is None:
                    # Only announce the kernel choice when it was automatic.
                    logger.info("Using liger-kernel for cross-entropy.")
            except ImportError:
                if self.cfg.use_liger_kernel is not None:
                    # Inside this branch the config value can only be True, so
                    # the user explicitly asked for a kernel that is absent.
                    logger.warning(
                        "use_liger_kernel=True but liger-kernel is not installed. Falling back to PyTorch."
                    )

    def __call__(self, model, batch, requested_protocols: set[str] | None = None):
        """Compute the cross entropy loss.

        Automatically handles target shifting (labels = inputs[1:]) and manages
        distributed loss computation if the model output is a DTensor.

        Args:
            model: The language model.
            batch: Dictionary containing 'input_ids' and optional 'labels'.
            requested_protocols: Optional set of requested protocols.

        Returns:
            Tuple of (loss tensor, exposed_protocols dictionary).
        """
        requested_protocols = requested_protocols or set()
        # Shallow copy so popping/replacing keys below does not mutate the
        # caller's batch dict.
        batch = copy.copy(batch)
        input_ids = batch.pop("input_ids")
        labels = batch.pop("labels", None)

        B, T = input_ids.shape

        if labels is not None:
            # Batcher already performed causal shifting (input_ids and labels are aligned)
            targets = labels
            batch["input_ids"] = input_ids
        else:
            assert (
                "cu_seqlens" not in batch
            ), "If input is flat, we cannot generate labels and inputs efficiently"
            assert input_ids.ndim == 2, "Input must be 2D for automatic causal shifting"

            # Perform standard causal shifting: targets = inputs[1:], inputs = inputs[:-1]
            targets = input_ids[:, 1:]
            batch["input_ids"] = input_ids[:, :-1]

            # Metadata tensors that match the sequence length must also be sliced
            # (any 2D+ tensor whose second dim equals the original T is assumed
            # to be per-token metadata).
            for k in list(batch.keys()):
                v = batch[k]
                if isinstance(v, torch.Tensor) and v.ndim >= 2 and v.shape[1] == T:
                    batch[k] = v[:, :-1]

            if "seq_lens" in batch:
                batch["seq_lens"] = torch.clamp(batch["seq_lens"] - 1, min=0)
                if torch.any(batch["seq_lens"] == 0).item():
                    warn_once(
                        logger,
                        "Some sequences are too short to be used for training (len = 1 -> len = 0 after shifting)",
                    )

        # Log sequence statistics accurately for all schemes.
        # Sample counts are divided by tp_size because every TP rank sees the
        # same batch and the metrics are aggregated across ranks.
        if "cu_seqlens" in batch:
            # Packed/Flat batch: metadata already adjusted for shifting
            cu = batch["cu_seqlens"]
            doc_lens = (cu[1:] - cu[:-1]).float()
            log_max("input_max_seq_len", doc_lens.max().item(), round=2)
            log_averaged(
                "input_mean_seq_len",
                lambda: doc_lens.mean().item(),
                weight=len(doc_lens) / self.tp_size,
                round=2,
            )
            log_summed(
                "total_samples",
                len(doc_lens) / self.tp_size,
                reset=False,
            )
            log_summed(
                "batch_samples",
                len(doc_lens) / self.tp_size,
            )
        elif "seq_lens" in batch:
            sl = batch["seq_lens"].float()
            log_max("input_max_seq_len", sl.max().item(), round=2)
            log_averaged(
                "input_mean_seq_len",
                lambda: sl.mean().item(),
                weight=sl.shape[0] / self.tp_size,
                round=2,
            )
            log_summed(
                "total_samples",
                len(sl) / self.tp_size,
                reset=False,
            )
            log_summed(
                "batch_samples",
                len(sl) / self.tp_size,
            )
        else:
            # Fixed-size batch
            current_T = batch["input_ids"].shape[1]
            log_max("input_max_seq_len", current_T, round=2)
            log_averaged("input_mean_seq_len", current_T, weight=B, round=2)
            log_summed(
                "total_samples",
                len(batch["input_ids"]) / self.tp_size,
                reset=False,
            )
            log_summed(
                "batch_samples",
                len(batch["input_ids"]) / self.tp_size,
            )

        model_out = model(**batch)
        input_ids = batch["input_ids"]
        logits = model_out["logits"]
        is_dtensor = isinstance(logits, DTensor)

        # cached_lambda defers these computations and evaluates them at most
        # once — only if a metric/protocol consumer actually asks for them.
        valid_tokens = cached_lambda(
            lambda: ((targets >= 0) & (targets != self.padding_token_id)).sum().item()
            / self.tp_size
        )
        predictions = cached_lambda(lambda: self.gather_predictions(logits))

        log_averaged(
            "accuracy",
            lambda: self.accuracy_metric(predictions(), targets),
            weight=valid_tokens,
            round=2,
        )
        log_summed(
            "batch_tokens",
            valid_tokens,
        )
        log_summed(
            "total_tokens",
            valid_tokens,
            reset=False,
        )

        targets_flat = targets.reshape(-1)
        enable_loss_parallel = False
        if is_dtensor:
            from torch.distributed.tensor.placement_types import Replicate

            # Targets must live on the same device mesh as the logits for the
            # distributed cross entropy computation below.
            if not isinstance(targets_flat, DTensor):
                targets_parallel = DTensor.from_local(
                    targets_flat, logits.device_mesh, (Replicate(),)
                )
            else:
                targets_parallel = targets_flat

            # Only enable loss_parallel if logits are actually sharded
            for placement in logits.placements:
                if isinstance(placement, Shard):
                    enable_loss_parallel = True
                    break
        else:
            targets_parallel = targets_flat

        # The Liger kernel path requires local (non-DTensor) GPU tensors.
        if (
            self._liger_available
            and targets_parallel.device.type != "cpu"
            and not is_dtensor
        ):
            # Liger kernel handles mixed precision internally, no need to cast to float
            loss = self._liger_cross_entropy(
                input=logits.view(-1, logits.size(-1)),
                target=targets_parallel,
                label_smoothing=self.cfg.label_smoothing,
                ignore_index=self.padding_token_id,
            )
        else:
            # Disable autocast and cast to float32 for numerical stability.
            # loss_parallel() lets F.cross_entropy consume sharded logits
            # without materializing the full vocab dimension.
            with (
                torch.autocast(targets_parallel.device.type, enabled=False),
                loss_parallel() if enable_loss_parallel else nullcontext(),
            ):
                loss = torch.nn.functional.cross_entropy(
                    input=logits.view(-1, logits.size(-1)).float(),
                    target=targets_parallel,
                    label_smoothing=self.cfg.label_smoothing,
                    ignore_index=self.padding_token_id,
                )

        log_averaged(
            "loss",
            value=lambda: loss.item(),
            weight=valid_tokens,
        )
        log_averaged_exponent(
            "perplexity",
            value=lambda: loss.item(),
            weight=valid_tokens,
        )

        # Build the exposed-protocols dict only when something was requested.
        exposed = {}
        if (
            StandardProtocols.LOGITS in requested_protocols
            or StandardProtocols.LOGITS_TARGETS in requested_protocols
            or StandardProtocols.LOGITS_MASK in requested_protocols
            or StandardProtocols.CLASSIFICATION in requested_protocols
            or StandardProtocols.LOSS in requested_protocols
            or StandardProtocols.INPUT_TOKENS in requested_protocols
        ):
            with torch.no_grad():
                # A "flat" batch packs all documents into a single row; tensors
                # are reshaped back to (num_docs, max_seqlen) before exposure.
                is_flat = B == 1 and "cu_seqlens" in batch
                current_seq_lens = batch.get("seq_lens")

                if StandardProtocols.LOSS in requested_protocols:
                    loss_for_exposed = loss
                    if isinstance(loss_for_exposed, DTensor):
                        loss_for_exposed = loss_for_exposed.full_tensor()
                    exposed[StandardProtocols.LOSS] = loss_for_exposed.detach()

                if StandardProtocols.LOGITS in requested_protocols:
                    res_logits = logits
                    if isinstance(res_logits, DTensor):
                        res_logits = res_logits.full_tensor()

                    if is_flat:
                        res_logits = self._unflatten_flat(
                            res_logits, batch["cu_seqlens"], batch["max_seqlen"]
                        )
                    exposed[StandardProtocols.LOGITS] = res_logits

                if StandardProtocols.LOGITS_TARGETS in requested_protocols:
                    targets_exposed = targets
                    if is_flat:
                        targets_exposed = self._unflatten_flat(
                            targets,
                            batch["cu_seqlens"],
                            batch["max_seqlen"],
                            pad_val=self.padding_token_id,
                        )
                    exposed[StandardProtocols.LOGITS_TARGETS] = targets_exposed

                if StandardProtocols.INPUT_TOKENS in requested_protocols:
                    input_tokens = input_ids
                    if is_flat:
                        input_tokens = self._unflatten_flat(
                            input_ids,
                            batch["cu_seqlens"],
                            batch["max_seqlen"],
                            pad_val=self.padding_token_id,
                        )
                    exposed[StandardProtocols.INPUT_TOKENS] = input_tokens

                if StandardProtocols.LOGITS_MASK in requested_protocols:
                    mask = targets != self.padding_token_id
                    if is_flat:
                        mask = self._unflatten_flat(
                            mask,
                            batch["cu_seqlens"],
                            batch["max_seqlen"],
                            pad_val=False,
                        )
                    elif current_seq_lens is not None:
                        # Also mask out right-padding beyond each row's length.
                        mask = mask & (
                            torch.arange(mask.shape[1], device=mask.device)
                            < current_seq_lens[:, None]
                        )

                    exposed[StandardProtocols.LOGITS_MASK] = mask

                if StandardProtocols.CLASSIFICATION in requested_protocols:
                    res_preds = predictions()  # Already gathered by cached_lambda
                    res_targets = targets

                    # Base mask for valid tokens
                    res_mask = res_targets != self.padding_token_id
                    # Refine mask for padded batches if current_seq_lens is available
                    if is_flat:
                        cu = batch["cu_seqlens"]
                        ms = batch["max_seqlen"]
                        res_preds = self._unflatten_flat(res_preds, cu, ms)
                        res_targets = self._unflatten_flat(
                            res_targets, cu, ms, pad_val=self.padding_token_id
                        )
                        res_mask = self._unflatten_flat(res_mask, cu, ms, pad_val=False)
                    elif current_seq_lens is not None:
                        res_mask = res_mask & (
                            torch.arange(res_mask.shape[1], device=res_mask.device)
                            < current_seq_lens[:, None]
                        )

                    exposed[StandardProtocols.CLASSIFICATION] = dict(
                        predictions=res_preds,
                        targets=res_targets,
                        mask=res_mask,
                    )

        return loss, exposed

    @staticmethod
    def _unflatten_flat(t, cu_seqlens, max_seqlen, pad_val=0):
        """Helper to reconstruct (batch, time) layout from a flat (1, sum_T) batch.

        Args:
            t: Tensor of shape (1, T_sum, ...) holding all documents end-to-end.
            cu_seqlens: Cumulative sequence lengths; doc i spans
                [cu_seqlens[i], cu_seqlens[i+1]).
            max_seqlen: Time dimension of the output buffer.
            pad_val: Fill value for positions past each document's length.

        Returns:
            Tensor of shape (num_docs, max_seqlen, ...) padded with ``pad_val``.
        """
        # t is (1, T_sum, ...)
        device = t.device
        dtype = t.dtype
        num_docs = len(cu_seqlens) - 1
        total_tokens = int(cu_seqlens[-1].item())

        # seqlens of each document
        seqlens = cu_seqlens[1:] - cu_seqlens[:-1]

        # Batch index for each token: [0,0,0, 1,1, 2,2,2,2, ...]
        batch_idx = torch.repeat_interleave(
            torch.arange(num_docs, device=device), seqlens.to(torch.long)
        )

        # Local index for each token: [0,1,2, 0,1, 0,1,2,3, ...]
        # Global index minus sequence start index
        local_idx = torch.arange(total_tokens, device=device) - torch.repeat_interleave(
            cu_seqlens[:-1].to(torch.long), seqlens.to(torch.long)
        )

        # Prepare output buffer (batch, max_time, ...)
        out_shape = (num_docs, max_seqlen, *t.shape[2:])
        out = torch.full(out_shape, pad_val, device=device, dtype=dtype)

        # Vectorized assignment: out[batch_idx, local_idx] = t[0, :total_tokens]
        out[batch_idx, local_idx] = t[0]

        return out

    @torch.no_grad()
    def gather_predictions(self, logits):
        """Get argmax token predictions from (possibly sharded) logits.

        For DTensor logits this is a two-stage distributed argmax: each TP rank
        takes the max over its local vocab shard, then the per-rank max values
        and globally-shifted indices are gathered and reduced to a single
        prediction per position.
        """
        is_dtensor = isinstance(logits, DTensor)
        if is_dtensor:
            assert isinstance(self.collective, MeshCollective)
            local_logits = logits.to_local()
            # Per-rank max over the local (sharded) vocab slice.
            maxes = torch.max(local_logits, -1)

            # NOTE(review): the Shard(1) placement concatenates the per-rank
            # maxima along dim 1 of maxes.values — verify this matches the
            # logits layout produced under tensor parallelism.
            maxes_values_distr = DTensor.from_local(
                maxes.values,
                device_mesh=self.collective.tp_mesh,
                placements=(Shard(1),),
            ).full_tensor()
            # Shift local indices into the global vocab index space.
            tok_shift = self.collective.tp_rank * local_logits.size(-1)
            maxes_index_distr = DTensor.from_local(
                maxes.indices + tok_shift,
                device_mesh=self.collective.tp_mesh,
                placements=(Shard(1),),
            ).full_tensor()

            # Reduce across ranks: pick the index whose value is the global max.
            max_total = torch.max(maxes_values_distr, -1, keepdim=True)
            predictions = torch.gather(
                maxes_index_distr,
                dim=1,
                index=max_total.indices,
            )
        else:
            predictions = torch.argmax(logits, dim=-1)
        return predictions

    @torch.no_grad()
    def accuracy_metric(self, predictions, targets):
        """Compute top-1 accuracy over valid (non-padding) targets.

        Expects `predictions` to already be plain token ids; DTensor logits are
        gathered beforehand by `gather_predictions`.
        """

        correct = predictions == targets
        # Valid tokens: non-negative and not the padding id.
        valid = (targets >= 0) & (targets != self.padding_token_id)
        correct = (correct & valid).float()
        return (correct.sum() / valid.sum()).item()

__call__(model, batch, requested_protocols=None)

Compute the cross entropy loss.

Automatically handles target shifting (labels = inputs[1:]) and manages distributed loss computation if the model output is a DTensor.

Parameters:

Name Type Description Default
model

The language model.

required
batch

Dictionary containing 'input_ids' and optional 'labels'.

required
requested_protocols set[str] | None

Optional set of requested protocols.

None

Returns:

Type Description

Tuple of (loss tensor, exposed_protocols dictionary).

Source code in optimus_dl/modules/criterion/cross_entropy.py
def __call__(self, model, batch, requested_protocols: set[str] | None = None):
    """Compute the cross entropy loss.

    Automatically handles target shifting (labels = inputs[1:]) and manages
    distributed loss computation if the model output is a DTensor.

    Args:
        model: The language model.
        batch: Dictionary containing 'input_ids' and optional 'labels'.
        requested_protocols: Optional set of requested protocols.

    Returns:
        Tuple of (loss tensor, exposed_protocols dictionary).
    """
    requested_protocols = requested_protocols or set()
    batch = copy.copy(batch)
    input_ids = batch.pop("input_ids")
    labels = batch.pop("labels", None)

    B, T = input_ids.shape

    if labels is not None:
        # Batcher already performed causal shifting (input_ids and labels are aligned)
        targets = labels
        batch["input_ids"] = input_ids
    else:
        assert (
            "cu_seqlens" not in batch
        ), "If input is flat, we cannot generate labels and inputs efficiently"
        assert input_ids.ndim == 2, "Input must be 2D for automatic causal shifting"

        # Perform standard causal shifting: targets = inputs[1:], inputs = inputs[:-1]
        targets = input_ids[:, 1:]
        batch["input_ids"] = input_ids[:, :-1]

        # Metadata tensors that match the sequence length must also be sliced
        for k in list(batch.keys()):
            v = batch[k]
            if isinstance(v, torch.Tensor) and v.ndim >= 2 and v.shape[1] == T:
                batch[k] = v[:, :-1]

        if "seq_lens" in batch:
            batch["seq_lens"] = torch.clamp(batch["seq_lens"] - 1, min=0)
            if torch.any(batch["seq_lens"] == 0).item():
                warn_once(
                    logger,
                    "Some sequences are too short to be used for training (len = 1 -> len = 0 after shifting)",
                )

    # Log sequence statistics accurately for all schemes
    if "cu_seqlens" in batch:
        # Packed/Flat batch: metadata already adjusted for shifting
        cu = batch["cu_seqlens"]
        doc_lens = (cu[1:] - cu[:-1]).float()
        log_max("input_max_seq_len", doc_lens.max().item(), round=2)
        log_averaged(
            "input_mean_seq_len",
            lambda: doc_lens.mean().item(),
            weight=len(doc_lens) / self.tp_size,
            round=2,
        )
        log_summed(
            "total_samples",
            len(doc_lens) / self.tp_size,
            reset=False,
        )
        log_summed(
            "batch_samples",
            len(doc_lens) / self.tp_size,
        )
    elif "seq_lens" in batch:
        sl = batch["seq_lens"].float()
        log_max("input_max_seq_len", sl.max().item(), round=2)
        log_averaged(
            "input_mean_seq_len",
            lambda: sl.mean().item(),
            weight=sl.shape[0] / self.tp_size,
            round=2,
        )
        log_summed(
            "total_samples",
            len(sl) / self.tp_size,
            reset=False,
        )
        log_summed(
            "batch_samples",
            len(sl) / self.tp_size,
        )
    else:
        # Fixed-size batch
        current_T = batch["input_ids"].shape[1]
        log_max("input_max_seq_len", current_T, round=2)
        log_averaged("input_mean_seq_len", current_T, weight=B, round=2)
        log_summed(
            "total_samples",
            len(batch["input_ids"]) / self.tp_size,
            reset=False,
        )
        log_summed(
            "batch_samples",
            len(batch["input_ids"]) / self.tp_size,
        )

    model_out = model(**batch)
    input_ids = batch["input_ids"]
    logits = model_out["logits"]
    is_dtensor = isinstance(logits, DTensor)

    valid_tokens = cached_lambda(
        lambda: ((targets >= 0) & (targets != self.padding_token_id)).sum().item()
        / self.tp_size
    )
    predictions = cached_lambda(lambda: self.gather_predictions(logits))

    log_averaged(
        "accuracy",
        lambda: self.accuracy_metric(predictions(), targets),
        weight=valid_tokens,
        round=2,
    )
    log_summed(
        "batch_tokens",
        valid_tokens,
    )
    log_summed(
        "total_tokens",
        valid_tokens,
        reset=False,
    )

    targets_flat = targets.reshape(-1)
    enable_loss_parallel = False
    if is_dtensor:
        from torch.distributed.tensor.placement_types import Replicate

        if not isinstance(targets_flat, DTensor):
            targets_parallel = DTensor.from_local(
                targets_flat, logits.device_mesh, (Replicate(),)
            )
        else:
            targets_parallel = targets_flat

        # Only enable loss_parallel if logits are actually sharded
        for placement in logits.placements:
            if isinstance(placement, Shard):
                enable_loss_parallel = True
                break
    else:
        targets_parallel = targets_flat

    if (
        self._liger_available
        and targets_parallel.device.type != "cpu"
        and not is_dtensor
    ):
        # Liger kernel handles mixed precision internally, no need to cast to float
        loss = self._liger_cross_entropy(
            input=logits.view(-1, logits.size(-1)),
            target=targets_parallel,
            label_smoothing=self.cfg.label_smoothing,
            ignore_index=self.padding_token_id,
        )
    else:
        with (
            torch.autocast(targets_parallel.device.type, enabled=False),
            loss_parallel() if enable_loss_parallel else nullcontext(),
        ):
            loss = torch.nn.functional.cross_entropy(
                input=logits.view(-1, logits.size(-1)).float(),
                target=targets_parallel,
                label_smoothing=self.cfg.label_smoothing,
                ignore_index=self.padding_token_id,
            )

    log_averaged(
        "loss",
        value=lambda: loss.item(),
        weight=valid_tokens,
    )
    log_averaged_exponent(
        "perplexity",
        value=lambda: loss.item(),
        weight=valid_tokens,
    )

    exposed = {}
    if (
        StandardProtocols.LOGITS in requested_protocols
        or StandardProtocols.LOGITS_TARGETS in requested_protocols
        or StandardProtocols.LOGITS_MASK in requested_protocols
        or StandardProtocols.CLASSIFICATION in requested_protocols
        or StandardProtocols.LOSS in requested_protocols
        or StandardProtocols.INPUT_TOKENS in requested_protocols
    ):
        with torch.no_grad():
            is_flat = B == 1 and "cu_seqlens" in batch
            current_seq_lens = batch.get("seq_lens")

            if StandardProtocols.LOSS in requested_protocols:
                loss_for_exposed = loss
                if isinstance(loss_for_exposed, DTensor):
                    loss_for_exposed = loss_for_exposed.full_tensor()
                exposed[StandardProtocols.LOSS] = loss_for_exposed.detach()

            if StandardProtocols.LOGITS in requested_protocols:
                res_logits = logits
                if isinstance(res_logits, DTensor):
                    res_logits = res_logits.full_tensor()

                if is_flat:
                    res_logits = self._unflatten_flat(
                        res_logits, batch["cu_seqlens"], batch["max_seqlen"]
                    )
                exposed[StandardProtocols.LOGITS] = res_logits

            if StandardProtocols.LOGITS_TARGETS in requested_protocols:
                targets_exposed = targets
                if is_flat:
                    targets_exposed = self._unflatten_flat(
                        targets,
                        batch["cu_seqlens"],
                        batch["max_seqlen"],
                        pad_val=self.padding_token_id,
                    )
                exposed[StandardProtocols.LOGITS_TARGETS] = targets_exposed

            if StandardProtocols.INPUT_TOKENS in requested_protocols:
                input_tokens = input_ids
                if is_flat:
                    input_tokens = self._unflatten_flat(
                        input_ids,
                        batch["cu_seqlens"],
                        batch["max_seqlen"],
                        pad_val=self.padding_token_id,
                    )
                exposed[StandardProtocols.INPUT_TOKENS] = input_tokens

            if StandardProtocols.LOGITS_MASK in requested_protocols:
                mask = targets != self.padding_token_id
                if is_flat:
                    mask = self._unflatten_flat(
                        mask,
                        batch["cu_seqlens"],
                        batch["max_seqlen"],
                        pad_val=False,
                    )
                elif current_seq_lens is not None:
                    mask = mask & (
                        torch.arange(mask.shape[1], device=mask.device)
                        < current_seq_lens[:, None]
                    )

                exposed[StandardProtocols.LOGITS_MASK] = mask

            if StandardProtocols.CLASSIFICATION in requested_protocols:
                res_preds = predictions()  # Already gathered by cached_lambda
                res_targets = targets

                # Base mask for valid tokens
                res_mask = res_targets != self.padding_token_id
                # Refine mask for padded batches if current_seq_lens is available
                if is_flat:
                    cu = batch["cu_seqlens"]
                    ms = batch["max_seqlen"]
                    res_preds = self._unflatten_flat(res_preds, cu, ms)
                    res_targets = self._unflatten_flat(
                        res_targets, cu, ms, pad_val=self.padding_token_id
                    )
                    res_mask = self._unflatten_flat(res_mask, cu, ms, pad_val=False)
                elif current_seq_lens is not None:
                    res_mask = res_mask & (
                        torch.arange(res_mask.shape[1], device=res_mask.device)
                        < current_seq_lens[:, None]
                    )

                exposed[StandardProtocols.CLASSIFICATION] = dict(
                    predictions=res_preds,
                    targets=res_targets,
                    mask=res_mask,
                )

    return loss, exposed

accuracy_metric(predictions, targets)

Compute top-1 accuracy.

Expects predictions already gathered to plain token ids (see gather_predictions, which performs the distributed max across tensor-parallel ranks for DTensors); compares them against targets, ignoring padding tokens.

Source code in optimus_dl/modules/criterion/cross_entropy.py
@torch.no_grad()
def accuracy_metric(self, predictions, targets):
    """Compute top-1 accuracy.

    Expects predictions already gathered to plain token ids; DTensor logits
    are reduced beforehand by `gather_predictions`.
    """

    correct = predictions == targets
    valid = (targets >= 0) & (targets != self.padding_token_id)
    correct = (correct & valid).float()
    return (correct.sum() / valid.sum()).item()

gather_predictions(logits)

Get predictions from logits.

Source code in optimus_dl/modules/criterion/cross_entropy.py
@torch.no_grad()
def gather_predictions(self, logits):
    """
    Get predictions from logits.
    """
    is_dtensor = isinstance(logits, DTensor)
    if is_dtensor:
        assert isinstance(self.collective, MeshCollective)
        local_logits = logits.to_local()
        maxes = torch.max(local_logits, -1)

        maxes_values_distr = DTensor.from_local(
            maxes.values,
            device_mesh=self.collective.tp_mesh,
            placements=(Shard(1),),
        ).full_tensor()
        tok_shift = self.collective.tp_rank * local_logits.size(-1)
        maxes_index_distr = DTensor.from_local(
            maxes.indices + tok_shift,
            device_mesh=self.collective.tp_mesh,
            placements=(Shard(1),),
        ).full_tensor()

        max_total = torch.max(maxes_values_distr, -1, keepdim=True)
        predictions = torch.gather(
            maxes_index_distr,
            dim=1,
            index=max_total.indices,
        )
    else:
        predictions = torch.argmax(logits, dim=-1)
    return predictions

CrossEntropyCriterionConfig dataclass

Bases: RegistryConfigStrict

CrossEntropyCriterionConfig(_name: str | None = None, label_smoothing: float = 0.0, use_liger_kernel: bool | None = None, padding_token_id: int = -100)

Parameters:

Name Type Description Default
label_smoothing float
0.0
use_liger_kernel bool | None
None
padding_token_id int
-100
Source code in optimus_dl/modules/criterion/cross_entropy.py
@dataclass
class CrossEntropyCriterionConfig(RegistryConfigStrict):
    """Configuration for the cross_entropy criterion."""

    # Label smoothing factor passed to the cross entropy loss (0.0 = disabled).
    label_smoothing: float = 0.0
    # Tri-state: True forces the Liger kernel (warns if not installed),
    # False disables it, None auto-detects and uses it when available.
    use_liger_kernel: bool | None = None
    # Target value treated as padding and ignored by the loss and metrics.
    padding_token_id: int = -100