[Misc] Simplify code and fix type annotations in conftest.py (#5118)

Author: Cyrus Leung (2024-06-03 07:05:50 +08:00, committed by GitHub)
Parent: a66cf40b20
Commit: dfbe60dc62
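
The gist of the annotation fixes: the batched runner helpers return, per prompt, a list of candidate token-id sequences plus a parallel list of decoded strings, which the old `List[Tuple[List[int], str]]` annotations did not express. A toy before/after sketch (hypothetical stand-in functions, not the actual conftest.py code):

from typing import List, Tuple

def generate_old(prompts: List[str]) -> List[Tuple[List[int], str]]:
    # Old annotation: claims a single (token_ids, text) pair per prompt.
    return [([1, 2, 3], "hello") for _ in prompts]

def generate_new(prompts: List[str]) -> List[Tuple[List[List[int]], List[str]]]:
    # Fixed annotation: each prompt yields a list of token-id sequences and a
    # parallel list of decoded strings, one entry per generated sample.
    return [([[1, 2, 3]], ["hello"]) for _ in prompts]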

@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple
 
 import pytest
 import torch
+import torch.nn.functional as F
 from PIL import Image
 from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
                           LlavaConfig, LlavaForConditionalGeneration)
@@ -12,9 +13,9 @@ from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
 from vllm import LLM, SamplingParams
 from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
 from vllm.distributed import destroy_model_parallel
-from vllm.inputs import PromptInputs
+from vllm.inputs import TextPrompt
 from vllm.logger import init_logger
-from vllm.sequence import MultiModalData
+from vllm.sequence import MultiModalData, SampleLogprobs
 
 logger = init_logger(__name__)
@@ -188,10 +189,11 @@ class HfRunner:
         prompts: List[str],
         images: Optional[List[Image.Image]] = None,
         **kwargs,
-    ) -> List[Tuple[List[int], str]]:
-        outputs: List[Tuple[List[int], str]] = []
+    ) -> List[Tuple[List[List[int]], List[str]]]:
         if images:
             assert len(prompts) == len(images)
+
+        outputs: List[Tuple[List[List[int]], List[str]]] = []
         for i, prompt in enumerate(prompts):
             processor_kwargs: Dict[str, Any] = {
                 "text": prompt,
@@ -201,17 +203,13 @@ class HfRunner:
                 processor_kwargs["images"] = images[i]
 
             inputs = self.processor(**processor_kwargs)
-            inputs = {
-                key: value.cuda() if value is not None else None
-                for key, value in inputs.items()
-            }
 
             output_ids = self.model.generate(
-                **inputs,
+                **inputs.to("cuda"),
                 use_cache=True,
                 **kwargs,
             )
-            output_str = self.tokenizer.batch_decode(
+            output_str = self.processor.batch_decode(
                 output_ids,
                 skip_special_tokens=True,
                 clean_up_tokenization_spaces=False,
@@ -224,23 +222,22 @@ class HfRunner:
         self,
         prompts: List[str],
         max_tokens: int,
-        images: Optional["torch.Tensor"] = None,
+        images: Optional[List[Image.Image]] = None,
     ) -> List[Tuple[List[int], str]]:
         outputs = self.generate(prompts,
                                 do_sample=False,
                                 max_new_tokens=max_tokens,
                                 images=images)
-        for i in range(len(outputs)):
-            output_ids, output_str = outputs[i]
-            outputs[i] = (output_ids[0], output_str[0])
-        return outputs
+
+        return [(output_ids[0], output_str[0])
+                for output_ids, output_str in outputs]
 
     def generate_beam_search(
         self,
         prompts: List[str],
         beam_width: int,
         max_tokens: int,
-    ) -> List[Tuple[List[int], str]]:
+    ) -> List[Tuple[List[List[int]], List[str]]]:
         outputs = self.generate(prompts,
                                 do_sample=False,
                                 max_new_tokens=max_tokens,
@@ -282,9 +279,7 @@ class HfRunner:
                 if self.model.get_output_embeddings().bias is not None:
                     logits += self.model.get_output_embeddings(
                     ).bias.unsqueeze(0)
-                logprobs = torch.nn.functional.log_softmax(logits,
-                                                           dim=-1,
-                                                           dtype=torch.float32)
+                logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
                 seq_logprobs.append(logprobs)
             all_logprobs.append(seq_logprobs)
         return all_logprobs
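
For reference, the `F.log_softmax` spelling above is purely cosmetic: `F` is the alias for `torch.nn.functional` added to the imports, so the computed logprobs are unchanged. A minimal standalone check (not part of conftest.py):

import torch
import torch.nn.functional as F

# Fake [num_tokens, vocab_size] logits standing in for the matmul output above.
logits = torch.randn(4, 128)

long_form = torch.nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32)
short_form = F.log_softmax(logits, dim=-1, dtype=torch.float32)
assert torch.equal(long_form, short_form)  # same function, shorter spelling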
@@ -294,10 +289,10 @@ class HfRunner:
         prompts: List[str],
         max_tokens: int,
         num_logprobs: int,
-    ) -> List[Tuple[List[int], str]]:
-        all_logprobs = []
-        all_output_ids = []
-        all_output_strs = []
+    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
+        all_logprobs: List[List[Dict[int, float]]] = []
+        all_output_ids: List[List[int]] = []
+        all_output_strs: List[str] = []
 
         for prompt in prompts:
             input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
@@ -310,7 +305,7 @@ class HfRunner:
                 return_dict_in_generate=True,
             )
 
-            seq_logprobs = []
+            seq_logprobs: List[torch.Tensor] = []
             for _, hidden_states in enumerate(output.hidden_states):
                 last_hidden_states = hidden_states[-1][0]
                 logits = torch.matmul(
@@ -321,13 +316,11 @@ class HfRunner:
                            None) is not None:
                     logits += self.model.get_output_embeddings(
                     ).bias.unsqueeze(0)
-                logprobs = torch.nn.functional.log_softmax(logits,
-                                                           dim=-1,
-                                                           dtype=torch.float32)
+                logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
                 seq_logprobs.append(logprobs)
 
             # convert to dict
-            seq_logprobs_lst = []
+            seq_logprobs_lst: List[Dict[int, float]] = []
             for tok_idx, tok_logprobs in enumerate(seq_logprobs):
                 # drop prompt logprobs
                 if tok_idx == 0:
@@ -372,13 +365,13 @@ class VllmRunner:
         tokenizer_name: Optional[str] = None,
         # Use smaller max model length, otherwise bigger model cannot run due
         # to kv cache size limit.
-        max_model_len=1024,
+        max_model_len: int = 1024,
         dtype: str = "half",
         disable_log_stats: bool = True,
         tensor_parallel_size: int = 1,
         block_size: int = 16,
         enable_chunked_prefill: bool = False,
-        swap_space=4,
+        swap_space: int = 4,
         **kwargs,
     ) -> None:
         self.model = LLM(
@@ -399,32 +392,31 @@ class VllmRunner:
         self,
         prompts: List[str],
         sampling_params: SamplingParams,
-        images: Optional["torch.Tensor"] = None,
-    ) -> List[Tuple[List[int], str]]:
+        images: Optional[torch.Tensor] = None,
+    ) -> List[Tuple[List[List[int]], List[str]]]:
         if images is not None:
-            assert len(prompts) == images.shape[0]
-        prompt_inputs: List[PromptInputs] = []
+            assert len(prompts) == len(images)
+
+        prompt_inputs: List[TextPrompt] = []
         for i, prompt in enumerate(prompts):
-            image = None if images is None else images[i:i + 1]
-            mm_data = None if image is None else MultiModalData(
-                type=MultiModalData.Type.IMAGE,
-                data=image,
-            )
+            prompt = TextPrompt(prompt=prompt)
+            if images is not None:
+                prompt["multi_modal_data"] = MultiModalData(
+                    type=MultiModalData.Type.IMAGE,
+                    data=images[i:i + 1],
+                )
 
-            prompt_inputs.append({
-                "prompt": prompt,
-                "multi_modal_data": mm_data,
-            })
+            prompt_inputs.append(prompt)
 
         req_outputs = self.model.generate(prompt_inputs,
                                           sampling_params=sampling_params)
 
-        outputs = []
+        outputs: List[Tuple[List[List[int]], List[str]]] = []
         for req_output in req_outputs:
             prompt_str = req_output.prompt
             prompt_ids = req_output.prompt_token_ids
-            req_sample_output_ids = []
-            req_sample_output_strs = []
+            req_sample_output_ids: List[List[int]] = []
+            req_sample_output_strs: List[str] = []
             for sample in req_output.outputs:
                 output_str = sample.text
                 output_ids = sample.token_ids
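
For orientation, a minimal sketch of the same `TextPrompt` construction outside the test runner, assuming the vLLM API at this revision; the model name, dummy image tensor, and sampling settings are illustrative placeholders rather than values taken from conftest.py:

import torch

from vllm import LLM, SamplingParams
from vllm.inputs import TextPrompt
from vllm.sequence import MultiModalData

llm = LLM(model="llava-hf/llava-1.5-7b-hf", max_model_len=1024)  # assumed model
pixel_values = torch.zeros(1, 3, 336, 336)  # placeholder; tests pass real preprocessed images

# Same pattern as VllmRunner.generate above: start from a TextPrompt and attach
# multi-modal data only when an image is available for the prompt.
prompt = TextPrompt(prompt="USER: <image>\nWhat is in the picture? ASSISTANT:")
prompt["multi_modal_data"] = MultiModalData(type=MultiModalData.Type.IMAGE,
                                            data=pixel_values)

outputs = llm.generate([prompt], SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)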
@@ -437,12 +429,12 @@ class VllmRunner:
         self,
         prompts: List[str],
         sampling_params: SamplingParams,
-    ) -> List[Tuple[List[int], str]]:
+    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
         assert sampling_params.logprobs is not None
         req_outputs = self.model.generate(prompts,
                                           sampling_params=sampling_params)
 
-        outputs = []
+        outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
         for req_output in req_outputs:
             for sample in req_output.outputs:
                 output_str = sample.text
@@ -467,7 +459,7 @@ class VllmRunner:
         prompts: List[str],
         max_tokens: int,
         num_logprobs: int,
-    ) -> List[Tuple[List[int], str]]:
+    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
         greedy_logprobs_params = SamplingParams(temperature=0.0,
                                                 max_tokens=max_tokens,
                                                 logprobs=num_logprobs)
@@ -481,7 +473,7 @@ class VllmRunner:
         prompts: List[str],
         beam_width: int,
         max_tokens: int,
-    ) -> List[Tuple[List[int], str]]:
+    ) -> List[Tuple[List[List[int]], List[str]]]:
         beam_search_params = SamplingParams(n=beam_width,
                                             use_beam_search=True,
                                             temperature=0.0,