tensor_parallel
optimus_dl.modules.model_transforms.tensor_parallel
Tensor Parallelism Transform.
TensorParallelConfig
dataclass
Bases: ModelTransformConfig
Configuration for Tensor Parallelism.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| custom_model_kwargs | dict | Custom model keyword arguments. | `dict()` (a new empty dict per instance) |
Source code in optimus_dl/modules/model_transforms/tensor_parallel.py
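The following is a minimal sketch of how a config like this could be declared. It assumes that `custom_model_kwargs` uses `field(default_factory=dict)`, which is what the `<class 'dict'>` default in the generated table suggests. `ModelTransformConfig` is replaced by an empty stand-in here, and the `"sequence_parallel"` key is purely hypothetical.

```python
from dataclasses import dataclass, field


@dataclass
class ModelTransformConfig:
    """Hypothetical stand-in for optimus_dl's ModelTransformConfig base class."""


@dataclass
class TensorParallelConfig(ModelTransformConfig):
    # Assumption: default_factory=dict, so every instance gets its own empty dict
    # instead of sharing one mutable default.
    custom_model_kwargs: dict = field(default_factory=dict)


default_cfg = TensorParallelConfig()
# "sequence_parallel" is an illustrative key, not a documented option.
custom_cfg = TensorParallelConfig(custom_model_kwargs={"sequence_parallel": True})
```

Using a default factory rather than a literal `{}` default is required by `dataclasses` for mutable defaults and keeps instances from accidentally sharing state.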
TensorParallelTransform
Bases: BaseModelTransform
Transform that applies Tensor Parallelism to a model.
This transform delegates the actual sharding logic to the model's apply_tp
method, providing it with the appropriate Tensor Parallel device mesh from
the global collective.
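The delegation pattern described above can be sketched roughly as follows. This is a self-contained illustration, not the library's implementation: `FakeMeshCollective`, its `tp_mesh` attribute, `ToyModel`, forwarding `custom_model_kwargs` to `apply_tp`, and returning the model from `apply` are all assumptions. In a real PyTorch model, `apply_tp` would typically shard submodules with something like `torch.distributed.tensor.parallel.parallelize_module` over the given device mesh.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class TensorParallelConfig:
    custom_model_kwargs: dict = field(default_factory=dict)


class FakeMeshCollective:
    """Hypothetical stand-in for a MeshCollective that exposes a TP sub-mesh."""

    def __init__(self, tp_mesh: Any):
        self.tp_mesh = tp_mesh  # assumed attribute name


class TensorParallelTransformSketch:
    def __init__(self, cfg: TensorParallelConfig, collective: Any):
        self.cfg = cfg
        self.collective = collective

    def apply(self, model: Any, **kwargs: Any) -> Any:
        # The transform holds no sharding logic itself; the model must provide it.
        if not hasattr(model, "apply_tp"):
            raise TypeError(f"{type(model).__name__} does not implement apply_tp")
        tp_mesh = self.collective.tp_mesh
        # Assumption: config kwargs are forwarded to the model's apply_tp.
        model.apply_tp(tp_mesh, **self.cfg.custom_model_kwargs)
        return model


class ToyModel:
    """Hypothetical model that records how it was asked to shard itself."""

    def __init__(self):
        self.tp_calls = []

    def apply_tp(self, mesh: Any, **kwargs: Any) -> None:
        # A real model would parallelize its layers over `mesh` here.
        self.tp_calls.append((mesh, kwargs))


transform = TensorParallelTransformSketch(
    TensorParallelConfig(custom_model_kwargs={"sequence_parallel": True}),
    FakeMeshCollective(tp_mesh="tp-mesh"),
)
toy = ToyModel()
result = transform.apply(toy)
```

Keeping the sharding plan inside the model means each architecture decides which layers are column- or row-parallel, while the transform only supplies the mesh.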
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| cfg | TensorParallelConfig | Tensor parallel configuration. | required |
| collective | Collective | Distributed collective (a MeshCollective is required). | required |
Source code in optimus_dl/modules/model_transforms/tensor_parallel.py
apply(model, **kwargs)
Apply the tensor parallel plan to the model.