[Kernel] LoRA - Enable CUDAGraphs for V1 (#14626)
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
parent 32ef4983cd
commit 0b1cfa6180
@@ -52,6 +52,7 @@ def test_worker_apply_lora(sql_lora_files):
             seed=0,
             dtype="float16",
             revision=None,
+            enforce_eager=True,
         ),
         load_config=LoadConfig(
             download_dir=None,
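
The new `enforce_eager=True` in this test's ModelConfig presumably keeps the worker-level LoRA test on the eager path now that LoRA no longer forces compilation off under V1. As a rough sketch of what the flag means at the public API level (model name is illustrative, not taken from this test):

from vllm import LLM

# enforce_eager=True runs the model eagerly, skipping CUDA graph capture
# and torch.compile for this engine instance.
llm = LLM(model="meta-llama/Llama-2-7b-hf", enforce_eager=True)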
@@ -2287,9 +2287,14 @@ class LoRAConfig:
         excluding anything before input ids/embeddings and after
         the final hidden states.
         """
-        # no factors to consider.
-        # LoRA is not compatible with `torch.compile` .
         factors: list[Any] = []
+        factors.append(self.max_lora_rank)
+        factors.append(self.max_loras)
+        factors.append(self.fully_sharded_loras)
+        factors.append(self.lora_dtype)
+        factors.append(self.lora_extra_vocab_size)
+        factors.append(self.long_lora_scaling_factors)
+        factors.append(self.bias_enabled)
         hash_str = hashlib.md5(str(factors).encode()).hexdigest()
         return hash_str

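The hash above now folds every LoRA setting that affects captured tensor shapes into the compilation cache key, so a cached `torch.compile` artifact cannot be reused across incompatible LoRA configurations. A self-contained sketch of the same factor-based hashing pattern (toy class, illustrative fields only):

import hashlib
from dataclasses import dataclass
from typing import Any


@dataclass
class ToyLoRAConfig:
    max_lora_rank: int = 16
    max_loras: int = 1
    lora_extra_vocab_size: int = 256

    def compute_hash(self) -> str:
        # Any field that changes compiled tensor shapes belongs in `factors`.
        factors: list[Any] = [
            self.max_lora_rank, self.max_loras, self.lora_extra_vocab_size
        ]
        return hashlib.md5(str(factors).encode()).hexdigest()


# Changing a shape-affecting field changes the cache key.
assert ToyLoRAConfig().compute_hash() != ToyLoRAConfig(max_lora_rank=32).compute_hash()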
@@ -3303,6 +3308,11 @@ class VllmConfig:
             vllm_factors.append("None")
         if self.lora_config:
             vllm_factors.append(self.lora_config.compute_hash())
+            # LoRA creates static buffers based on max_num_batched_tokens.
+            # The tensor sizes and strides get captured in the torch.compile
+            # graph explicitly.
+            vllm_factors.append(
+                str(self.scheduler_config.max_num_batched_tokens))
         else:
             vllm_factors.append("None")
         if self.speculative_config:
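VllmConfig composes the per-subconfig hashes into one key, and with LoRA enabled it now also appends the scheduler's `max_num_batched_tokens`, since the static LoRA buffers (and hence the shapes baked into the compiled graph) are sized from that limit. A toy sketch of the composition (function name and arguments are illustrative):

import hashlib


def combined_hash(sub_config_hashes: list[str], max_num_batched_tokens: int) -> str:
    # Sub-config hashes plus any top-level value that static buffers are sized from.
    vllm_factors = list(sub_config_hashes)
    vllm_factors.append(str(max_num_batched_tokens))
    return hashlib.md5(str(vllm_factors).encode()).hexdigest()


# Different batch-token limits produce different compilation cache keys.
assert combined_hash(["abc123"], 8192) != combined_hash(["abc123"], 16384)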
@@ -3453,12 +3463,15 @@ class VllmConfig:
                 " Disabling `torch.compile`.")
             self.compilation_config.level = CompilationLevel.NO_COMPILATION

-        if self.lora_config is not None and self.compilation_config.level !=\
-            CompilationLevel.NO_COMPILATION:
-            logger.warning("LoRA is not supported with `torch.compile` yet. "
-                           "Disabling `torch.compile`.")
+        if ((not envs.VLLM_USE_V1) and self.lora_config is not None
+                and self.compilation_config.level
+                != CompilationLevel.NO_COMPILATION):
+            logger.warning(
+                "LoRA for V0 is not supported with `torch.compile` yet. "
+                "Disabling `torch.compile`.")
             self.compilation_config.level = CompilationLevel.NO_COMPILATION

+
         if self.model_config and self.model_config.use_mla and \
                 not (current_platform.is_cuda() or current_platform.is_rocm()):
             logger.info(
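The guard now fires only for the V0 engine: with V1, LoRA no longer forces `torch.compile` off. A condensed sketch of the new predicate (the helper below is illustrative, not vLLM API):

def disable_compile_for_lora(use_v1: bool, has_lora: bool,
                             compilation_enabled: bool) -> bool:
    # Before this change, LoRA always disabled compilation; now only V0 does.
    return (not use_v1) and has_lora and compilation_enabled


assert disable_compile_for_lora(use_v1=False, has_lora=True, compilation_enabled=True)
assert not disable_compile_for_lora(use_v1=True, has_lora=True, compilation_enabled=True)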
@@ -237,16 +237,19 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
         self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        added_tokens_mask = x > self.base_layer.org_vocab_size - 1
-        embeddings_indices = self.punica_wrapper.embeddings_indices
-        indices = embeddings_indices[1].view_as(x)
+        added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1,
+                                        1, 0)
+        embeddings_indices = torch.narrow(
+            self.punica_wrapper._embeddings_indices, 1, 0, x.size(0))
+
+        indices = embeddings_indices[1]
         full_lora_a_embeddings = F.embedding(
             x + indices,
             self.lora_a_stacked_2d,
         )
-        indices = embeddings_indices[0].view_as(x)
-        full_output = self.base_layer.forward(
-            x.add_(indices * added_tokens_mask))
+        indices = embeddings_indices[0]
+        full_output = self.base_layer.forward(x +
+                                              (indices * added_tokens_mask))

         full_output_org = full_output
         if full_output.ndim == 3:
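Both replacements keep the forward pass CUDA-graph friendly: the index data lives in one statically allocated buffer and each step reads only a leading slice sized to the live batch via `torch.narrow` (a view, no allocation), while `torch.where` turns the added-token comparison into an integer mask that folds into the index arithmetic. A standalone sketch of the pattern (shapes and names are illustrative):

import torch

# Static, max-sized index buffer allocated once up front.
max_num_tokens = 8
embeddings_indices = torch.zeros(2, max_num_tokens, dtype=torch.long)

x = torch.tensor([3, 7, 1])  # current batch of token ids
# Take a view over the first x.size(0) columns; storage is unchanged.
indices = torch.narrow(embeddings_indices, 1, 0, x.size(0))
assert indices.shape == (2, 3)
assert indices.data_ptr() == embeddings_indices.data_ptr()

# Integer 0/1 mask instead of a bool tensor, usable directly in arithmetic.
org_vocab_size = 5
added_tokens_mask = torch.where(x > org_vocab_size - 1, 1, 0)
print(added_tokens_mask)  # tensor([0, 1, 0])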
@@ -254,7 +254,9 @@ class PunicaWrapperGPU(PunicaWrapperBase, V1KernelMixin):
         y_org = y
         y = y.view(-1, y.shape[-1])
         if lora_bias_stacked is not None:
-            self._apply_bias(self.token_lora_indices, y, output_slices,
+            token_lora_indices = torch.narrow(self._token_lora_indices, 0, 0,
+                                              y.size(0))
+            self._apply_bias(token_lora_indices, y, output_slices,
                              lora_bias_stacked)

         if env.VLLM_USE_V1:
@@ -365,7 +367,9 @@ class PunicaWrapperGPU(PunicaWrapperBase, V1KernelMixin):
         assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
         if lora_bias_stacked is not None:
             assert len(lora_bias_stacked) == len(output_slices)
-            y = self._apply_bias(self.token_lora_indices, y, output_slices,
+            token_lora_indices = torch.narrow(self._token_lora_indices, 0, 0,
+                                              y.size(0))
+            y = self._apply_bias(token_lora_indices, y, output_slices,
                                  lora_bias_stacked)

         if buffer is None:
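Both `_apply_bias` call sites above apply the same idea: `self._token_lora_indices` is a static, max-sized buffer, and `torch.narrow` takes a per-step view covering just the first `y.size(0)` tokens, so nothing is reallocated while a CUDA graph is being captured. A toy sketch (sizes are illustrative):

import torch

max_num_tokens = 8
token_lora_indices = torch.full((max_num_tokens,), -1, dtype=torch.long)  # static buffer

y = torch.randn(3, 16)  # 3 live tokens this step
live_indices = torch.narrow(token_lora_indices, 0, 0, y.size(0))  # view of the first 3 entries
assert live_indices.shape == (3,)
assert live_indices.data_ptr() == token_lora_indices.data_ptr()  # same storage, no copy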