[Misc] Simplify code and fix type annotations in conftest.py (#5118)
parent a66cf40b20
commit dfbe60dc62
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple

 import pytest
 import torch
+import torch.nn.functional as F
 from PIL import Image
 from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
                           LlavaConfig, LlavaForConditionalGeneration)
@@ -12,9 +13,9 @@ from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
 from vllm import LLM, SamplingParams
 from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
 from vllm.distributed import destroy_model_parallel
-from vllm.inputs import PromptInputs
+from vllm.inputs import TextPrompt
 from vllm.logger import init_logger
-from vllm.sequence import MultiModalData
+from vllm.sequence import MultiModalData, SampleLogprobs

 logger = init_logger(__name__)

@@ -188,10 +189,11 @@ class HfRunner:
         prompts: List[str],
         images: Optional[List[Image.Image]] = None,
         **kwargs,
-    ) -> List[Tuple[List[int], str]]:
-        outputs: List[Tuple[List[int], str]] = []
+    ) -> List[Tuple[List[List[int]], List[str]]]:
         if images:
             assert len(prompts) == len(images)
+
+        outputs: List[Tuple[List[List[int]], List[str]]] = []
         for i, prompt in enumerate(prompts):
             processor_kwargs: Dict[str, Any] = {
                 "text": prompt,
@@ -201,17 +203,13 @@ class HfRunner:
                 processor_kwargs["images"] = images[i]

             inputs = self.processor(**processor_kwargs)
-            inputs = {
-                key: value.cuda() if value is not None else None
-                for key, value in inputs.items()
-            }

             output_ids = self.model.generate(
-                **inputs,
+                **inputs.to("cuda"),
                 use_cache=True,
                 **kwargs,
             )
-            output_str = self.tokenizer.batch_decode(
+            output_str = self.processor.batch_decode(
                 output_ids,
                 skip_special_tokens=True,
                 clean_up_tokenization_spaces=False,
@@ -224,23 +222,22 @@ class HfRunner:
         self,
         prompts: List[str],
         max_tokens: int,
-        images: Optional["torch.Tensor"] = None,
+        images: Optional[List[Image.Image]] = None,
     ) -> List[Tuple[List[int], str]]:
         outputs = self.generate(prompts,
                                 do_sample=False,
                                 max_new_tokens=max_tokens,
                                 images=images)
-        for i in range(len(outputs)):
-            output_ids, output_str = outputs[i]
-            outputs[i] = (output_ids[0], output_str[0])
-        return outputs
+
+        return [(output_ids[0], output_str[0])
+                for output_ids, output_str in outputs]

     def generate_beam_search(
         self,
         prompts: List[str],
         beam_width: int,
         max_tokens: int,
-    ) -> List[Tuple[List[int], str]]:
+    ) -> List[Tuple[List[List[int]], List[str]]]:
         outputs = self.generate(prompts,
                                 do_sample=False,
                                 max_new_tokens=max_tokens,
@@ -282,9 +279,7 @@ class HfRunner:
                 if self.model.get_output_embeddings().bias is not None:
                     logits += self.model.get_output_embeddings(
                     ).bias.unsqueeze(0)
-                logprobs = torch.nn.functional.log_softmax(logits,
-                                                           dim=-1,
-                                                           dtype=torch.float32)
+                logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
                 seq_logprobs.append(logprobs)
             all_logprobs.append(seq_logprobs)
         return all_logprobs
@@ -294,10 +289,10 @@ class HfRunner:
         prompts: List[str],
         max_tokens: int,
         num_logprobs: int,
-    ) -> List[Tuple[List[int], str]]:
-        all_logprobs = []
-        all_output_ids = []
-        all_output_strs = []
+    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
+        all_logprobs: List[List[Dict[int, float]]] = []
+        all_output_ids: List[List[int]] = []
+        all_output_strs: List[str] = []

         for prompt in prompts:
             input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
@@ -310,7 +305,7 @@ class HfRunner:
                 return_dict_in_generate=True,
             )

-            seq_logprobs = []
+            seq_logprobs: List[torch.Tensor] = []
             for _, hidden_states in enumerate(output.hidden_states):
                 last_hidden_states = hidden_states[-1][0]
                 logits = torch.matmul(
@@ -321,13 +316,11 @@ class HfRunner:
                            None) is not None:
                     logits += self.model.get_output_embeddings(
                     ).bias.unsqueeze(0)
-                logprobs = torch.nn.functional.log_softmax(logits,
-                                                           dim=-1,
-                                                           dtype=torch.float32)
+                logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
                 seq_logprobs.append(logprobs)

             # convert to dict
-            seq_logprobs_lst = []
+            seq_logprobs_lst: List[Dict[int, float]] = []
             for tok_idx, tok_logprobs in enumerate(seq_logprobs):
                 # drop prompt logprobs
                 if tok_idx == 0:
@@ -372,13 +365,13 @@ class VllmRunner:
         tokenizer_name: Optional[str] = None,
         # Use smaller max model length, otherwise bigger model cannot run due
         # to kv cache size limit.
-        max_model_len=1024,
+        max_model_len: int = 1024,
         dtype: str = "half",
         disable_log_stats: bool = True,
         tensor_parallel_size: int = 1,
         block_size: int = 16,
         enable_chunked_prefill: bool = False,
-        swap_space=4,
+        swap_space: int = 4,
         **kwargs,
     ) -> None:
         self.model = LLM(
@@ -399,32 +392,31 @@ class VllmRunner:
         self,
         prompts: List[str],
         sampling_params: SamplingParams,
-        images: Optional["torch.Tensor"] = None,
-    ) -> List[Tuple[List[int], str]]:
+        images: Optional[torch.Tensor] = None,
+    ) -> List[Tuple[List[List[int]], List[str]]]:
         if images is not None:
-            assert len(prompts) == images.shape[0]
+            assert len(prompts) == len(images)

-        prompt_inputs: List[PromptInputs] = []
+        prompt_inputs: List[TextPrompt] = []
         for i, prompt in enumerate(prompts):
-            image = None if images is None else images[i:i + 1]
-            mm_data = None if image is None else MultiModalData(
-                type=MultiModalData.Type.IMAGE,
-                data=image,
-            )
+            prompt = TextPrompt(prompt=prompt)
+            if images is not None:
+                prompt["multi_modal_data"] = MultiModalData(
+                    type=MultiModalData.Type.IMAGE,
+                    data=images[i:i + 1],
+                )

-            prompt_inputs.append({
-                "prompt": prompt,
-                "multi_modal_data": mm_data,
-            })
+            prompt_inputs.append(prompt)

         req_outputs = self.model.generate(prompt_inputs,
                                           sampling_params=sampling_params)
-        outputs = []
+
+        outputs: List[Tuple[List[List[int]], List[str]]] = []
         for req_output in req_outputs:
             prompt_str = req_output.prompt
             prompt_ids = req_output.prompt_token_ids
-            req_sample_output_ids = []
-            req_sample_output_strs = []
+            req_sample_output_ids: List[List[int]] = []
+            req_sample_output_strs: List[str] = []
             for sample in req_output.outputs:
                 output_str = sample.text
                 output_ids = sample.token_ids
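For reference, a minimal sketch of the prompt-construction path that VllmRunner.generate follows after this hunk, written against the vLLM APIs used above (TextPrompt, MultiModalData). The prompt text, the placeholder image tensor, and its shape are illustrative assumptions, not values taken from this commit.

# Sketch only: prompt text and tensor shape are placeholders.
import torch

from vllm.inputs import TextPrompt
from vllm.sequence import MultiModalData

prompts = ["USER: <image>\nWhat is shown here? ASSISTANT:"]
images = torch.zeros(1, 3, 336, 336)  # one image slice per prompt (assumed shape)

prompt_inputs = []
for i, text in enumerate(prompts):
    prompt = TextPrompt(prompt=text)
    if images is not None:
        # TextPrompt is a TypedDict, so the optional key is attached by item assignment.
        prompt["multi_modal_data"] = MultiModalData(
            type=MultiModalData.Type.IMAGE,
            data=images[i:i + 1],
        )
    prompt_inputs.append(prompt)

# prompt_inputs is then passed straight to LLM.generate(..., sampling_params=...).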
@@ -437,12 +429,12 @@ class VllmRunner:
         self,
         prompts: List[str],
         sampling_params: SamplingParams,
-    ) -> List[Tuple[List[int], str]]:
+    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
         assert sampling_params.logprobs is not None

         req_outputs = self.model.generate(prompts,
                                           sampling_params=sampling_params)
-        outputs = []
+        outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
         for req_output in req_outputs:
             for sample in req_output.outputs:
                 output_str = sample.text
@@ -467,7 +459,7 @@ class VllmRunner:
         prompts: List[str],
         max_tokens: int,
         num_logprobs: int,
-    ) -> List[Tuple[List[int], str]]:
+    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
         greedy_logprobs_params = SamplingParams(temperature=0.0,
                                                 max_tokens=max_tokens,
                                                 logprobs=num_logprobs)
@@ -481,7 +473,7 @@ class VllmRunner:
         prompts: List[str],
         beam_width: int,
         max_tokens: int,
-    ) -> List[Tuple[List[int], str]]:
+    ) -> List[Tuple[List[List[int]], List[str]]]:
         beam_search_params = SamplingParams(n=beam_width,
                                             use_beam_search=True,
                                             temperature=0.0,
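For context, a small usage sketch of the runner methods whose annotations change above, written as a hypothetical test snippet. The model name, prompts, and token budget are placeholders; only the method names and return shapes come from this diff.

# Hypothetical test snippet; HfRunner/VllmRunner are the classes edited in conftest.py.
prompts = ["Hello, my name is", "The capital of France is"]

hf_model = HfRunner("facebook/opt-125m")
# generate() yields (List[List[int]], List[str]) per prompt; generate_greedy()
# flattens that to one (token_ids, text) pair per prompt.
hf_outputs = hf_model.generate_greedy(prompts, max_tokens=32)
del hf_model

vllm_model = VllmRunner("facebook/opt-125m")
# generate_greedy_logprobs() now also advertises its logprobs in the annotation:
# List[Tuple[List[int], str, Optional[SampleLogprobs]]].
vllm_outputs = vllm_model.generate_greedy_logprobs(prompts,
                                                   max_tokens=32,
                                                   num_logprobs=5)
del vllm_model

for (hf_ids, hf_text), (vllm_ids, vllm_text, vllm_logprobs) in zip(
        hf_outputs, vllm_outputs):
    assert isinstance(hf_ids, list) and isinstance(hf_text, str)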