[TPU][V1][Bugfix] Fix w8a8 recompilation with GSM8K (#15714)
Signed-off-by: NickLucche <nlucches@redhat.com>
Parent: 5b800f0932
Commit: da461f3cbf
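Summary of the change (inferred from the diff below): on TPU, torch.ops.xla.quantized_matmul returns fp32, while the runner pre-compiled its sampler graphs against bfloat16 dummy hidden states. With w8a8 models the runtime hidden-state dtype therefore differed from the warm-up dtype, and GSM8K runs triggered fresh dynamo/XLA compilations. The fix casts the quantized matmul output back to the activation dtype and makes the runner record and reuse the actual hidden-state dtype during precompilation. The sketch below, using an illustrative graph-counting backend (not part of vLLM), shows why a dtype change alone is enough to force a recompilation:

import torch

compiles = 0

def counting_backend(gm, example_inputs):
    # Custom dynamo backend that just counts how many graphs get compiled
    # and otherwise runs them as-is.
    global compiles
    compiles += 1
    return gm.forward

@torch.compile(backend=counting_backend)
def sample(hidden: torch.Tensor) -> torch.Tensor:
    return hidden.argmax(dim=-1)

sample(torch.randn(8, 128, dtype=torch.bfloat16))  # warm-up dtype
sample(torch.randn(8, 128, dtype=torch.float32))   # runtime dtype differs
print(compiles)  # 2 -- dynamo guards on dtype, so the mismatch recompiles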
@@ -28,16 +28,14 @@ docker run --privileged --net host --shm-size=16G -it \
     && echo TEST_3 \
     && pytest -v -s /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine \
     && echo TEST_4 \
-    && python3 /workspace/vllm/examples/offline_inference/tpu.py \
+    && pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
     && echo TEST_5 \
-    && pytest -s -v /workspace/vllm/tests/v1/tpu/worker/test_tpu_model_runner.py \
+    && python3 /workspace/vllm/examples/offline_inference/tpu.py \
     && echo TEST_6 \
+    && pytest -s -v /workspace/vllm/tests/v1/tpu/worker/test_tpu_model_runner.py \
+    && echo TEST_7 \
     && pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py" \
 
 
 # TODO: This test fails because it uses RANDOM_SEED sampling
 # && VLLM_USE_V1=1 pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \
-
-# TODO: Re-enable this after fixing recompilation in quantization.
-# && echo TEST_4 \
-# && pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \
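TEST_4 now runs tests/tpu/test_quantization_accuracy.py again, which evaluates a w8a8 checkpoint on GSM8K through lm-eval. A rough local equivalent of that check, assuming lm-eval is installed and using an illustrative quantized model name (the test's actual model and thresholds may differ):

import lm_eval

# Stand-in for what the accuracy test exercises: run GSM8K through
# lm-eval's vLLM backend and inspect the exact-match scores. The model
# name below is only an example of a w8a8 checkpoint.
results = lm_eval.simple_evaluate(
    model="vllm",
    model_args=("pretrained=neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8,"
                "max_model_len=4096"),
    tasks=["gsm8k"],
    num_fewshot=5,
)
print(results["results"]["gsm8k"])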
@@ -97,7 +97,8 @@ class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
                                              block_size=-1,
                                              int4_weight=False,
                                              quantize_activation=True)
-
+        # `quantized_matmul` output is fp32, cast it down to bf16 for perf
+        out = out.to(x.dtype)
         # Explicitly capture control flow to make dynamo happy.
         # https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501
         return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias])
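This is the kernel-side half of the fix: quantized_matmul produces fp32, and without the cast every downstream op, and ultimately the sampler, would see fp32 hidden states even though the model dtype is bf16. A minimal standalone illustration of the cast, with a random tensor standing in for the quantized matmul output:

import torch

# Stand-ins: x plays the activations, out_fp32 plays the fp32 result
# that an int8 quantized matmul typically produces.
x = torch.randn(4, 16, dtype=torch.bfloat16)
out_fp32 = torch.randn(4, 8, dtype=torch.float32)

# Cast back to the activation dtype so hidden states keep a single,
# stable dtype through the rest of the graph.
out = out_fp32.to(x.dtype)
assert out.dtype == torch.bfloat16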
@@ -80,6 +80,7 @@ class TPUModelRunner:
         self.enforce_eager = model_config.enforce_eager
         self.pin_memory = is_pin_memory_available()
         self.dtype = self.model_config.dtype
+        self._hidden_states_dtype = self.dtype
 
         self.is_multimodal_model = model_config.is_multimodal_model
         self.sliding_window = model_config.get_sliding_window()
@@ -771,10 +772,11 @@ class TPUModelRunner:
         torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
 
         with set_forward_context(attn_metadata, self.vllm_config, 0):
-            self.model(input_ids=input_ids,
+            out = self.model(input_ids=input_ids,
                              positions=position_ids,
                              kv_caches=kv_caches,
                              inputs_embeds=inputs_embeds)
+        self._hidden_states_dtype = out.dtype
 
     def capture_model(self) -> None:
         """Compile the model."""
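The runner-side half of the fix is a record-and-reuse pattern: _hidden_states_dtype defaults to the model dtype at init and is overwritten with whatever the dummy forward pass actually returns, so later warm-up tensors match the real hidden-state dtype. A stripped-down sketch of the pattern (a hypothetical class, not the actual TPUModelRunner):

import torch

class RunnerSketch:
    def __init__(self, model: torch.nn.Module, model_dtype: torch.dtype):
        self.model = model
        self.dtype = model_dtype
        # Until a forward pass runs, assume hidden states use the model dtype.
        self._hidden_states_dtype = self.dtype

    def dummy_run(self, input_ids: torch.Tensor) -> None:
        out = self.model(input_ids)
        # Record what the model actually emits (fp32 before the w8a8 cast fix,
        # bf16 after it) so precompilation uses the same dtype.
        self._hidden_states_dtype = out.dtype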
@@ -800,7 +802,7 @@ class TPUModelRunner:
             num_reqs_to_sample = MIN_NUM_SEQS
             dummy_hidden = torch.randn((num_tokens, hsize),
                                        device=device,
-                                       dtype=torch.bfloat16)
+                                       dtype=self._hidden_states_dtype)
             # Compile for [8, 16, .., 128,.., `self.max_num_reqs`]
             while True:
                 indices = torch.zeros(
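With the recorded dtype, the sampler warm-up tensor no longer hard-codes bfloat16. For context, the loop's comment indicates the warm-up walks padded request counts up to self.max_num_reqs; a rough sketch of such a schedule is below (vLLM's actual padding helper may round differently):

# Illustrative only: double the number of requests each step, clamped to
# the maximum, so every padded batch size seen at runtime has a graph.
MIN_NUM_SEQS = 8
max_num_reqs = 136  # example value

num_reqs_to_sample = MIN_NUM_SEQS
sizes = []
while True:
    sizes.append(num_reqs_to_sample)
    if num_reqs_to_sample >= max_num_reqs:
        break
    num_reqs_to_sample = min(num_reqs_to_sample * 2, max_num_reqs)
print(sizes)  # [8, 16, 32, 64, 128, 136]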
@@ -823,7 +825,7 @@
                     num_reqs_to_sample + 1, self.max_num_reqs)
         xm.wait_device_ops()
         end = time.perf_counter()
-        logger.info("Compilation finished in in %.2f [secs].", end - start)
+        logger.info("Compilation finished in %.2f [secs].", end - start)
         # Record the number cached XLA graph after warming up, this will be
         # used for checking there is no additional graph compilation during
         # runtime execution.
@@ -105,8 +105,8 @@ class TPUWorker:
 
         # Increase the cache size limit, which is the maximum number of
         # dynamo graphs that can be compiled.
-        # NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and
-        # 30-40 graphs for decode. 128 is an arbitrary safe number.
+        # TODO (NickLucche) On gsm we compile 80+ graphs.
+        # Re-evaluate limit, with MM we may get close to this limit.
         torch._dynamo.config.cache_size_limit = 128
         # Use persistent cache to avoid XLA recompilation.
         # NOTE(woosuk): Set per-rank cache path since different ranks
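The rewritten comment records the motivation: on GSM8K-style workloads with quantized models, 80+ distinct dynamo graphs were observed, so 128 is no longer a comfortable margin. The limit caps how many compiled specializations dynamo keeps per code object; once it is exceeded, dynamo stops compiling new variants of that frame and falls back to running it uncompiled. A short sketch of setting it outside vLLM:

import torch

# Cap on compiled graph specializations per code object; past the limit,
# dynamo skips compilation for that frame instead of compiling more.
torch._dynamo.config.cache_size_limit = 128

# When experimenting with the limit, compiled state can be cleared with:
torch._dynamo.reset()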