[Misc] Simplify code and fix type annotations in conftest.py (#5118)

Author: Cyrus Leung
Date:   2024-06-03 07:05:50 +08:00 (committed by GitHub)
Parent: a66cf40b20
Commit: dfbe60dc62


@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple
 import pytest
 import torch
+import torch.nn.functional as F
 from PIL import Image
 from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
                           LlavaConfig, LlavaForConditionalGeneration)
@@ -12,9 +13,9 @@ from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
 from vllm import LLM, SamplingParams
 from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
 from vllm.distributed import destroy_model_parallel
-from vllm.inputs import PromptInputs
+from vllm.inputs import TextPrompt
 from vllm.logger import init_logger
-from vllm.sequence import MultiModalData
+from vllm.sequence import MultiModalData, SampleLogprobs
 
 logger = init_logger(__name__)
@@ -188,10 +189,11 @@ class HfRunner:
         prompts: List[str],
         images: Optional[List[Image.Image]] = None,
         **kwargs,
-    ) -> List[Tuple[List[int], str]]:
-        outputs: List[Tuple[List[int], str]] = []
+    ) -> List[Tuple[List[List[int]], List[str]]]:
         if images:
             assert len(prompts) == len(images)
+
+        outputs: List[Tuple[List[List[int]], List[str]]] = []
         for i, prompt in enumerate(prompts):
             processor_kwargs: Dict[str, Any] = {
                 "text": prompt,
@@ -201,17 +203,13 @@ class HfRunner:
                 processor_kwargs["images"] = images[i]
 
             inputs = self.processor(**processor_kwargs)
-            inputs = {
-                key: value.cuda() if value is not None else None
-                for key, value in inputs.items()
-            }
 
             output_ids = self.model.generate(
-                **inputs,
+                **inputs.to("cuda"),
                 use_cache=True,
                 **kwargs,
             )
-            output_str = self.tokenizer.batch_decode(
+            output_str = self.processor.batch_decode(
                 output_ids,
                 skip_special_tokens=True,
                 clean_up_tokenization_spaces=False,
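Note on the `**inputs.to("cuda")` change: the HuggingFace processor returns a `BatchFeature`/`BatchEncoding`, whose `.to()` already moves every tensor it holds to the target device, so the manual per-key `.cuda()` comprehension is redundant. A minimal sketch with placeholder tensors (not the test suite's real inputs):

import torch
from transformers import BatchFeature

# Placeholder tensors only; this just shows that BatchFeature.to() moves
# every tensor value, which is what **inputs.to("cuda") relies on above.
inputs = BatchFeature({
    "input_ids": torch.tensor([[1, 2, 3]]),
    "pixel_values": torch.rand(1, 3, 336, 336),
})
if torch.cuda.is_available():
    inputs = inputs.to("cuda")
print({key: value.device for key, value in inputs.items()})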
@@ -224,23 +222,22 @@ class HfRunner:
         self,
         prompts: List[str],
         max_tokens: int,
-        images: Optional["torch.Tensor"] = None,
+        images: Optional[List[Image.Image]] = None,
     ) -> List[Tuple[List[int], str]]:
         outputs = self.generate(prompts,
                                 do_sample=False,
                                 max_new_tokens=max_tokens,
                                 images=images)
-        for i in range(len(outputs)):
-            output_ids, output_str = outputs[i]
-            outputs[i] = (output_ids[0], output_str[0])
-        return outputs
+
+        return [(output_ids[0], output_str[0])
+                for output_ids, output_str in outputs]
 
     def generate_beam_search(
         self,
         prompts: List[str],
         beam_width: int,
         max_tokens: int,
-    ) -> List[Tuple[List[int], str]]:
+    ) -> List[Tuple[List[List[int]], List[str]]]:
         outputs = self.generate(prompts,
                                 do_sample=False,
                                 max_new_tokens=max_tokens,
@ -282,9 +279,7 @@ class HfRunner:
if self.model.get_output_embeddings().bias is not None: if self.model.get_output_embeddings().bias is not None:
logits += self.model.get_output_embeddings( logits += self.model.get_output_embeddings(
).bias.unsqueeze(0) ).bias.unsqueeze(0)
logprobs = torch.nn.functional.log_softmax(logits, logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
dim=-1,
dtype=torch.float32)
seq_logprobs.append(logprobs) seq_logprobs.append(logprobs)
all_logprobs.append(seq_logprobs) all_logprobs.append(seq_logprobs)
return all_logprobs return all_logprobs
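The two `log_softmax` call sites are only shortened to use the `F` alias; behaviour is unchanged. A standalone sketch of what the call computes, with random logits rather than the model's real outputs: passing `dtype=torch.float32` upcasts half-precision logits before normalization, so each row still exponentiates to a valid distribution.

import torch
import torch.nn.functional as F

# Random half-precision logits stand in for the model outputs.
logits = torch.randn(4, 32000, dtype=torch.float16)
logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
assert logprobs.dtype == torch.float32
assert torch.allclose(logprobs.exp().sum(dim=-1), torch.ones(4), atol=1e-4)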
@@ -294,10 +289,10 @@ class HfRunner:
         prompts: List[str],
         max_tokens: int,
         num_logprobs: int,
-    ) -> List[Tuple[List[int], str]]:
-        all_logprobs = []
-        all_output_ids = []
-        all_output_strs = []
+    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
+        all_logprobs: List[List[Dict[int, float]]] = []
+        all_output_ids: List[List[int]] = []
+        all_output_strs: List[str] = []
 
         for prompt in prompts:
             input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
@@ -310,7 +305,7 @@ class HfRunner:
                 return_dict_in_generate=True,
             )
 
-            seq_logprobs = []
+            seq_logprobs: List[torch.Tensor] = []
             for _, hidden_states in enumerate(output.hidden_states):
                 last_hidden_states = hidden_states[-1][0]
                 logits = torch.matmul(
@@ -321,13 +316,11 @@ class HfRunner:
                                 None) is not None:
                     logits += self.model.get_output_embeddings(
                     ).bias.unsqueeze(0)
-                logprobs = torch.nn.functional.log_softmax(logits,
-                                                           dim=-1,
-                                                           dtype=torch.float32)
+                logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
                 seq_logprobs.append(logprobs)
 
             # convert to dict
-            seq_logprobs_lst = []
+            seq_logprobs_lst: List[Dict[int, float]] = []
             for tok_idx, tok_logprobs in enumerate(seq_logprobs):
                 # drop prompt logprobs
                 if tok_idx == 0:
@@ -372,13 +365,13 @@ class VllmRunner:
         tokenizer_name: Optional[str] = None,
         # Use smaller max model length, otherwise bigger model cannot run due
         # to kv cache size limit.
-        max_model_len=1024,
+        max_model_len: int = 1024,
         dtype: str = "half",
         disable_log_stats: bool = True,
         tensor_parallel_size: int = 1,
         block_size: int = 16,
         enable_chunked_prefill: bool = False,
-        swap_space=4,
+        swap_space: int = 4,
         **kwargs,
     ) -> None:
         self.model = LLM(
@@ -399,32 +392,31 @@ class VllmRunner:
         self,
         prompts: List[str],
         sampling_params: SamplingParams,
-        images: Optional["torch.Tensor"] = None,
-    ) -> List[Tuple[List[int], str]]:
+        images: Optional[torch.Tensor] = None,
+    ) -> List[Tuple[List[List[int]], List[str]]]:
         if images is not None:
-            assert len(prompts) == images.shape[0]
+            assert len(prompts) == len(images)
 
-        prompt_inputs: List[PromptInputs] = []
+        prompt_inputs: List[TextPrompt] = []
         for i, prompt in enumerate(prompts):
-            image = None if images is None else images[i:i + 1]
-            mm_data = None if image is None else MultiModalData(
-                type=MultiModalData.Type.IMAGE,
-                data=image,
-            )
+            prompt = TextPrompt(prompt=prompt)
+            if images is not None:
+                prompt["multi_modal_data"] = MultiModalData(
+                    type=MultiModalData.Type.IMAGE,
+                    data=images[i:i + 1],
+                )
 
-            prompt_inputs.append({
-                "prompt": prompt,
-                "multi_modal_data": mm_data,
-            })
+            prompt_inputs.append(prompt)
 
         req_outputs = self.model.generate(prompt_inputs,
                                           sampling_params=sampling_params)
-        outputs = []
+
+        outputs: List[Tuple[List[List[int]], List[str]]] = []
         for req_output in req_outputs:
             prompt_str = req_output.prompt
             prompt_ids = req_output.prompt_token_ids
-            req_sample_output_ids = []
-            req_sample_output_strs = []
+            req_sample_output_ids: List[List[int]] = []
+            req_sample_output_strs: List[str] = []
             for sample in req_output.outputs:
                 output_str = sample.text
                 output_ids = sample.token_ids
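Note on the `TextPrompt` change: at this revision `TextPrompt` is a `TypedDict`, so constructing it yields an ordinary dict and `multi_modal_data` is attached by key only when an image is supplied, which is exactly what the runner does above. A small sketch with a placeholder image tensor (model-specific preprocessing and the `LLM` call are omitted):

import torch
from vllm.inputs import TextPrompt
from vllm.sequence import MultiModalData

# TextPrompt is a TypedDict, so this builds a plain dict at runtime.
prompt = TextPrompt(prompt="USER: <image>\nWhat is shown? ASSISTANT:")
prompt["multi_modal_data"] = MultiModalData(
    type=MultiModalData.Type.IMAGE,
    data=torch.zeros(1, 3, 336, 336),  # placeholder pixel values, not real data
)
print(type(prompt) is dict, prompt["prompt"])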
@@ -437,12 +429,12 @@ class VllmRunner:
         self,
         prompts: List[str],
         sampling_params: SamplingParams,
-    ) -> List[Tuple[List[int], str]]:
+    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
         assert sampling_params.logprobs is not None
 
         req_outputs = self.model.generate(prompts,
                                           sampling_params=sampling_params)
-        outputs = []
+        outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
         for req_output in req_outputs:
             for sample in req_output.outputs:
                 output_str = sample.text
@@ -467,7 +459,7 @@ class VllmRunner:
         prompts: List[str],
         max_tokens: int,
         num_logprobs: int,
-    ) -> List[Tuple[List[int], str]]:
+    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
         greedy_logprobs_params = SamplingParams(temperature=0.0,
                                                 max_tokens=max_tokens,
                                                 logprobs=num_logprobs)
@@ -481,7 +473,7 @@ class VllmRunner:
         prompts: List[str],
         beam_width: int,
         max_tokens: int,
-    ) -> List[Tuple[List[int], str]]:
+    ) -> List[Tuple[List[List[int]], List[str]]]:
         beam_search_params = SamplingParams(n=beam_width,
                                             use_beam_search=True,
                                             temperature=0.0,