[Kernel] LoRA - Enable CUDAGraphs for V1 (#14626)

Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Varun Sundar Rabindranath 2025-03-13 23:42:04 -04:00 committed by GitHub
parent 32ef4983cd
commit 0b1cfa6180
4 changed files with 35 additions and 14 deletions


@@ -52,6 +52,7 @@ def test_worker_apply_lora(sql_lora_files):
             seed=0,
             dtype="float16",
             revision=None,
+            enforce_eager=True,
         ),
         load_config=LoadConfig(
             download_dir=None,

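The test pins enforce_eager=True only because it builds the worker config directly; the point of the change is that a LoRA-enabled V1 engine no longer has to. As a rough usage sketch (the model name and adapter path are placeholders, not taken from this commit), leaving enforce_eager at its default lets V1 capture CUDA graphs while serving LoRA adapters:

# Hypothetical usage sketch: model name and adapter path are placeholders.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# With V1 enabled, leaving enforce_eager at its default (False) lets the
# engine capture CUDA graphs even when LoRA adapters are active.
llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)

prompts = ["Write a SQL query that counts users per country."]
outputs = llm.generate(
    prompts,
    SamplingParams(temperature=0.0, max_tokens=64),
    lora_request=LoRARequest("sql-adapter", 1, "/path/to/sql_lora"),
)
print(outputs[0].outputs[0].text)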

@@ -2287,9 +2287,14 @@ class LoRAConfig:
         excluding anything before input ids/embeddings and after
         the final hidden states.
         """
-        # no factors to consider.
-        # LoRA is not compatible with `torch.compile` .
         factors: list[Any] = []
+        factors.append(self.max_lora_rank)
+        factors.append(self.max_loras)
+        factors.append(self.fully_sharded_loras)
+        factors.append(self.lora_dtype)
+        factors.append(self.lora_extra_vocab_size)
+        factors.append(self.long_lora_scaling_factors)
+        factors.append(self.bias_enabled)
         hash_str = hashlib.md5(str(factors).encode()).hexdigest()
         return hash_str
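compute_hash previously returned a constant because LoRA forced compilation off; now the adapter shape parameters feed the `torch.compile` cache key, so compiled artifacts are not reused across incompatible LoRA setups. A simplified, self-contained sketch of the same pattern (the dataclass below is a stand-in, not the real LoRAConfig):

# Simplified sketch of the factor-hashing pattern above; this dataclass is a
# stand-in, not the real vLLM LoRAConfig.
import hashlib
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class LoRAConfigSketch:
    max_lora_rank: int = 16
    max_loras: int = 1
    fully_sharded_loras: bool = False
    lora_dtype: Optional[str] = None
    lora_extra_vocab_size: int = 256
    long_lora_scaling_factors: Optional[tuple] = None
    bias_enabled: bool = False

    def compute_hash(self) -> str:
        factors: list[Any] = [
            self.max_lora_rank,
            self.max_loras,
            self.fully_sharded_loras,
            self.lora_dtype,
            self.lora_extra_vocab_size,
            self.long_lora_scaling_factors,
            self.bias_enabled,
        ]
        return hashlib.md5(str(factors).encode()).hexdigest()


# Changing any factor changes the key, so a cached compiled graph built for
# rank-16 adapters is not reused for rank-64 adapters.
assert (LoRAConfigSketch(max_lora_rank=16).compute_hash()
        != LoRAConfigSketch(max_lora_rank=64).compute_hash())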
@@ -3303,6 +3308,11 @@ class VllmConfig:
             vllm_factors.append("None")
         if self.lora_config:
             vllm_factors.append(self.lora_config.compute_hash())
+            # LoRA creates static buffers based on max_num_batched_tokens.
+            # The tensor sizes and strides get captured in the torch.compile
+            # graph explicitly.
+            vllm_factors.append(
+                str(self.scheduler_config.max_num_batched_tokens))
         else:
             vllm_factors.append("None")
         if self.speculative_config:
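The new comment is the key observation: the LoRA buffers are allocated at max_num_batched_tokens, and those sizes and strides are baked into the captured graph, so that value has to be part of the engine-level cache key as well. Illustrative sketch only (the helper and its arguments are made up to mirror the diff):

# Illustrative only: how max_num_batched_tokens ends up in the compile cache
# key alongside the LoRA factors (helper name and values are made up).
import hashlib


def compute_engine_cache_key(lora_config_hash: str,
                             max_num_batched_tokens: int) -> str:
    vllm_factors: list[str] = []
    vllm_factors.append(lora_config_hash)
    # Static LoRA buffers are sized by max_num_batched_tokens, and those
    # sizes/strides are captured in the torch.compile graph, so two engines
    # that differ only in this value must not share compiled artifacts.
    vllm_factors.append(str(max_num_batched_tokens))
    return hashlib.md5(str(vllm_factors).encode()).hexdigest()


assert (compute_engine_cache_key("abc123", 8192)
        != compute_engine_cache_key("abc123", 16384))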
@@ -3453,12 +3463,15 @@ class VllmConfig:
                 " Disabling `torch.compile`.")
             self.compilation_config.level = CompilationLevel.NO_COMPILATION
-        if self.lora_config is not None and self.compilation_config.level !=\
-                CompilationLevel.NO_COMPILATION:
-            logger.warning("LoRA is not supported with `torch.compile` yet. "
-                           "Disabling `torch.compile`.")
+        if ((not envs.VLLM_USE_V1) and self.lora_config is not None
+                and self.compilation_config.level
+                != CompilationLevel.NO_COMPILATION):
+            logger.warning(
+                "LoRA for V0 is not supported with `torch.compile` yet. "
+                "Disabling `torch.compile`.")
             self.compilation_config.level = CompilationLevel.NO_COMPILATION
         if self.model_config and self.model_config.use_mla and \
                 not (current_platform.is_cuda() or current_platform.is_rocm()):
             logger.info(

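The fallback that disabled `torch.compile` whenever LoRA was configured is now scoped to V0; on V1, LoRA keeps the requested compilation level and therefore CUDA graphs. A standalone sketch of the gating decision, with the enum modeled after the diff rather than imported from vLLM:

# Standalone sketch of the gating logic above; the enum values and flag are
# modeled after the diff, not imported from vLLM.
from enum import IntEnum


class CompilationLevel(IntEnum):
    NO_COMPILATION = 0
    PIECEWISE = 3


def resolve_compilation_level(use_v1: bool, has_lora: bool,
                              level: CompilationLevel) -> CompilationLevel:
    # Only the V0 engine still falls back to eager when LoRA is active;
    # V1 keeps the requested compilation level (and hence CUDA graphs).
    if (not use_v1) and has_lora and level != CompilationLevel.NO_COMPILATION:
        return CompilationLevel.NO_COMPILATION
    return level


assert (resolve_compilation_level(True, True, CompilationLevel.PIECEWISE)
        == CompilationLevel.PIECEWISE)
assert (resolve_compilation_level(False, True, CompilationLevel.PIECEWISE)
        == CompilationLevel.NO_COMPILATION)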

@@ -237,16 +237,19 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
         self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        added_tokens_mask = x > self.base_layer.org_vocab_size - 1
-        embeddings_indices = self.punica_wrapper.embeddings_indices
-        indices = embeddings_indices[1].view_as(x)
+        added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1,
+                                        1, 0)
+        embeddings_indices = torch.narrow(
+            self.punica_wrapper._embeddings_indices, 1, 0, x.size(0))
+        indices = embeddings_indices[1]
         full_lora_a_embeddings = F.embedding(
             x + indices,
             self.lora_a_stacked_2d,
         )
-        indices = embeddings_indices[0].view_as(x)
-        full_output = self.base_layer.forward(
-            x.add_(indices * added_tokens_mask))
+        indices = embeddings_indices[0]
+        full_output = self.base_layer.forward(x +
+                                              (indices * added_tokens_mask))
         full_output_org = full_output
         if full_output.ndim == 3:

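The forward rewrite removes two things that defeat graph capture: data-dependent reshapes (.view_as(x)) and in-place mutation of the input (x.add_). Instead it takes torch.narrow views of a preallocated indices buffer and adds out of place, so every intermediate has stable storage and a shape that depends only on the number of scheduled tokens. A minimal PyTorch sketch of the pattern (buffer name and sizes are illustrative, not taken from vLLM):

# Minimal PyTorch sketch of the buffer-narrowing pattern used above; the
# buffer name and sizes are illustrative, not taken from vLLM.
import torch

MAX_NUM_BATCHED_TOKENS = 8192

# Persistent buffer allocated once at the maximum size. CUDA graph capture
# records this tensor's address, so it must never be reallocated.
embeddings_indices = torch.zeros(2, MAX_NUM_BATCHED_TOKENS, dtype=torch.long)


def lookup_indices(x: torch.Tensor) -> torch.Tensor:
    # torch.narrow returns a view of the persistent buffer whose length
    # matches the current number of tokens - no new allocation and no
    # data-dependent reshape like .view_as(x).
    active = torch.narrow(embeddings_indices, 1, 0, x.size(0))
    # Out-of-place add keeps the graph's input tensor unmodified, unlike the
    # previous x.add_(...) which mutated the input in place.
    return x + active[1]


tokens = torch.randint(0, 32000, (7,))
print(lookup_indices(tokens).shape)  # torch.Size([7])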

@@ -254,7 +254,9 @@ class PunicaWrapperGPU(PunicaWrapperBase, V1KernelMixin):
         y_org = y
         y = y.view(-1, y.shape[-1])
         if lora_bias_stacked is not None:
-            self._apply_bias(self.token_lora_indices, y, output_slices,
-                             lora_bias_stacked)
+            token_lora_indices = torch.narrow(self._token_lora_indices, 0, 0,
+                                              y.size(0))
+            self._apply_bias(token_lora_indices, y, output_slices,
+                             lora_bias_stacked)

         if env.VLLM_USE_V1:
@@ -365,7 +367,9 @@ class PunicaWrapperGPU(PunicaWrapperBase, V1KernelMixin):
         assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
         if lora_bias_stacked is not None:
             assert len(lora_bias_stacked) == len(output_slices)
-            y = self._apply_bias(self.token_lora_indices, y, output_slices,
-                                 lora_bias_stacked)
+            token_lora_indices = torch.narrow(self._token_lora_indices, 0, 0,
+                                              y.size(0))
+            y = self._apply_bias(token_lora_indices, y, output_slices,
+                                 lora_bias_stacked)

         if buffer is None:
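PunicaWrapperGPU applies the same treatment to _token_lora_indices: the full-size buffer stays alive and only a narrowed view of it is handed to _apply_bias. Because the buffers never move, the LoRA path can be captured once and replayed after copying fresh metadata into them. A rough capture/replay sketch (requires a CUDA device; the buffers and the toy bias op below are illustrative, not vLLM code):

# Rough sketch of why fixed-address buffers matter for CUDA graph replay.
# Requires a CUDA device; buffer names and the toy bias op are illustrative.
import torch

assert torch.cuda.is_available()
device = "cuda"

MAX_TOKENS, HIDDEN, MAX_LORAS = 256, 64, 4

# Persistent buffers sized for the maximum batch; only their contents change.
token_lora_indices = torch.zeros(MAX_TOKENS, dtype=torch.long, device=device)
hidden = torch.zeros(MAX_TOKENS, HIDDEN, device=device)
lora_bias = torch.randn(MAX_LORAS, HIDDEN, device=device)


def apply_lora_bias() -> torch.Tensor:
    # Narrowed views of persistent buffers: same storage on every call, so
    # the addresses recorded during capture stay valid at replay time.
    idx = torch.narrow(token_lora_indices, 0, 0, MAX_TOKENS)
    y = torch.narrow(hidden, 0, 0, MAX_TOKENS)
    return y + lora_bias[idx]


# Warm up on a side stream (recommended before CUDA graph capture).
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    for _ in range(3):
        apply_lora_bias()
torch.cuda.current_stream().wait_stream(side)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_out = apply_lora_bias()

# Per step: refresh buffer contents in place, then replay the captured graph.
token_lora_indices.copy_(
    torch.randint(0, MAX_LORAS, (MAX_TOKENS,), device=device))
hidden.copy_(torch.randn(MAX_TOKENS, HIDDEN, device=device))
graph.replay()
print(static_out.shape)  # torch.Size([256, 64])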