[TPU][V1][Bugfix] Fix w8a8 recompilation with GSM8K (#15714)

Signed-off-by: NickLucche <nlucches@redhat.com>
Nicolò Lucchesi 2025-03-29 05:13:06 +01:00 committed by GitHub
parent 5b800f0932
commit da461f3cbf
4 changed files with 16 additions and 15 deletions


@@ -28,16 +28,14 @@ docker run --privileged --net host --shm-size=16G -it \
&& echo TEST_3 \
&& pytest -v -s /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine \
&& echo TEST_4 \
&& python3 /workspace/vllm/examples/offline_inference/tpu.py \
&& pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
&& echo TEST_5 \
&& pytest -s -v /workspace/vllm/tests/v1/tpu/worker/test_tpu_model_runner.py \
&& python3 /workspace/vllm/examples/offline_inference/tpu.py \
&& echo TEST_6 \
&& pytest -s -v /workspace/vllm/tests/v1/tpu/worker/test_tpu_model_runner.py \
&& echo TEST_7 \
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py" \
# TODO: This test fails because it uses RANDOM_SEED sampling
# && VLLM_USE_V1=1 pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \
# TODO: Re-enable this after fixing recompilation in quantization.
# && echo TEST_4 \
# && pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \


@@ -97,7 +97,8 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
block_size=-1,
int4_weight=False,
quantize_activation=True)
# `quantized_matmul` output is fp32, cast it down to bf16 for perf
out = out.to(x.dtype)
# Explicitly capture control flow to make dynamo happy.
# https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501
return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias])
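
For orientation: torch_xla's quantized_matmul returns fp32, so the kernel now casts the result back to the activation dtype (per the comment above, mainly for performance), and the model-runner hunk below records whatever dtype the hidden states actually end up with so the sampler is warmed up against it. The following is a minimal, self-contained sketch of the apply path using hypothetical names: the quantized kernel is emulated by a plain fp32 matmul, and the bias branch is a plain Python `if` here for readability, whereas the real kernel captures it with torch.cond on two class methods so dynamo sees the control flow explicitly.

from typing import Optional

import torch


def w8a8_apply_sketch(x: torch.Tensor,
                      w_q: torch.Tensor,
                      w_scale: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Stand-in for torch_xla's quantized_matmul: accumulate in fp32 and
    # dequantize with a per-output-channel weight scale.
    out = torch.matmul(x.to(torch.float32), w_q.to(torch.float32).t()) * w_scale
    # Cast back down to the activation dtype (e.g. bf16) so the hidden-states
    # dtype stays stable for everything downstream.
    out = out.to(x.dtype)
    # The real kernel expresses this branch with torch.cond instead of `if`.
    return out if bias is None else out + bias


x = torch.randn(4, 16, dtype=torch.bfloat16)
w_q = torch.randint(-128, 127, (32, 16), dtype=torch.int8)
w_scale = torch.rand(32)
assert w8a8_apply_sketch(x, w_q, w_scale).dtype == torch.bfloat16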


@@ -80,6 +80,7 @@ class TPUModelRunner:
self.enforce_eager = model_config.enforce_eager
self.pin_memory = is_pin_memory_available()
self.dtype = self.model_config.dtype
self._hidden_states_dtype = self.dtype
self.is_multimodal_model = model_config.is_multimodal_model
self.sliding_window = model_config.get_sliding_window()
@@ -771,10 +772,11 @@ class TPUModelRunner:
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
with set_forward_context(attn_metadata, self.vllm_config, 0):
self.model(input_ids=input_ids,
out = self.model(input_ids=input_ids,
positions=position_ids,
kv_caches=kv_caches,
inputs_embeds=inputs_embeds)
self._hidden_states_dtype = out.dtype
def capture_model(self) -> None:
"""Compile the model."""
@@ -800,7 +802,7 @@ class TPUModelRunner:
num_reqs_to_sample = MIN_NUM_SEQS
dummy_hidden = torch.randn((num_tokens, hsize),
device=device,
dtype=torch.bfloat16)
dtype=self._hidden_states_dtype)
# Compile for [8, 16, .., 128,.., `self.max_num_reqs`]
while True:
indices = torch.zeros(
@@ -823,7 +825,7 @@ class TPUModelRunner:
num_reqs_to_sample + 1, self.max_num_reqs)
xm.wait_device_ops()
end = time.perf_counter()
logger.info("Compilation finished in in %.2f [secs].", end - start)
logger.info("Compilation finished in %.2f [secs].", end - start)
# Record the number of cached XLA graphs after warming up; this will be
# used to check that there is no additional graph compilation during
# runtime execution.
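
The model-runner side of the fix is the same issue seen from the other end: the sampler graphs were pre-compiled against dummy hidden states hard-coded to bfloat16, so a model whose forward pass produces a different dtype triggered recompilation at runtime. The new code records the dtype observed during the dummy run and reuses it. A minimal sketch of that pattern, with hypothetical class and method names rather than the vLLM API:

import torch
from torch import nn


class SamplerWarmupSketch:
    def __init__(self, model: nn.Module, default_dtype: torch.dtype):
        self.model = model
        # Fallback until a dummy run reveals the real hidden-states dtype.
        self._hidden_states_dtype = default_dtype

    def dummy_run(self, dummy_input: torch.Tensor) -> None:
        out = self.model(dummy_input)
        # Quantized (w8a8) layers may emit a dtype other than the configured
        # model dtype, so record what actually comes out.
        self._hidden_states_dtype = out.dtype

    def warmup_sampler(self, num_tokens: int, hidden_size: int,
                       device: torch.device) -> torch.Tensor:
        # Build the dummy hidden states with the recorded dtype so the graph
        # compiled here matches the one executed at runtime.
        return torch.randn((num_tokens, hidden_size),
                           device=device,
                           dtype=self._hidden_states_dtype)


runner = SamplerWarmupSketch(nn.Linear(16, 32), torch.bfloat16)
runner.dummy_run(torch.randn(4, 16))
dummy_hidden = runner.warmup_sampler(8, 32, torch.device("cpu"))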


@@ -105,8 +105,8 @@ class TPUWorker:
# Increase the cache size limit, which is the maximum number of
# dynamo graphs that can be compiled.
# NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and
# 30-40 graphs for decode. 128 is an arbitrary safe number.
# TODO (NickLucche) On GSM8K we compile 80+ graphs.
# Re-evaluate this limit; with multimodal models we may get close to it.
torch._dynamo.config.cache_size_limit = 128
# Use persistent cache to avoid XLA recompilation.
# NOTE(woosuk): Set per-rank cache path since different ranks
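
For context, torch._dynamo.config.cache_size_limit caps how many compiled graph variants dynamo keeps (roughly, per compiled code object) before it stops caching new ones; the new TODO notes that a GSM8K run already needs 80+ graphs, so 128 is less generous than it first looks. A small sketch of the knob itself; the value 128 comes from this diff, while the reset call is only illustrative:

import torch

# Raise dynamo's graph-cache ceiling before warmup so the graphs compiled for
# the many (num_tokens, num_reqs) shape combinations are kept rather than
# dropped and recompiled during serving.
torch._dynamo.config.cache_size_limit = 128

# Optional: start from a clean slate before counting compilations in a test.
torch._dynamo.reset()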