[torch.compile][TPU] Make @support_torch_compile work for XLA backend (#15782)

Signed-off-by: Siyuan Liu <lsiyuan@google.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
Siyuan Liu authored 2025-04-07 23:23:53 -07:00 · committed by GitHub
parent f6b32efb7f
commit 87918e40c4
2 changed files with 47 additions and 76 deletions
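For context on the hunks below: before this change the TPU model runner wrapped the loaded model in a hand-written ModelWrapperV1 and compiled that wrapper itself with the XLA backend; this commit removes that path so compilation is driven through the @support_torch_compile machinery (via TorchCompileWrapperWithCustomDispatcher) instead. A minimal sketch of the pattern being removed, assuming a working torch_xla install with an XLA device available; the nn.Linear is a toy stand-in, not code from the diff:

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

# Toy stand-in for the wrapped model; the deleted runner code compiled
# ModelWrapperV1 the same way.
model = nn.Linear(16, 16).to(xm.xla_device())

# The hand-rolled path removed below: one static graph, no dynamic shapes,
# compiled through the XLA backend that torch_xla registers.
compiled = torch.compile(model,
                         backend="openxla",
                         fullgraph=True,
                         dynamic=False)

out = compiled(torch.randn(4, 16, device=xm.xla_device()))
xm.mark_step()        # cut the pending lazy-tensor graph and dispatch it
xm.wait_device_ops()  # block until the device has finished executing it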

View File

@@ -15,6 +15,7 @@ import torch_xla.runtime as xr
 import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
+from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
 from vllm.config import VllmConfig
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
@@ -691,11 +692,10 @@ class TPUModelRunner:
         hidden_states = self.model(
             input_ids=input_ids,
             positions=self.position_ids,
-            kv_caches=self.kv_caches,
             inputs_embeds=inputs_embeds,
         )
-        selected_token_ids = self.model.sample_from_hidden(
-            hidden_states, tpu_sampling_metadata)
+        selected_token_ids = self.sample_from_hidden(hidden_states,
+                                                     tpu_sampling_metadata)
         # Remove padding on cpu and keep dynamic op outside of xla graph.
         selected_token_ids = selected_token_ids.cpu()[:num_reqs]
@@ -795,17 +795,15 @@ class TPUModelRunner:
                 "get_tensor_model_parallel_rank",
                 return_value=xm_tp_rank):
             model = get_model(vllm_config=self.vllm_config)
-        model = model.eval()
+        # Sync all pending XLA execution during model initialization and weight
+        # loading.
         xm.mark_step()
         xm.wait_device_ops()
-        model = ModelWrapperV1(model)
-        self.model = torch.compile(model,
-                                   backend="openxla",
-                                   fullgraph=True,
-                                   dynamic=False)
+        self.model = model
+        self.sampler = TPUSampler()
 
     @torch.no_grad()
-    def _dummy_run(self, kv_caches, num_tokens: int) -> None:
+    def _dummy_run(self, num_tokens: int) -> None:
         if self.is_multimodal_model:
             input_ids = None
             inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
@@ -856,7 +854,6 @@ class TPUModelRunner:
         with set_forward_context(attn_metadata, self.vllm_config, 0):
             out = self.model(input_ids=input_ids,
                              positions=position_ids,
-                             kv_caches=kv_caches,
                              inputs_embeds=inputs_embeds)
         self._hidden_states_dtype = out.dtype
@@ -868,7 +865,7 @@ class TPUModelRunner:
         start = time.perf_counter()
         for num_tokens in self.num_tokens_paddings:
             logger.info(" -- num_tokens: %d", num_tokens)
-            self._dummy_run(self.kv_caches, num_tokens)
+            self._dummy_run(num_tokens)
             xm.mark_step()
         xm.wait_device_ops()
         end = time.perf_counter()
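The warm-up loop above compiles one XLA graph per entry in num_tokens_paddings, so at serving time any token count can be rounded up to an already-compiled bucket instead of triggering a fresh compilation. A self-contained illustration of that round-up idea, reusing the ceiling-division formula from the _get_padded_number helper shown as context at the end of this file's diff (the bucket sizes here are made up for the example):

def padded(n: int, multiple: int) -> int:
    # Ceiling division, then scale back up: 5 -> 8, 8 -> 8, 9 -> 16 for multiple=8.
    return ((n + multiple - 1) // multiple) * multiple

# Hypothetical bucket list; the real runner derives num_tokens_paddings itself.
num_tokens_paddings = [16, 32, 64, 128]

for num_tokens in (3, 17, 100):
    bucket = min(p for p in num_tokens_paddings if p >= num_tokens)
    print(f"{num_tokens} tokens -> reuse the graph compiled for {bucket}")

Keeping the set of padded shapes small is what makes this ahead-of-time warm-up affordable, since XLA compiles a separate graph for every distinct input shape.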
@@ -899,8 +896,7 @@ class TPUModelRunner:
                 from_input_batch(self.input_batch, indices)
             logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
                         num_reqs_to_sample)
-            out = self.model.sample_from_hidden(dummy_hidden,
-                                                sampling_meta)
+            out = self.sample_from_hidden(dummy_hidden, sampling_meta)
             out = out.cpu()
             # Requests can't be more than tokens. But do compile for the
             # next bigger value in case num_tokens uses bucketed padding.
@@ -954,45 +950,17 @@ class TPUModelRunner:
             self.vllm_config.compilation_config.static_forward_context,
             self.kv_caches)
 
-
-class ModelWrapperV1(nn.Module):
-
-    def __init__(self, model: nn.Module):
-        super().__init__()
-        self.model = model
-        self.sampler = TPUSampler()
-
-    def sample(
-            self, logits: torch.Tensor,
-            sampling_metadata: TPUSupportedSamplingMetadata) -> SamplerOutput:
-        sampler_out = self.sampler(logits, sampling_metadata)
-        return sampler_out
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        kv_caches: list[torch.Tensor],
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        """Executes the forward pass of the model.
-
-        Args:
-            input_ids: The input token IDs of shape [num_tokens].
-            positions: The input position IDs of shape [num_tokens].
-            kv_caches: The key and value caches. They can be None during the
-                memory profiling at initialization.
-            inputs_embeds: The input embeddings of shape [num_tokens,
-                hidden_size]. It is used for multimodal models.
-        """
-        hidden_states = self.model(
-            input_ids=input_ids,
-            positions=positions,
-            inputs_embeds=inputs_embeds,
-        )
-        return hidden_states
+    def reset_dynamo_cache(self):
+        if self.is_multimodal_model:
+            assert hasattr(self.model, "language_model")
+            compiled_model = self.model.language_model.model
+        else:
+            compiled_model = self.model.model
+        if isinstance(compiled_model, TorchCompileWrapperWithCustomDispatcher):
+            logger.info("Clear dynamo cache and cached dynamo bytecode.")
+            torch._dynamo.eval_frame.remove_from_cache(
+                compiled_model.original_code_object)
+            compiled_model.compiled_codes.clear()
 
     def sample_from_hidden(
         self,
@@ -1000,33 +968,30 @@ class ModelWrapperV1(nn.Module):
         sampling_metadata: TPUSupportedSamplingMetadata,
     ) -> torch.Tensor:
         """
         Sample with xla-friendly function. This function is to be traced
-        separately from `forward` for lighter compilation overhead.
+        separately for lighter compilation overhead.
         """
         # Tensor `sample_hidden_states` is of fixed pre-compiled size.
         sample_hidden_states = \
             hidden_states[sampling_metadata.indices_do_sample]
-        logits = self.compute_logits(sample_hidden_states)
+        # SamplingMetadata here for pruning output in LogitsProcessor, disabled.
+        logits = self.model.compute_logits(sample_hidden_states, None)
+
+        def sample(
+            logits: torch.Tensor,
+            sampling_metadata: TPUSupportedSamplingMetadata
+        ) -> SamplerOutput:
+            sampler_out = self.sampler(logits, sampling_metadata)
+            return sampler_out
 
         # Optimized greedy sampling branch, tracing both paths in a single pass
         # NOTE all_greedy is a scalar, this is just an optimized if/else.
-        out_tokens = torch.where(sampling_metadata.all_greedy,
-                                 torch.argmax(logits, dim=-1, keepdim=True),
-                                 self.sample(logits, sampling_metadata)\
-                                 .sampled_token_ids)
+        out_tokens = torch.where(
+            sampling_metadata.all_greedy,
+            torch.argmax(logits, dim=-1, keepdim=True),
+            sample(logits, sampling_metadata).sampled_token_ids)
         return out_tokens
 
-    def compute_logits(self,
-                       hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
-        # SamplingMetadata here for pruning output in LogitsProcessor, disabled
-        logits = self.model.compute_logits(hidden_states, None)
-        return logits
-
-    def get_multimodal_embeddings(self, *args, **kwargs):
-        return self.model.get_multimodal_embeddings(*args, **kwargs)
-
-    def get_input_embeddings(self, *args, **kwargs):
-        return self.model.get_input_embeddings(*args, **kwargs)
-
 
 def _get_padded_number(n: int, multiple: int) -> int:
     return ((n + multiple - 1) // multiple) * multiple
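One detail in sample_from_hidden worth spelling out: because sampling_metadata.all_greedy is a scalar tensor, torch.where evaluates both the greedy argmax and the sampler output inside a single traced graph and merely selects between them, avoiding data-dependent Python control flow that would otherwise force separate compilations. A standalone sketch of the same trick with plain torch on CPU (toy shapes, with multinomial standing in for the TPU sampler):

import torch

logits = torch.randn(4, 32)      # [num_sampling_reqs, vocab_size]
all_greedy = torch.tensor(True)  # scalar predicate, not a per-row mask

greedy_ids = torch.argmax(logits, dim=-1, keepdim=True)
sampled_ids = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)

# Both branches are always computed; the scalar predicate only picks which
# result flows onward, so the traced graph contains no Python if/else.
out_tokens = torch.where(all_greedy, greedy_ids, sampled_ids)
print(out_tokens.shape)          # torch.Size([4, 1])

The trade-off is that the non-greedy branch still executes even when every request is greedy, which is why the NOTE in the diff calls this an optimized if/else rather than a free one.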

View File

@@ -157,13 +157,19 @@ class TPUWorker:
             runner_kv_caches)
 
         self.model_runner._dummy_run(
-            runner_kv_caches,
-            num_tokens=self.scheduler_config.max_num_batched_tokens,
-        )
+            self.scheduler_config.max_num_batched_tokens)
 
         # Synchronize before measuring the memory usage.
         xm.wait_device_ops()
 
+        # During the profiling run, the model runs without KV cache. After
+        # the profiling run, the model always runs with KV cache. Here we clear
+        # the dynamo cache and cached bytecode to ensure the model always has
+        # one compiled bytecode. Having one FX graph/cached bytecode per
+        # compiled model is required for `support_torch_compile` decorator to
+        # skip dynamo guard.
+        self.model_runner.reset_dynamo_cache()
+
         # Get the maximum amount of memory used by the model weights and
         # intermediate activations.
         m = xm.get_memory_info(self.device)
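The comment added above is the heart of the change on the worker side: the profiling run executes the model before the KV caches exist, so the decorated model has already captured and cached one dynamo bytecode for that calling pattern, and a second one would appear once the real KV caches are bound. reset_dynamo_cache() drops the stale entry (via torch._dynamo.eval_frame.remove_from_cache on the wrapper's original code object, plus clearing its recorded compiled_codes) so that exactly one compiled bytecode remains. As a generic, CPU-only illustration of dynamo caching bytecode per compiled function and of dropping that cache, using the coarse public torch._dynamo.reset() rather than the targeted removal in the diff:

import torch


@torch.compile(backend="eager")  # "eager" backend: no TPU/XLA needed for this demo
def step(x: torch.Tensor) -> torch.Tensor:
    return x * 2 + 1


step(torch.randn(8))   # first call: dynamo traces the function and caches bytecode
step(torch.randn(8))   # second call: the cached bytecode is reused via guards

torch._dynamo.reset()  # drop all cached compiled code (coarse analogue of the
                       # targeted remove_from_cache + compiled_codes.clear())
step(torch.randn(8))   # traces and compiles fresh on the next call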