[Hardware][Intel-Gaudi] Enable LoRA support for Intel Gaudi (HPU) (#10565)
Signed-off-by: Sanju C Sudhakaran <scsudhakaran@habana.ai>
commit 8195824206
parent f092153fbe
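For context, a minimal offline-inference sketch of what this change enables on Gaudi. It is not part of the commit: the model name and adapter path are placeholders, and it assumes a vLLM build with HPU support plus the pinned vllm-hpu-extension installed.

# Hypothetical usage sketch: a base model plus one LoRA adapter on HPU.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(model="meta-llama/Llama-2-7b-hf",  # placeholder model
          enable_lora=True,
          max_loras=2)
params = SamplingParams(temperature=0.0, max_tokens=64)

# LoRARequest(name, integer id, local adapter path); the id must be >= 1.
outputs = llm.generate(
    ["What is the capital of France?"],
    params,
    lora_request=LoRARequest("demo-adapter", 1, "/path/to/lora/adapter"),
)
print(outputs[0].outputs[0].text)
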
@@ -8,4 +8,4 @@ pandas
 tabulate
 setuptools>=61
 setuptools-scm>=8
-vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@e096d6f
+vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@4312768

@@ -30,6 +30,7 @@ from vllm.model_executor.layers.rotary_embedding import (
     LinearScalingRotaryEmbedding, RotaryEmbedding)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
+from vllm.platforms import current_platform
 
 if TYPE_CHECKING:
     from vllm.lora.punica_wrapper import PunicaWrapperBase

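The new import is what the HPU-specific branches below key off. A small illustrative sketch of the dispatch idiom (the helper function is hypothetical, not from this commit):

from vllm.platforms import current_platform

def pick_lora_backend() -> str:
    # Hypothetical helper for illustration only.
    if current_platform.is_hpu():
        return "hpu"
    if current_platform.is_cuda():
        return "cuda"
    return "cpu"
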
@@ -1068,6 +1069,11 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
         ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"),
                                                       posinf=float("inf"),
                                                       neginf=float("-inf")))
+
+        # HPU needs special handling to prune out dummy samples.
+        if current_platform.is_hpu():
+            lora_logits = lora_logits[:logits.shape[0], :]
+
         logits[:,
                self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
                lora_logits.shape[1]] = lora_logits

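My reading of the new comment: HPU pads batches to fixed bucket sizes, so the LoRA logits computed over the padded batch can have more rows than the real logits; trimming to logits.shape[0] drops the dummy rows before the copy. A toy shape-only illustration, independent of vLLM, with made-up sizes:

import torch

num_real_tokens, padded_tokens = 5, 8          # e.g. a bucket pads 5 -> 8
org_vocab, lora_vocab = 32000, 256

logits = torch.zeros(num_real_tokens, org_vocab + lora_vocab)
lora_logits = torch.randn(padded_tokens, lora_vocab)

# Same pruning step as in the diff: keep only the real (non-dummy) rows.
lora_logits = lora_logits[:logits.shape[0], :]
logits[:, org_vocab:org_vocab + lora_logits.shape[1]] = lora_logits
assert logits.shape == (num_real_tokens, org_vocab + lora_vocab)
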
vllm/lora/punica_wrapper/punica_hpu.py (new file, 87 lines)
@@ -0,0 +1,87 @@
+from typing import Optional, Tuple, Union, final
+
+import torch
+from vllm_hpu_extension.ops import (dispatch_bgmv_embedding,
+                                    dispatch_bgmv_linear)
+
+from .punica_base import PunicaWrapperBase
+
+
+@final
+class PunicaWrapperHPU(PunicaWrapperBase):
+
+    def __init__(self, max_num_batched_tokens: int, max_batches: int,
+                 device: Union[torch.device, str], **kwargs):
+        # Increasing max_num_batched_tokens by 3x to handle increase in
+        # tensor size due to padding.
+        PunicaWrapperBase.__init__(self, 3 * max_num_batched_tokens,
+                                   max_batches, device)
+
+    def add_lora_embedding(self,
+                           y: torch.Tensor,
+                           x: torch.Tensor,
+                           lora_b_stacked: torch.Tensor,
+                           add_input: bool = True,
+                           **kwargs) -> None:
+        dispatch_bgmv_embedding(y, x, lora_b_stacked, 0)
+
+    def add_lora_linear(self,
+                        y: torch.Tensor,
+                        x: torch.Tensor,
+                        lora_a_stacked: Tuple[torch.Tensor, ...],
+                        lora_b_stacked: Tuple[torch.Tensor, ...],
+                        lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
+                        scale: float,
+                        output_slices: Tuple[int, ...],
+                        *,
+                        buffer: Optional[Tuple[torch.Tensor, ...]] = None,
+                        **kwargs) -> None:
+        y_org = y
+        x = x.view(-1, x.shape[-1])
+        y = y.view(-1, y.shape[-1])
+        offset_left = 0
+
+        for slice_idx in range(len(output_slices)):
+            dispatch_bgmv_linear(
+                y[:, offset_left:offset_left + output_slices[slice_idx]], x,
+                lora_a_stacked[slice_idx], lora_b_stacked[slice_idx], 0, scale)
+            offset_left += output_slices[slice_idx]
+        y = y.view_as(y_org)
+
+    def add_lora_logits(self,
+                        y: torch.Tensor,
+                        x: torch.Tensor,
+                        lora_a_stacked: torch.Tensor,
+                        lora_b_stacked: torch.Tensor,
+                        scale,
+                        *,
+                        buffer: Optional[torch.Tensor] = None,
+                        **kwargs) -> None:
+        y_org = y
+        y = y.view(-1, y.shape[-1])
+        x = x.view(-1, x.shape[-1])
+        dispatch_bgmv_linear(y, x, lora_a_stacked, lora_b_stacked, 0, scale)
+        y = y.view_as(y_org)
+
+    def add_shrink(
+        self,
+        y: Union[Tuple[torch.Tensor, ...], torch.Tensor],
+        x: torch.Tensor,
+        lora_a_stacked: Tuple[torch.Tensor, ...],
+        scale: float,
+        **kwargs,
+    ) -> None:
+        raise NotImplementedError
+
+    def add_expand(
+        self,
+        y: torch.Tensor,
+        x: Union[Tuple[torch.Tensor, ...], torch.Tensor],
+        lora_b_stacked: Tuple[torch.Tensor, ...],
+        lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]],
+        output_slices: Tuple[int, ...],
+        offset_start: int = 0,
+        add_input=True,
+        **kwargs,
+    ) -> None:
+        raise NotImplementedError

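To make the offset_left loop in add_lora_linear concrete: each LoRA (A, B) pair writes into its own column slice of the fused output, e.g. the q/k/v slices of a merged QKV projection. The sketch below is a toy reconstruction of just that slicing arithmetic; fake_bgmv_linear is a made-up stand-in, since on Gaudi the real kernel is dispatch_bgmv_linear from vllm_hpu_extension.

import torch

def fake_bgmv_linear(y_slice, x, lora_a, lora_b, scale):
    # Stand-in for the HPU kernel: y_slice += scale * (x @ A) @ B.
    # In-place add on the slice view updates the underlying y tensor.
    y_slice += scale * (x @ lora_a) @ lora_b

tokens, hidden, rank = 4, 16, 8
output_slices = (32, 8, 8)                     # e.g. fused q/k/v widths
x = torch.randn(tokens, hidden)
y = torch.zeros(tokens, sum(output_slices))

lora_a = [torch.randn(hidden, rank) for _ in output_slices]
lora_b = [torch.randn(rank, w) for w in output_slices]

offset_left = 0
for i, width in enumerate(output_slices):
    fake_bgmv_linear(y[:, offset_left:offset_left + width], x,
                     lora_a[i], lora_b[i], scale=1.0)
    offset_left += width
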
@@ -10,5 +10,10 @@ def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase:
         from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU
         print_info_once("Using PunicaWrapperGPU.")
         return PunicaWrapperGPU(*args, **kwargs)
+    elif current_platform.is_hpu():
+        # Lazy import to avoid ImportError
+        from vllm.lora.punica_wrapper.punica_hpu import PunicaWrapperHPU
+        print_info_once("Using PunicaWrapperHPU.")
+        return PunicaWrapperHPU(*args, **kwargs)
     else:
         raise NotImplementedError

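A sketch of how the factory would be exercised on Gaudi, assuming get_punica_wrapper is re-exported from the vllm.lora.punica_wrapper package; the argument values are placeholders that merely mirror the PunicaWrapperHPU.__init__ signature added above.

from vllm.lora.punica_wrapper import get_punica_wrapper

wrapper = get_punica_wrapper(
    max_num_batched_tokens=8192,   # internally tripled on HPU for padding
    max_batches=128,
    device="hpu",
)
print(type(wrapper).__name__)      # "PunicaWrapperHPU" when running on Gaudi
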
@@ -622,6 +622,10 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
             assert hasattr(
                 self.model, "embedding_padding_modules"
             ), "Model does not have embedding_padding_modules"
+            assert not self.lora_config.bias_enabled, \
+                "Bias support in LoRA is not enabled in HPU yet."
+            assert not self.lora_config.fully_sharded_loras, \
+                "Fully sharded LoRAs is not enabled in HPU yet."
             self.lora_manager = LRUCacheWorkerLoRAManager(
                 self.scheduler_config.max_num_seqs,
                 self.scheduler_config.max_num_batched_tokens,
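A hedged configuration sketch reflecting the two new guards: on HPU the LoRA path expects LoRA bias and fully sharded LoRA to stay disabled. The keyword names mirror vLLM's engine arguments; the model and values are placeholders.

from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-2-7b-hf",   # placeholder model
    enable_lora=True,
    max_loras=2,
    max_lora_rank=16,
    fully_sharded_loras=False,          # must stay False on HPU for now
)
# LoRA bias support must likewise remain disabled (it is off by default).
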
@@ -1282,11 +1286,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
     def profile_run(self) -> None:
         num_layers = self.model_config.get_num_layers(self.parallel_config)
         kv_caches = [None] * num_layers
-        max_batch_size = self.bucketing_global_state.prompt_bs_bucket_cfg[-1]
-        max_seq_len = min(
-            self.bucketing_global_state.prompt_seq_bucket_cfg[-1],
-            self.max_num_batched_tokens // max_batch_size)
-
+        max_seq_len = self.bucketing_global_state.prompt_seq_bucket_cfg[-1]
+        max_batch_size = min(self.max_num_batched_tokens // max_seq_len,
+                             self.scheduler_config.max_num_seqs)
         self.warmup_scenario(max_batch_size, max_seq_len, True, kv_caches,
                              False, True)
         return
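A quick worked example of the new sizing logic: the profiling batch size is now derived from the largest prompt sequence bucket and capped by the scheduler's max_num_seqs, instead of starting from the largest batch bucket. Numbers below are illustrative, not values from this commit.

max_num_batched_tokens = 8192
prompt_seq_bucket_max = 1024   # bucketing_global_state.prompt_seq_bucket_cfg[-1]
max_num_seqs = 64

max_seq_len = prompt_seq_bucket_max
max_batch_size = min(max_num_batched_tokens // max_seq_len, max_num_seqs)
assert max_batch_size == 8     # 8192 // 1024 = 8, below the cap of 64
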
@@ -1304,7 +1306,6 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
             f"bs{batch_size}_"
             f"seq{seq_len}_"
             f"graphs{'T' if use_graphs else 'F'}")
-        max_num_seqs = self.scheduler_config.max_num_seqs
         # This represents the maximum number of different requests
         # that will have unique loras, an therefore the max amount of memory
         # consumption create dummy lora request copies from the lora request

@@ -1326,16 +1327,10 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
                 dummy_lora_requests.append(dummy_lora_request)
             dummy_lora_requests_per_seq = [
                 dummy_lora_requests[idx % len(dummy_lora_requests)]
-                for idx in range(max_num_seqs)
+                for idx in range(batch_size)
             ]
         self.profiler.start('internal', scenario_name)
         times = 3 if use_graphs or is_pt_profiler_run else 1
-        if self.lora_config and not is_lora_profile_run:
-            lora_mapping = LoRAMapping(
-                **dict(index_mapping=[0] * batch_size * seq_len,
-                       prompt_mapping=[0] * batch_size * seq_len,
-                       is_prefill=is_prompt))
-            self.set_active_loras(set(), lora_mapping)
         if is_prompt:
             seqs = [
                 self.create_dummy_seq_group_metadata(
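The comprehension now sizes the dummy-LoRA assignment to the warm-up batch rather than to max_num_seqs. A toy illustration of the round-robin mapping, with placeholder request names:

dummy_lora_requests = ["lora-1", "lora-2", "lora-3"]   # placeholders
batch_size = 8

dummy_lora_requests_per_seq = [
    dummy_lora_requests[idx % len(dummy_lora_requests)]
    for idx in range(batch_size)
]
print(dummy_lora_requests_per_seq)
# ['lora-1', 'lora-2', 'lora-3', 'lora-1', 'lora-2', 'lora-3', 'lora-1', 'lora-2']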