Disable spec-decode + chunked-prefill for draft models with tensor parallelism > 1 (#10136)
Signed-off-by: Sourashis Roy <sroy@roblox.com>
parent 0535e5fe6c · commit f6778620a9
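In effect, the commit rejects configurations like the following at engine-config validation time. This repro is a sketch, not part of the commit; the keyword arguments mirror the kwargs exercised by the new test below, and `from vllm import LLM` is assumed as the usual entry point:

from vllm import LLM

# With this commit, combining chunked prefill with a speculative draft model
# running at tensor parallel size > 1 is rejected while the engine config is
# being validated, before any workers are launched.
llm = LLM(
    model="meta-llama/Llama-2-7b-chat-hf",
    speculative_model="JackFram/llama-68m",
    num_speculative_tokens=5,
    enable_chunked_prefill=True,
    tensor_parallel_size=2,
    speculative_draft_tensor_parallel_size=2,
)
# Raises: ValueError: Chunked prefill and speculative decoding can be
# enabled simultaneously only for draft models with tensor parallel size 1.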
@@ -50,3 +50,49 @@ def test_spec_decode_xfail_spec_max_model_len(test_llm_generator):
     with pytest.raises(ValueError, match="cannot be larger than"):
         get_output_from_llm_generator(test_llm_generator, prompts,
                                       sampling_params)
+
+
+@pytest.mark.parametrize("common_llm_kwargs",
+                         [{
+                             "model": "meta-llama/Llama-2-7b-chat-hf",
+                             "speculative_model": "JackFram/llama-68m",
+                             "num_speculative_tokens": 5,
+                             "enable_chunked_prefill": "True",
+                         }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [
+    {
+        "tensor_parallel_size": 2,
+        "speculative_draft_tensor_parallel_size": 2,
+    },
+    {
+        "tensor_parallel_size": 4,
+        "speculative_draft_tensor_parallel_size": 4,
+    },
+    {
+        "tensor_parallel_size": 8,
+        "speculative_draft_tensor_parallel_size": 8,
+    },
+])
+@pytest.mark.parametrize("test_llm_kwargs", [{}])
+@pytest.mark.parametrize("seed", [1])
+def test_spec_decode_xfail_chunked_prefill_draft_model_tp_not_one(
+        test_llm_generator):
+    """Verify that speculative decoding fails if chunked prefill is enabled
+    for a draft model with tensor parallelism greater than 1.
+    """
+    output_len = 128
+    temperature = 0.0
+
+    prompts = [
+        "Hello, my name is",
+    ]
+
+    sampling_params = SamplingParams(
+        max_tokens=output_len,
+        ignore_eos=True,
+        temperature=temperature,
+    )
+
+    with pytest.raises(ValueError, match="with tensor parallel size 1"):
+        get_output_from_llm_generator(test_llm_generator, prompts,
+                                      sampling_params)
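Note that the `match="with tensor parallel size 1"` argument pins the test to the specific error string introduced in `SpeculativeConfig` below, so the test passes only when the engine raises the new incompatibility error rather than some unrelated `ValueError`.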
@@ -1388,6 +1388,23 @@ class SpeculativeConfig:
                     "Chunked prefill and hidden-state based draft models are "
                     "not compatible.")
 
+            speculative_draft_tensor_parallel_size = \
+                SpeculativeConfig._verify_and_get_draft_model_tensor_parallel_size(
+                    target_parallel_config,
+                    speculative_draft_tensor_parallel_size,
+                    draft_hf_config
+                )
+
+            if (enable_chunked_prefill and
+                    speculative_draft_tensor_parallel_size != 1):
+                # TODO - Investigate why the error reported in
+                # https://github.com/vllm-project/vllm/pull/9291#issuecomment-2463266258
+                # is happening and re-enable it.
+                raise ValueError(
+                    "Chunked prefill and speculative decoding can be enabled "
+                    "simultaneously only for draft models with tensor "
+                    "parallel size 1.")
+
             draft_model_config.max_model_len = (
                 SpeculativeConfig._maybe_override_draft_max_model_len(
                     speculative_max_model_len,
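The new guard boils down to a single predicate. A minimal standalone sketch for readability (the function name is illustrative, not vLLM API):

def chunked_prefill_compatible_with_spec_decode(
        enable_chunked_prefill: bool, draft_tp_size: int) -> bool:
    # Chunked prefill may be combined with speculative decoding only when
    # the draft model runs at tensor parallel size 1; see the linked comment
    # on PR #9291 for the unresolved failure at draft TP > 1.
    return not enable_chunked_prefill or draft_tp_size == 1


assert chunked_prefill_compatible_with_spec_decode(True, 1)
assert not chunked_prefill_compatible_with_spec_decode(True, 2)
assert chunked_prefill_compatible_with_spec_decode(False, 8)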
@@ -1466,15 +1483,16 @@ class SpeculativeConfig:
         )
 
     @staticmethod
-    def create_draft_parallel_config(
+    def _verify_and_get_draft_model_tensor_parallel_size(
         target_parallel_config: ParallelConfig,
         speculative_draft_tensor_parallel_size: Optional[int],
-        draft_hf_config: PretrainedConfig,
-    ) -> ParallelConfig:
-        """Create a parallel config for use by the draft worker.
-
-        This is mostly a copy of the target parallel config, except the tp_size.
+        draft_hf_config: PretrainedConfig) -> int:
+        """
+        Verifies and adjusts the tensor parallel size for a draft model
+        specified using speculative_draft_tensor_parallel_size.
         """
+        # If speculative_draft_tensor_parallel_size is unset, set it
+        # appropriately; otherwise verify that it is set correctly.
         if speculative_draft_tensor_parallel_size is None:
             if draft_hf_config.model_type == "mlp_speculator":
                 speculative_draft_tensor_parallel_size = 1
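Pieced together from this hunk and the next, the helper resolves the draft TP size roughly as follows. This is a sketch for readability, not the committed code; the branch that falls back to the target model's TP size sits in lines elided between the two hunks, so that part is an assumption:

from typing import Optional

def resolve_draft_tp(target_tp: int, draft_tp: Optional[int],
                     draft_model_type: str) -> int:
    if draft_tp is None:
        if draft_model_type == "mlp_speculator":
            return 1  # mlp_speculator drafts run at TP 1
        return target_tp  # assumed default: match the target model's TP
    if draft_tp not in (1, target_tp):
        raise ValueError(
            f"speculative_draft_tensor_parallel_size={draft_tp} cannot be "
            "other value than 1 or target model tensor_parallel_size")
    return draft_tp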
@@ -1490,7 +1508,18 @@ class SpeculativeConfig:
             raise ValueError(
                 f"{speculative_draft_tensor_parallel_size=} cannot be "
                 f"other value than 1 or target model tensor_parallel_size")
-
+        return speculative_draft_tensor_parallel_size
+
+    @staticmethod
+    def create_draft_parallel_config(
+        target_parallel_config: ParallelConfig,
+        speculative_draft_tensor_parallel_size: int,
+        draft_hf_config: PretrainedConfig,
+    ) -> ParallelConfig:
+        """Create a parallel config for use by the draft worker.
+
+        This is mostly a copy of the target parallel config, except the tp_size.
+        """
         draft_parallel_config = ParallelConfig(
             pipeline_parallel_size=target_parallel_config.
             pipeline_parallel_size,
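A note on the refactor: splitting verification (`_verify_and_get_draft_model_tensor_parallel_size`) out of `create_draft_parallel_config` lets the chunked-prefill check in the `@@ -1388` hunk run on the already-resolved draft TP size before the draft `ParallelConfig` is constructed, and `create_draft_parallel_config` now receives a plain `int` instead of an `Optional[int]`.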