Add pipeline parallel support to TransformersModel (#12832)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
Harry Mellor 2025-03-25 02:41:45 +00:00 committed by GitHub
parent 911c8eb000
commit 97cfa65df7
4 changed files with 244 additions and 87 deletions

View File

@@ -73,7 +73,7 @@ The Transformers fallback explicitly supports the following features:
- <project:#quantization-index> (except GGUF)
- <project:#lora-adapter>
- <project:#distributed-serving> (pipeline parallel coming soon <gh-pr:12832>!)
- <project:#distributed-serving> (requires `transformers>=4.49.0`)
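With this change, pipeline parallelism works through the Transformers backend. A minimal offline-serving sketch, assuming the `model_impl` and `pipeline_parallel_size` engine arguments available in this version of vLLM (the model name is illustrative):

```python
# Minimal sketch: serve a Hugging Face model via the Transformers backend with
# pipeline parallelism. Assumes transformers>=4.49.0 and at least 2 GPUs.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",  # illustrative model
    model_impl="transformers",   # force the Transformers backend ("auto" also falls back to it)
    pipeline_parallel_size=2,    # shard decoder layers across 2 pipeline stages
    tensor_parallel_size=1,
)

outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```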
#### Remote code

View File

@@ -175,6 +175,8 @@ TEXT_GENERATION_MODELS = {
"inceptionai/jais-13b-chat": PPTestSettings.fast(),
"ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
"meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(),
# Tests TransformersModel
"ArthurZ/Ilama-3.2-1B": PPTestSettings.fast(),
"openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(),
"openbmb/MiniCPM3-4B": PPTestSettings.fast(),
# Uses Llama
@@ -243,6 +245,7 @@ TEST_MODELS = [
# [LANGUAGE GENERATION]
"microsoft/Phi-3.5-MoE-instruct",
"meta-llama/Llama-3.2-1B-Instruct",
# "ArthurZ/Ilama-3.2-1B", NOTE: Uncomment after #13905
"ibm/PowerLM-3b",
# [LANGUAGE EMBEDDING]
"intfloat/e5-mistral-7b-instruct",

View File

@@ -15,21 +15,25 @@
# limitations under the License.
"""Wrapper around `transformers` models"""
import re
from itertools import chain
from typing import Iterable, Literal, Optional, Union
import torch
from torch import nn
from transformers import AutoModel, PreTrainedModel
from transformers import AutoModel, PretrainedConfig, PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from vllm.attention import Attention
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, VllmConfig)
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.distributed.utils import get_pp_indices
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
@@ -37,8 +41,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsQuant
from .utils import maybe_prefix
from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, maybe_prefix)
logger = init_logger(__name__)
@@ -53,7 +58,7 @@ def vllm_flash_attention_forward(
# Transformers kwargs
scaling: Optional[float] = None,
# vLLM kwargs
attention_instances: Optional[list[Attention]] = None,
attention_instances: Optional[dict[int, Attention]] = None,
**kwargs):
self_attn = attention_instances[module.layer_idx]
if scaling is not None:
@@ -72,13 +77,12 @@ def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
def replace_linear_class(
linear: nn.Linear,
style: Literal["colwise", "rowwise"],
quant_config=None) -> Union[ColumnParallelLinear, RowParallelLinear]:
linear: nn.Linear, style: Literal["colwise", "rowwise"],
quant_config: QuantizationConfig
) -> Union[ColumnParallelLinear, RowParallelLinear]:
"""
Replace nn.Linear with one of vLLM's tensor parallel linear classes.
`quant_config` is not yet supported.
Args:
linear (nn.Linear): `nn.Linear` to be replaced.
style (str): Tensor parallel style of the new linear, e.g. "colwise".
@@ -105,7 +109,7 @@ def replace_linear_class(
)
class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
embedding_padding_modules = ["lm_head"]
embedding_modules = ["embed_tokens"
] # TODO transformers will have a util to get it
@@ -114,31 +118,175 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
super().__init__()
logger.info("Using Transformers backend.")
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
config: PretrainedConfig = vllm_config.model_config.hf_config
cache_config: CacheConfig = vllm_config.cache_config
device_config: DeviceConfig = vllm_config.device_config
model_config: ModelConfig = vllm_config.model_config
parallel_config: ParallelConfig = vllm_config.parallel_config
quant_config: QuantizationConfig = vllm_config.quant_config
self.config = config
self.cache_config = cache_config
self.device_config = device_config
self.model_config = model_config
self.parallel_config = parallel_config
self.quant_config = quant_config
self.vocab_size = model_config.get_vocab_size()
self.unpadded_vocab_size = model_config.get_vocab_size()
self.model: PreTrainedModel = AutoModel.from_config(
self.config,
attn_implementation="vllm",
torch_dtype=vllm_config.model_config.dtype,
trust_remote_code=vllm_config.model_config.trust_remote_code,
)
self.pp_group = get_pp_group()
self.pp_size = self.pp_group.world_size
self.pp_rank = self.pp_group.rank_in_group
self.tp_size = get_tensor_model_parallel_world_size()
# Use meta device to delay allocating GPU tensors
with torch.device("meta"):
self.model: PreTrainedModel = AutoModel.from_config(
config,
attn_implementation="vllm",
torch_dtype=model_config.dtype,
trust_remote_code=model_config.trust_remote_code,
)
prefix = self.model.base_model_prefix
# MLP modifications
self.apply_base_model_tp_plan(self.model)
self.pipeline_parallel()
self.tensor_parallel()
# Attention modifications (assumes 1 attention op per hidden layer)
num_heads = model_config.get_num_attention_heads(parallel_config)
head_size = model_config.get_head_size()
num_kv_heads = model_config.get_num_kv_heads(parallel_config)
self.attention_instances = [
# Input embeddings
if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
self.model.set_input_embeddings(
VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
quant_config=quant_config,
))
# Attention layers
self.attention_instances = self.create_attention_instances()
# Output embeddings
if not isinstance(getattr(self, "lm_head", None), PPMissingLayer):
self.unpadded_vocab_size = config.vocab_size
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "lm_head"),
)
if config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(
self.model.get_input_embeddings())
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size,
logit_scale)
# Initialize buffers (e.g. rotary embedding inverse frequency)
self.init_buffers(self.model)
# Move remaining meta tensors to device (should happen last)
self.meta_to_empty(self.model)
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(["hidden_states"],
config.hidden_size))
def pipeline_parallel(self):
"""
Apply the model's pipeline parallelization plan.
"""
if self.pp_size <= 1:
return
if not self.model.supports_pp_plan:
raise ValueError(
f"{type(self.model)} does not support pipeline parallel yet!")
module_lists = []
module_list_idx = None
pp_plan = list(self.model._pp_plan.keys())
for i, name in enumerate(pp_plan):
if isinstance(getattr(self.model, name), nn.ModuleList):
module_lists.append(name)
module_list_idx = i
if len(module_lists) > 1:
raise ValueError(
"Pipeline parallel of models with multiple `ModuleList`s "
"in the base model are not supported yet!")
if module_list_idx is None:
raise ValueError(
f"Could not find `ModuleList` in {type(self.model)}")
# Layers before module list
for name in pp_plan[:module_list_idx]:
if self.pp_group.is_first_rank or (self.config.tie_word_embeddings
and self.pp_group.is_last_rank):
continue
setattr(self.model, name, PPMissingLayer())
# Module list
start_layer, end_layer = get_pp_indices(self.config.num_hidden_layers,
self.pp_rank, self.pp_size)
layers_name = pp_plan[module_list_idx]
layers = getattr(self.model, layers_name)
for i in range(len(layers)):
if start_layer <= i and i < end_layer:
continue
layers[i] = PPMissingLayer(return_tuple=True)
# Layers after module list
for name in pp_plan[module_list_idx + 1:]:
# Modules that should be on last rank
if not self.pp_group.is_last_rank:
setattr(self.model, name, PPMissingLayer())
if not self.pp_group.is_last_rank:
self.lm_head = PPMissingLayer()
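To make the effect of `pipeline_parallel()` concrete, here is a rough standalone illustration of the layer partitioning, assuming `get_pp_indices` falls back to a plain even split (the real helper in `vllm.distributed.utils` may partition differently); every layer index outside the returned range is replaced with `PPMissingLayer(return_tuple=True)`:

```python
# Rough illustration of which decoder layers each pipeline rank keeps,
# assuming a plain even split across pipeline stages.
def even_pp_indices(num_layers: int, pp_rank: int, pp_size: int) -> tuple[int, int]:
    per_rank = num_layers // pp_size
    start = pp_rank * per_rank
    end = num_layers if pp_rank == pp_size - 1 else start + per_rank
    return start, end

num_hidden_layers, pp_size = 16, 2
for pp_rank in range(pp_size):
    start, end = even_pp_indices(num_hidden_layers, pp_rank, pp_size)
    # Layers outside [start, end) become PPMissingLayer(return_tuple=True).
    print(f"pp_rank={pp_rank}: owns layers {start}..{end - 1}")
# pp_rank=0: owns layers 0..7
# pp_rank=1: owns layers 8..15
```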
def tensor_parallel(self):
"""
Apply the model's tensor parallelization plan.
Currently only supports linear layers.
"""
if self.tp_size > 1 and self.config.base_model_tp_plan is None:
raise ValueError(
f"{type(self.model)} does not support tensor parallel yet!")
tp_plan = self.model._tp_plan
def _tensor_parallel(module: nn.Module, prefix: str = ""):
for child_name, child_module in module.named_children():
qual_name = maybe_prefix(prefix, child_name)
for pattern, style in tp_plan.items():
if re.match(pattern, qual_name) and isinstance(
child_module, nn.Linear):
new_module = replace_linear_class(
child_module, style, self.quant_config)
setattr(module, child_name, new_module)
log_replacement(qual_name, child_module, new_module)
else:
_tensor_parallel(child_module, prefix=qual_name)
_tensor_parallel(self.model)
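The `_tp_plan` consumed above maps regex-style patterns over qualified module names to a parallel style. A small sketch of how such a pattern selects a linear layer; the plan entries below are illustrative, not the plan of any particular model:

```python
import re

# Illustrative plan only; the real entries come from the model's
# base_model_tp_plan in its Transformers config and differ per architecture.
tp_plan = {
    "layers.*.self_attn.q_proj": "colwise",
    "layers.*.self_attn.o_proj": "rowwise",
    "layers.*.mlp.gate_proj": "colwise",
    "layers.*.mlp.down_proj": "rowwise",
}

qual_name = "layers.3.mlp.down_proj"  # qualified name built by maybe_prefix()
for pattern, style in tp_plan.items():
    # The glob-style "*" happens to work with re.match because "." matches any
    # character, so "layers.*.mlp.down_proj" matches "layers.3.mlp.down_proj".
    if re.match(pattern, qual_name):
        print(f"{qual_name} -> {style}")  # layers.3.mlp.down_proj -> rowwise
```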
def create_attention_instances(self) -> dict[int, Attention]:
"""
Create `Attention` instances to inform KV cache allocation.
"""
num_heads = self.model_config.get_num_attention_heads(
self.parallel_config)
head_size = self.model_config.get_head_size()
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
start, end = get_pp_indices(self.config.num_hidden_layers,
self.pp_rank, self.pp_size)
return {
i:
Attention(
num_heads=num_heads,
head_size=head_size,
@@ -146,77 +294,70 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
# Transformers, it's updated in vllm_flash_attention_forward
scale=head_size**-0.5,
num_kv_heads=num_kv_heads,
cache_config=cache_config,
cache_config=self.cache_config,
quant_config=self.quant_config,
prefix=f"{i}.attn") for i in range(config.num_hidden_layers)
]
prefix=f"{i}.attn")
for i in range(start, end)
}
# Model modifications
self.replace_vocab_embed_class(self.model)
# ForCausalLM modifications
self.lm_head = ParallelLMHead(self.vocab_size,
config.hidden_size,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "lm_head"))
if config.tie_word_embeddings:
self.lm_head.weight = self.model.get_input_embeddings().weight
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
self.vocab_size, logit_scale)
self.sampler = get_sampler()
def apply_base_model_tp_plan(self, module: nn.Module, prefix: str = ""):
def init_buffers(self, module: nn.Module):
"""
Apply the base model tensor parallelization plan to a module.
Currently only supports linear layers.
If a `buffer` is on the `meta` device, then its parent
`module` is the original module created by:
```python
with torch.device("meta"):
self.model: PreTrainedModel = AutoModel.from_config(...)
```
This means that:
- `type(module)` is a class from `transformers`
- This class is constructed using a `PretrainedConfig`
"""
if (self.config.base_model_tp_plan is None
and get_tensor_model_parallel_world_size() > 1):
raise ValueError(
"Trying to run tensor parallelization but the model does not "
"support it yet!")
for name, buffer in module.named_buffers(recurse=False):
if buffer.device == torch.device("meta"):
new_buffer = getattr(type(module)(self.config), name)
setattr(module, name, new_buffer)
for child in module.children():
self.init_buffers(child)
for child_name, child_module in module.named_children():
qual_name = maybe_prefix(prefix, child_name)
for pattern, style in self.config.base_model_tp_plan.items():
if re.match(pattern, qual_name) and isinstance(
child_module, nn.Linear):
new_module = replace_linear_class(child_module, style,
self.quant_config)
setattr(module, child_name, new_module)
log_replacement(qual_name, child_module, new_module)
else:
self.apply_base_model_tp_plan(child_module, prefix=qual_name)
def replace_vocab_embed_class(self, module: nn.Module):
# Use native set input embeddings
new_module = VocabParallelEmbedding(
self.vocab_size,
self.config.hidden_size,
org_num_embeddings=self.vocab_size,
quant_config=None,
)
log_replacement("input embedding", self.model.get_input_embeddings(),
new_module)
module.set_input_embeddings(new_module)
def meta_to_empty(self, module: nn.Module):
tensors = list(chain(module.buffers(), module.parameters()))
if tensors and all(t.device == torch.device("meta") for t in tensors):
module.to_empty(device=self.device_config.device)
return # We can stop recursing because to_empty is recursive
for child in module.children():
self.meta_to_empty(child)
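`meta_to_empty` is the counterpart to constructing the model under `torch.device("meta")`: parameters get real (uninitialized) storage only once the parallel plans have decided which modules this rank actually owns. A plain PyTorch sketch of that pattern, independent of vLLM:

```python
import torch
from torch import nn

# Constructing under the meta device records shapes and dtypes but allocates
# no storage, so even very large models are "built" instantly.
with torch.device("meta"):
    layer = nn.Linear(4096, 4096)
assert layer.weight.is_meta

# to_empty() materializes uninitialized storage on the target device; the real
# values are expected to be copied in later (here, by load_weights()).
layer = layer.to_empty(device="cpu")
assert not layer.weight.is_meta
```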
def forward(
self,
input_ids: torch.Tensor,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
model_output = self.model(
input_ids[None, ...],
if not get_pp_group().is_first_rank:
assert intermediate_tensors is not None
input_ids = None
inputs_embeds = intermediate_tensors["hidden_states"]
if input_ids is not None:
input_ids = input_ids[None, ...]
if inputs_embeds is not None:
inputs_embeds = inputs_embeds[None, ...]
hidden_states = self.model(
input_ids=input_ids,
inputs_embeds=inputs_embeds,
use_cache=False,
position_ids=positions[None, ...],
intermediate_tensors=intermediate_tensors,
attention_instances=self.attention_instances,
return_dict=False)[0][0, ...] # we remove batch dimension for now
return model_output
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
return hidden_states
def compute_logits(
self,
@@ -238,8 +379,11 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
params_dict = dict(self.named_parameters())
loaded_params = set[str]()
for name, loaded_weight in weights:
if name not in params_dict:
name = f"{self.model.base_model_prefix}.{name}"
# Necessary for some models which use remote code
if not name.startswith(prefix := self.model.base_model_prefix):
name = maybe_prefix(prefix, name)
if is_pp_missing_parameter(name, self):
continue
if name in params_dict:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",

View File

@@ -472,6 +472,16 @@ class PPMissingLayer(torch.nn.Identity):
def __init__(self, *args, **kwargs):
super().__init__()
self.return_tuple = kwargs.get("return_tuple", False)
def forward(self, *args, **kwargs):
"""
Return the first arg from args or the first value from kwargs.
Wraps the input in a tuple if `self.return_tuple` is True.
"""
input = args[0] if args else next(iter(kwargs.values()))
return (input, ) if self.return_tuple else input
_CPU_OFFLOAD_BYTES = 0
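For reference, a short usage sketch of `PPMissingLayer` (import path assumed to be `vllm.model_executor.models.utils`): it behaves like `nn.Identity`, and `return_tuple=True` lets it stand in for decoder layers whose callers expect a `(hidden_states, ...)` tuple:

```python
import torch
from vllm.model_executor.models.utils import PPMissingLayer  # assumed path

hidden_states = torch.randn(2, 8)

# Default: behaves exactly like nn.Identity.
assert torch.equal(PPMissingLayer()(hidden_states), hidden_states)

# return_tuple=True mimics decoder layers that return (hidden_states, ...),
# so the Transformers forward loop can still index the output with [0].
out = PPMissingLayer(return_tuple=True)(hidden_states)
assert isinstance(out, tuple) and torch.equal(out[0], hidden_states)
```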
@@ -650,4 +660,4 @@ def cast_overflow_tensors(
if tensors.isinf().any() or tensors.isnan().any():
clamp_value = torch.finfo(tensors.dtype).max - offset
tensors = torch.clamp(tensors, min=-clamp_value, max=clamp_value)
return tensors
return tensors