[Misc] Add quantization config support for speculative model. (#7343)
This commit is contained in:
parent
9c8e2d1161
commit
b67ae00cdb
@ -42,3 +42,51 @@ def test_spec_decode_cuda_graph(baseline_llm_generator, test_llm_generator,
|
|||||||
max_output_len=output_len,
|
max_output_len=output_len,
|
||||||
force_output_len=True,
|
force_output_len=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"common_llm_kwargs",
|
||||||
|
[{
|
||||||
|
"model": "JackFram/llama-160m",
|
||||||
|
|
||||||
|
# Skip cuda graph recording for fast test.
|
||||||
|
"enforce_eager": True,
|
||||||
|
|
||||||
|
# Required for spec decode.
|
||||||
|
"use_v2_block_manager": True,
|
||||||
|
}])
|
||||||
|
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
|
||||||
|
{
|
||||||
|
"speculative_model": "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit",
|
||||||
|
"num_speculative_tokens": 5,
|
||||||
|
},
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"test_llm_kwargs",
|
||||||
|
[
|
||||||
|
# Explicitly specify draft model quantization
|
||||||
|
{
|
||||||
|
"speculative_model_quantization": "gptq",
|
||||||
|
},
|
||||||
|
# Explicitly specify GPTQ-based draft model to use marlin quantization
|
||||||
|
{
|
||||||
|
"speculative_model_quantization": "marlin",
|
||||||
|
},
|
||||||
|
# Not explicitly specify draft model quantization
|
||||||
|
{
|
||||||
|
"speculative_model_quantization": None,
|
||||||
|
},
|
||||||
|
])
|
||||||
|
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
|
||||||
|
@pytest.mark.parametrize("batch_size", [2])
|
||||||
|
@pytest.mark.parametrize("seed", [1])
|
||||||
|
def test_speculative_model_quantization_config(baseline_llm_generator,
|
||||||
|
test_llm_generator,
|
||||||
|
batch_size: int):
|
||||||
|
"""Verify spec decode works well with draft model quantization configs.
|
||||||
|
"""
|
||||||
|
run_greedy_equality_correctness_test(baseline_llm_generator,
|
||||||
|
test_llm_generator,
|
||||||
|
batch_size,
|
||||||
|
max_output_len=32,
|
||||||
|
force_output_len=True)
|
||||||
|
@ -961,6 +961,7 @@ class SpeculativeConfig:
|
|||||||
target_parallel_config: ParallelConfig,
|
target_parallel_config: ParallelConfig,
|
||||||
target_dtype: str,
|
target_dtype: str,
|
||||||
speculative_model: Optional[str],
|
speculative_model: Optional[str],
|
||||||
|
speculative_model_quantization: Optional[str],
|
||||||
speculative_draft_tensor_parallel_size: Optional[int],
|
speculative_draft_tensor_parallel_size: Optional[int],
|
||||||
num_speculative_tokens: Optional[int],
|
num_speculative_tokens: Optional[int],
|
||||||
speculative_max_model_len: Optional[int],
|
speculative_max_model_len: Optional[int],
|
||||||
@ -989,6 +990,9 @@ class SpeculativeConfig:
|
|||||||
target_dtype (str): The data type used for the target model.
|
target_dtype (str): The data type used for the target model.
|
||||||
speculative_model (Optional[str]): The name of the speculative
|
speculative_model (Optional[str]): The name of the speculative
|
||||||
model, if provided.
|
model, if provided.
|
||||||
|
speculative_model_quantization (Optional[str]): Quantization method
|
||||||
|
that was used to quantize the speculative model weights. If
|
||||||
|
None, we assume the model weights are not quantized.
|
||||||
speculative_draft_tensor_parallel_size (Optional[int]): The degree
|
speculative_draft_tensor_parallel_size (Optional[int]): The degree
|
||||||
of the tensor parallelism for the draft model.
|
of the tensor parallelism for the draft model.
|
||||||
num_speculative_tokens (Optional[int]): The number of speculative
|
num_speculative_tokens (Optional[int]): The number of speculative
|
||||||
@ -1056,11 +1060,11 @@ class SpeculativeConfig:
|
|||||||
"Speculative decoding requires usage of the V2 "
|
"Speculative decoding requires usage of the V2 "
|
||||||
"block manager. Enable it with --use-v2-block-manager.")
|
"block manager. Enable it with --use-v2-block-manager.")
|
||||||
|
|
||||||
# TODO: The user should be able to specify revision/quantization/max
|
# TODO: The user should be able to specify revision/max model len
|
||||||
# model len for the draft model. It is not currently supported.
|
# for the draft model. It is not currently supported.
|
||||||
draft_revision = None
|
draft_revision = None
|
||||||
draft_code_revision = None
|
draft_code_revision = None
|
||||||
draft_quantization = None
|
draft_quantization = speculative_model_quantization
|
||||||
|
|
||||||
if speculative_model == "[ngram]":
|
if speculative_model == "[ngram]":
|
||||||
if ngram_prompt_lookup_min is None:
|
if ngram_prompt_lookup_min is None:
|
||||||
|
@ -129,6 +129,7 @@ class EngineArgs:
|
|||||||
guided_decoding_backend: str = 'outlines'
|
guided_decoding_backend: str = 'outlines'
|
||||||
# Speculative decoding configuration.
|
# Speculative decoding configuration.
|
||||||
speculative_model: Optional[str] = None
|
speculative_model: Optional[str] = None
|
||||||
|
speculative_model_quantization: Optional[str] = None
|
||||||
speculative_draft_tensor_parallel_size: Optional[int] = None
|
speculative_draft_tensor_parallel_size: Optional[int] = None
|
||||||
num_speculative_tokens: Optional[int] = None
|
num_speculative_tokens: Optional[int] = None
|
||||||
speculative_max_model_len: Optional[int] = None
|
speculative_max_model_len: Optional[int] = None
|
||||||
@ -571,6 +572,18 @@ class EngineArgs:
|
|||||||
default=EngineArgs.speculative_model,
|
default=EngineArgs.speculative_model,
|
||||||
help=
|
help=
|
||||||
'The name of the draft model to be used in speculative decoding.')
|
'The name of the draft model to be used in speculative decoding.')
|
||||||
|
# Quantization settings for speculative model.
|
||||||
|
parser.add_argument(
|
||||||
|
'--speculative-model-quantization',
|
||||||
|
type=nullable_str,
|
||||||
|
choices=[*QUANTIZATION_METHODS, None],
|
||||||
|
default=EngineArgs.speculative_model_quantization,
|
||||||
|
help='Method used to quantize the weights of speculative model.'
|
||||||
|
'If None, we first check the `quantization_config` '
|
||||||
|
'attribute in the model config file. If that is '
|
||||||
|
'None, we assume the model weights are not '
|
||||||
|
'quantized and use `dtype` to determine the data '
|
||||||
|
'type of the weights.')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--num-speculative-tokens',
|
'--num-speculative-tokens',
|
||||||
type=int,
|
type=int,
|
||||||
@ -844,6 +857,8 @@ class EngineArgs:
|
|||||||
target_parallel_config=parallel_config,
|
target_parallel_config=parallel_config,
|
||||||
target_dtype=self.dtype,
|
target_dtype=self.dtype,
|
||||||
speculative_model=self.speculative_model,
|
speculative_model=self.speculative_model,
|
||||||
|
speculative_model_quantization = \
|
||||||
|
self.speculative_model_quantization,
|
||||||
speculative_draft_tensor_parallel_size = \
|
speculative_draft_tensor_parallel_size = \
|
||||||
self.speculative_draft_tensor_parallel_size,
|
self.speculative_draft_tensor_parallel_size,
|
||||||
num_speculative_tokens=self.num_speculative_tokens,
|
num_speculative_tokens=self.num_speculative_tokens,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user