[Core] *Prompt* logprobs support in Multi-step (#8199)
commit a8c1d161a7 (parent 7c7714d856)
@@ -20,6 +20,8 @@ from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding,
                           BatchFeature)
 from transformers.models.auto.auto_factory import _BaseAutoModelClass
 
+from tests.models.utils import (TokensTextLogprobs,
+                                TokensTextLogprobsPromptLogprobs)
 from vllm import LLM, SamplingParams
 from vllm.assets.image import ImageAsset
 from vllm.assets.video import VideoAsset
@@ -33,7 +35,6 @@ from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
                          to_enc_dec_tuple_list, zip_enc_dec_prompts)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
-from vllm.sequence import SampleLogprobs
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
                         identity, is_cpu)
 
@@ -469,7 +470,7 @@ class HfRunner:
         audios: Optional[PromptAudioInput] = None,
         videos: Optional[List[np.ndarray]] = None,
         **kwargs: Any,
-    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
+    ) -> List[TokensTextLogprobs]:
         all_logprobs: List[List[Dict[int, float]]] = []
         all_output_ids: List[List[int]] = []
         all_output_strs: List[str] = []
@@ -525,7 +526,7 @@ class HfRunner:
         max_tokens: int,
         num_logprobs: int,
         **kwargs: Any,
-    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
+    ) -> List[TokensTextLogprobs]:
         '''
         Greedy logprobs generation for vLLM encoder/decoder models
         '''
@@ -653,14 +654,16 @@ class VllmRunner:
     @staticmethod
     def _final_steps_generate_w_logprobs(
         req_outputs: List[RequestOutput],
-    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
-        outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
+    ) -> List[TokensTextLogprobsPromptLogprobs]:
+        outputs: List[TokensTextLogprobsPromptLogprobs] = []
         for req_output in req_outputs:
+            assert len(req_output.outputs) > 0
             for sample in req_output.outputs:
                 output_str = sample.text
                 output_ids = list(sample.token_ids)
                 output_logprobs = sample.logprobs
-            outputs.append((output_ids, output_str, output_logprobs))
+            outputs.append((output_ids, output_str, output_logprobs,
+                            req_output.prompt_logprobs))
         return outputs
 
     def generate_w_logprobs(
@@ -670,7 +673,8 @@ class VllmRunner:
         images: Optional[PromptImageInput] = None,
         audios: Optional[PromptAudioInput] = None,
         videos: Optional[PromptVideoInput] = None,
-    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
+    ) -> Union[List[TokensTextLogprobs],
+               List[TokensTextLogprobsPromptLogprobs]]:
         assert sampling_params.logprobs is not None
 
         if images is not None:
@@ -695,13 +699,20 @@ class VllmRunner:
 
         req_outputs = self.model.generate(inputs,
                                           sampling_params=sampling_params)
-        return self._final_steps_generate_w_logprobs(req_outputs)
+
+        toks_str_logsprobs_prompt_logprobs = (
+            self._final_steps_generate_w_logprobs(req_outputs))
+        # Omit prompt logprobs if not required by sampling params
+        return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
+                if sampling_params.prompt_logprobs is None else
+                toks_str_logsprobs_prompt_logprobs)
 
     def generate_encoder_decoder_w_logprobs(
         self,
         encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
         sampling_params: SamplingParams,
-    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
+    ) -> Union[List[TokensTextLogprobs],
+               List[TokensTextLogprobsPromptLogprobs]]:
         '''
         Logprobs generation for vLLM encoder/decoder models
         '''
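Reviewer note, illustrative only (not part of the diff): the new return expression trims each 4-tuple from `_final_steps_generate_w_logprobs` back to the legacy 3-tuple whenever prompt logprobs were not requested. A minimal standalone sketch of that trimming, with fabricated values:

# Toy values standing in for _final_steps_generate_w_logprobs() output.
four_tuples = [([1, 2], "ab", [{1: -0.1}, {2: -0.2}], [None, {9: -0.5}])]
prompt_logprobs_requested = False  # i.e. sampling_params.prompt_logprobs is None
trimmed = ([x[0:-1] for x in four_tuples]
           if not prompt_logprobs_requested else four_tuples)
assert trimmed == [([1, 2], "ab", [{1: -0.1}, {2: -0.2}])]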
@@ -709,7 +720,12 @@ class VllmRunner:
         assert sampling_params.logprobs is not None
         req_outputs = self.model.generate(encoder_decoder_prompts,
                                           sampling_params=sampling_params)
-        return self._final_steps_generate_w_logprobs(req_outputs)
+        toks_str_logsprobs_prompt_logprobs = (
+            self._final_steps_generate_w_logprobs(req_outputs))
+        # Omit prompt logprobs if not required by sampling params
+        return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
+                if sampling_params.prompt_logprobs is None else
+                toks_str_logsprobs_prompt_logprobs)
 
     def generate_greedy(
         self,
@@ -727,44 +743,48 @@ class VllmRunner:
         prompts: List[str],
         max_tokens: int,
         num_logprobs: int,
+        num_prompt_logprobs: Optional[int] = None,
         images: Optional[PromptImageInput] = None,
         audios: Optional[PromptAudioInput] = None,
         videos: Optional[PromptVideoInput] = None,
         stop_token_ids: Optional[List[int]] = None,
-    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
-        greedy_logprobs_params = SamplingParams(temperature=0.0,
-                                                max_tokens=max_tokens,
-                                                logprobs=num_logprobs,
-                                                stop_token_ids=stop_token_ids)
-        outputs = self.generate_w_logprobs(prompts,
-                                           greedy_logprobs_params,
-                                           images=images,
-                                           audios=audios,
-                                           videos=videos)
-
-        return [(output_ids, output_str, output_logprobs)
-                for output_ids, output_str, output_logprobs in outputs]
+    ) -> Union[List[TokensTextLogprobs],
+               List[TokensTextLogprobsPromptLogprobs]]:
+        greedy_logprobs_params = SamplingParams(
+            temperature=0.0,
+            max_tokens=max_tokens,
+            logprobs=num_logprobs,
+            prompt_logprobs=(num_prompt_logprobs),
+            stop_token_ids=stop_token_ids)
+
+        return self.generate_w_logprobs(prompts,
+                                        greedy_logprobs_params,
+                                        images=images,
+                                        audios=audios,
+                                        videos=videos)
 
     def generate_encoder_decoder_greedy_logprobs(
         self,
         encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
         max_tokens: int,
         num_logprobs: int,
-    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
-        greedy_logprobs_params = SamplingParams(temperature=0.0,
-                                                use_beam_search=False,
-                                                max_tokens=max_tokens,
-                                                logprobs=num_logprobs)
+        num_prompt_logprobs: Optional[int] = None,
+    ) -> Union[List[TokensTextLogprobs],
+               List[TokensTextLogprobsPromptLogprobs]]:
+        greedy_logprobs_params = SamplingParams(
+            temperature=0.0,
+            use_beam_search=False,
+            max_tokens=max_tokens,
+            logprobs=num_logprobs,
+            prompt_logprobs=(num_prompt_logprobs),
+        )
         '''
         Greedy logprobs generation for vLLM encoder/decoder models
         '''
 
-        outputs = self.generate_encoder_decoder_w_logprobs(
+        return self.generate_encoder_decoder_w_logprobs(
             encoder_decoder_prompts, greedy_logprobs_params)
 
-        return [(output_ids, output_str, output_logprobs)
-                for output_ids, output_str, output_logprobs in outputs]
-
     def generate_beam_search(
         self,
         prompts: List[str],
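Reviewer note, illustrative only (not part of the diff): a sketch of how a test might drive the new `num_prompt_logprobs` parameter through the runner above. The helper name is hypothetical; `vllm_runner`, `model` and `prompts` are assumed to come from the existing test fixtures.

def _greedy_logprobs_with_prompt_logprobs(vllm_runner, model: str, prompts):
    # Request 5 sample logprobs per generated token and 5 prompt logprobs
    # per prompt token; each returned element is then a 4-tuple ending in
    # the prompt logprobs list.
    with vllm_runner(model, dtype="half") as vllm_model:
        return vllm_model.generate_greedy_logprobs(
            prompts, max_tokens=16, num_logprobs=5, num_prompt_logprobs=5)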
@@ -1,7 +1,7 @@
 import warnings
 from typing import Dict, List, Optional, Sequence, Tuple, Union
 
-from vllm.sequence import Logprob, SampleLogprobs
+from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
 
 TokensText = Tuple[List[int], str]
 
@@ -34,20 +34,47 @@ def check_outputs_equal(
         assert output_ids_0 == output_ids_1, fail_msg
 
 
+# Representation of generated sequence as a tuple of
+# * Token ID list
+# * String
+# * List of top sample logprobs for each sampled token
+#
+# Assumes prompt logprobs were not requested.
 TokensTextLogprobs = Tuple[List[int], str, Optional[Union[List[Dict[int,
                                                                     float]],
                                                            SampleLogprobs]]]
 
-# Allow for tokens to be represented as str's rather than IDs
+# Allow for tokens to be represented as str's rather than IDs;
+# tuple of
+# * Token string representations list
+# * String
+# * Optional list of top sample logprobs for each sampled token
+#
+# Assumes prompt logprobs were not requested.
 TextTextLogprobs = Tuple[List[str], str, Optional[Union[List[Dict[str, float]],
                                                          List[Dict[str,
                                                                    Logprob]]]]]
 
+# Representation of generated sequence as a tuple of
+# * Token ID list
+# * String
+# * Optional list of top sample logprobs for each sampled token
+# * Optional list of top prompt logprobs for each prompt token
+#
+# Allows prompt logprobs to be requested.
+TokensTextLogprobsPromptLogprobs = Tuple[
+    List[int], str, Optional[Union[List[Dict[int, float]], SampleLogprobs]],
+    Optional[Union[List[Optional[Dict[int, float]]], PromptLogprobs]]]
+
 
 def check_logprobs_close(
     *,
-    outputs_0_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]],
-    outputs_1_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]],
+    outputs_0_lst: Sequence[Union[TokensTextLogprobs,
+                                  TokensTextLogprobsPromptLogprobs,
+                                  TextTextLogprobs]],
+    outputs_1_lst: Sequence[Union[TokensTextLogprobs,
+                                  TokensTextLogprobsPromptLogprobs,
+                                  TextTextLogprobs]],
     name_0: str,
     name_1: str,
     num_outputs_0_skip_tokens: int = 0,
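Reviewer note, illustrative only (not part of the diff): a small helper annotated with the new aliases, showing how the 4-tuple form relates to the legacy 3-tuple form. The helper name is hypothetical and it assumes it runs inside the vLLM test tree so that tests.models.utils is importable.

from typing import List

from tests.models.utils import (TokensTextLogprobs,
                                TokensTextLogprobsPromptLogprobs)


def strip_prompt_logprobs(
        outputs: List[TokensTextLogprobsPromptLogprobs]
) -> List[TokensTextLogprobs]:
    # Drop the trailing prompt-logprobs element from each 4-tuple, yielding
    # the 3-tuple form used when prompt logprobs were not requested.
    return [out[:-1] for out in outputs]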
@@ -57,6 +84,18 @@ def check_logprobs_close(
     """Compare the logprobs of two sequences generated by different models,
     which should be similar but not necessarily equal.
 
+    How sample logprobs are compared:
+    * `always_check_logprobs == True`: set of highest-logprob token ids
+      must match between seq0 and seq1 at all sampled token offsets
+    * `always_check_logprobs == False`: highest-logprob token ids are
+      only compared at sampled token offsets for which generated token
+      ids don't match
+
+    Prompt logprobs must be provided either for both input sequences, or
+    for neither. If prompt logprobs are provided, then highest-logprob
+    prompt token ids must match between seq0 and seq1 at all prompt token
+    offsets.
+
     Args:
       outputs_0_lst: First sequence to compare
       outputs_0_lst: Second sequence to compare
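Reviewer note, illustrative only (not part of the diff): a toy example of the prompt-logprobs rule stated above. At each prompt offset the two engines must agree on the set of top-k token ids, although the logprob values themselves may differ slightly; the values below are fabricated.

prompt_logprobs_a = [None, {10: -0.1, 11: -2.3}, {42: -0.5, 7: -1.9}]
prompt_logprobs_b = [None, {11: -2.2, 10: -0.2}, {7: -2.0, 42: -0.4}]
for elem_a, elem_b in zip(prompt_logprobs_a, prompt_logprobs_b):
    if elem_a is None:
        # The first prompt token never has logprobs.
        assert elem_b is None
    else:
        # Same top-k token ids, possibly different logprob values.
        assert set(elem_a.keys()) == set(elem_b.keys())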
@@ -78,8 +117,65 @@ def check_logprobs_close(
     for prompt_idx, (outputs_0,
                      outputs_1) in enumerate(zip(outputs_0_lst,
                                                  outputs_1_lst)):
-        output_ids_0, output_str_0, logprobs_0 = outputs_0
-        output_ids_1, output_str_1, logprobs_1 = outputs_1
+        assert len(outputs_0) == len(outputs_1)
+        if len(outputs_0) == 3:
+            assert len(outputs_1) == 3
+            # Break out tokens, text & sample logprobs
+            # (prompt logprobs were not provided)
+            output_ids_0, output_str_0, logprobs_0 = outputs_0
+            output_ids_1, output_str_1, logprobs_1 = outputs_1
+        elif len(outputs_0) == 4:
+            assert len(outputs_1) == 4
+            # Break out tokens, text, sample logprobs & prompt logprobs
+            (
+                output_ids_0,
+                output_str_0,
+                logprobs_0,
+                prompt_logprobs_0,
+            ) = outputs_0
+            (
+                output_ids_1,
+                output_str_1,
+                logprobs_1,
+                prompt_logprobs_1,
+            ) = outputs_1
+
+            # Test prompt logprobs closeness
+            if (prompt_logprobs_0 is not None
+                    and prompt_logprobs_1 is not None):
+                # Both sequences' prompt logprobs lists are not `None`
+                # (although individual list elements may be `None`);
+                # for each token's logprobs:
+                for idx, (logprobs_elem_0, logprobs_elem_1) in enumerate(
+                        zip(prompt_logprobs_0, prompt_logprobs_1)):
+                    fail_msg = (
+                        f"Prompt logprobs test:"
+                        f"\n{name_0}:\tPrompt index {idx}\t{logprobs_elem_0}"
+                        f"\n{name_1}:\tPrompt index {idx}\t{logprobs_elem_1}")
+
+                    if logprobs_elem_0 is None:
+                        # If the seq 0 token's logprobs are `None`,
+                        # the seq 1 token's logprobs must be `None`
+                        assert logprobs_elem_1 is None, fail_msg
+                    else:
+                        # If the seq 0 token's logprobs are not `None`,
+                        # the seq 1 token's logprobs must not be `None`
+                        assert logprobs_elem_1 is not None, fail_msg
+                        # Logprobs check: top-k token choices must be the same
+                        assert (set(logprobs_elem_0.keys()) == set(
+                            logprobs_elem_1.keys())), fail_msg
+            else:
+                # Both sequence logprobs lists must be `None`
+                fail_msg = (f"Prompt logprobs test:"
+                            f"\n{name_0}:\tlogprobs\t{prompt_logprobs_0}"
+                            f"\n{name_1}:\tlogprobs\t{prompt_logprobs_1}")
+
+                assert (prompt_logprobs_0 is None
+                        and prompt_logprobs_1 is None), fail_msg
+        else:
+            raise ValueError(f"Outputs tuple must have 3 or 4 elements but "
+                             f"{len(outputs_0)} elements were provided: "
+                             f"{outputs_0}")
 
         if logprobs_0 is None:
             logprobs_0 = [None] * len(output_ids_0)
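Reviewer note, illustrative only (not part of the diff): a sketch of invoking `check_logprobs_close` with two output lists that both carry the new 4-tuple shape. All values are fabricated; it assumes the snippet runs inside the vLLM test tree where tests.models.utils is importable.

from tests.models.utils import check_logprobs_close

outputs_a = [([5, 6], "ab", [{5: -0.10}, {6: -0.20}], [None, {1: -0.30}])]
outputs_b = [([5, 6], "ab", [{5: -0.15}, {6: -0.25}], [None, {1: -0.35}])]
# Passes: sampled token ids match, and the prompt logprobs at each offset
# agree on their top-k token id sets.
check_logprobs_close(outputs_0_lst=outputs_a,
                     outputs_1_lst=outputs_b,
                     name_0="engine_a",
                     name_1="engine_b")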
@@ -100,3 +100,95 @@ def test_multi_step_llm(
         name_0="hf",
         name_1="vllm",
     )
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", ["half"])
+@pytest.mark.parametrize("tp_size", [1])
+@pytest.mark.parametrize("max_tokens", [5])
+@pytest.mark.parametrize("enforce_eager", [True])
+@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
+@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
+@pytest.mark.parametrize("num_logprobs,num_prompt_logprobs", [(5, 5)])
+def test_multi_step_llm_w_prompt_logprobs(
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    tp_size: int,
+    max_tokens: int,
+    enforce_eager: int,
+    num_scheduler_steps: int,
+    num_prompts: int,
+    num_logprobs: Optional[int],
+    num_prompt_logprobs: Optional[int],
+) -> None:
+    """Test prompt logprobs with multi-step scheduling via sync LLM Engine.
+
+    Set up a vLLM engine instance w/ single-step scheduling as a ground-truth
+    reference.
+
+    Prompt them with the same example prompts.
+
+    Validate:
+    * All generated logprobs are very close
+
+    Args:
+      vllm_runner: vLLM model runner fixture
+      example_prompts: test fixture providing example prompts
+      model: model under test (same for single- and multi-step engines)
+      dtype: tensor datatype for engine to utilize
+      tp_size: degree of tensor-parallelism
+      max_tokens: the maximum number of tokens to generate
+      enforce_eager
+      num_scheduler_steps: for multi-step scheduling, GPU-side steps per
+                           GPU -> CPU output transfer
+      num_prompts: number of example prompts under test
+      num_logprobs: corresponds to the `logprobs` argument to the OpenAI
+                    completions endpoint; `None` -> no logprobs
+      num_prompt_logprobs: number of logprobs to return for each prompt token;
+                           note that this argument is not supported by the
+                           OpenAI completions endpoint.
+    """
+
+    prompts = example_prompts
+    if len(prompts) < num_prompts:
+        prompts = prompts * ((num_prompts // len(prompts)) + 1)
+    prompts = prompts[:num_prompts]
+    assert len(prompts) == num_prompts
+
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            enforce_eager=enforce_eager,
+            gpu_memory_utilization=0.7,
+            tensor_parallel_size=tp_size,
+            use_v2_block_manager=True,
+            num_scheduler_steps=num_scheduler_steps,
+    ) as vllm_model:
+        vllm_outputs = vllm_model.generate_greedy_logprobs(
+            prompts,
+            max_tokens,
+            num_logprobs,
+            num_prompt_logprobs=num_prompt_logprobs)
+
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            enforce_eager=enforce_eager,
+            gpu_memory_utilization=0.7,
+            tensor_parallel_size=tp_size,
+    ) as vllm_model:
+        single_step_vllm_outputs = vllm_model.generate_greedy_logprobs(
+            prompts,
+            max_tokens,
+            num_logprobs,
+            num_prompt_logprobs=num_prompt_logprobs)
+
+    check_logprobs_close(
+        outputs_0_lst=single_step_vllm_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
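Reviewer note, illustrative only (not part of the diff): a minimal end-to-end sketch of what this PR enables, i.e. populated prompt logprobs while the engine runs multi-step scheduling. The model name is a placeholder and the engine flags mirror the test above; behavior is assumed for a vLLM build that includes this change.

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m",
          use_v2_block_manager=True,
          num_scheduler_steps=8)
params = SamplingParams(temperature=0.0,
                        max_tokens=5,
                        logprobs=5,
                        prompt_logprobs=5)
outputs = llm.generate(["The capital of France is"], params)
for out in outputs:
    # With this change, prompt_logprobs is no longer dropped under
    # multi-step scheduling.
    print(out.prompt_logprobs)
    print(out.outputs[0].logprobs)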
@@ -493,6 +493,7 @@ async def completions_with_server_args(
     '''
 
     outputs = None
+    max_wait_seconds = 240 * 3  # 240 is default
     with RemoteOpenAIServer(model_name,
                             server_cli_args,
                             max_wait_seconds=max_wait_seconds) as server:
@@ -503,7 +504,7 @@ async def completions_with_server_args(
                                              stream=False,
                                              max_tokens=5,
                                              logprobs=num_logprobs)
-    assert outputs is not None
+    assert outputs is not None, "Completion API call failed."
 
     return outputs
 
@@ -614,34 +614,66 @@ def _pythonize_sampler_output(
 
     frozen_model_input = model_input.frozen_model_input
     assert frozen_model_input.sampling_metadata is not None
+    sampling_metadata = frozen_model_input.sampling_metadata
     # samples generation should have been skipped
     assert not output.outputs
 
     pinned_buffer = pinned_sampled_token_buffer[:model_input.num_queries]
 
-    # CPU GPU sync
-    pinned_buffer = pinned_buffer.copy_(sampled_token_ids, non_blocking=False)
+    # We guarantee output tensors are ready, so it is safe to
+    # pythonize the sampler output & obtain CPU-side logprobs.
+    #
+    # However we should check whether logprobs pythonization may
+    # be skipped entirely, i.e. because no logprobs were requested
+    # or pythonization was not deferred. To that end,
+    #
+    # * `prompt_logprobs_are_requested_for_prefill` signals that
+    #   there are *any* prefill-phase requests which specify that
+    #   prompt logprobs should be returned.
+    #
+    # * `any_logprobs_are_requested` signals that there are any
+    #   requests which (1) specify that sample logprobs should be
+    #   returned, or (2) are in the prefill phase AND specify that
+    #   prompt logprobs should be returned.
+    #
+    # Later on, these flags cause adjustments to the pythonization
+    # process to accommodate logprobs.
+
+    seq_groups = sampling_metadata.seq_groups
+    prompt_logprobs_are_requested_for_prefill = any([
+        sg.sampling_params.prompt_logprobs is not None and sg.is_prompt
+        for sg in seq_groups
+    ])
+    any_logprobs_are_requested = (
+        prompt_logprobs_are_requested_for_prefill
+        or any([sg.sampling_params.logprobs is not None for sg in seq_groups]))
+
+    if prompt_logprobs_are_requested_for_prefill:
+        # CPU GPU sync, after gathering *only* sampled tokens (since
+        # requesting prompt logprobs leads `sampled_token_ids` to
+        # include prompt token ids in addition to sampled token ids.)
+        sample_idx_tensor = torch.tensor(
+            [sdx for sg in seq_groups for sdx in sg.sample_indices])
+        pinned_buffer = pinned_buffer.copy_(
+            sampled_token_ids[sample_idx_tensor, :], non_blocking=False)
+    else:
+        # CPU GPU sync
+        pinned_buffer = pinned_buffer.copy_(sampled_token_ids,
+                                            non_blocking=False)
 
     # this will not block as the tensors are already on CPU
     samples_list = pinned_buffer.tolist()
 
-    sampling_metadata = frozen_model_input.sampling_metadata
-
     skip_sampler_cpu_output = (
         frozen_model_input.sampling_metadata.skip_sampler_cpu_output)
 
-    # We are guaranteed output tensors are ready, so it is safe to
-    # pythonize the sampler output & obtain CPU-side logprobs.
-    #
-    # However this computation may be skipped entirely
-    # if no pythonization was deferred.
-    seq_groups = sampling_metadata.seq_groups
-    logprobs_are_requested = any([
-        sg.sampling_params.logprobs is not None
-        or sg.sampling_params.prompt_logprobs is not None for sg in seq_groups
-    ])
+    # *Don't* skip logprobs pythonization *if*:
+    # * Any requests require logprobs to be returned in this
+    #   iteration AND
+    # * These requests are being scheduled in a fashion which
+    #   defers pythonization (i.e. multi-step scheduling.)
     do_pythonize_logprobs = (skip_sampler_cpu_output
-                             and logprobs_are_requested)
+                             and any_logprobs_are_requested)
     (
         prompt_logprobs,
         sample_logprobs,
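Reviewer note, illustrative only (not part of the diff): a standalone sketch, with synthetic shapes, of the gather performed above. When prompt logprobs are requested for prefills, `sampled_token_ids` also holds prompt positions, so only the rows at each group's sample indices are copied into the pinned buffer.

import torch

# Rows 0, 1, 3, 4 stand in for prompt positions; rows 2 and 5 hold the
# tokens actually sampled for two sequence groups.
sampled_token_ids = torch.arange(12).reshape(6, 2)
sample_indices = [2, 5]  # analogous to sg.sample_indices per seq group
sample_idx_tensor = torch.tensor(sample_indices)
pinned_buffer = torch.empty((len(sample_indices), 2),
                            dtype=sampled_token_ids.dtype)
pinned_buffer.copy_(sampled_token_ids[sample_idx_tensor, :],
                    non_blocking=False)
print(pinned_buffer.tolist())  # [[4, 5], [10, 11]]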
@@ -666,7 +698,7 @@ def _pythonize_sampler_output(
                 prompt_logprobs[sgdx],
                 sample_logprobs[sgdx],
             )
-        elif logprobs_are_requested:
+        elif any_logprobs_are_requested:
             (
                 group_prompt_logprobs,
                 group_sample_logprobs,
@@ -696,7 +728,7 @@ def _pythonize_sampler_output(
                     seq_output.parent_seq_id = seq_ids[parent_id]
                     seq_output.output_token = next_token_id
 
-                    if logprobs_are_requested:
+                    if any_logprobs_are_requested:
                         seq_output.logprobs = group_sample_logprobs[tdx]
                     else:
                         logprobs = next(iter(seq_output.logprobs.values()))
@@ -714,7 +746,7 @@ def _pythonize_sampler_output(
                 seq_outputs.append(
                     SequenceOutput(seq_ids[parent_id], next_token_id,
                                    (group_sample_logprobs[tdx]
-                                    if logprobs_are_requested else {
+                                    if any_logprobs_are_requested else {
                                         next_token_id:
                                         Logprob(logprob=float('inf'),
                                                 rank=None,
@@ -722,12 +754,12 @@ def _pythonize_sampler_output(
                                     })))
         if cache is not None:
             completion_seq_group_output.prompt_logprobs = \
-                group_prompt_logprobs if logprobs_are_requested else None
+                group_prompt_logprobs if any_logprobs_are_requested else None
             output.outputs.append(completion_seq_group_output)
         else:
             output.outputs.append(
                 CompletionSequenceGroupOutput(
                     seq_outputs, (group_prompt_logprobs
-                                  if logprobs_are_requested else None)))
+                                  if any_logprobs_are_requested else None)))
 
     assert len(output.outputs) > 0