[CI] Make mistral tests pass (#4596)
parent d7740ea4dc
commit f6a593093a
@@ -76,7 +76,7 @@ steps:
   #mirror_hardwares: [amd]
   commands:
     - bash ../.buildkite/download-images.sh
-    - pytest -v -s models --ignore=models/test_llava.py --ignore=models/test_mistral.py
+    - pytest -v -s models --ignore=models/test_llava.py

 - label: Llava Test
   #mirror_hardwares: [amd]
@@ -272,6 +272,68 @@ class HfRunner:
             all_logprobs.append(seq_logprobs)
         return all_logprobs

+    def generate_greedy_logprobs_limit(
+        self,
+        prompts: List[str],
+        max_tokens: int,
+        num_logprobs: int,
+    ) -> List[Tuple[List[int], str]]:
+        all_logprobs = []
+        all_output_ids = []
+        all_output_strs = []
+
+        for prompt in prompts:
+            input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
+            output = self.model.generate(
+                input_ids.cuda(),
+                use_cache=True,
+                do_sample=False,
+                max_new_tokens=max_tokens,
+                output_hidden_states=True,
+                return_dict_in_generate=True,
+            )
+
+            seq_logprobs = []
+            for _, hidden_states in enumerate(output.hidden_states):
+                last_hidden_states = hidden_states[-1][0]
+                logits = torch.matmul(
+                    last_hidden_states,
+                    self.model.get_output_embeddings().weight.t(),
+                )
+                if getattr(self.model.get_output_embeddings(), "bias",
+                           None) is not None:
+                    logits += self.model.get_output_embeddings(
+                    ).bias.unsqueeze(0)
+                logprobs = torch.nn.functional.log_softmax(logits,
+                                                           dim=-1,
+                                                           dtype=torch.float32)
+                seq_logprobs.append(logprobs)
+
+            # convert to dict
+            seq_logprobs_lst = []
+            for tok_idx, tok_logprobs in enumerate(seq_logprobs):
+                # drop prompt logprobs
+                if tok_idx == 0:
+                    tok_logprobs = tok_logprobs[-1, :].reshape(1, -1)
+                topk = tok_logprobs.topk(num_logprobs)
+
+                tok_logprobs_dct = {}
+                for token_id, logprob in zip(topk.indices[0], topk.values[0]):
+                    tok_logprobs_dct[token_id.item()] = logprob.item()
+
+                seq_logprobs_lst.append(tok_logprobs_dct)
+
+            all_logprobs.append(seq_logprobs_lst)
+            seq_ids = output.sequences[0]
+            output_len = seq_ids.shape[0] - input_ids.shape[1]
+            output_ids = seq_ids[-output_len:]
+            all_output_ids.append(output_ids.tolist())
+            all_output_strs.append(self.tokenizer.decode(output_ids))
+
+        outputs = zip(all_output_ids, all_output_strs, all_logprobs)
+        return [(output_ids, output_str, output_logprobs)
+                for output_ids, output_str, output_logprobs in outputs]
+
     def __del__(self):
         del self.model
         cleanup()
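The new HfRunner.generate_greedy_logprobs_limit helper returns, for each prompt, a (token_ids, text, logprobs) tuple in which logprobs holds one {token_id: logprob} dict of the top num_logprobs candidates per generated token. A minimal usage sketch, not part of this commit, assuming the hf_runner pytest fixture that wraps HfRunner:

# Hypothetical usage inside a test; `hf_runner` is the fixture used by the tests below,
# and the prompt string is made up for illustration.
hf_model = hf_runner("mistralai/Mistral-7B-Instruct-v0.1", dtype="bfloat16")
outputs = hf_model.generate_greedy_logprobs_limit(
    ["Hello, my name is"], max_tokens=8, num_logprobs=5)
del hf_model

for token_ids, text, logprobs in outputs:
    # One top-5 dict per generated token; prompt logprobs are already dropped.
    assert len(token_ids) == len(logprobs)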
@@ -8,7 +8,7 @@ import pytest

 MODELS = [
     "meta-llama/Llama-2-7b-hf",
-    # "mistralai/Mistral-7B-v0.1",  # Broken
+    # "mistralai/Mistral-7B-v0.1",  # Tested by test_mistral.py
     # "Deci/DeciLM-7b",  # Broken
     # "tiiuae/falcon-7b",  # Broken
     "EleutherAI/gpt-j-6b",
@@ -4,6 +4,8 @@ Run `pytest tests/models/test_mistral.py`.
 """
 import pytest

+from tests.models.utils import check_logprobs_close
+
 MODELS = [
     "mistralai/Mistral-7B-Instruct-v0.1",
 ]
@@ -11,30 +13,31 @@ MODELS = [

 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
-@pytest.mark.parametrize("max_tokens", [128])
-@pytest.mark.skip(
-    "Two problems: 1. Failing correctness tests. 2. RuntimeError: expected "
-    "scalar type BFloat16 but found Half (only in CI).")
+@pytest.mark.parametrize("max_tokens", [64])
+@pytest.mark.parametrize("num_logprobs", [5])
 def test_models(
     hf_runner,
     vllm_runner,
-    example_long_prompts,
+    example_prompts,
     model: str,
     dtype: str,
     max_tokens: int,
+    num_logprobs: int,
 ) -> None:
+    # TODO(sang): Sliding window should be tested separately.
     hf_model = hf_runner(model, dtype=dtype)
-    hf_outputs = hf_model.generate_greedy(example_long_prompts, max_tokens)
+    hf_outputs = hf_model.generate_greedy_logprobs_limit(
+        example_prompts, max_tokens, num_logprobs)
     del hf_model

     vllm_model = vllm_runner(model, dtype=dtype)
-    vllm_outputs = vllm_model.generate_greedy(example_long_prompts, max_tokens)
+    vllm_outputs = vllm_model.generate_greedy_logprobs(example_prompts,
+                                                       max_tokens,
+                                                       num_logprobs)
     del vllm_model
-
-    for i in range(len(example_long_prompts)):
-        hf_output_ids, hf_output_str = hf_outputs[i]
-        vllm_output_ids, vllm_output_str = vllm_outputs[i]
-        assert hf_output_str == vllm_output_str, (
-            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
-        assert hf_output_ids == vllm_output_ids, (
-            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
+    check_logprobs_close(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
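check_logprobs_close, imported from tests/models/utils.py, is not shown in this diff. An illustrative sketch of the comparison it enables, assuming each output is a (token_ids, text, logprobs) tuple as produced above: instead of requiring identical text, the test only requires that wherever the two runners pick different tokens, each runner's choice still lands in the other's top-num_logprobs set. The real helper may differ in details.

# Sketch only; the actual implementation lives in tests/models/utils.py.
def check_logprobs_close(outputs_0_lst, outputs_1_lst, name_0, name_1):
    for prompt_idx, (outputs_0, outputs_1) in enumerate(
            zip(outputs_0_lst, outputs_1_lst)):
        output_ids_0, output_str_0, logprobs_0 = outputs_0
        output_ids_1, output_str_1, logprobs_1 = outputs_1
        for idx, (id_0, id_1) in enumerate(zip(output_ids_0, output_ids_1)):
            if id_0 != id_1:
                # Divergence is tolerated only if each token is still a
                # top-k candidate of the other runner at this position.
                assert id_0 in logprobs_1[idx] and id_1 in logprobs_0[idx], (
                    f"Prompt {prompt_idx}: {name_0}={output_str_0!r}, "
                    f"{name_1}={output_str_1!r}")
                # After the first divergence the sequences may drift apart.
                break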
@@ -109,7 +109,7 @@ class RotaryEmbedding(nn.Module):
         key_pass = key[..., self.rotary_dim:]

         self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
-            positions.device)
+            positions.device, dtype=query.dtype)
         cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
                                      if offsets is not None else positions]
         cos, sin = cos_sin.chunk(2, dim=-1)
@@ -143,7 +143,8 @@ class RotaryEmbedding(nn.Module):
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
+        self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
+                                                   dtype=query.dtype)
         # ops.rotary_embedding()/batched_rotary_embedding()
         # are in-place operations that update the query and key tensors.
         if offsets is not None:
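Both rotary-embedding hunks make the cached cos/sin table follow the query's dtype as well as its device, so the cache and the activations it is combined with share one scalar type. A standalone sketch of the pattern, not vLLM code, with made-up tensors for illustration:

import torch

# The cache may have been built under a different default dtype (e.g. float32)
# than the bfloat16 the model actually runs in.
query = torch.randn(4, 8, dtype=torch.bfloat16)
cos_sin_cache = torch.randn(16, 8, dtype=torch.float32)

# Mirrors the shape of the fix: one .to() call handles both the device move
# and the dtype cast, so downstream ops see matching scalar types.
cos_sin_cache = cos_sin_cache.to(query.device, dtype=query.dtype)
assert cos_sin_cache.dtype == query.dtype == torch.bfloat16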