[Misc] Add quantization config support for speculative model. (#7343)

shangmingc 2024-08-16 10:34:28 +08:00 committed by GitHub
parent 9c8e2d1161
commit b67ae00cdb
3 changed files with 71 additions and 4 deletions

tests/spec_decode/e2e/test_integration.py

@@ -42,3 +42,51 @@ def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator,
         max_output_len=output_len,
         force_output_len=True,
     )
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "model": "JackFram/llama-160m",
+
+        # Skip cuda graph recording for fast test.
+        "enforce_eager": True,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [
+    {
+        "speculative_model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
+        "num_speculative_tokens": 5,
+    },
+])
+@pytest.mark.parametrize(
+    "test_llm_kwargs",
+    [
+        # Explicitly specify draft model quantization
+        {
+            "speculative_model_quantization": "gptq",
+        },
+        # Explicitly specify GPTQ-based draft model to use marlin quantization
+        {
+            "speculative_model_quantization": "marlin",
+        },
+        # Do not explicitly specify draft model quantization
+        {
+            "speculative_model_quantization": None,
+        },
+    ])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("batch_size", [2])
+@pytest.mark.parametrize("seed", [1])
+def test_speculative_model_quantization_config(baseline_llm_generator,
+                                               test_llm_generator,
+                                               batch_size: int):
+    """Verify spec decode works well with draft model quantization configs.
+    """
+    run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len=32,
+                                         force_output_len=True)
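For context, a minimal offline-inference sketch mirroring the test parametrization above. This is an illustration, not part of the commit; it assumes, as the test harness does, that extra keyword arguments to LLM are forwarded through to EngineArgs:

from vllm import LLM, SamplingParams

# Sketch only: kwargs mirror the test parametrization above.
llm = LLM(
    model="JackFram/llama-160m",
    speculative_model="LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
    # Or "marlin"; leave as None to auto-detect from the checkpoint.
    speculative_model_quantization="gptq",
    num_speculative_tokens=5,
    use_v2_block_manager=True,
    enforce_eager=True,
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)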

vllm/config.py

@@ -961,6 +961,7 @@ class SpeculativeConfig:
         target_parallel_config: ParallelConfig,
         target_dtype: str,
         speculative_model: Optional[str],
+        speculative_model_quantization: Optional[str],
         speculative_draft_tensor_parallel_size: Optional[int],
         num_speculative_tokens: Optional[int],
         speculative_max_model_len: Optional[int],
@@ -989,6 +990,9 @@ class SpeculativeConfig:
             target_dtype (str): The data type used for the target model.
             speculative_model (Optional[str]): The name of the speculative
                 model, if provided.
+            speculative_model_quantization (Optional[str]): Quantization method
+                that was used to quantize the speculative model weights. If
+                None, we assume the model weights are not quantized.
             speculative_draft_tensor_parallel_size (Optional[int]): The degree
                 of the tensor parallelism for the draft model.
             num_speculative_tokens (Optional[int]): The number of speculative
@@ -1056,11 +1060,11 @@ class SpeculativeConfig:
                 "Speculative decoding requires usage of the V2 "
                 "block manager. Enable it with --use-v2-block-manager.")
 
-        # TODO: The user should be able to specify revision/quantization/max
-        # model len for the draft model. It is not currently supported.
+        # TODO: The user should be able to specify revision/max model len
+        # for the draft model. It is not currently supported.
         draft_revision = None
         draft_code_revision = None
-        draft_quantization = None
+        draft_quantization = speculative_model_quantization
 
         if speculative_model == "[ngram]":
             if ngram_prompt_lookup_min is None:
@ -1217,7 +1221,7 @@ class SpeculativeConfig:
elif speculative_draft_tensor_parallel_size != 1:
# TODO(wooyeon): allow tp values larger than 1
raise ValueError(
f"{speculative_draft_tensor_parallel_size=} cannot be"
f"{speculative_draft_tensor_parallel_size=} cannot be "
f"other value than 1")
draft_parallel_config = ParallelConfig(
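The draft model's quantization method now comes from the new argument instead of being hard-coded to None. A hedged sketch of the resolution order this enables follows; resolve_draft_quantization is a hypothetical helper for illustration only, since in vLLM the None fallback is actually handled by ModelConfig, which inspects the draft checkpoint's quantization_config:

from typing import Optional

def resolve_draft_quantization(
        flag_value: Optional[str],
        hf_quantization_config: Optional[dict]) -> Optional[str]:
    # Hypothetical helper, not code from this commit.
    if flag_value is not None:
        # An explicit --speculative-model-quantization always wins.
        return flag_value
    if hf_quantization_config is not None:
        # Fall back to the method recorded in the checkpoint itself.
        return hf_quantization_config.get("quant_method")
    # Unquantized draft model: `dtype` alone determines the weight type.
    return None

For example, resolve_draft_quantization(None, {"quant_method": "gptq"}) yields "gptq", matching the auto-detection case exercised by the test's speculative_model_quantization=None variant.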

vllm/engine/arg_utils.py

@@ -129,6 +129,7 @@ class EngineArgs:
     guided_decoding_backend: str = 'outlines'
     # Speculative decoding configuration.
     speculative_model: Optional[str] = None
+    speculative_model_quantization: Optional[str] = None
     speculative_draft_tensor_parallel_size: Optional[int] = None
     num_speculative_tokens: Optional[int] = None
     speculative_max_model_len: Optional[int] = None
@@ -571,6 +572,18 @@ class EngineArgs:
             default=EngineArgs.speculative_model,
             help=
             'The name of the draft model to be used in speculative decoding.')
+        # Quantization settings for speculative model.
+        parser.add_argument(
+            '--speculative-model-quantization',
+            type=nullable_str,
+            choices=[*QUANTIZATION_METHODS, None],
+            default=EngineArgs.speculative_model_quantization,
+            help='Method used to quantize the weights of the speculative '
+            'model. If None, we first check the `quantization_config` '
+            'attribute in the model config file. If that is '
+            'None, we assume the model weights are not '
+            'quantized and use `dtype` to determine the data '
+            'type of the weights.')
         parser.add_argument(
             '--num-speculative-tokens',
             type=int,
@@ -844,6 +857,8 @@ class EngineArgs:
             target_parallel_config=parallel_config,
             target_dtype=self.dtype,
             speculative_model=self.speculative_model,
+            speculative_model_quantization = \
+                self.speculative_model_quantization,
             speculative_draft_tensor_parallel_size = \
                 self.speculative_draft_tensor_parallel_size,
             num_speculative_tokens=self.num_speculative_tokens,
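Finally, a sketch of driving the new flag end to end through EngineArgs, equivalent to passing --speculative-model-quantization marlin on the command line. The field names come from the diffs above; the attribute path in the final print is an assumption about the EngineConfig/SpeculativeConfig internals of this vintage:

from vllm.engine.arg_utils import EngineArgs

# Mirrors: --speculative-model-quantization marlin
engine_args = EngineArgs(
    model="JackFram/llama-160m",
    speculative_model="LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
    speculative_model_quantization="marlin",
    num_speculative_tokens=5,
    use_v2_block_manager=True,
    enforce_eager=True,
)
engine_config = engine_args.create_engine_config()
# Assumed attribute path: the draft ModelConfig should now carry
# quantization="marlin" rather than whatever the checkpoint declares.
print(engine_config.speculative_config.draft_model_config.quantization)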