[Misc] Simplify code and fix type annotations in conftest.py (#5118)

commit dfbe60dc62
parent a66cf40b20
@@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple
 
 import pytest
 import torch
+import torch.nn.functional as F
 from PIL import Image
 from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
                           LlavaConfig, LlavaForConditionalGeneration)
@@ -12,9 +13,9 @@ from transformers import (AutoModelForCausalLM, AutoProcessor, AutoTokenizer,
 from vllm import LLM, SamplingParams
 from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
 from vllm.distributed import destroy_model_parallel
-from vllm.inputs import PromptInputs
+from vllm.inputs import TextPrompt
 from vllm.logger import init_logger
-from vllm.sequence import MultiModalData
+from vllm.sequence import MultiModalData, SampleLogprobs
 
 logger = init_logger(__name__)
 
@@ -188,10 +189,11 @@ class HfRunner:
         prompts: List[str],
         images: Optional[List[Image.Image]] = None,
         **kwargs,
-    ) -> List[Tuple[List[int], str]]:
-        outputs: List[Tuple[List[int], str]] = []
+    ) -> List[Tuple[List[List[int]], List[str]]]:
         if images:
             assert len(prompts) == len(images)
+
+        outputs: List[Tuple[List[List[int]], List[str]]] = []
         for i, prompt in enumerate(prompts):
             processor_kwargs: Dict[str, Any] = {
                 "text": prompt,
@@ -201,17 +203,13 @@ class HfRunner:
                 processor_kwargs["images"] = images[i]
 
             inputs = self.processor(**processor_kwargs)
-            inputs = {
-                key: value.cuda() if value is not None else None
-                for key, value in inputs.items()
-            }
 
             output_ids = self.model.generate(
-                **inputs,
+                **inputs.to("cuda"),
                 use_cache=True,
                 **kwargs,
             )
-            output_str = self.tokenizer.batch_decode(
+            output_str = self.processor.batch_decode(
                 output_ids,
                 skip_special_tokens=True,
                 clean_up_tokenization_spaces=False,
@@ -224,23 +222,22 @@ class HfRunner:
         self,
         prompts: List[str],
         max_tokens: int,
-        images: Optional["torch.Tensor"] = None,
+        images: Optional[List[Image.Image]] = None,
     ) -> List[Tuple[List[int], str]]:
         outputs = self.generate(prompts,
                                 do_sample=False,
                                 max_new_tokens=max_tokens,
                                 images=images)
-        for i in range(len(outputs)):
-            output_ids, output_str = outputs[i]
-            outputs[i] = (output_ids[0], output_str[0])
-        return outputs
+
+        return [(output_ids[0], output_str[0])
+                for output_ids, output_str in outputs]
 
     def generate_beam_search(
         self,
         prompts: List[str],
         beam_width: int,
         max_tokens: int,
-    ) -> List[Tuple[List[int], str]]:
+    ) -> List[Tuple[List[List[int]], List[str]]]:
         outputs = self.generate(prompts,
                                 do_sample=False,
                                 max_new_tokens=max_tokens,
@@ -282,9 +279,7 @@ class HfRunner:
                 if self.model.get_output_embeddings().bias is not None:
                     logits += self.model.get_output_embeddings(
                     ).bias.unsqueeze(0)
-                logprobs = torch.nn.functional.log_softmax(logits,
-                                                           dim=-1,
-                                                           dtype=torch.float32)
+                logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
                 seq_logprobs.append(logprobs)
             all_logprobs.append(seq_logprobs)
         return all_logprobs
@@ -294,10 +289,10 @@ class HfRunner:
         prompts: List[str],
         max_tokens: int,
         num_logprobs: int,
-    ) -> List[Tuple[List[int], str]]:
-        all_logprobs = []
-        all_output_ids = []
-        all_output_strs = []
+    ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]:
+        all_logprobs: List[List[Dict[int, float]]] = []
+        all_output_ids: List[List[int]] = []
+        all_output_strs: List[str] = []
 
         for prompt in prompts:
             input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
@@ -310,7 +305,7 @@ class HfRunner:
                 return_dict_in_generate=True,
             )
 
-            seq_logprobs = []
+            seq_logprobs: List[torch.Tensor] = []
             for _, hidden_states in enumerate(output.hidden_states):
                 last_hidden_states = hidden_states[-1][0]
                 logits = torch.matmul(
@@ -321,13 +316,11 @@ class HfRunner:
                            None) is not None:
                     logits += self.model.get_output_embeddings(
                     ).bias.unsqueeze(0)
-                logprobs = torch.nn.functional.log_softmax(logits,
-                                                           dim=-1,
-                                                           dtype=torch.float32)
+                logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
                 seq_logprobs.append(logprobs)
 
             # convert to dict
-            seq_logprobs_lst = []
+            seq_logprobs_lst: List[Dict[int, float]] = []
             for tok_idx, tok_logprobs in enumerate(seq_logprobs):
                 # drop prompt logprobs
                 if tok_idx == 0:
@@ -372,13 +365,13 @@ class VllmRunner:
         tokenizer_name: Optional[str] = None,
         # Use smaller max model length, otherwise bigger model cannot run due
         # to kv cache size limit.
-        max_model_len=1024,
+        max_model_len: int = 1024,
         dtype: str = "half",
         disable_log_stats: bool = True,
         tensor_parallel_size: int = 1,
         block_size: int = 16,
         enable_chunked_prefill: bool = False,
-        swap_space=4,
+        swap_space: int = 4,
         **kwargs,
     ) -> None:
         self.model = LLM(
@@ -399,32 +392,31 @@ class VllmRunner:
         self,
         prompts: List[str],
         sampling_params: SamplingParams,
-        images: Optional["torch.Tensor"] = None,
-    ) -> List[Tuple[List[int], str]]:
+        images: Optional[torch.Tensor] = None,
+    ) -> List[Tuple[List[List[int]], List[str]]]:
         if images is not None:
-            assert len(prompts) == images.shape[0]
+            assert len(prompts) == len(images)
 
-        prompt_inputs: List[PromptInputs] = []
+        prompt_inputs: List[TextPrompt] = []
         for i, prompt in enumerate(prompts):
-            image = None if images is None else images[i:i + 1]
-            mm_data = None if image is None else MultiModalData(
-                type=MultiModalData.Type.IMAGE,
-                data=image,
-            )
+            prompt = TextPrompt(prompt=prompt)
+            if images is not None:
+                prompt["multi_modal_data"] = MultiModalData(
+                    type=MultiModalData.Type.IMAGE,
+                    data=images[i:i + 1],
+                )
 
-            prompt_inputs.append({
-                "prompt": prompt,
-                "multi_modal_data": mm_data,
-            })
+            prompt_inputs.append(prompt)
 
         req_outputs = self.model.generate(prompt_inputs,
                                           sampling_params=sampling_params)
-        outputs = []
+
+        outputs: List[Tuple[List[List[int]], List[str]]] = []
         for req_output in req_outputs:
             prompt_str = req_output.prompt
             prompt_ids = req_output.prompt_token_ids
-            req_sample_output_ids = []
-            req_sample_output_strs = []
+            req_sample_output_ids: List[List[int]] = []
+            req_sample_output_strs: List[str] = []
             for sample in req_output.outputs:
                 output_str = sample.text
                 output_ids = sample.token_ids
@@ -437,12 +429,12 @@ class VllmRunner:
         self,
         prompts: List[str],
         sampling_params: SamplingParams,
-    ) -> List[Tuple[List[int], str]]:
+    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
         assert sampling_params.logprobs is not None
 
         req_outputs = self.model.generate(prompts,
                                           sampling_params=sampling_params)
-        outputs = []
+        outputs: List[Tuple[List[int], str, Optional[SampleLogprobs]]] = []
         for req_output in req_outputs:
             for sample in req_output.outputs:
                 output_str = sample.text
@@ -467,7 +459,7 @@ class VllmRunner:
         prompts: List[str],
         max_tokens: int,
         num_logprobs: int,
-    ) -> List[Tuple[List[int], str]]:
+    ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
         greedy_logprobs_params = SamplingParams(temperature=0.0,
                                                 max_tokens=max_tokens,
                                                 logprobs=num_logprobs)
@@ -481,7 +473,7 @@ class VllmRunner:
         prompts: List[str],
         beam_width: int,
         max_tokens: int,
-    ) -> List[Tuple[List[int], str]]:
+    ) -> List[Tuple[List[List[int]], List[str]]]:
         beam_search_params = SamplingParams(n=beam_width,
                                             use_beam_search=True,
                                             temperature=0.0,
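
For context, here is a minimal sketch of the input-construction pattern this change moves VllmRunner.generate to. The build_prompt_inputs helper and its placeholder arguments are illustrative only and not part of the commit; TextPrompt and MultiModalData are used exactly as they appear in the hunks above.

from typing import List, Optional

import torch

from vllm.inputs import TextPrompt
from vllm.sequence import MultiModalData


def build_prompt_inputs(
        prompts: List[str],
        images: Optional[torch.Tensor] = None) -> List[TextPrompt]:
    """Illustrative helper mirroring the updated VllmRunner.generate."""
    if images is not None:
        # One image (batch slice) per prompt, as in the test helper.
        assert len(prompts) == len(images)

    prompt_inputs: List[TextPrompt] = []
    for i, text in enumerate(prompts):
        prompt = TextPrompt(prompt=text)
        if images is not None:
            # Attach the i-th image slice as multi-modal data for this prompt.
            prompt["multi_modal_data"] = MultiModalData(
                type=MultiModalData.Type.IMAGE,
                data=images[i:i + 1],
            )
        prompt_inputs.append(prompt)

    return prompt_inputs

The resulting list can be passed directly to LLM.generate(prompt_inputs, sampling_params=...), which is what the updated helper does.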