Add pipeline parallel support to TransformersModel
(#12832)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: Isotr0py <2037008807@qq.com> Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
parent
911c8eb000
commit
97cfa65df7
@ -73,7 +73,7 @@ The Transformers fallback explicitly supports the following features:
|
|||||||
|
|
||||||
- <project:#quantization-index> (except GGUF)
|
- <project:#quantization-index> (except GGUF)
|
||||||
- <project:#lora-adapter>
|
- <project:#lora-adapter>
|
||||||
- <project:#distributed-serving> (pipeline parallel coming soon <gh-pr:12832>!)
|
- <project:#distributed-serving> (requires `transformers>=4.49.0`)
|
||||||
|
|
||||||
#### Remote code
|
#### Remote code
|
||||||
|
|
||||||
|
@ -175,6 +175,8 @@ TEXT_GENERATION_MODELS = {
|
|||||||
"inceptionai/jais-13b-chat": PPTestSettings.fast(),
|
"inceptionai/jais-13b-chat": PPTestSettings.fast(),
|
||||||
"ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
|
"ai21labs/Jamba-tiny-dev": PPTestSettings.fast(),
|
||||||
"meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(),
|
"meta-llama/Llama-3.2-1B-Instruct": PPTestSettings.detailed(),
|
||||||
|
# Tests TransformersModel
|
||||||
|
"ArthurZ/Ilama-3.2-1B": PPTestSettings.fast(),
|
||||||
"openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(),
|
"openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(),
|
||||||
"openbmb/MiniCPM3-4B": PPTestSettings.fast(),
|
"openbmb/MiniCPM3-4B": PPTestSettings.fast(),
|
||||||
# Uses Llama
|
# Uses Llama
|
||||||
@ -243,6 +245,7 @@ TEST_MODELS = [
|
|||||||
# [LANGUAGE GENERATION]
|
# [LANGUAGE GENERATION]
|
||||||
"microsoft/Phi-3.5-MoE-instruct",
|
"microsoft/Phi-3.5-MoE-instruct",
|
||||||
"meta-llama/Llama-3.2-1B-Instruct",
|
"meta-llama/Llama-3.2-1B-Instruct",
|
||||||
|
# "ArthurZ/Ilama-3.2-1B", NOTE: Uncomment after #13905
|
||||||
"ibm/PowerLM-3b",
|
"ibm/PowerLM-3b",
|
||||||
# [LANGUAGE EMBEDDING]
|
# [LANGUAGE EMBEDDING]
|
||||||
"intfloat/e5-mistral-7b-instruct",
|
"intfloat/e5-mistral-7b-instruct",
|
||||||
|
@ -15,21 +15,25 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Wrapper around `transformers` models"""
|
"""Wrapper around `transformers` models"""
|
||||||
import re
|
import re
|
||||||
|
from itertools import chain
|
||||||
from typing import Iterable, Literal, Optional, Union
|
from typing import Iterable, Literal, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import AutoModel, PreTrainedModel
|
from transformers import AutoModel, PretrainedConfig, PreTrainedModel
|
||||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||||
|
|
||||||
from vllm.attention import Attention
|
from vllm.attention import Attention
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
|
||||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
ParallelConfig, VllmConfig)
|
||||||
|
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||||
|
from vllm.distributed.utils import get_pp_indices
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||||
ReplicatedLinear,
|
ReplicatedLinear,
|
||||||
RowParallelLinear)
|
RowParallelLinear)
|
||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead, VocabParallelEmbedding)
|
ParallelLMHead, VocabParallelEmbedding)
|
||||||
@ -37,8 +41,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from .interfaces import SupportsLoRA, SupportsQuant
|
from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant
|
||||||
from .utils import maybe_prefix
|
from .utils import (PPMissingLayer, is_pp_missing_parameter,
|
||||||
|
make_empty_intermediate_tensors_factory, maybe_prefix)
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -53,7 +58,7 @@ def vllm_flash_attention_forward(
|
|||||||
# Transformers kwargs
|
# Transformers kwargs
|
||||||
scaling: Optional[float] = None,
|
scaling: Optional[float] = None,
|
||||||
# vLLM kwargs
|
# vLLM kwargs
|
||||||
attention_instances: Optional[list[Attention]] = None,
|
attention_instances: Optional[dict[Attention]] = None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
self_attn = attention_instances[module.layer_idx]
|
self_attn = attention_instances[module.layer_idx]
|
||||||
if scaling is not None:
|
if scaling is not None:
|
||||||
@ -72,13 +77,12 @@ def log_replacement(name: str, old_module: nn.Module, new_module: nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def replace_linear_class(
|
def replace_linear_class(
|
||||||
linear: nn.Linear,
|
linear: nn.Linear, style: Literal["colwise", "rowwise"],
|
||||||
style: Literal["colwise", "rowwise"],
|
quant_config: QuantizationConfig
|
||||||
quant_config=None) -> Union[ColumnParallelLinear, RowParallelLinear]:
|
) -> Union[ColumnParallelLinear, RowParallelLinear]:
|
||||||
"""
|
"""
|
||||||
Replace nn.Linear with one of vLLM's tensor parallel linear classes.
|
Replace nn.Linear with one of vLLM's tensor parallel linear classes.
|
||||||
|
|
||||||
`quant_config` is not yet supported.
|
|
||||||
Args:
|
Args:
|
||||||
linear (nn.Linear): `nn.Linear` to be replaced.
|
linear (nn.Linear): `nn.Linear` to be replaced.
|
||||||
style (str): Tensor parallel style of the new linear, e.g. "colwise".
|
style (str): Tensor parallel style of the new linear, e.g. "colwise".
|
||||||
@ -105,7 +109,7 @@ def replace_linear_class(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
|
class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA, SupportsPP):
|
||||||
embedding_padding_modules = ["lm_head"]
|
embedding_padding_modules = ["lm_head"]
|
||||||
embedding_modules = ["embed_tokens"
|
embedding_modules = ["embed_tokens"
|
||||||
] # TODO transformers will have a util to get it
|
] # TODO transformers will have a util to get it
|
||||||
@ -114,31 +118,175 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
logger.info("Using Transformers backend.")
|
logger.info("Using Transformers backend.")
|
||||||
|
|
||||||
config = vllm_config.model_config.hf_config
|
config: PretrainedConfig = vllm_config.model_config.hf_config
|
||||||
cache_config = vllm_config.cache_config
|
cache_config: CacheConfig = vllm_config.cache_config
|
||||||
model_config = vllm_config.model_config
|
device_config: DeviceConfig = vllm_config.device_config
|
||||||
parallel_config = vllm_config.parallel_config
|
model_config: ModelConfig = vllm_config.model_config
|
||||||
|
parallel_config: ParallelConfig = vllm_config.parallel_config
|
||||||
|
quant_config: QuantizationConfig = vllm_config.quant_config
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.cache_config = cache_config
|
||||||
|
self.device_config = device_config
|
||||||
|
self.model_config = model_config
|
||||||
|
self.parallel_config = parallel_config
|
||||||
|
self.quant_config = quant_config
|
||||||
|
|
||||||
self.vocab_size = model_config.get_vocab_size()
|
self.vocab_size = model_config.get_vocab_size()
|
||||||
self.unpadded_vocab_size = model_config.get_vocab_size()
|
self.unpadded_vocab_size = model_config.get_vocab_size()
|
||||||
|
|
||||||
|
self.pp_group = get_pp_group()
|
||||||
|
self.pp_size = self.pp_group.world_size
|
||||||
|
self.pp_rank = self.pp_group.rank_in_group
|
||||||
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
|
||||||
|
# Use meta device to delay allocating GPU tensors
|
||||||
|
with torch.device("meta"):
|
||||||
self.model: PreTrainedModel = AutoModel.from_config(
|
self.model: PreTrainedModel = AutoModel.from_config(
|
||||||
self.config,
|
config,
|
||||||
attn_implementation="vllm",
|
attn_implementation="vllm",
|
||||||
torch_dtype=vllm_config.model_config.dtype,
|
torch_dtype=model_config.dtype,
|
||||||
trust_remote_code=vllm_config.model_config.trust_remote_code,
|
trust_remote_code=model_config.trust_remote_code,
|
||||||
)
|
)
|
||||||
prefix = self.model.base_model_prefix
|
prefix = self.model.base_model_prefix
|
||||||
|
|
||||||
# MLP modifications
|
self.pipeline_parallel()
|
||||||
self.apply_base_model_tp_plan(self.model)
|
self.tensor_parallel()
|
||||||
|
|
||||||
# Attention modifications (assumes 1 attention op per hidden layer)
|
# Input embeddings
|
||||||
num_heads = model_config.get_num_attention_heads(parallel_config)
|
if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
|
||||||
head_size = model_config.get_head_size()
|
self.model.set_input_embeddings(
|
||||||
num_kv_heads = model_config.get_num_kv_heads(parallel_config)
|
VocabParallelEmbedding(
|
||||||
self.attention_instances = [
|
config.vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
org_num_embeddings=config.vocab_size,
|
||||||
|
quant_config=quant_config,
|
||||||
|
))
|
||||||
|
|
||||||
|
# Attention layers
|
||||||
|
self.attention_instances = self.create_attention_instances()
|
||||||
|
|
||||||
|
# Output embeddings
|
||||||
|
if not isinstance(getattr(self, "lm_head", None), PPMissingLayer):
|
||||||
|
self.unpadded_vocab_size = config.vocab_size
|
||||||
|
self.lm_head = ParallelLMHead(
|
||||||
|
config.vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=maybe_prefix(prefix, "lm_head"),
|
||||||
|
)
|
||||||
|
if config.tie_word_embeddings:
|
||||||
|
self.lm_head = self.lm_head.tie_weights(
|
||||||
|
self.model.get_input_embeddings())
|
||||||
|
|
||||||
|
logit_scale = getattr(config, "logit_scale", 1.0)
|
||||||
|
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
||||||
|
config.vocab_size,
|
||||||
|
logit_scale)
|
||||||
|
|
||||||
|
# Initialize buffers (e.g. rotary embedding inverse frequency)
|
||||||
|
self.init_buffers(self.model)
|
||||||
|
|
||||||
|
# Move remaining meta tensors to device (should happen last)
|
||||||
|
self.meta_to_empty(self.model)
|
||||||
|
|
||||||
|
self.sampler = get_sampler()
|
||||||
|
|
||||||
|
self.make_empty_intermediate_tensors = (
|
||||||
|
make_empty_intermediate_tensors_factory(["hidden_states"],
|
||||||
|
config.hidden_size))
|
||||||
|
|
||||||
|
def pipeline_parallel(self):
|
||||||
|
"""
|
||||||
|
Apply the model's pipeline parallelization plan.
|
||||||
|
"""
|
||||||
|
if self.pp_size <= 1:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not self.model.supports_pp_plan:
|
||||||
|
raise ValueError(
|
||||||
|
f"{type(self.model)} does not support pipeline parallel yet!")
|
||||||
|
|
||||||
|
module_lists = []
|
||||||
|
module_list_idx = None
|
||||||
|
pp_plan = list(self.model._pp_plan.keys())
|
||||||
|
for i, name in enumerate(pp_plan):
|
||||||
|
if isinstance(getattr(self.model, name), nn.ModuleList):
|
||||||
|
module_lists.append(name)
|
||||||
|
module_list_idx = i
|
||||||
|
|
||||||
|
if len(module_lists) > 1:
|
||||||
|
raise ValueError(
|
||||||
|
"Pipeline parallel of models with multiple `ModuleList`s "
|
||||||
|
"in the base model are not supported yet!")
|
||||||
|
if module_list_idx is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Could not find `ModuleList` in {type(self.model)}")
|
||||||
|
|
||||||
|
# Layers before module list
|
||||||
|
for name in pp_plan[:module_list_idx]:
|
||||||
|
if self.pp_group.is_first_rank or (self.config.tie_word_embeddings
|
||||||
|
and self.pp_group.is_last_rank):
|
||||||
|
continue
|
||||||
|
setattr(self.model, name, PPMissingLayer())
|
||||||
|
|
||||||
|
# Module list
|
||||||
|
start_layer, end_layer = get_pp_indices(self.config.num_hidden_layers,
|
||||||
|
self.pp_rank, self.pp_size)
|
||||||
|
layers_name = pp_plan[module_list_idx]
|
||||||
|
layers = getattr(self.model, layers_name)
|
||||||
|
for i in range(len(layers)):
|
||||||
|
if start_layer <= i and i < end_layer:
|
||||||
|
continue
|
||||||
|
layers[i] = PPMissingLayer(return_tuple=True)
|
||||||
|
|
||||||
|
# Layers after module list
|
||||||
|
for name in pp_plan[module_list_idx + 1:]:
|
||||||
|
# Modules that should be on last rank
|
||||||
|
if not self.pp_group.is_last_rank:
|
||||||
|
setattr(self.model, name, PPMissingLayer())
|
||||||
|
|
||||||
|
if not self.pp_group.is_last_rank:
|
||||||
|
self.lm_head = PPMissingLayer()
|
||||||
|
|
||||||
|
def tensor_parallel(self):
|
||||||
|
"""
|
||||||
|
Apply the model's tensor parallelization plan.
|
||||||
|
Currently only supports linear layers.
|
||||||
|
"""
|
||||||
|
if self.tp_size > 1 and self.config.base_model_tp_plan is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"{type(self.model)} does not support tensor parallel yet!")
|
||||||
|
|
||||||
|
tp_plan = self.model._tp_plan
|
||||||
|
|
||||||
|
def _tensor_parallel(module: nn.Module, prefix: str = ""):
|
||||||
|
for child_name, child_module in module.named_children():
|
||||||
|
qual_name = maybe_prefix(prefix, child_name)
|
||||||
|
for pattern, style in tp_plan.items():
|
||||||
|
if re.match(pattern, qual_name) and isinstance(
|
||||||
|
child_module, nn.Linear):
|
||||||
|
new_module = replace_linear_class(
|
||||||
|
child_module, style, self.quant_config)
|
||||||
|
setattr(module, child_name, new_module)
|
||||||
|
log_replacement(qual_name, child_module, new_module)
|
||||||
|
else:
|
||||||
|
_tensor_parallel(child_module, prefix=qual_name)
|
||||||
|
|
||||||
|
_tensor_parallel(self.model)
|
||||||
|
|
||||||
|
def create_attention_instances(self) -> dict[int, Attention]:
|
||||||
|
"""
|
||||||
|
Create `Attention` instances to inform KV cache allocation.
|
||||||
|
"""
|
||||||
|
num_heads = self.model_config.get_num_attention_heads(
|
||||||
|
self.parallel_config)
|
||||||
|
head_size = self.model_config.get_head_size()
|
||||||
|
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
|
||||||
|
start, end = get_pp_indices(self.config.num_hidden_layers,
|
||||||
|
self.pp_rank, self.pp_size)
|
||||||
|
return {
|
||||||
|
i:
|
||||||
Attention(
|
Attention(
|
||||||
num_heads=num_heads,
|
num_heads=num_heads,
|
||||||
head_size=head_size,
|
head_size=head_size,
|
||||||
@ -146,77 +294,70 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
|
|||||||
# Transformers, it's updated in vllm_flash_attention_forward
|
# Transformers, it's updated in vllm_flash_attention_forward
|
||||||
scale=head_size**-0.5,
|
scale=head_size**-0.5,
|
||||||
num_kv_heads=num_kv_heads,
|
num_kv_heads=num_kv_heads,
|
||||||
cache_config=cache_config,
|
cache_config=self.cache_config,
|
||||||
quant_config=self.quant_config,
|
quant_config=self.quant_config,
|
||||||
prefix=f"{i}.attn") for i in range(config.num_hidden_layers)
|
prefix=f"{i}.attn")
|
||||||
]
|
for i in range(start, end)
|
||||||
|
}
|
||||||
|
|
||||||
# Model modifications
|
def init_buffers(self, module: nn.Module):
|
||||||
self.replace_vocab_embed_class(self.model)
|
|
||||||
|
|
||||||
# ForCausalLM modifications
|
|
||||||
self.lm_head = ParallelLMHead(self.vocab_size,
|
|
||||||
config.hidden_size,
|
|
||||||
quant_config=self.quant_config,
|
|
||||||
prefix=maybe_prefix(prefix, "lm_head"))
|
|
||||||
if config.tie_word_embeddings:
|
|
||||||
self.lm_head.weight = self.model.get_input_embeddings().weight
|
|
||||||
|
|
||||||
logit_scale = getattr(config, "logit_scale", 1.0)
|
|
||||||
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
|
|
||||||
self.vocab_size, logit_scale)
|
|
||||||
self.sampler = get_sampler()
|
|
||||||
|
|
||||||
def apply_base_model_tp_plan(self, module: nn.Module, prefix: str = ""):
|
|
||||||
"""
|
"""
|
||||||
Apply the base model tensor parallelization plan to a module.
|
If a `buffer` is on the `meta` device, then its parent
|
||||||
Currently only supports linear layers.
|
`module` is the original module created by:
|
||||||
|
|
||||||
|
```python
|
||||||
|
with torch.device("meta"):
|
||||||
|
self.model: PreTrainedModel = AutoModel.from_config(...)
|
||||||
|
```
|
||||||
|
|
||||||
|
This means that:
|
||||||
|
- `type(module)` is a class from `transformers`
|
||||||
|
- This class is constructed using a `PretrainedConfig`
|
||||||
"""
|
"""
|
||||||
if (self.config.base_model_tp_plan is None
|
for name, buffer in module.named_buffers(recurse=False):
|
||||||
and get_tensor_model_parallel_world_size() > 1):
|
if buffer.device == torch.device("meta"):
|
||||||
raise ValueError(
|
new_buffer = getattr(type(module)(self.config), name)
|
||||||
"Trying to run tensor parallelization but the model does not "
|
setattr(module, name, new_buffer)
|
||||||
"support it yet!")
|
for child in module.children():
|
||||||
|
self.init_buffers(child)
|
||||||
|
|
||||||
for child_name, child_module in module.named_children():
|
def meta_to_empty(self, module: nn.Module):
|
||||||
qual_name = maybe_prefix(prefix, child_name)
|
tensors = list(chain(module.buffers(), module.parameters()))
|
||||||
for pattern, style in self.config.base_model_tp_plan.items():
|
if tensors and all(t.device == torch.device("meta") for t in tensors):
|
||||||
if re.match(pattern, qual_name) and isinstance(
|
module.to_empty(device=self.device_config.device)
|
||||||
child_module, nn.Linear):
|
return # We can stop recursing because to_empty is recursive
|
||||||
new_module = replace_linear_class(child_module, style,
|
for child in module.children():
|
||||||
self.quant_config)
|
self.meta_to_empty(child)
|
||||||
setattr(module, child_name, new_module)
|
|
||||||
log_replacement(qual_name, child_module, new_module)
|
|
||||||
else:
|
|
||||||
self.apply_base_model_tp_plan(child_module, prefix=qual_name)
|
|
||||||
|
|
||||||
def replace_vocab_embed_class(self, module: nn.Module):
|
|
||||||
# Use native set input embeddings
|
|
||||||
new_module = VocabParallelEmbedding(
|
|
||||||
self.vocab_size,
|
|
||||||
self.config.hidden_size,
|
|
||||||
org_num_embeddings=self.vocab_size,
|
|
||||||
quant_config=None,
|
|
||||||
)
|
|
||||||
log_replacement("input embedding", self.model.get_input_embeddings(),
|
|
||||||
new_module)
|
|
||||||
module.set_input_embeddings(new_module)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: torch.Tensor,
|
input_ids: Optional[torch.Tensor],
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
inputs_embeds: Optional[torch.Tensor] = None,
|
inputs_embeds: Optional[torch.Tensor] = None,
|
||||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||||
model_output = self.model(
|
if not get_pp_group().is_first_rank:
|
||||||
input_ids[None, ...],
|
assert intermediate_tensors is not None
|
||||||
|
input_ids = None
|
||||||
|
inputs_embeds = intermediate_tensors["hidden_states"]
|
||||||
|
|
||||||
|
if input_ids is not None:
|
||||||
|
input_ids = input_ids[None, ...]
|
||||||
|
if inputs_embeds is not None:
|
||||||
|
inputs_embeds = inputs_embeds[None, ...]
|
||||||
|
|
||||||
|
hidden_states = self.model(
|
||||||
|
input_ids=input_ids,
|
||||||
|
inputs_embeds=inputs_embeds,
|
||||||
use_cache=False,
|
use_cache=False,
|
||||||
position_ids=positions[None, ...],
|
position_ids=positions[None, ...],
|
||||||
intermediate_tensors=intermediate_tensors,
|
|
||||||
attention_instances=self.attention_instances,
|
attention_instances=self.attention_instances,
|
||||||
return_dict=False)[0][0, ...] # we remove batch dimension for now
|
return_dict=False)[0][0, ...] # we remove batch dimension for now
|
||||||
return model_output
|
|
||||||
|
if not get_pp_group().is_last_rank:
|
||||||
|
return IntermediateTensors({"hidden_states": hidden_states})
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
def compute_logits(
|
def compute_logits(
|
||||||
self,
|
self,
|
||||||
@ -238,8 +379,11 @@ class TransformersModel(nn.Module, SupportsQuant, SupportsLoRA):
|
|||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
loaded_params = set[str]()
|
loaded_params = set[str]()
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
if name not in params_dict:
|
# Necessary for some models which use remote code
|
||||||
name = f"{self.model.base_model_prefix}.{name}"
|
if not name.startswith(prefix := self.model.base_model_prefix):
|
||||||
|
name = maybe_prefix(prefix, name)
|
||||||
|
if is_pp_missing_parameter(name, self):
|
||||||
|
continue
|
||||||
if name in params_dict:
|
if name in params_dict:
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(param, "weight_loader",
|
weight_loader = getattr(param, "weight_loader",
|
||||||
|
@ -472,6 +472,16 @@ class PPMissingLayer(torch.nn.Identity):
|
|||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.return_tuple = kwargs.get("return_tuple", False)
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
Return the first arg from args or the first value from kwargs.
|
||||||
|
|
||||||
|
Wraps the input in a tuple if `self.return_tuple` is True.
|
||||||
|
"""
|
||||||
|
input = args[0] if args else next(iter(kwargs.values()))
|
||||||
|
return (input, ) if self.return_tuple else input
|
||||||
|
|
||||||
|
|
||||||
_CPU_OFFLOAD_BYTES = 0
|
_CPU_OFFLOAD_BYTES = 0
|
||||||
|
Loading…
x
Reference in New Issue
Block a user