[torch.compile][TPU] Make @support_torch_compile work for XLA backend (#15782)
Signed-off-by: Siyuan Liu <lsiyuan@google.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
parent f6b32efb7f
commit 87918e40c4
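For context, here is a minimal sketch (not part of this diff) of how a vLLM model class typically opts into compilation through the `support_torch_compile` decorator that this commit makes work on the XLA backend. The class, layer sizes, and forward body are illustrative only; the decorator and module path are vLLM's.

# Illustrative sketch, not part of this commit. Assumes vLLM's
# `support_torch_compile` decorator from vllm.compilation.decorators.
import torch
from torch import nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig


@support_torch_compile
class ToyModel(nn.Module):  # hypothetical model used only for illustration

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.embed = nn.Embedding(32000, 1024)  # arbitrary sizes

    def forward(self, input_ids: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        # When compilation is enabled, the decorator wraps this forward so
        # Dynamo traces it and later calls dispatch through cached compiled
        # bytecode instead of re-running guard checks.
        return self.embed(input_ids)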
@@ -15,6 +15,7 @@ import torch_xla.runtime as xr
 import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
+from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
 from vllm.config import VllmConfig
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
@@ -691,11 +692,10 @@ class TPUModelRunner:
         hidden_states = self.model(
             input_ids=input_ids,
             positions=self.position_ids,
-            kv_caches=self.kv_caches,
             inputs_embeds=inputs_embeds,
         )
-        selected_token_ids = self.model.sample_from_hidden(
-            hidden_states, tpu_sampling_metadata)
+        selected_token_ids = self.sample_from_hidden(hidden_states,
+                                                     tpu_sampling_metadata)
         # Remove padding on cpu and keep dynamic op outside of xla graph.
         selected_token_ids = selected_token_ids.cpu()[:num_reqs]

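A note on the comment kept above ("Remove padding on cpu and keep dynamic op outside of xla graph"): the traced XLA graph only ever sees a fixed, padded shape, and the slice to the real request count happens on the host. A toy illustration of that pattern, not taken from the diff:

# Toy example of keeping the dynamic-shape op out of the device graph.
import torch

PADDED_BATCH = 16          # fixed size the graph is compiled for
num_reqs = 11              # actual, varying number of requests

padded_tokens = torch.zeros(PADDED_BATCH, dtype=torch.int64)  # device tensor in practice
tokens = padded_tokens.cpu()[:num_reqs]  # dynamic slice stays on the CPU side
assert tokens.shape == (num_reqs,)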
@@ -795,17 +795,15 @@ class TPUModelRunner:
                 "get_tensor_model_parallel_rank",
                 return_value=xm_tp_rank):
             model = get_model(vllm_config=self.vllm_config)
-        model = model.eval()
+        # Sync all pending XLA execution during model initialization and weight
+        # loading.
         xm.mark_step()
         xm.wait_device_ops()
-        model = ModelWrapperV1(model)
-        self.model = torch.compile(model,
-                                   backend="openxla",
-                                   fullgraph=True,
-                                   dynamic=False)
+        self.model = model
+        self.sampler = TPUSampler()

     @torch.no_grad()
-    def _dummy_run(self, kv_caches, num_tokens: int) -> None:
+    def _dummy_run(self, num_tokens: int) -> None:
         if self.is_multimodal_model:
             input_ids = None
             inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
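The new comment documents why `xm.mark_step()` and `xm.wait_device_ops()` are called right after weight loading: torch_xla executes lazily, so pending work has to be flushed and waited on before moving forward. A minimal sketch of that sync pattern, assuming a TPU/XLA device is available:

# Minimal sketch of the sync pattern described above (not part of this commit).
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
weights = torch.randn(1024, 1024, device=device)  # queued lazily, not yet run

xm.mark_step()        # cut the lazy trace and dispatch the pending computation
xm.wait_device_ops()  # block until the device has actually finished executing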
@@ -856,7 +854,6 @@ class TPUModelRunner:
         with set_forward_context(attn_metadata, self.vllm_config, 0):
             out = self.model(input_ids=input_ids,
                              positions=position_ids,
-                             kv_caches=kv_caches,
                              inputs_embeds=inputs_embeds)
         self._hidden_states_dtype = out.dtype

@@ -868,7 +865,7 @@ class TPUModelRunner:
         start = time.perf_counter()
         for num_tokens in self.num_tokens_paddings:
             logger.info(" -- num_tokens: %d", num_tokens)
-            self._dummy_run(self.kv_caches, num_tokens)
+            self._dummy_run(num_tokens)
             xm.mark_step()
         xm.wait_device_ops()
         end = time.perf_counter()
@@ -899,8 +896,7 @@ class TPUModelRunner:
                 from_input_batch(self.input_batch, indices)
             logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
                         num_reqs_to_sample)
-            out = self.model.sample_from_hidden(dummy_hidden,
-                                                sampling_meta)
+            out = self.sample_from_hidden(dummy_hidden, sampling_meta)
             out = out.cpu()
             # Requests can't be more than tokens. But do compile for the
             # next bigger value in case num_tokens uses bucketed padding.
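The "bucketed padding" the comment refers to means rounding a request or token count up to the next precompiled bucket. A worked example of that arithmetic, using the `_get_padded_number` helper that appears later in this diff:

# Worked example (illustrative) of the bucketed-padding arithmetic.
def _get_padded_number(n: int, multiple: int) -> int:
    return ((n + multiple - 1) // multiple) * multiple

assert _get_padded_number(17, 16) == 32   # 17 tokens fall into the 32-token bucket
assert _get_padded_number(16, 16) == 16   # exact multiples are unchanged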
@@ -954,45 +950,17 @@ class TPUModelRunner:
             self.vllm_config.compilation_config.static_forward_context,
             self.kv_caches)

-
-class ModelWrapperV1(nn.Module):
-
-    def __init__(self, model: nn.Module):
-        super().__init__()
-        self.model = model
-        self.sampler = TPUSampler()
-
-    def sample(
-        self, logits: torch.Tensor,
-        sampling_metadata: TPUSupportedSamplingMetadata) -> SamplerOutput:
-        sampler_out = self.sampler(logits, sampling_metadata)
-        return sampler_out
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        kv_caches: list[torch.Tensor],
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        """Executes the forward pass of the model.
-
-        Args:
-            input_ids: The input token IDs of shape [num_tokens].
-            positions: The input position IDs of shape [num_tokens].
-            kv_caches: The key and value caches. They can be None during the
-                memory profiling at initialization.
-            inputs_embeds: The input embeddings of shape [num_tokens,
-                hidden_size]. It is used for multimodal models.
-        """
-
-        hidden_states = self.model(
-            input_ids=input_ids,
-            positions=positions,
-            inputs_embeds=inputs_embeds,
-        )
-
-        return hidden_states
+    def reset_dynamo_cache(self):
+        if self.is_multimodal_model:
+            assert hasattr(self.model, "language_model")
+            compiled_model = self.model.language_model.model
+        else:
+            compiled_model = self.model.model
+        if isinstance(compiled_model, TorchCompileWrapperWithCustomDispatcher):
+            logger.info("Clear dynamo cache and cached dynamo bytecode.")
+            torch._dynamo.eval_frame.remove_from_cache(
+                compiled_model.original_code_object)
+            compiled_model.compiled_codes.clear()

     def sample_from_hidden(
         self,
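The new `reset_dynamo_cache` drops the Dynamo entries for the wrapped model's original code object via `torch._dynamo.eval_frame.remove_from_cache`, an internal PyTorch API, so the next call re-traces from scratch. A hedged toy example of the same idea on a standalone function rather than on `TorchCompileWrapperWithCustomDispatcher`:

# Toy example of clearing Dynamo's cached bytecode for one code object
# (internal API; mirrors reset_dynamo_cache above, but on a plain function).
import torch


def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2


compiled = torch.compile(double, backend="eager")
compiled(torch.ones(4))  # populates Dynamo's cache for double.__code__

# Drop the cached guarded bytecode so the next call re-traces from scratch.
torch._dynamo.eval_frame.remove_from_cache(double.__code__)
compiled(torch.ones(4))  # recompiles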
@@ -1001,32 +969,29 @@ class ModelWrapperV1(nn.Module):
     ) -> torch.Tensor:
         """
         Sample with xla-friendly function. This function is to be traced
-        separately from `forward` for lighter compilation overhead.
+        separately for lighter compilation overhead.
         """
         # Tensor `sample_hidden_states` is of fixed pre-compiled size.
         sample_hidden_states = \
             hidden_states[sampling_metadata.indices_do_sample]
-        logits = self.compute_logits(sample_hidden_states)
+        # SamplingMetadata here for pruning output in LogitsProcessor, disabled.
+        logits = self.model.compute_logits(sample_hidden_states, None)
+
+        def sample(
+            logits: torch.Tensor,
+            sampling_metadata: TPUSupportedSamplingMetadata
+        ) -> SamplerOutput:
+            sampler_out = self.sampler(logits, sampling_metadata)
+            return sampler_out
+
         # Optimized greedy sampling branch, tracing both paths in a single pass
         # NOTE all_greedy is a scalar, this is just an optimized if/else.
-        out_tokens = torch.where(sampling_metadata.all_greedy,
+        out_tokens = torch.where(
+            sampling_metadata.all_greedy,
             torch.argmax(logits, dim=-1, keepdim=True),
-            self.sample(logits, sampling_metadata)\
-                .sampled_token_ids)
+            sample(logits, sampling_metadata).sampled_token_ids)
         return out_tokens

-    def compute_logits(self,
-                       hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
-        # SamplingMetadata here for pruning output in LogitsProcessor, disabled
-        logits = self.model.compute_logits(hidden_states, None)
-        return logits
-
-    def get_multimodal_embeddings(self, *args, **kwargs):
-        return self.model.get_multimodal_embeddings(*args, **kwargs)
-
-    def get_input_embeddings(self, *args, **kwargs):
-        return self.model.get_input_embeddings(*args, **kwargs)
-

 def _get_padded_number(n: int, multiple: int) -> int:
     return ((n + multiple - 1) // multiple) * multiple

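The comment in `sample_from_hidden` notes that `all_greedy` is a scalar and that `torch.where` acts as an "optimized if/else": both branches are computed and traced, so one static graph covers greedy and non-greedy sampling with no data-dependent control flow. A small self-contained illustration of that trick, not taken from the diff:

# Illustration of torch.where with a scalar condition tracing both branches.
import torch

logits = torch.randn(4, 32)
all_greedy = torch.tensor(True)  # scalar flag, analogous to all_greedy

greedy = torch.argmax(logits, dim=-1, keepdim=True)
random = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)

# Both `greedy` and `random` are evaluated; torch.where only selects one,
# which keeps the captured graph free of data-dependent branching.
out = torch.where(all_greedy, greedy, random)
assert torch.equal(out, greedy)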
@@ -157,13 +157,19 @@ class TPUWorker:
                 runner_kv_caches)

         self.model_runner._dummy_run(
-            runner_kv_caches,
-            num_tokens=self.scheduler_config.max_num_batched_tokens,
-        )
+            self.scheduler_config.max_num_batched_tokens)

         # Synchronize before measuring the memory usage.
         xm.wait_device_ops()

+        # During the profiling run, the model runs without KV cache. After
+        # the profiling run, the model always runs with KV cache. Here we clear
+        # the dynamo cache and cached bytecode to ensure the model always has
+        # one compiled bytecode. Having one FX graph/cached bytecode per
+        # compiled model is required for `support_torch_compile` decorator to
+        # skip dynamo guard.
+        self.model_runner.reset_dynamo_cache()
+
         # Get the maximum amount of memory used by the model weights and
         # intermediate activations.
         m = xm.get_memory_info(self.device)