Skip to content

tensor_parallel

optimus_dl.modules.model_transforms.tensor_parallel

Tensor Parallelism Transform.

TensorParallelConfig dataclass

Bases: ModelTransformConfig

Configuration for Tensor Parallelism.

Attributes:

Name Type Description

Parameters:

Name Type Description Default
custom_model_kwargs dict

Additional keyword arguments passed to the model's `apply_tp` method
(e.g., sequence_parallel=True).

Default: {} (empty dict)
Source code in optimus_dl/modules/model_transforms/tensor_parallel.py
@dataclass
class TensorParallelConfig(ModelTransformConfig):
    """Configuration for Tensor Parallelism.

    Attributes:
        custom_model_kwargs: Additional keyword arguments passed to the model's
            `apply_tp` method (e.g., sequence_parallel=True).
    """

    # Forwarded verbatim as **kwargs into `model.apply_tp(...)`; defaults to
    # an empty dict (per-field factory avoids the shared-mutable-default trap).
    custom_model_kwargs: dict = field(default_factory=dict)

TensorParallelTransform

Bases: BaseModelTransform

Transform that applies Tensor Parallelism to a model.

This transform delegates the actual sharding logic to the model's apply_tp method, providing it with the appropriate Tensor Parallel device mesh from the global collective.

Parameters:

Name Type Description Default
cfg TensorParallelConfig

Tensor parallel configuration.

required
collective Collective

Distributed collective (MeshCollective required).

required
Source code in optimus_dl/modules/model_transforms/tensor_parallel.py
@register_model_transform("tensor_parallel", TensorParallelConfig)
class TensorParallelTransform(BaseModelTransform):
    """Transform that applies Tensor Parallelism to a model.

    The transform itself contains no sharding logic: it looks up the tensor
    parallel device mesh on the (mesh-based) collective and hands it to the
    model's own `apply_tp` method, forwarding any configured extra kwargs.

    Args:
        cfg: Tensor parallel configuration.
        collective: Distributed collective (MeshCollective required).
    """

    def __init__(
        self,
        cfg: TensorParallelConfig,
        collective: Collective,
        **kwargs: Any,
    ):
        super().__init__(cfg, **kwargs)
        self.collective = collective

    def apply(self, model: BaseModel, **kwargs) -> BaseModel:
        """Apply the tensor parallel plan to the model."""
        if isinstance(self.collective, MeshCollective):
            mesh = self.collective.tp_mesh
            if mesh is not None:
                logger.info(f"Applying Tensor Parallelism with mesh: {mesh}")
                # The model owns its parallelization plan; delegate to it.
                model.apply_tp(mesh, **self.cfg.custom_model_kwargs)
                logger.info("Tensor Parallelism applied successfully.")
            else:
                # tp_size == 1 means there is nothing to shard.
                logger.info("No TP mesh found (tp_size=1). Skipping Tensor Parallelism.")
        else:
            logger.warning("TensorParallel requires MeshCollective. Skipping.")
        return model

apply(model, **kwargs)

Apply the tensor parallel plan to the model.

Source code in optimus_dl/modules/model_transforms/tensor_parallel.py
def apply(self, model: BaseModel, **kwargs) -> BaseModel:
    """Apply the tensor parallel plan to the model."""
    collective = self.collective
    if not isinstance(collective, MeshCollective):
        # Without a device mesh we cannot shard; leave the model untouched.
        logger.warning("TensorParallel requires MeshCollective. Skipping.")
        return model

    mesh = collective.tp_mesh
    if mesh is None:
        # tp_size == 1: nothing to parallelize.
        logger.info("No TP mesh found (tp_size=1). Skipping Tensor Parallelism.")
        return model

    logger.info(f"Applying Tensor Parallelism with mesh: {mesh}")
    # The model owns its parallelization plan; delegate sharding to it.
    model.apply_tp(mesh, **self.cfg.custom_model_kwargs)
    logger.info("Tensor Parallelism applied successfully.")
    return model