[torch.compile][TPU] Make @support_torch_compile work for XLA backend (#15782)
Signed-off-by: Siyuan Liu <lsiyuan@google.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>

This commit is contained in:
parent f6b32efb7f
commit 87918e40c4

@@ -15,6 +15,7 @@ import torch_xla.runtime as xr
 import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
+from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
 from vllm.config import VllmConfig
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
@@ -691,11 +692,10 @@ class TPUModelRunner:
         hidden_states = self.model(
             input_ids=input_ids,
             positions=self.position_ids,
-            kv_caches=self.kv_caches,
             inputs_embeds=inputs_embeds,
         )
-        selected_token_ids = self.model.sample_from_hidden(
-            hidden_states, tpu_sampling_metadata)
+        selected_token_ids = self.sample_from_hidden(hidden_states,
+                                                     tpu_sampling_metadata)
         # Remove padding on cpu and keep dynamic op outside of xla graph.
         selected_token_ids = selected_token_ids.cpu()[:num_reqs]

@@ -795,17 +795,15 @@ class TPUModelRunner:
                 "get_tensor_model_parallel_rank",
                 return_value=xm_tp_rank):
             model = get_model(vllm_config=self.vllm_config)
-        model = model.eval()
+        # Sync all pending XLA execution during model initialization and weight
+        # loading.
         xm.mark_step()
         xm.wait_device_ops()
-        model = ModelWrapperV1(model)
-        self.model = torch.compile(model,
-                                   backend="openxla",
-                                   fullgraph=True,
-                                   dynamic=False)
+        self.model = model
+        self.sampler = TPUSampler()

     @torch.no_grad()
-    def _dummy_run(self, kv_caches, num_tokens: int) -> None:
+    def _dummy_run(self, num_tokens: int) -> None:
         if self.is_multimodal_model:
             input_ids = None
             inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
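Note on the removal above: the runner no longer wraps the loaded model in ModelWrapperV1 and calls torch.compile(backend="openxla") by hand; compilation is instead driven by the @support_torch_compile class decorator named in the title. A minimal sketch of that decorator pattern for readers unfamiliar with it (the class name, layer sizes, and forward body below are placeholders for illustration, not code from this PR):

# Illustrative sketch only: MyTPUModel, its layers, and its forward body
# are placeholders. The point is that TPUModelRunner no longer calls
# torch.compile() itself; the support_torch_compile class decorator wraps
# the model's forward() with the backend chosen by the vLLM compilation
# config (openxla on TPU).
import torch
from torch import nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig


@support_torch_compile
class MyTPUModel(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.embed = nn.Embedding(32000, 1024)
        self.proj = nn.Linear(1024, 1024)

    def forward(self, input_ids: torch.Tensor,
                positions: torch.Tensor) -> torch.Tensor:
        # Once the graph for a given shape is captured, the decorator's
        # custom dispatcher jumps straight to the cached compiled bytecode.
        return self.proj(self.embed(input_ids))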
@@ -856,7 +854,6 @@ class TPUModelRunner:
         with set_forward_context(attn_metadata, self.vllm_config, 0):
             out = self.model(input_ids=input_ids,
                              positions=position_ids,
-                             kv_caches=kv_caches,
                              inputs_embeds=inputs_embeds)
         self._hidden_states_dtype = out.dtype

@@ -868,7 +865,7 @@ class TPUModelRunner:
         start = time.perf_counter()
         for num_tokens in self.num_tokens_paddings:
             logger.info(" -- num_tokens: %d", num_tokens)
-            self._dummy_run(self.kv_caches, num_tokens)
+            self._dummy_run(num_tokens)
             xm.mark_step()
         xm.wait_device_ops()
         end = time.perf_counter()
@@ -899,8 +896,7 @@ class TPUModelRunner:
                 from_input_batch(self.input_batch, indices)
             logger.info(" -- num_tokens: %d, num_seqs: %d", num_tokens,
                         num_reqs_to_sample)
-            out = self.model.sample_from_hidden(dummy_hidden,
-                                                sampling_meta)
+            out = self.sample_from_hidden(dummy_hidden, sampling_meta)
             out = out.cpu()
             # Requests can't be more than tokens. But do compile for the
             # next bigger value in case num_tokens uses bucketed padding.
@@ -954,45 +950,17 @@ class TPUModelRunner:
             self.vllm_config.compilation_config.static_forward_context,
             self.kv_caches)

-
-class ModelWrapperV1(nn.Module):
-
-    def __init__(self, model: nn.Module):
-        super().__init__()
-        self.model = model
-        self.sampler = TPUSampler()
-
-    def sample(
-        self, logits: torch.Tensor,
-        sampling_metadata: TPUSupportedSamplingMetadata) -> SamplerOutput:
-        sampler_out = self.sampler(logits, sampling_metadata)
-        return sampler_out
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        kv_caches: list[torch.Tensor],
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        """Executes the forward pass of the model.
-
-        Args:
-            input_ids: The input token IDs of shape [num_tokens].
-            positions: The input position IDs of shape [num_tokens].
-            kv_caches: The key and value caches. They can be None during the
-                memory profiling at initialization.
-            inputs_embeds: The input embeddings of shape [num_tokens,
-                hidden_size]. It is used for multimodal models.
-        """
-
-        hidden_states = self.model(
-            input_ids=input_ids,
-            positions=positions,
-            inputs_embeds=inputs_embeds,
-        )
-
-        return hidden_states
+    def reset_dynamo_cache(self):
+        if self.is_multimodal_model:
+            assert hasattr(self.model, "language_model")
+            compiled_model = self.model.language_model.model
+        else:
+            compiled_model = self.model.model
+        if isinstance(compiled_model, TorchCompileWrapperWithCustomDispatcher):
+            logger.info("Clear dynamo cache and cached dynamo bytecode.")
+            torch._dynamo.eval_frame.remove_from_cache(
+                compiled_model.original_code_object)
+            compiled_model.compiled_codes.clear()

     def sample_from_hidden(
         self,
@@ -1000,33 +968,30 @@ class ModelWrapperV1(nn.Module):
         sampling_metadata: TPUSupportedSamplingMetadata,
     ) -> torch.Tensor:
         """
         Sample with xla-friendly function. This function is to be traced
-        separately from `forward` for lighter compilation overhead.
+        separately for lighter compilation overhead.
         """
         # Tensor `sample_hidden_states` is of fixed pre-compiled size.
         sample_hidden_states = \
             hidden_states[sampling_metadata.indices_do_sample]
-        logits = self.compute_logits(sample_hidden_states)
+        # SamplingMetadata here for pruning output in LogitsProcessor, disabled.
+        logits = self.model.compute_logits(sample_hidden_states, None)
+
+        def sample(
+            logits: torch.Tensor,
+            sampling_metadata: TPUSupportedSamplingMetadata
+        ) -> SamplerOutput:
+            sampler_out = self.sampler(logits, sampling_metadata)
+            return sampler_out
+
         # Optimized greedy sampling branch, tracing both paths in a single pass
         # NOTE all_greedy is a scalar, this is just an optimized if/else.
-        out_tokens = torch.where(sampling_metadata.all_greedy,
-                                 torch.argmax(logits, dim=-1, keepdim=True),
-                                 self.sample(logits, sampling_metadata)\
-                                     .sampled_token_ids)
+        out_tokens = torch.where(
+            sampling_metadata.all_greedy,
+            torch.argmax(logits, dim=-1, keepdim=True),
+            sample(logits, sampling_metadata).sampled_token_ids)
         return out_tokens
-
-    def compute_logits(self,
-                       hidden_states: torch.Tensor) -> Optional[torch.Tensor]:
-        # SamplingMetadata here for pruning output in LogitsProcessor, disabled
-        logits = self.model.compute_logits(hidden_states, None)
-        return logits
-
-    def get_multimodal_embeddings(self, *args, **kwargs):
-        return self.model.get_multimodal_embeddings(*args, **kwargs)
-
-    def get_input_embeddings(self, *args, **kwargs):
-        return self.model.get_input_embeddings(*args, **kwargs)


 def _get_padded_number(n: int, multiple: int) -> int:
     return ((n + multiple - 1) // multiple) * multiple
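The rewritten sampling branch above keeps control flow out of the XLA graph: all_greedy is a scalar tensor, both candidate outputs are computed inside the trace, and torch.where selects one, so a single compiled graph serves greedy and non-greedy batches. A small self-contained sketch of that pattern in plain PyTorch (shapes and names chosen purely for illustration):

# Standalone sketch of the "optimized if/else" used in sample_from_hidden:
# both branches live in the traced graph and a scalar boolean tensor picks
# the result, so no Python branch depends on runtime data.
import torch

logits = torch.randn(4, 128)                     # [num_reqs, vocab_size]
all_greedy = torch.tensor(True)                  # scalar predicate, kept as a tensor

greedy_ids = torch.argmax(logits, dim=-1, keepdim=True)
sampled_ids = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)

out_tokens = torch.where(all_greedy, greedy_ids, sampled_ids)
print(out_tokens.shape)                          # torch.Size([4, 1])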
@@ -157,13 +157,19 @@ class TPUWorker:
                 runner_kv_caches)

         self.model_runner._dummy_run(
-            runner_kv_caches,
-            num_tokens=self.scheduler_config.max_num_batched_tokens,
-        )
+            self.scheduler_config.max_num_batched_tokens)

         # Synchronize before measuring the memory usage.
         xm.wait_device_ops()

+        # During the profiling run, the model runs without KV cache. After
+        # the profiling run, the model always runs with KV cache. Here we clear
+        # the dynamo cache and cached bytecode to ensure the model always has
+        # one compiled bytecode. Having one FX graph/cached bytecode per
+        # compiled model is required for `support_torch_compile` decorator to
+        # skip dynamo guard.
+        self.model_runner.reset_dynamo_cache()
+
         # Get the maximum amount of memory used by the model weights and
         # intermediate activations.
         m = xm.get_memory_info(self.device)
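The new comment block carries the core reasoning of this change: the profiling run compiles the model without a KV cache, and the custom dispatcher installed by @support_torch_compile only skips Dynamo guard evaluation when exactly one cached bytecode exists, so the worker clears the profiling-run compilation before steady-state execution. A rough standalone illustration of that idea (it uses the blunt global torch._dynamo.reset() instead of the targeted per-code-object removal that reset_dynamo_cache() performs):

# Rough illustration only, not the PR's code path: reset_dynamo_cache()
# removes a single code object's cache via
# torch._dynamo.eval_frame.remove_from_cache(...) and clears the wrapper's
# compiled_codes; the global reset below stands in for "discard the
# profiling-run compilation so steady state starts from one clean graph".
import torch


def toy_forward(x: torch.Tensor) -> torch.Tensor:
    return torch.relu(x @ x.T)


compiled = torch.compile(toy_forward, dynamic=False)
compiled(torch.randn(8, 8))      # profiling-style warm-up: first compilation

torch._dynamo.reset()            # drop bytecode/guards cached by the warm-up

compiled(torch.randn(16, 16))    # steady-state shape recompiles from scratch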
|
Loading…
x
Reference in New Issue
Block a user