[TPU][V1][Bugfix] Fix w8a8 recompiilation with GSM8K (#15714)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
parent
5b800f0932
commit
da461f3cbf
@ -28,16 +28,14 @@ docker run --privileged --net host --shm-size=16G -it \
|
||||
&& echo TEST_3 \
|
||||
&& pytest -v -s /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine \
|
||||
&& echo TEST_4 \
|
||||
&& python3 /workspace/vllm/examples/offline_inference/tpu.py \
|
||||
&& pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
|
||||
&& echo TEST_5 \
|
||||
&& pytest -s -v /workspace/vllm/tests/v1/tpu/worker/test_tpu_model_runner.py \
|
||||
&& python3 /workspace/vllm/examples/offline_inference/tpu.py \
|
||||
&& echo TEST_6 \
|
||||
&& pytest -s -v /workspace/vllm/tests/v1/tpu/worker/test_tpu_model_runner.py \
|
||||
&& echo TEST_7 \
|
||||
&& pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py" \
|
||||
|
||||
|
||||
# TODO: This test fails because it uses RANDOM_SEED sampling
|
||||
# && VLLM_USE_V1=1 pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \
|
||||
|
||||
# TODO: Re-enable this after fixing recompilation in quantization.
|
||||
# && echo TEST_4 \
|
||||
# && pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
|
||||
|
@ -97,7 +97,8 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
|
||||
block_size=-1,
|
||||
int4_weight=False,
|
||||
quantize_activation=True)
|
||||
|
||||
# `quantized_matmul` output is fp32, cast it down to bf16 for perf
|
||||
out = out.to(x.dtype)
|
||||
# Explicitly capture control flow to make dynamo happy.
|
||||
# https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501
|
||||
return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias])
|
||||
|
@ -80,6 +80,7 @@ class TPUModelRunner:
|
||||
self.enforce_eager = model_config.enforce_eager
|
||||
self.pin_memory = is_pin_memory_available()
|
||||
self.dtype = self.model_config.dtype
|
||||
self._hidden_states_dtype = self.dtype
|
||||
|
||||
self.is_multimodal_model = model_config.is_multimodal_model
|
||||
self.sliding_window = model_config.get_sliding_window()
|
||||
@ -771,10 +772,11 @@ class TPUModelRunner:
|
||||
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
|
||||
|
||||
with set_forward_context(attn_metadata, self.vllm_config, 0):
|
||||
self.model(input_ids=input_ids,
|
||||
out = self.model(input_ids=input_ids,
|
||||
positions=position_ids,
|
||||
kv_caches=kv_caches,
|
||||
inputs_embeds=inputs_embeds)
|
||||
self._hidden_states_dtype = out.dtype
|
||||
|
||||
def capture_model(self) -> None:
|
||||
"""Compile the model."""
|
||||
@ -800,7 +802,7 @@ class TPUModelRunner:
|
||||
num_reqs_to_sample = MIN_NUM_SEQS
|
||||
dummy_hidden = torch.randn((num_tokens, hsize),
|
||||
device=device,
|
||||
dtype=torch.bfloat16)
|
||||
dtype=self._hidden_states_dtype)
|
||||
# Compile for [8, 16, .., 128,.., `self.max_num_reqs`]
|
||||
while True:
|
||||
indices = torch.zeros(
|
||||
@ -823,7 +825,7 @@ class TPUModelRunner:
|
||||
num_reqs_to_sample + 1, self.max_num_reqs)
|
||||
xm.wait_device_ops()
|
||||
end = time.perf_counter()
|
||||
logger.info("Compilation finished in in %.2f [secs].", end - start)
|
||||
logger.info("Compilation finished in %.2f [secs].", end - start)
|
||||
# Record the number cached XLA graph after warming up, this will be
|
||||
# used for checking there is no additional graph compilation during
|
||||
# runtime execution.
|
||||
|
@ -105,8 +105,8 @@ class TPUWorker:
|
||||
|
||||
# Increase the cache size limit, which is the maximum number of
|
||||
# dynamo graphs that can be compiled.
|
||||
# NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and
|
||||
# 30-40 graphs for decode. 128 is an arbitrary safe number.
|
||||
# TODO (NickLucche) On gsm we compile 80+ graphs.
|
||||
# Re-evaluate limit, with MM we may get close to this limit.
|
||||
torch._dynamo.config.cache_size_limit = 128
|
||||
# Use persistent cache to avoid XLA recompilation.
|
||||
# NOTE(woosuk): Set per-rank cache path since different ranks
|
||||
|
Loading…
x
Reference in New Issue
Block a user