From a454748544f67a7677a2bf71ab329da0600d34a6 Mon Sep 17 00:00:00 2001
From: Chengji Yao
Date: Wed, 9 Apr 2025 17:51:51 -0700
Subject: [PATCH] [TPU][V1] Refine tpu_model_runner to mitigate future recompilation issues (#16275)

Signed-off-by: Chengji Yao
---
 tests/tpu/test_compilation.py                |  14 +-
 tests/v1/tpu/worker/test_tpu_model_runner.py |  25 ++-
 vllm/v1/sample/tpu/metadata.py               |  77 ++++----
 vllm/v1/worker/tpu_model_runner.py           | 175 +++++++++++--------
 4 files changed, 166 insertions(+), 125 deletions(-)

diff --git a/tests/tpu/test_compilation.py b/tests/tpu/test_compilation.py
index 2a71f460..06e00187 100644
--- a/tests/tpu/test_compilation.py
+++ b/tests/tpu/test_compilation.py
@@ -44,7 +44,7 @@ def test_tpu_compilation():
         assert generated_text.startswith(answer)
 
     compiled_codes = sorted(
-        glob.glob(os.path.join(temp_dir, "__transformed_code*.py")))
+        glob.glob(os.path.join(temp_dir, "__transformed_code*for_forward.py")))
 
     for i, compiled_code in enumerate(compiled_codes):
         print("{} file: {}".format(i + 1, compiled_code))
@@ -52,15 +52,21 @@ def test_tpu_compilation():
     # We should only trigger Dynamo compilation 2 times:
     # 1. Forward pass without kv_caches
     # 2. Forward pass with kv_caches
-    # Check we have 4 compiled codes
+    # Check we have 2 compiled codes
     assert len(compiled_codes) == 2
 
     kv_cache_prefix = "kv_cache"
     attn_prefix = "ragged_paged_attention"
 
+    def extract_compiled_index(s):
+        parts = s.replace(".", "_").split("_")
+        numbers = [int(part) for part in parts if part.isdigit()]
+        return numbers[0]
+
     # Check all the compilations are as expected
-    compiled_fns = sorted(
-        glob.glob(os.path.join(temp_dir, "__compiled_fn*Captured*.py")))
+    compiled_fns = sorted(glob.glob(
+        os.path.join(temp_dir, "__compiled_fn*Captured*.py")),
+                          key=lambda s: extract_compiled_index(s))
 
     for i, compiled_fn in enumerate(compiled_fns):
         print("{} file: {}".format(i + 1, compiled_fn))
diff --git a/tests/v1/tpu/worker/test_tpu_model_runner.py b/tests/v1/tpu/worker/test_tpu_model_runner.py
index 6b6a91b8..8ea8c890 100644
--- a/tests/v1/tpu/worker/test_tpu_model_runner.py
+++ b/tests/v1/tpu/worker/test_tpu_model_runner.py
@@ -7,9 +7,9 @@ from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
 from vllm.sampling_params import SamplingParams
 from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
                                        SchedulerOutput)
