[Misc] Remove unnecessary ModelRunner imports (#4703)

Author: Woosuk Kwon
Date: 2024-05-09 00:17:17 -07:00
Committed by: GitHub
parent f12b20decc
commit 190bc838e1
2 changed files with 31 additions and 73 deletions
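
The diff below applies one pattern throughout both test files: instead of constructing a throwaway ModelRunner just to read its pin_memory attribute (and deleting it afterwards), the tests query pinned-memory support directly. A minimal sketch of the substitution, using only names that appear in the diff; the surrounding test scaffolding is omitted:

# Before (removed by this commit): a ModelRunner built from all-None configs,
# used only for its pin_memory flag.
#     from vllm.worker.model_runner import ModelRunner
#     model_runner = ModelRunner(model_config=None, parallel_config=None,
#                                scheduler_config=None, device_config=None,
#                                load_config=None, lora_config=None)
#     pin_memory = model_runner.pin_memory
#
# After: ask the utility directly.
from vllm.utils import is_pin_memory_available

pin_memory = is_pin_memory_available()  # True when pinned (page-locked) host memory can be used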

View File

@@ -11,8 +11,7 @@ from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
-from vllm.utils import Counter
-from vllm.worker.model_runner import ModelRunner
+from vllm.utils import Counter, is_pin_memory_available
 
 
 class MockLogitsSampler(Sampler):
@@ -26,20 +25,14 @@ class MockLogitsSampler(Sampler):
 
 def _prepare_test(
     batch_size: int
-) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]:
+) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler]:
     input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
     fake_logits = torch.full((batch_size, VOCAB_SIZE),
                              1e-2,
                              dtype=input_tensor.dtype)
     sampler = MockLogitsSampler(fake_logits)
-    model_runner = ModelRunner(model_config=None,
-                               parallel_config=None,
-                               scheduler_config=None,
-                               device_config=None,
-                               load_config=None,
-                               lora_config=None)
-    return input_tensor, fake_logits, sampler, model_runner
+    return input_tensor, fake_logits, sampler
 
 
 VOCAB_SIZE = 32000
@@ -53,7 +46,6 @@ def _do_sample(
     batch_size: int,
     input_tensor: torch.Tensor,
     sampler: MockLogitsSampler,
-    model_runner: ModelRunner,
     sampling_params: SamplingParams,
     device: str,
 ):
@@ -75,7 +67,7 @@ def _do_sample(
         seq_lens,
         query_lens=seq_lens,
         device=device,
-        pin_memory=model_runner.pin_memory)
+        pin_memory=is_pin_memory_available())
     return sampler(logits=input_tensor, sampling_metadata=sampling_metadata)
@@ -85,19 +77,16 @@ def test_sampler_all_greedy(seed: int, device: str):
     set_random_seed(seed)
     torch.set_default_device(device)
     batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
-        batch_size)
+    input_tensor, fake_logits, sampler = _prepare_test(batch_size)
 
     sampling_params = SamplingParams(temperature=0)
-    sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
+    sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                 sampling_params, device)
     expected = torch.argmax(fake_logits, dim=-1)
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
             assert nth_output.output_token == expected[i].item()
 
-    del model_runner
-
 
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -105,8 +94,7 @@ def test_sampler_all_random(seed: int, device: str):
     set_random_seed(seed)
     torch.set_default_device(device)
     batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
-        batch_size)
+    _, fake_logits, sampler = _prepare_test(batch_size)
 
     for i in range(batch_size):
         fake_logits[i, i] = 1e2
@@ -115,15 +103,13 @@ def test_sampler_all_random(seed: int, device: str):
         temperature=1.0,
         n=random.randint(1, 10),
     )
-    sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
+    sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                 sampling_params, device)
 
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
             assert nth_output.output_token == i
 
-    del model_runner
-
 
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -131,7 +117,7 @@ def test_sampler_all_random_seed(seed: int, device: str):
     set_random_seed(seed)
     torch.set_default_device(device)
     batch_size = random.randint(1, 256)
-    _, fake_logits, sampler, model_runner = _prepare_test(batch_size)
+    _, fake_logits, sampler = _prepare_test(batch_size)
 
     for i in range(batch_size):
         fake_logits[i, i] = 1e2
@@ -141,15 +127,13 @@ def test_sampler_all_random_seed(seed: int, device: str):
         n=random.randint(1, 10),
         seed=random.randint(0, 10000),
     )
-    sampler_output = _do_sample(batch_size, fake_logits, sampler, model_runner,
+    sampler_output = _do_sample(batch_size, fake_logits, sampler,
                                 sampling_params, device)
 
     for i, sequence_output in enumerate(sampler_output):
         for nth_output in sequence_output.samples:
             assert nth_output.output_token == i
 
-    del model_runner
-
 
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -157,7 +141,7 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str):
     set_random_seed(seed)
     torch.set_default_device(device)
     batch_size = random.randint(1, 256)
-    _, fake_logits, sampler, model_runner = _prepare_test(batch_size)
+    _, fake_logits, sampler = _prepare_test(batch_size)
 
     sampling_params = SamplingParams(
         temperature=1.0,
@@ -165,15 +149,13 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str):
         seed=random.randint(0, 10000),
     )
     first_sampler_output = _do_sample(batch_size, fake_logits, sampler,
-                                      model_runner, sampling_params, device)
+                                      sampling_params, device)
 
     second_sampler_output = _do_sample(batch_size, fake_logits, sampler,
-                                       model_runner, sampling_params, device)
+                                       sampling_params, device)
 
     assert first_sampler_output == second_sampler_output
 
-    del model_runner
-
 
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
@@ -181,20 +163,18 @@ def test_sampler_all_beam(seed: int, device: str):
     set_random_seed(seed)
     torch.set_default_device(device)
     batch_size = random.randint(1, 256)
-    _, fake_logits, sampler, model_runner = _prepare_test(batch_size)
+    _, fake_logits, sampler = _prepare_test(batch_size)
 
     sampling_params = SamplingParams(
         temperature=0,
         best_of=2,
         use_beam_search=True,
     )
-    _do_sample(batch_size, fake_logits, sampler, model_runner, sampling_params,
-               device)
+    _do_sample(batch_size, fake_logits, sampler, sampling_params, device)
     # no assertion here as I am not sure how to determine whether
     # the outputs are expected - in other words, this just tests
     # whether there are no exceptions in the sampler
     # when handling an all-beam search case.
-    del model_runner
 
 
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
@@ -448,13 +428,13 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
            ("Invalid test case, expected_penalization does not match computed"
             "batch size")
 
-        _, fake_logits, sampler, model_runner = _prepare_test(batch_size)
+        _, fake_logits, sampler = _prepare_test(batch_size)
 
         sampling_metadata = SamplingMetadata.prepare(
             seq_group_metadata_list,
             seq_lens=seq_lens if seq_lens else None,
             query_lens=seq_lens if seq_lens else None,
             device=device,
-            pin_memory=model_runner.pin_memory)
+            pin_memory=is_pin_memory_available())
         # the logits tensor is modified in-place by the sampler
         _ = sampler(logits=fake_logits, sampling_metadata=sampling_metadata)
@@ -480,8 +460,6 @@ def test_sampler_min_tokens_penalty(seed: int, device: str):
                 fake_logits[logits_idx, :] ==
                 -float('inf')) == 0, "No tokens should have been penalized"
 
-        del model_runner
-
     for test_case in test_cases:
         run_test_case(**test_case)
@@ -492,8 +470,7 @@ def test_sampler_mixed(seed: int, device: str):
     set_random_seed(seed)
     torch.set_default_device(device)
     batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, model_runner = _prepare_test(
-        batch_size)
+    input_tensor, fake_logits, sampler = _prepare_test(batch_size)
 
     seq_group_metadata_list = []
     expected_tokens: List[Optional[List[int]]] = []
@@ -534,13 +511,13 @@ def test_sampler_mixed(seed: int, device: str):
         ))
         seq_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
 
-    def test_sampling(model_runner: ModelRunner):
+    def test_sampling():
         sampling_metadata = SamplingMetadata.prepare(
             seq_group_metadata_list,
             seq_lens,
             query_lens=seq_lens,
             device=device,
-            pin_memory=model_runner.pin_memory)
+            pin_memory=is_pin_memory_available())
 
         sampler_output = sampler(logits=fake_logits,
                                  sampling_metadata=sampling_metadata)
@@ -570,7 +547,7 @@ def test_sampler_mixed(seed: int, device: str):
             assert nth_output.output_token in expected_tokens[i]
 
     # Test batch
-    test_sampling(model_runner)
+    test_sampling()
 
     # Shuffle the batch and resample
     target_index = list(range(batch_size))
@@ -583,9 +560,7 @@ def test_sampler_mixed(seed: int, device: str):
 
     # This time, results of seeded random samples will be compared with
     # the corresponding sample in the pre-shuffled batch
-    test_sampling(model_runner)
+    test_sampling()
 
-    del model_runner
-
 
 @pytest.mark.parametrize("seed", RANDOM_SEEDS)
@@ -605,12 +580,6 @@ def test_sampler_top_k_top_p(seed: int, device: str):
                              device=input_tensor.device,
                              dtype=input_tensor.dtype)
     sampler = MockLogitsSampler(fake_logits)
-    model_runner = ModelRunner(model_config=None,
-                               parallel_config=None,
-                               scheduler_config=None,
-                               device_config=None,
-                               load_config=None,
-                               lora_config=None)
 
     generation_model = GenerationMixin()
     generation_config = GenerationConfig(top_k=top_k,
@@ -641,7 +610,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
         seq_lens,
         query_lens=seq_lens,
         device=device,
-        pin_memory=model_runner.pin_memory)
+        pin_memory=is_pin_memory_available())
 
     sample_probs = None
@@ -657,5 +626,3 @@ def test_sampler_top_k_top_p(seed: int, device: str):
     hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
     assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
     assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
-
-    del model_runner
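
Assembled from the added lines above, the sampler test helper now reads roughly as follows; MockLogitsSampler and VOCAB_SIZE are defined elsewhere in the same test file, so this is a sketch rather than a standalone module:

def _prepare_test(
    batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler]:
    # No ModelRunner is created any more; the helper only builds the input
    # tensor, the fake logits, and the mock sampler that the tests need.
    input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
    fake_logits = torch.full((batch_size, VOCAB_SIZE),
                             1e-2,
                             dtype=input_tensor.dtype)
    sampler = MockLogitsSampler(fake_logits)
    return input_tensor, fake_logits, sampler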

View File

@@ -9,7 +9,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
-from vllm.worker.model_runner import ModelRunner
+from vllm.utils import is_pin_memory_available
 
 
 class MockLogitsProcessor(LogitsProcessor):
@@ -30,21 +30,15 @@ class MockLogitsProcessor(LogitsProcessor):
 
 
 def _prepare_test(
     batch_size: int
-) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor, ModelRunner]:
+) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsProcessor]:
     vocab_size = 32000
     input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
     fake_logits = torch.full((batch_size, vocab_size),
                              1e-2,
                              dtype=input_tensor.dtype)
     logits_processor = MockLogitsProcessor(32000, 0.5, fake_logits)
-    model_runner = ModelRunner(model_config=None,
-                               parallel_config=None,
-                               scheduler_config=None,
-                               device_config=None,
-                               load_config=None,
-                               lora_config=None)
-    return input_tensor, fake_logits, logits_processor, model_runner
+    return input_tensor, fake_logits, logits_processor
 
 
 RANDOM_SEEDS = list(range(128))
@@ -59,8 +53,7 @@ def test_logits_processors(seed: int, device: str):
     set_random_seed(seed)
     torch.set_default_device(device)
     batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, logits_processor, model_runner = _prepare_test(
-        batch_size)
+    input_tensor, fake_logits, logits_processor = _prepare_test(batch_size)
 
     # This sample logits processor gives infinite score to the i-th token,
     # where i is the length of the input sequence.
@@ -87,8 +80,8 @@ def test_logits_processors(seed: int, device: str):
         seq_group_metadata_list,
         seq_lens,
         query_lens=seq_lens,
-        device=model_runner.device,
-        pin_memory=model_runner.pin_memory)
+        device=device,
+        pin_memory=is_pin_memory_available())
     logits_processor_output = logits_processor(
         embedding=None,
         hidden_states=input_tensor,
@@ -99,5 +92,3 @@ def test_logits_processors(seed: int, device: str):
     fake_logits *= logits_processor.scale
     assert torch.allclose(logits_processor_output[:, 1], fake_logits[:, 1],
                           1e-4)
-
-    del model_runner
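
The logits-processor test follows the same pattern when it builds sampling metadata: the device string now comes from the test parameters rather than model_runner.device, and pinned memory is probed directly. A sketch assembled from the added lines above; seq_group_metadata_list, seq_lens, and device are built by the test itself:

sampling_metadata = SamplingMetadata.prepare(
    seq_group_metadata_list,
    seq_lens,
    query_lens=seq_lens,
    device=device,                         # was model_runner.device
    pin_memory=is_pin_memory_available())  # was model_runner.pin_memory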