[MISC] Consolidate FP8 kv-cache tests (#8131)
commit 2ad2e5608e
parent d3311562fb
@@ -23,7 +23,12 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"
 # Run basic model test
 docker exec cpu-test bash -c "
   pip install pytest matplotlib einops transformers_stream_generator
-  pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_oot_registration.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported
+  pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py \
+    --ignore=tests/models/test_oot_registration.py \
+    --ignore=tests/models/test_registry.py \
+    --ignore=tests/models/test_fp8.py \
+    --ignore=tests/models/test_jamba.py \
+    --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported
 
 # online inference
 docker exec cpu-test bash -c "
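The new --ignore=tests/models/test_fp8.py entry keeps the consolidated, GPU-only fp8 suite out of the CPU run. For illustration only, the same exclusion could also be expressed with pytest's collect_ignore hook; this is a hedged sketch (hypothetical conftest.py placement, not what the commit does):

    # tests/conftest.py -- sketch only; the commit keeps using --ignore flags
    import torch

    # pytest skips collecting these paths when no CUDA device is present
    collect_ignore = []
    if not torch.cuda.is_available():
        collect_ignore.append("models/test_fp8.py")
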
@@ -16,18 +16,6 @@ MODELS = [
     "facebook/opt-125m",
     "meta-llama/Llama-2-7b-hf",
 ]
-E5M2_KV_MODELS = [
-    "facebook/opt-125m",
-    "meta-llama/Llama-2-7b-chat-hf",
-]
-E4M3_KV_MODELS = [
-    "meta-llama/Llama-2-7b-chat-hf", "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V",
-    "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
-]
-KV_CACHE_QUANTIZATION_PATHS = {
-    "meta-llama/Llama-2-7b-chat-hf":
-    "./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json"
-}
 
 
 @pytest.mark.parametrize("model", MODELS)
@@ -78,10 +66,10 @@ def test_models(
     )
 
 
-@pytest.mark.parametrize("kv_cache_dtype,model",
-                         [("fp8_e5m2", m)
-                          for m in E5M2_KV_MODELS] + [("fp8_e4m3", m)
-                                                      for m in E4M3_KV_MODELS])
+@pytest.mark.parametrize(
+    "kv_cache_dtype,model",
+    [("fp8_e4m3",
+      "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme")])
 # Due to low-precision numerical divergence, we only test logprob of 4 tokens
 @pytest.mark.parametrize("max_tokens", [4])
 @pytest.mark.parametrize("chunked_prefill_token_size", [4, 16])
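For clarity, the removed decorator expanded the two deleted model lists into five (kv_cache_dtype, model) cases, while the new decorator keeps a single fp8_e4m3 smoke case here; the broader matrix moves to tests/models/test_fp8.py (see the later hunk). A small sketch of the expansion, with the model lists copied from the removed code above:

    # sketch: how the old vs. new parametrization expands
    E5M2_KV_MODELS = ["facebook/opt-125m", "meta-llama/Llama-2-7b-chat-hf"]
    E4M3_KV_MODELS = [
        "meta-llama/Llama-2-7b-chat-hf",
        "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V",
        "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
    ]

    old_cases = ([("fp8_e5m2", m) for m in E5M2_KV_MODELS] +
                 [("fp8_e4m3", m) for m in E4M3_KV_MODELS])
    new_cases = [("fp8_e4m3",
                  "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme")]

    print(len(old_cases), len(new_cases))  # 5 1
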
@@ -104,30 +92,15 @@ def test_models_with_fp8_kv_cache(
     disable_async_output_proc: bool,
 ) -> None:
     """
-    Only checks log probs match between chunked-prefill and
-    non-chunked-prefill version of vLLM model runner.
-
-    This test is used when there is discrepancy in kernels
-    / numerics (e.g. when using lower-precision types like FP8).
+    Check output logprobs match between no_chunked_prefill and chunked_prefill
+    with fp8 kv cache. General fp8 kv-cache tests are covered in test_fp8.py,
+    so here we only check chunked prefill.
     """
     NUM_LOG_PROBS = 8
 
-    if model == "facebook/opt-125m":
-        pytest.skip(
-            "#7378: CUDA illegal memory access (undiagnosed) facebook/opt-125m"
-        )
-    if ((model, kv_cache_dtype, chunked_prefill_token_size) == (
-            "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V", "fp8_e4m3", 4)):
-        pytest.skip("flakey test, see: #7874 #8051")
-
     max_num_seqs = chunked_prefill_token_size
     max_num_batched_tokens = chunked_prefill_token_size
 
-    extra_kwargs = {}
-    if model in KV_CACHE_QUANTIZATION_PATHS:
-        extra_kwargs["quantization_param_path"] = KV_CACHE_QUANTIZATION_PATHS[
-            model]
-
     with vllm_runner(
             model,
             tensor_parallel_size=tensor_parallel_size,
@@ -135,7 +108,6 @@ def test_models_with_fp8_kv_cache(
             max_num_seqs=max_num_seqs,
             kv_cache_dtype=kv_cache_dtype,
             disable_async_output_proc=disable_async_output_proc,
-            **extra_kwargs,
     ) as vllm_model:
         no_chunked_prefill_outputs = vllm_model.generate_greedy_logprobs(
             example_prompts, max_tokens, NUM_LOG_PROBS)
@@ -149,7 +121,6 @@ def test_models_with_fp8_kv_cache(
             max_num_seqs=max_num_seqs,
             kv_cache_dtype=kv_cache_dtype,
             disable_async_output_proc=disable_async_output_proc,
-            **extra_kwargs,
     ) as vllm_model:
         chunked_prefill_outputs = vllm_model.generate_greedy_logprobs(
             example_prompts, max_tokens, NUM_LOG_PROBS)
@@ -3,116 +3,97 @@
 Note: these tests will only pass on L4 GPU.
 """
 import os
-from typing import List
+from typing import Optional
 
 import pytest
-import torch
-from transformers import AutoTokenizer
 
+from tests.kernels.utils import override_backend_env_variable
 from tests.quantization.utils import is_quant_method_supported
-from vllm import LLM, SamplingParams
+
+from ..models.utils import check_logprobs_close
 
 os.environ["TOKENIZERS_PARALLELISM"] = "true"
 
-MAX_MODEL_LEN = 1024
 
-MODELS = [
-    "nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV",
-    "meta-llama/Meta-Llama-3-8B-Instruct",
-]
-
-EXPECTED_STRS_MAP = {
-    "nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV": {
-        "auto": [
-            'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (',
-            'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
-            'Artificial intelligence (AI) and human intelligence (HI) process information in distinct ways, with both',
-            'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
-            'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep',
-            'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
-            'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
-            'Here are the translations:\n\n**Japanese:** (Haya aki no tori, nemuri no'
-        ],
-        "fp8": [
-            'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
-            'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
-            'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
-            'A neural network is a complex system made up of several basic components that work together to enable it to',
-            'Zeta-5, a highly advanced robot designed for menial labor, had never experienced anything like',
-            'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here',
-            'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
-            'Here are the translations:\n\n**Japanese:** (Haya kotori wa mushi o tsuk'
-        ]
-    },
-    "meta-llama/Meta-Llama-3-8B-Instruct": {
-        "auto": [
-            'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
-            'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
-            'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
-            'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
-            'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short',
-            'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
-            'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
-            'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu'
-        ],
-        "fp8": [
-            'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
-            'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
-            'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
-            'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne',
-            'In the year 2154, robotics engineer Dr. Rachel Kim had spent years perfecting her latest',
-            'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
-            'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
-            'Here are the translations:\n\n**Japanese:** (Haya tori, mushi o tsukamu'
-        ]
-    },
-}
-
-
-# This test compares against golden strings for exact match since
-# there is no baseline implementation to compare against
-# and is unstable w.r.t specifics of the fp8 implementation or
-# the hardware being run on.
-# Disabled to prevent it from breaking the build
-@pytest.mark.skip(
-    reason=
-    "Prevent unstable test based on golden strings from breaking the build.")
 @pytest.mark.skipif(not is_quant_method_supported("fp8"),
                     reason="fp8 is not supported on this GPU type.")
-@pytest.mark.parametrize("model_name", MODELS)
-@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
-def test_models(example_prompts, model_name, kv_cache_dtype) -> None:
-    model = LLM(model=model_name,
-                max_model_len=MAX_MODEL_LEN,
-                trust_remote_code=True,
-                enforce_eager=True,
-                quantization="fp8",
-                kv_cache_dtype=kv_cache_dtype)
-
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    formatted_prompts = [
-        tokenizer.apply_chat_template([{
-            "role": "user",
-            "content": prompt
-        }],
-                                      tokenize=False,
-                                      add_generation_prompt=True)
-        for prompt in example_prompts
-    ]
+@pytest.mark.parametrize(
+    "kv_cache_dtype,base_model,test_model,scale_path",
+    [
+        # Test FP8 checkpoint w. fp8_e4m3 kv-cache scaling factors.
+        ("fp8_e4m3", "meta-llama/Meta-Llama-3-8B-Instruct",
+         "nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV", None),
+        # Test FP16 checkpoint w. fp8_e5m2 kv-cache.
+        ("fp8_e5m2", "meta-llama/Meta-Llama-3-8B-Instruct",
+         "meta-llama/Meta-Llama-3-8B-Instruct", None),
+        # Test FP16 checkpoint w. fp8_e4m3 kv-cache scaling factors in json.
+        ("fp8_e4m3", "meta-llama/Llama-2-7b-chat-hf",
+         "meta-llama/Llama-2-7b-chat-hf",
+         "./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json")
+    ])
+# Due to low-precision numerical divergence, we only test logprob of 4 tokens
+@pytest.mark.parametrize("max_tokens", [4])
+@pytest.mark.parametrize("enforce_eager", [False, True])
+@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"])
+# NOTE: Increasing this in this suite will fail CI because we currently cannot
+# reset distributed env properly. Use a value > 1 just when you test.
+@pytest.mark.parametrize("tensor_parallel_size", [1])
+# Due to low-precision numerical divergence, this test is too sensitive for
+# the async postprocessor
+@pytest.mark.parametrize("disable_async_output_proc", [True])
+def test_models(
+    vllm_runner,
+    example_prompts,
+    kv_cache_dtype: str,
+    base_model: str,
+    test_model: str,
+    scale_path: Optional[str],
+    max_tokens: int,
+    enforce_eager: bool,
+    backend: str,
+    tensor_parallel_size: int,
+    disable_async_output_proc: bool,
+    monkeypatch,
+) -> None:
+    """
+    Only checks log probs match to cover the discrepancy in
+    numerical sensitive kernels.
+    """
+    override_backend_env_variable(monkeypatch, backend)
+
+    MAX_MODEL_LEN = 1024
+    NUM_LOG_PROBS = 8
 
-    params = SamplingParams(max_tokens=20, temperature=0)
-    generations: List[str] = []
-    # Note: these need to be run 1 at a time due to numerical precision,
-    # since the expected strs were generated this way.
-    for prompt in formatted_prompts:
-        outputs = model.generate(prompt, params)
-        generations.append(outputs[0].outputs[0].text)
-    del model
+    with vllm_runner(
+            base_model,
+            max_model_len=MAX_MODEL_LEN,
+            tensor_parallel_size=tensor_parallel_size,
+            enforce_eager=enforce_eager,
+            kv_cache_dtype="auto",
+            disable_async_output_proc=disable_async_output_proc,
+    ) as vllm_model:
+        baseline_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, NUM_LOG_PROBS)
 
-    print(model_name, kv_cache_dtype, generations)
-    expected_strs = EXPECTED_STRS_MAP[model_name][kv_cache_dtype]
-    for i in range(len(example_prompts)):
-        generated_str = generations[i]
-        expected_str = expected_strs[i]
-        assert expected_str == generated_str, (
-            f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}")
+    extra_kwargs = {}
+    if scale_path is not None:
+        extra_kwargs["quantization_param_path"] = scale_path
+
+    with vllm_runner(
+            test_model,
+            max_model_len=MAX_MODEL_LEN,
+            tensor_parallel_size=tensor_parallel_size,
+            enforce_eager=enforce_eager,
+            kv_cache_dtype=kv_cache_dtype,
+            disable_async_output_proc=disable_async_output_proc,
+            **extra_kwargs,
+    ) as vllm_model:
+        test_outputs = vllm_model.generate_greedy_logprobs(
+            example_prompts, max_tokens, NUM_LOG_PROBS)
+
+    check_logprobs_close(
+        outputs_0_lst=baseline_outputs,
+        outputs_1_lst=test_outputs,
+        name_0="fp16_kv_cache",
+        name_1="fp8_kv_cache",
+    )
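Instead of asserting exact golden strings, the consolidated test compares the fp16-kv baseline run against the fp8-kv run with check_logprobs_close. Roughly, that helper tolerates token-level divergence as long as each mismatching token still ranks in the other run's top log-probs. Here is a minimal sketch of that idea; the output shape (token_ids, text, list of top-logprob dicts) is assumed, and this is not the actual vLLM implementation:

    from typing import Dict, List, Tuple

    Output = Tuple[List[int], str, List[Dict[int, float]]]

    def logprobs_close_sketch(outputs_0: List[Output],
                              outputs_1: List[Output]) -> None:
        # sketch of a tolerance check in the spirit of check_logprobs_close
        for (ids_0, _, top_0), (ids_1, _, top_1) in zip(outputs_0, outputs_1):
            for i, (t0, t1) in enumerate(zip(ids_0, ids_1)):
                if t0 == t1:
                    continue
                # divergent token: each run's token must still appear in the
                # other run's top-k logprobs, otherwise the outputs are "far"
                assert t0 in top_1[i] and t1 in top_0[i]
                # after the first divergence, later tokens are not comparable
                break
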
@@ -1,96 +0,0 @@
-# flake8: noqa
-"""Tests fp8 models against ground truth generation
-This verifies the flashinfer backend with fp8
-quantization and fp8 KV Cache without scaling
-factors Note: these tests will only pass on H100 GPU.
-"""
-import os
-from typing import List
-
-import pytest
-from transformers import AutoTokenizer
-
-from tests.quantization.utils import is_quant_method_supported
-from vllm import LLM, SamplingParams
-
-os.environ["TOKENIZERS_PARALLELISM"] = "true"
-
-MAX_MODEL_LEN = 1024
-
-MODELS = [
-    "nm-testing/Meta-Llama-3-8B-Instruct-FP8",
-]
-
-EXPECTED_STRS_MAP = {
-    "nm-testing/Meta-Llama-3-8B-Instruct-FP8": {
-        "auto": [
-            'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (',
-            'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
-            'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
-            'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne',
-            'In the sterile, metallic halls of the robotics lab, a peculiar phenomenon occurred. Zeta-5',
-            'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
-            'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
-            'Here are the translations:\n\n**Japanese:** (Haya aki no tori, mushi o',
-        ],
-        "fp8": [
-            'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
-            'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
-            'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
-            'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
-            'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep',
-            'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here',
-            'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
-            'Here are the translations:\n\n**Japanese:** (Haya aki no tori, guri o',
-        ]
-    }
-}
-
-
-# This test compares against golden strings for exact match since
-# there is no baseline implementation to compare against
-# and is unstable w.r.t specifics of the fp8 implementation or
-# the hardware being run on.
-# No assert to prevent it from breaking the build
-@pytest.mark.skipif(not is_quant_method_supported("fp8"),
-                    reason="fp8 is not supported on this GPU type.")
-@pytest.mark.parametrize("model_name", MODELS)
-@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
-@pytest.mark.parametrize("backend", ["XFORMERS", "FLASHINFER"])
-def test_models(example_prompts, model_name, kv_cache_dtype, backend) -> None:
-    # Note that the golden strings may not work for FLASHINFER Backend.
-    # The intention is to test the path
-    os.environ["VLLM_ATTENTION_BACKEND"] = backend
-    model = LLM(model=model_name,
-                max_model_len=MAX_MODEL_LEN,
-                trust_remote_code=True,
-                quantization="fp8",
-                kv_cache_dtype=kv_cache_dtype)
-
-    tokenizer = AutoTokenizer.from_pretrained(model_name)
-    formatted_prompts = [
-        tokenizer.apply_chat_template([{
-            "role": "user",
-            "content": prompt
-        }],
-                                      tokenize=False,
-                                      add_generation_prompt=True)
-        for prompt in example_prompts
-    ]
-
-    params = SamplingParams(max_tokens=20, temperature=0)
-    generations: List[str] = []
-    # Note: these need to be run 1 at a time due to numerical precision,
-    # since the expected strs were generated this way.
-    for prompt in formatted_prompts:
-        outputs = model.generate(prompt, params)
-        generations.append(outputs[0].outputs[0].text)
-    del model
-
-    print(f"Testing: {model_name} with kv_cache_dtype: {kv_cache_dtype}")
-    expected_strs = EXPECTED_STRS_MAP[model_name][kv_cache_dtype]
-    for i in range(len(example_prompts)):
-        generated_str = generations[i]
-        expected_str = expected_strs[i]
-        print(f"generated_str\n: {generated_str}")
-        print(f"expected_str\n: {expected_str}")
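The deleted file selected the attention backend by writing os.environ["VLLM_ATTENTION_BACKEND"] directly, which persists into later tests in the same process; the consolidated test instead calls override_backend_env_variable with pytest's monkeypatch, which undoes the change automatically. A minimal sketch of what such a helper can look like (an assumption for illustration; the real helper is imported from tests/kernels/utils.py):

    # sketch of a monkeypatch-based override; not the actual vLLM helper
    def override_backend_env_variable(monkeypatch, backend: str) -> None:
        # monkeypatch.setenv restores the variable when the test finishes
        monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend)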