Disable spec-decode + chunked-prefill for draft models with tensor parallelism > 1 (#10136)

Signed-off-by: Sourashis Roy <sroy@roblox.com>
This commit is contained in:
sroy745 2024-11-08 07:56:18 -08:00 committed by GitHub
parent 0535e5fe6c
commit f6778620a9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 83 additions and 8 deletions

View File

@@ -50,3 +50,49 @@ def test_spec_decode_xfail_spec_max_model_len(test_llm_generator):
with pytest.raises(ValueError, match="cannot be larger than"):
get_output_from_llm_generator(test_llm_generator, prompts,
sampling_params)
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "model": "meta-llama/Llama-2-7b-chat-hf",
        "speculative_model": "JackFram/llama-68m",
        "num_speculative_tokens": 5,
        # Use a real boolean, not the string "True": any non-empty string is
        # truthy, which would mask a type error here and break under stricter
        # config validation.
        "enable_chunked_prefill": True,
    }])
# Draft TP must match target TP for these cases; generate the three
# (tp, draft_tp) pairs from one sequence instead of three hand-copied dicts.
@pytest.mark.parametrize("per_test_common_llm_kwargs", [
    {
        "tensor_parallel_size": tp,
        "speculative_draft_tensor_parallel_size": tp,
    } for tp in (2, 4, 8)
])
@pytest.mark.parametrize("test_llm_kwargs", [{}])
@pytest.mark.parametrize("seed", [1])
def test_spec_decode_xfail_chunked_prefill_draft_model_tp_not_one(
        test_llm_generator):
    """Verify that speculative decoding fails if chunked prefill is enabled
    for a draft model with tensor parallelism greater than 1.

    SpeculativeConfig rejects this combination (see PR #9291 follow-up), so
    constructing/running the LLM must raise a ValueError mentioning the
    tensor-parallel-size-1 restriction.
    """
    output_len = 128
    temperature = 0.0

    prompts = [
        "Hello, my name is",
    ]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
    )

    # The error is raised during engine construction / first generation,
    # before any tokens are produced.
    with pytest.raises(ValueError, match="with tensor parallel size 1"):
        get_output_from_llm_generator(test_llm_generator, prompts,
                                      sampling_params)

View File

@@ -1388,6 +1388,23 @@ class SpeculativeConfig:
"Chunked prefill and hidden-state based draft models are "
"not compatible.")
speculative_draft_tensor_parallel_size = \
SpeculativeConfig._verify_and_get_draft_model_tensor_parallel_size(
target_parallel_config,
speculative_draft_tensor_parallel_size,
draft_hf_config
)
if (enable_chunked_prefill and \
speculative_draft_tensor_parallel_size != 1):
# TODO - Investigate why the error reported in
# https://github.com/vllm-project/vllm/pull/9291#issuecomment-2463266258
# is happening and re-enable it.
raise ValueError(
"Chunked prefill and speculative decoding can be enabled "
"simultaneously only for draft models with tensor "
"parallel size 1.")
draft_model_config.max_model_len = (
SpeculativeConfig._maybe_override_draft_max_model_len(
speculative_max_model_len,
@@ -1466,15 +1483,16 @@ class SpeculativeConfig:
)
@staticmethod
def create_draft_parallel_config( def _verify_and_get_draft_model_tensor_parallel_size(
target_parallel_config: ParallelConfig, target_parallel_config: ParallelConfig,
speculative_draft_tensor_parallel_size: Optional[int], speculative_draft_tensor_parallel_size: Optional[int],
draft_hf_config: PretrainedConfig, draft_hf_config: PretrainedConfig) -> int:
) -> ParallelConfig:
"""Create a parallel config for use by the draft worker.
This is mostly a copy of the target parallel config, except the tp_size.
""" """
Verifies and adjusts the tensor parallel size for a draft model
specified using speculative_draft_tensor_parallel_size.
"""
# If speculative_draft_tensor_parallel_size is unset then set it
# appropriately else verify that it is set correctly.
if speculative_draft_tensor_parallel_size is None:
if draft_hf_config.model_type == "mlp_speculator":
speculative_draft_tensor_parallel_size = 1
@@ -1490,7 +1508,18 @@ class SpeculativeConfig:
raise ValueError(
f"{speculative_draft_tensor_parallel_size=} cannot be "
f"other value than 1 or target model tensor_parallel_size")
return speculative_draft_tensor_parallel_size
@staticmethod
def create_draft_parallel_config(
target_parallel_config: ParallelConfig,
speculative_draft_tensor_parallel_size: int,
draft_hf_config: PretrainedConfig,
) -> ParallelConfig:
"""Create a parallel config for use by the draft worker.
This is mostly a copy of the target parallel config, except the tp_size.
"""
draft_parallel_config = ParallelConfig(
pipeline_parallel_size=target_parallel_config.
pipeline_parallel_size,