[torch.compile][TPU] Make @support_torch_compile work for XLA backend (#15782)

Signed-off-by: Siyuan Liu <lsiyuan@google.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Siyuan Liu 2025-04-07 23:23:53 -07:00 committed by GitHub
parent f6b32efb7f
commit 87918e40c4
2 changed files with 47 additions and 76 deletions
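For context, here is a rough sketch of how `support_torch_compile` is applied on the model side in vLLM; the class name is made up, and the import path and constructor signature are stated from memory rather than taken from this diff. The decorator owns dynamo tracing and backend dispatch, which is why the TPU runner below can stop wrapping the model in `torch.compile` itself.

# Hedged sketch, not part of this commit: a model class opting into compilation
# via the decorator instead of the runner calling torch.compile on a wrapper.
from typing import Optional

import torch
from torch import nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig


@support_torch_compile
class MyTPUModel(nn.Module):  # illustrative name, not a real vLLM class
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.embed = nn.Embedding(32000, 512)  # stand-in for real layers

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Tensor-annotated arguments are treated as dynamic in their first dim.
        x = inputs_embeds if inputs_embeds is not None else self.embed(input_ids)
        return x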

vllm/v1/worker/tpu_model_runner.py

@@ -15,6 +15,7 @@ import torch_xla.runtime as xr
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.layer import Attention
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
@@ -691,11 +692,10 @@ class TPUModelRunner:
hidden_states = self.model(
input_ids=input_ids,
positions=self.position_ids,
kv_caches=self.kv_caches,
inputs_embeds=inputs_embeds,
)
selected_token_ids = self.model.sample_from_hidden(
hidden_states, tpu_sampling_metadata)
selected_token_ids = self.sample_from_hidden(hidden_states,
tpu_sampling_metadata)
# Remove padding on cpu and keep dynamic op outside of xla graph.
selected_token_ids = selected_token_ids.cpu()[:num_reqs]
@@ -795,17 +795,15 @@ class TPUModelRunner:
"get_tensor_model_parallel_rank",
return_value=xm_tp_rank):
model = get_model(vllm_config=self.vllm_config)
model = model.eval()
# Sync all pending XLA execution during model initialization and weight
# loading.
xm.mark_step()
xm.wait_device_ops()
model = ModelWrapperV1(model)
self.model = torch.compile(model,
backend="openxla",
fullgraph=True,
dynamic=False)
self.model = model
self.sampler = TPUSampler()
@torch.no_grad()
def _dummy_run(self, kv_caches, num_tokens: int) -> None:
def _dummy_run(self, num_tokens: int) -> None:
if self.is_multimodal_model:
input_ids = None
inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
@@ -856,7 +854,6 @@ class TPUModelRunner:
with set_forward_context(attn_metadata, self.vllm_config, 0):
out = self.model(input_ids=input_ids,
positions=position_ids,
kv_caches=kv_caches,
inputs_embeds=inputs_embeds)
self._hidden_states_dtype = out.dtype
@@ -868,7 +865,7 @@ class TPUModelRunner:
start = time.perf_counter()
for num_tokens in self.num_tokens_paddings:
logger.info(" -- num_tokens: %d", num_tokens)
self._dummy_run(self.kv_caches, num_tokens)
self._dummy_run(num_tokens)
xm.mark_step()
xm.wait_device_ops()
end = time.perf_counter()
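
The capture loop above walks the bucketed token paddings so each padded size is traced and compiled ahead of time. A minimal standalone sketch of the same torch_xla lazy-tensor pattern, with made-up sizes and a toy computation standing in for the model forward:

# Standalone sketch, assuming torch_xla is installed; the sizes and the matmul
# are placeholders for the padded dummy batches and the model forward pass.
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
for num_tokens in (16, 32, 64):                      # illustrative paddings
    x = torch.zeros(num_tokens, 128, device=device)  # dummy input of padded size
    _ = x @ x.T                                      # stand-in for model(...)
    xm.mark_step()        # cut the lazy graph here so this size gets compiled
xm.wait_device_ops()      # block until all queued device work has finished
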
@@ -899,8 +896,7 @@ class TPUModelRunner:
from_input_batch(self.input_batch, indices)
logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
num_reqs_to_sample)
out = self.model.sample_from_hidden(dummy_hidden,
sampling_meta)
out = self.sample_from_hidden(dummy_hidden, sampling_meta)
out = out.cpu()
# Requests can't be more than tokens. But do compile for the
# next bigger value in case num_tokens uses bucketed padding.
@@ -954,45 +950,17 @@ class TPUModelRunner:
self.vllm_config.compilation_config.static_forward_context,
self.kv_caches)
class ModelWrapperV1(nn.Module):
def __init__(self, model: nn.Module):
super().__init__()
self.model = model
self.sampler = TPUSampler()
def sample(
self, logits: torch.Tensor,
sampling_metadata: TPUSupportedSamplingMetadata) -> SamplerOutput:
sampler_out = self.sampler(logits, sampling_metadata)
return sampler_out
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: list[torch.Tensor],
inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Executes the forward pass of the model.
Args:
input_ids: The input token IDs of shape [num_tokens].
positions: The input position IDs of shape [num_tokens].
kv_caches: The key and value caches. They can be None during the
memory profiling at initialization.
inputs_embeds: The input embeddings of shape [num_tokens,
hidden_size]. It is used for multimodal models.
"""
hidden_states = self.model(
input_ids=input_ids,
positions=positions,
inputs_embeds=inputs_embeds,
)
return hidden_states
def reset_dynamo_cache(self):
if self.is_multimodal_model:
assert hasattr(self.model, "language_model")
compiled_model = self.model.language_model.model
else:
compiled_model = self.model.model
if isinstance(compiled_model, TorchCompileWrapperWithCustomDispatcher):
logger.info("Clear dynamo cache and cached dynamo bytecode.")
torch._dynamo.eval_frame.remove_from_cache(
compiled_model.original_code_object)
compiled_model.compiled_codes.clear()
def sample_from_hidden(
self,
@@ -1001,32 +969,29 @@ class ModelWrapperV1(nn.Module):
) -> torch.Tensor:
"""
Sample with xla-friendly function. This function is to be traced
separately from `forward` for lighter compilation overhead.
separately for lighter compilation overhead.
"""
# Tensor `sample_hidden_states` is of fixed pre-compiled size.
sample_hidden_states = \
hidden_states[sampling_metadata.indices_do_sample]
logits = self.compute_logits(sample_hidden_states)
# SamplingMetadata here for pruning output in LogitsProcessor, disabled.
logits = self.model.compute_logits(sample_hidden_states, None)
def sample(
logits: torch.Tensor,
sampling_metadata: TPUSupportedSamplingMetadata
) -> SamplerOutput:
sampler_out = self.sampler(logits, sampling_metadata)
return sampler_out
# Optimized greedy sampling branch, tracing both paths in a single pass
# NOTE all_greedy is a scalar, this is just an optimized if/else.
out_tokens = torch.where(sampling_metadata.all_greedy,
out_tokens = torch.where(
sampling_metadata.all_greedy,
torch.argmax(logits, dim=-1, keepdim=True),
self.sample(logits, sampling_metadata)\
.sampled_token_ids)
sample(logits, sampling_metadata).sampled_token_ids)
return out_tokens
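
As the NOTE in the diff says, `all_greedy` is a scalar, so `torch.where` acts as a traced if/else: both branches are computed inside the graph and one result is selected, avoiding a Python-level branch that would split the XLA graph. A standalone sketch of the pattern, with `torch.multinomial` standing in for the TPU sampler:

# Standalone sketch, not vLLM code: a scalar-predicated select keeps both
# sampling paths inside a single traced graph.
import torch

logits = torch.randn(4, 32)
all_greedy = torch.tensor(True)  # scalar flag, like sampling_metadata.all_greedy
greedy = torch.argmax(logits, dim=-1, keepdim=True)                       # [4, 1]
sampled = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1) # [4, 1]
out_tokens = torch.where(all_greedy, greedy, sampled)  # selects greedy here
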
def compute_logits(self,
hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
# SamplingMetadata here for pruning output in LogitsProcessor, disabled
logits = self.model.compute_logits(hidden_states, None)
return logits
def get_multimodal_embeddings(self, *args, **kwargs):
return self.model.get_multimodal_embeddings(*args, **kwargs)
def get_input_embeddings(self, *args, **kwargs):
return self.model.get_input_embeddings(*args, **kwargs)
def _get_padded_number(n: int, multiple: int) -> int:
return ((n + multiple - 1) // multiple) * multiple
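
For example, with a bucket multiple of 16, `_get_padded_number(33, 16)` returns 48 and `_get_padded_number(32, 16)` returns 32, which is the rounding the bucketed paddings above rely on.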

vllm/v1/worker/tpu_worker.py

@@ -157,13 +157,19 @@ class TPUWorker:
runner_kv_caches)
self.model_runner._dummy_run(
runner_kv_caches,
num_tokens=self.scheduler_config.max_num_batched_tokens,
)
self.scheduler_config.max_num_batched_tokens)
# Synchronize before measuring the memory usage.
xm.wait_device_ops()
# During the profiling run, the model runs without KV cache. After
# the profiling run, the model always runs with KV cache. Here we clear
# the dynamo cache and cached bytecode to ensure the model always has
# one compiled bytecode. Having one FX graph/cached bytecode per
# compiled model is required for `support_torch_compile` decorator to
# skip dynamo guard.
self.model_runner.reset_dynamo_cache()
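
A hypothetical standalone illustration of why this reset matters (not vLLM code; `f` is made up, and backend="eager" is used only to keep the sketch light): with `dynamic=False`, each distinct input configuration leaves its own cached dynamo bytecode, and `remove_from_cache` is the same helper that `reset_dynamo_cache` calls on the compiled model's code object.

import torch

def f(x: torch.Tensor) -> torch.Tensor:
    return x * 2

compiled = torch.compile(f, backend="eager", dynamic=False)
compiled(torch.ones(4))   # first compile, one cached bytecode for f
compiled(torch.ones(8))   # new static shape -> a second cache entry
torch._dynamo.eval_frame.remove_from_cache(f)  # drop all cached bytecode for f
compiled(torch.ones(8))   # recompiles once; a single cache entry remains
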
# Get the maximum amount of memory used by the model weights and
# intermediate activations.
m = xm.get_memory_info(self.device)