-from vllm.v1.worker.tpu_model_runner import (TPUModelRunner,
-                                             _get_padded_token_len,
-                                             _get_paddings)
+from vllm.v1.worker.tpu_model_runner import (
+    TPUModelRunner, _get_padded_num_reqs_with_upper_limit,
+    _get_padded_token_len, _get_req_paddings, _get_token_paddings)
 
 # Mock torch_xla module since it may not be available in the test environments
 torch_xla_patcher = mock.patch.dict(
@@ -296,16 +296,29 @@ def test_update_states_request_unscheduled(model_runner):
 def test_get_paddings():
     min_token_size, max_token_size, padding_gap = 16, 512, 64
     expected_paddings = [16, 32, 64, 128, 192, 256, 320, 384, 448, 512]
-    actual_paddings = _get_paddings(min_token_size, max_token_size,
-                                    padding_gap)
+    actual_paddings = _get_token_paddings(min_token_size, max_token_size,
+                                          padding_gap)
     assert actual_paddings == expected_paddings
 
 
 def test_get_padded_token_len():
     min_token_size, max_token_size, padding_gap = 16, 512, 64
-    paddings = _get_paddings(min_token_size, max_token_size, padding_gap)
+    paddings = _get_token_paddings(min_token_size, max_token_size, padding_gap)
     assert _get_padded_token_len(paddings, 1) == 16
     assert _get_padded_token_len(paddings, 16) == 16
     assert _get_padded_token_len(paddings, 20) == 32
     assert _get_padded_token_len(paddings, 300) == 320
     assert _get_padded_token_len(paddings, 512) == 512
+
+
+def test_get_padded_num_reqs_with_upper_limit():
+    assert _get_padded_num_reqs_with_upper_limit(3, 32) == 8
+    assert _get_padded_num_reqs_with_upper_limit(9, 32) == 16
+    assert _get_padded_num_reqs_with_upper_limit(19, 32) == 32
+    assert _get_padded_num_reqs_with_upper_limit(17, 28) == 28
+
+
+def test_get_req_paddings():
+    assert _get_req_paddings(1, 32) == [8, 16, 32]
+    assert _get_req_paddings(8, 32) == [8, 16, 32]
+    assert _get_req_paddings(8, 36) == [8, 16, 32, 36]
diff --git a/vllm/v1/sample/tpu/metadata.py b/vllm/v1/sample/tpu/metadata.py
index 89d3ddf5..10995d67 100644
--- a/vllm/v1/sample/tpu/metadata.py
+++ b/vllm/v1/sample/tpu/metadata.py
@@ -3,7 +3,6 @@ from dataclasses import dataclass, field
 from typing import Optional
 
 import torch
-import torch_xla.core.xla_model as xm
 
 from vllm.v1.worker.gpu_input_batch import InputBatch
 
@@ -24,15 +23,15 @@ class TPUSupportedSamplingMetadata:
     # This class exposes a more xla-friendly interface than SamplingMetadata
     # on TPU, in particular all arguments should be traceable and no optionals
    # are allowed, to avoid graph recompilation on Nones.
-    temperature: torch.Tensor
+    temperature: torch.Tensor = None
 
-    min_p: torch.Tensor
+    min_p: torch.Tensor = None
     # Still too slow on forward_native!
     top_k: torch.Tensor = None
     top_p: torch.Tensor = None
 
     # Greedy sampling flag for compiling single xla graph.
-    all_greedy: torch.Tensor = None
+    all_greedy: bool = True
 
     # Generator not supported by xla
     generators: dict[int,
@@ -57,64 +56,58 @@ class TPUSupportedSamplingMetadata:
     allowed_token_ids_mask = None
     bad_words_token_ids = None
-    indices_do_sample: torch.Tensor = None
 
     @classmethod
     def from_input_batch(
-            cls, input_batch: InputBatch,
-            indices_do_sample: torch.Tensor) -> "TPUSupportedSamplingMetadata":
+        cls,
+        input_batch: InputBatch,
+        padded_num_reqs: int,
+        xla_device: torch.device,
+        generate_params_if_all_greedy: bool = False
+    ) -> "TPUSupportedSamplingMetadata":
         """
         Copy sampling tensors slices from `input_batch` to on device tensors.
 
         `InputBatch._make_sampling_metadata` causes recompilation on XLA as it
         slices dynamic shapes on device tensors. This impl moves the dynamic
-        ops to CPU and produces tensors of fixed `padded_num_reqs` size. It
-        also reuses the on-device persistent tensors managed in `input_batch`
-        to reduce waste.
+        ops to CPU and produces tensors of fixed `padded_num_reqs` size.
 
-        `indices_do_sample` contains the indices to be fed to the Sampler,
-        normally one per request, here padded to the closest pre-compiled shape
-        We expect sampling params tensors to be padded to the same fixed shape.
-
-        Eg. 3 requests, tensors padded to 4
-        temperature: [0.7, 0.2, 0.9]=>[0.7, 0.2, 0.9, 0.0]
-        sample indices: [4, 10, 11]=>indices_do_sample: [4, 10, 11, 0]
+        Args:
+            input_batch: The input batch containing sampling parameters.
+            padded_num_reqs: The padded number of requests.
+            xla_device: The XLA device.
+            generate_params_if_all_greedy: If True, generate sampling parameters
+                even if all requests are greedy. This is useful for cases where
+                we want to pre-compile a graph with sampling parameters, even if
+                they are not strictly needed for greedy decoding.
""" - num_reqs = input_batch.num_reqs - padded_num_reqs = len(indices_do_sample) + # Early return to avoid unnecessary cpu to tpu copy + if (input_batch.all_greedy is True + and generate_params_if_all_greedy is False): + return cls(all_greedy=True) - def copy_slice(cpu_tensor: torch.Tensor, tpu_tensor: torch.Tensor, - fill_val) -> torch.Tensor: - # Copy slice from CPU to corresponding TPU pre-allocated tensor. + num_reqs = input_batch.num_reqs + + def fill_slice(cpu_tensor: torch.Tensor, fill_val) -> torch.Tensor: # Pad value is the default one. cpu_tensor[num_reqs:padded_num_reqs] = fill_val - # Subtle compilation: len(tpu_tensor) must be >= `padded_num_reqs` - tpu_tensor[:padded_num_reqs] = cpu_tensor[:padded_num_reqs] - # NOTE NickLucche The sync CPU-TPU graph we produce here must be - # consistent. We can't have flags to skip copies or we'll end up - # recompiling. - copy_slice(input_batch.temperature_cpu_tensor, input_batch.temperature, + fill_slice(input_batch.temperature_cpu_tensor, DEFAULT_SAMPLING_PARAMS["temperature"]) # TODO Temporarily disabled until sampling options are enabled - # copy_slice(input_batch.top_p_cpu_tensor, input_batch.top_p) - # copy_slice(input_batch.top_k_cpu_tensor, input_batch.top_k) - copy_slice(input_batch.min_p_cpu_tensor, input_batch.min_p, + # fill_slice(input_batch.top_p_cpu_tensor) + # fill_slice(input_batch.top_k_cpu_tensor) + fill_slice(input_batch.min_p_cpu_tensor, DEFAULT_SAMPLING_PARAMS["min_p"]) - xm.mark_step() - xm.wait_device_ops() - # Slice persistent device tensors to a fixed pre-compiled padded shape. return cls( - temperature=input_batch.temperature[:padded_num_reqs], - # Scalar tensor for xla-friendly tracing. - all_greedy=torch.tensor(input_batch.all_greedy, - dtype=torch.bool, - device=input_batch.device), + temperature=input_batch.temperature_cpu_tensor[:padded_num_reqs]. + to(xla_device), + all_greedy=input_batch.all_greedy, # TODO enable more and avoid returning None values top_p=None, # input_batch.top_p[:padded_num_reqs], top_k=None, # input_batch.top_k[:padded_num_reqs], - min_p=input_batch.min_p[:padded_num_reqs], - generators=input_batch.generators, - indices_do_sample=indices_do_sample) + min_p=input_batch.min_p_cpu_tensor[:padded_num_reqs].to( + xla_device), + generators=input_batch.generators) diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index c99c6cb7..773c4264 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -32,7 +32,7 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheSpec, SlidingWindowSpec) from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, - ModelRunnerOutput, SamplerOutput) + ModelRunnerOutput) from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler from vllm.v1.utils import bind_kv_cache @@ -177,10 +177,12 @@ class TPUModelRunner: # Range tensor with values [0 .. self.max_num_tokens - 1]. 
         # Used to initialize positions / context_lens / seq_lens
         self.arange_np = np.arange(self.max_num_tokens, dtype=np.int32)
-        self.num_tokens_paddings = _get_paddings(
+        self.num_tokens_paddings = _get_token_paddings(
             min_token_size=16,
             max_token_size=self.max_num_tokens,
             padding_gap=envs.VLLM_TPU_BUCKET_PADDING_GAP)
+        self.num_reqs_paddings = _get_req_paddings(
+            min_req_size=MIN_NUM_SEQS, max_req_size=self.max_num_reqs)
 
     def _update_num_xla_graphs(self, case_str):
         check_comp = self.check_recompilation and not self.enforce_eager
@@ -508,7 +510,7 @@ class TPUModelRunner:
         # Padded to avoid recompiling when `num_reqs` varies.
         logits_indices = self.query_start_loc_cpu[1:padded_num_reqs + 1] - 1
         logits_indices = logits_indices.to(self.device)
-        return attn_metadata, logits_indices
+        return attn_metadata, logits_indices, padded_num_reqs
 
     def _scatter_placeholders(
         self,
@@ -663,7 +665,8 @@ class TPUModelRunner:
             mm_embeds = []
 
         # Prepare inputs
-        attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
+        attn_metadata, logits_indices, padded_num_reqs = self._prepare_inputs(
+            scheduler_output)
         if self.is_multimodal_model:
             # NOTE(woosuk): To unify token ids and soft tokens (vision
             # embeddings), we always use embeddings (rather than token ids)
@@ -682,11 +685,6 @@ class TPUModelRunner:
             input_ids = self.input_ids
             inputs_embeds = None
         num_reqs = self.input_batch.num_reqs
-        # NOTE (NickLucche) here we sync with TPU: sampling params tensors
-        # are copied to device in chunks of pre-compiled padded shape to
-        # avoid recompilations.
-        tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
-            from_input_batch(self.input_batch, logits_indices)
         # Run the decoder
         with set_forward_context(attn_metadata, self.vllm_config):
             hidden_states = self.model(
@@ -694,6 +692,10 @@ class TPUModelRunner:
                 positions=self.position_ids,
                 inputs_embeds=inputs_embeds,
             )
+        hidden_states = self.select_hidden_states(hidden_states,
+                                                  logits_indices)
+        tpu_sampling_metadata = TPUSupportedSamplingMetadata.\
+            from_input_batch(self.input_batch, padded_num_reqs, self.device)
         selected_token_ids = self.sample_from_hidden(hidden_states,
                                                      tpu_sampling_metadata)
         # Remove padding on cpu and keep dynamic op outside of xla graph.
@@ -857,60 +859,78 @@ class TPUModelRunner:
                              inputs_embeds=inputs_embeds)
             self._hidden_states_dtype = out.dtype
 
-    def capture_model(self) -> None:
-        """Compile the model."""
-
+    def _precompile_backbone(self) -> None:
         logger.info("Compiling the model with different input shapes.")
         start = time.perf_counter()
         for num_tokens in self.num_tokens_paddings:
             logger.info("  -- num_tokens: %d", num_tokens)
             self._dummy_run(num_tokens)
-            xm.mark_step()
         xm.wait_device_ops()
         end = time.perf_counter()
-        logger.info("Compilation finished in in %.2f [secs].", end - start)
-        self._update_num_xla_graphs("model")
+        logger.info("Compilation finished in %.2f [secs].", end - start)
+        self._update_num_xla_graphs("model backbone")
 
+    def _precompile_select_hidden_states(self) -> None:
+        # Compile hidden state selection function for bucketed
+        # n_tokens x max_num_reqs. Graph is really small so this is fine.
+        logger.info(
+            "Compiling select_hidden_states with different input shapes.")
+        start = time.perf_counter()
+        hsize = self.model_config.get_hidden_size()
+        for num_tokens in self.num_tokens_paddings:
+            dummy_hidden = torch.zeros((num_tokens, hsize),
+                                       device=self.device,
+                                       dtype=self._hidden_states_dtype)
+            torch._dynamo.mark_dynamic(dummy_hidden, 0)
+            for num_reqs in self.num_reqs_paddings:
+                indices = torch.zeros(num_reqs,
+                                      dtype=torch.int32,
+                                      device=self.device)
+                torch._dynamo.mark_dynamic(indices, 0)
+                self.select_hidden_states(dummy_hidden, indices)
+            logger.info("  -- num_tokens: %d", num_tokens)
+        xm.wait_device_ops()
+        end = time.perf_counter()
+        logger.info("Compilation finished in %.2f [secs].", end - start)
+        self._update_num_xla_graphs("select_hidden_states")
+
+    def _precompile_sample_from_hidden(self) -> None:
         logger.info("Compiling sampling with different input shapes.")
         start = time.perf_counter()
         hsize = self.model_config.get_hidden_size()
-        device = self.device
-        # Compile sampling step for different model+sampler outputs in bucketed
-        # n_tokens x max_num_reqs. Graph is really small so this is fine.
-        for num_tokens in self.num_tokens_paddings:
-            num_reqs_to_sample = MIN_NUM_SEQS
-            dummy_hidden = torch.randn((num_tokens, hsize),
-                                       device=device,
+        for num_reqs in self.num_reqs_paddings:
+            dummy_hidden = torch.zeros((num_reqs, hsize),
+                                       device=self.device,
                                        dtype=self._hidden_states_dtype)
-            # Compile for [8, 16, .., 128,.., `self.max_num_reqs`]
-            while True:
-                indices = torch.zeros(
-                    num_reqs_to_sample,
-                    dtype=torch.int32,
-                    device=device,
-                )
-                xm.mark_step()
-                sampling_meta = TPUSupportedSamplingMetadata.\
-                    from_input_batch(self.input_batch, indices)
-                logger.info("  -- num_tokens: %d, num_seqs: %d", num_tokens,
-                            num_reqs_to_sample)
-                out = self.sample_from_hidden(dummy_hidden, sampling_meta)
-                out = out.cpu()
-                # Requests can't be more than tokens. But do compile for the
-                # next bigger value in case num_tokens uses bucketed padding.
-                if num_reqs_to_sample >= min(num_tokens, self.max_num_reqs):
-                    break
-                # Make sure to compile the `max_num_reqs` upper-limit case
-                num_reqs_to_sample = _get_padded_num_reqs_with_upper_limit(
-                    num_reqs_to_sample + 1, self.max_num_reqs)
+            # The first dimension of dummy_hidden cannot be mark_dynamic because
+            # some operations in the sampler require it to be static.
+            for all_greedy in [False, True]:
+                generate_params_if_all_greedy = not all_greedy
+                sampling_metadata = (
+                    TPUSupportedSamplingMetadata.from_input_batch(
+                        self.input_batch,
+                        num_reqs,
+                        self.device,
+                        generate_params_if_all_greedy,
+                    ))
+                sampling_metadata.all_greedy = all_greedy
+                self.sample_from_hidden(dummy_hidden, sampling_metadata)
+            logger.info("  -- num_seqs: %d", num_reqs)
         xm.wait_device_ops()
         end = time.perf_counter()
-        logger.info("Compilation finished in in %.2f [secs].", end - start)
         self._update_num_xla_graphs("sampling")
 
+    def capture_model(self) -> None:
+        """
+        Precompile all the subgraphs with possible input shapes.
+        """
+        # TODO: precompile encoder
+        self._precompile_backbone()
+        self._precompile_select_hidden_states()
+        self._precompile_sample_from_hidden()
+
     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize KV cache based on `kv_cache_config`.
@@ -962,48 +982,55 @@ class TPUModelRunner:
                 compiled_model.original_code_object)
             compiled_model.compiled_codes.clear()
 
+    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
+    def select_hidden_states(self, hidden_states, indices_do_sample):
+        return hidden_states[indices_do_sample]
+
+    @torch.compile(backend="openxla", fullgraph=True, dynamic=False)
     def sample_from_hidden(
         self,
-        hidden_states: torch.Tensor,
+        sample_hidden_states: torch.Tensor,
         sampling_metadata: TPUSupportedSamplingMetadata,
     ) -> torch.Tensor:
         """
-        Sample with xla-friendly function. This function is to be traced
-        separately for lighter compilation overhead.
-        """
-        # Tensor `sample_hidden_states` is of fixed pre-compiled size.
-        sample_hidden_states = \
-            hidden_states[sampling_metadata.indices_do_sample]
-        # SamplingMetadata here for pruning output in LogitsProcessor, disabled.
+        Sample with xla-friendly function. This function is to be traced
+        separately from `forward` for lighter compilation overhead.
+        """
         logits = self.model.compute_logits(sample_hidden_states, None)
-
-        def sample(
-            logits: torch.Tensor,
-            sampling_metadata: TPUSupportedSamplingMetadata
-        ) -> SamplerOutput:
-            sampler_out = self.sampler(logits, sampling_metadata)
-            return sampler_out
-
-        # Optimized greedy sampling branch, tracing both paths in a single pass
-        # NOTE all_greedy is a scalar, this is just an optimized if/else.
-        out_tokens = torch.where(
-            sampling_metadata.all_greedy,
-            torch.argmax(logits, dim=-1, keepdim=True),
-            sample(logits, sampling_metadata).sampled_token_ids)
+        if sampling_metadata.all_greedy:
+            out_tokens = torch.argmax(logits, dim=-1, keepdim=True)
+        else:
+            out_tokens = self.sampler(logits,
+                                      sampling_metadata).sampled_token_ids
         return out_tokens
 
+    def get_multimodal_embeddings(self, *args, **kwargs):
+        return self.model.get_multimodal_embeddings(*args, **kwargs)
 
-def _get_padded_number(n: int, multiple: int) -> int:
-    return ((n + multiple - 1) // multiple) * multiple
+    def get_input_embeddings(self, *args, **kwargs):
+        return self.model.get_input_embeddings(*args, **kwargs)
 
 
-def _get_padded_num_reqs_with_upper_limit(x, upper_limit) -> int:
+def _get_req_paddings(min_req_size: int, max_req_size: int) -> list[int]:
+    logger.info("Preparing request paddings:")
+    # assert min_req_size is power of 2
+    assert (min_req_size & (min_req_size - 1) == 0) and min_req_size > 0
+    paddings: list = []
+    num = max(MIN_NUM_SEQS, min_req_size)
+    while num <= max_req_size and (len(paddings) == 0 or paddings[-1] != num):
+        paddings.append(num)
+        logger.info("    %d", num)
+        num = _get_padded_num_reqs_with_upper_limit(num + 1, max_req_size)
+    return paddings
+
+
+def _get_padded_num_reqs_with_upper_limit(x: int, upper_limit: int) -> int:
     res = MIN_NUM_SEQS if x <= MIN_NUM_SEQS else 1 << (x - 1).bit_length()
     return min(res, upper_limit)
 
 
-def _get_paddings(min_token_size: int, max_token_size: int,
-                  padding_gap: int) -> list[int]:
+def _get_token_paddings(min_token_size: int, max_token_size: int,
+                        padding_gap: int) -> list[int]:
     """Generate a list of padding size, starting from min_token_size,
     ending with a number that can cover max_token_size
 
@@ -1013,18 +1040,20 @@ def _get_paddings(min_token_size: int,
     first increase the size to twice,
     then increase the padding size by padding_gap.
""" + # assert min_token_size is power of 2 + assert (min_token_size & (min_token_size - 1) == 0) and min_token_size > 0 paddings = [] num = min_token_size if padding_gap == 0: - logger.info("Using exponential paddings:") + logger.info("Using exponential token paddings:") while num <= max_token_size: logger.info(" %d", num) paddings.append(num) num *= 2 else: - logger.info("Using incremental paddings:") + logger.info("Using incremental token paddings:") while num <= padding_gap: logger.info(" %d", num) paddings.append(num)