[Speculative decoding 2/9] Multi-step worker for draft model (#2424)
This commit is contained in:
parent
71d63ed72e
commit
18bfcdd05c
0
tests/worker/__init__.py
Normal file
0
tests/worker/__init__.py
Normal file
0
tests/worker/spec_decode/__init__.py
Normal file
0
tests/worker/spec_decode/__init__.py
Normal file
261
tests/worker/spec_decode/test_multi_step_worker.py
Normal file
261
tests/worker/spec_decode/test_multi_step_worker.py
Normal file
@ -0,0 +1,261 @@
|
||||
import torch
|
||||
import random
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from vllm.worker.spec_decode.multi_step_worker import MultiStepWorker
|
||||
from vllm.worker.worker import Worker
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
|
||||
from .utils import (create_execute_model_data, create_worker,
|
||||
create_seq_group_metadata_from_prompts, zero_kv_cache,
|
||||
patch_execute_model_with_seeds,
|
||||
assert_logprobs_dict_allclose)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('num_steps', list(range(1, 17)))
|
||||
def test_assert_enough_kv_space(num_steps: int):
|
||||
"""Test that the multi step worker checks for sufficient space in the KV
|
||||
cache. It should throw if it cannot run all the steps.
|
||||
"""
|
||||
block_size = 16
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
|
||||
prompts = [
|
||||
list(range(block_size * 3)),
|
||||
list(range(block_size * 2)),
|
||||
]
|
||||
|
||||
prev_output_tokens = [
|
||||
list(range(block_size * 1)),
|
||||
list(range(block_size * 2)),
|
||||
]
|
||||
|
||||
final_seq_lens = [
|
||||
len(prompt + output) + num_steps
|
||||
for prompt, output in zip(prompts, prev_output_tokens)
|
||||
]
|
||||
|
||||
inputs = create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
final_seq_lens,
|
||||
continuations=prev_output_tokens)
|
||||
|
||||
assert_enough_kv_space = MultiStepWorker._assert_enough_kv_space # pylint: disable=protected-access
|
||||
worker = MagicMock()
|
||||
worker.model_runner.block_size = block_size
|
||||
|
||||
for seq_group_metadata in inputs:
|
||||
original_block_tables = seq_group_metadata.block_tables
|
||||
|
||||
# No exception.
|
||||
assert_enough_kv_space(worker, inputs, num_steps)
|
||||
|
||||
seq_group_metadata.block_tables = {
|
||||
seq_id: []
|
||||
for seq_id, physical_blocks in original_block_tables.items()
|
||||
}
|
||||
|
||||
# Expect exception.
|
||||
with pytest.raises(ValueError,
|
||||
match='times but found insufficient KV space for'):
|
||||
assert_enough_kv_space(worker, inputs, num_steps)
|
||||
|
||||
seq_group_metadata.block_tables = original_block_tables
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_same_output_for_single_step():
|
||||
"""Verify the multi step worker produces the same output as the normal
|
||||
worker for num_steps=1.
|
||||
"""
|
||||
seed = 100
|
||||
model_name = 'JackFram/llama-68m'
|
||||
|
||||
block_size = 32
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
multi_step_worker = create_worker(
|
||||
MultiStepWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
)
|
||||
worker = create_worker(
|
||||
Worker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
)
|
||||
multi_step_worker.model_runner = worker.model_runner
|
||||
multi_step_worker.cache_engine = worker.cache_engine
|
||||
|
||||
num_steps = 1
|
||||
|
||||
prompts = [
|
||||
[1, 2, 3, 4, 5],
|
||||
[6, 7, 8, 9, 10],
|
||||
]
|
||||
|
||||
final_seq_lens = [len(prompt) + num_steps for prompt in prompts]
|
||||
|
||||
multi_step_execute_model_data = create_execute_model_data(
|
||||
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
|
||||
prompts, num_gpu_blocks, block_size,
|
||||
final_seq_lens=final_seq_lens))
|
||||
|
||||
single_step_execute_model_data = create_execute_model_data(
|
||||
seq_group_metadata_list=create_seq_group_metadata_from_prompts(
|
||||
prompts, num_gpu_blocks, block_size,
|
||||
final_seq_lens=final_seq_lens))
|
||||
|
||||
zero_kv_cache(multi_step_worker.cache_engine)
|
||||
set_random_seed(seed)
|
||||
actual_output = multi_step_worker.execute_model_multi_step(
|
||||
**multi_step_execute_model_data.to_dict(), num_steps=num_steps)
|
||||
assert len(actual_output) == num_steps
|
||||
actual_output = actual_output[0]
|
||||
|
||||
zero_kv_cache(worker.cache_engine)
|
||||
set_random_seed(seed)
|
||||
expected_output = worker.execute_model(
|
||||
**single_step_execute_model_data.to_dict(), )
|
||||
|
||||
actual_token_ids = [
|
||||
output.samples[0].output_token for output in actual_output
|
||||
]
|
||||
actual_logprobs = [output.samples[0].logprobs for output in actual_output]
|
||||
|
||||
expected_token_ids = [
|
||||
output.samples[0].output_token for output in expected_output
|
||||
]
|
||||
expected_logprobs = [
|
||||
output.samples[0].logprobs for output in expected_output
|
||||
]
|
||||
|
||||
assert actual_token_ids == expected_token_ids
|
||||
|
||||
print(f'{actual_logprobs=}')
|
||||
print(f'{expected_logprobs=}')
|
||||
assert_logprobs_dict_allclose(actual_logprobs, expected_logprobs)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_same_output_for_multi_step():
|
||||
"""Verify the multi-step worker produces the same output as the normal
|
||||
worker when num_steps > 1. This test runs the multi-step worker once, and
|
||||
then runs the worker num_steps times, and compares the output.
|
||||
"""
|
||||
seed = 100
|
||||
model_name = 'JackFram/llama-68m'
|
||||
|
||||
block_size = 16
|
||||
num_gpu_blocks = 2048 // block_size
|
||||
multi_step_worker = create_worker(
|
||||
MultiStepWorker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
)
|
||||
|
||||
worker = create_worker(
|
||||
Worker,
|
||||
model_name,
|
||||
block_size,
|
||||
num_gpu_blocks,
|
||||
seed,
|
||||
)
|
||||
|
||||
# Make sure we go over the block boundary.
|
||||
num_steps = block_size + 1
|
||||
|
||||
random.seed(seed)
|
||||
prompts = [[
|
||||
random.randint(0, 1000) for _ in range(random.randint(10, 20))
|
||||
] for _ in range(10)]
|
||||
|
||||
final_seq_lens = [len(prompt) + num_steps for prompt in prompts]
|
||||
|
||||
rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
|
||||
multi_step_worker.execute_model = patch_execute_model_with_seeds(
|
||||
multi_step_worker, rand_seeds)
|
||||
worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)
|
||||
|
||||
continuations = [[1] for _ in prompts]
|
||||
execute_model_data = create_execute_model_data(
|
||||
create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=continuations,
|
||||
final_seq_lens=final_seq_lens), )
|
||||
|
||||
# Run multi-step.
|
||||
zero_kv_cache(multi_step_worker.cache_engine)
|
||||
set_random_seed(seed)
|
||||
multi_step_output = multi_step_worker.execute_model_multi_step(
|
||||
**execute_model_data.to_dict(), num_steps=num_steps)
|
||||
|
||||
# Run single-step repeatedly.
|
||||
zero_kv_cache(worker.cache_engine)
|
||||
single_step_output = []
|
||||
continuations = [[1] for _ in prompts]
|
||||
set_random_seed(seed)
|
||||
|
||||
for _ in multi_step_output:
|
||||
|
||||
execute_model_data = create_execute_model_data(
|
||||
create_seq_group_metadata_from_prompts(
|
||||
prompts,
|
||||
num_gpu_blocks,
|
||||
block_size,
|
||||
continuations=continuations,
|
||||
final_seq_lens=final_seq_lens))
|
||||
|
||||
single_step_output.append(
|
||||
worker.execute_model(**execute_model_data.to_dict(), ))
|
||||
|
||||
# Append output tokens to new sequence data.
|
||||
for i, seq_group_output in enumerate(single_step_output[-1]):
|
||||
continuations[i].append(seq_group_output.samples[0].output_token)
|
||||
|
||||
# Get token ids and logprobs for comparison.
|
||||
multi_step_output_logprobs = [[] for _ in prompts]
|
||||
single_step_output_logprobs = [[] for _ in prompts]
|
||||
|
||||
multi_step_output_token_ids = [[] for _ in prompts]
|
||||
single_step_output_token_ids = [[] for _ in prompts]
|
||||
for i, _ in enumerate(prompts):
|
||||
for multi_step, single_step in zip(multi_step_output,
|
||||
single_step_output):
|
||||
multi_step_output_token_ids[i].append(
|
||||
multi_step[i].samples[0].output_token)
|
||||
single_step_output_token_ids[i].append(
|
||||
single_step[i].samples[0].output_token)
|
||||
|
||||
multi_step_output_logprobs[i].append(
|
||||
multi_step[i].samples[0].logprobs)
|
||||
single_step_output_logprobs[i].append(
|
||||
single_step[i].samples[0].logprobs)
|
||||
|
||||
# Print per-sequence token ids
|
||||
for i, (multi_step_tokens, single_step_tokens) in enumerate(
|
||||
zip(multi_step_output_token_ids, single_step_output_token_ids)):
|
||||
print(f'{i=} {multi_step_tokens=}')
|
||||
print(f'{i=} {single_step_tokens=}')
|
||||
print(f'{i=} equal {multi_step_tokens == single_step_tokens}')
|
||||
|
||||
# Assert token ids are equal.
|
||||
for multi_step_tokens, single_step_tokens in zip(
|
||||
multi_step_output_token_ids, single_step_output_token_ids):
|
||||
assert multi_step_tokens == single_step_tokens
|
||||
|
||||
# Assert logprobs are equal.
|
||||
for multi_step_logprobs, single_step_logprobs in zip(
|
||||
multi_step_output_logprobs, single_step_output_logprobs):
|
||||
assert_logprobs_dict_allclose(multi_step_logprobs,
|
||||
single_step_logprobs)
|
177
tests/worker/spec_decode/utils.py
Normal file
177
tests/worker/spec_decode/utils.py
Normal file
@ -0,0 +1,177 @@
|
||||
import torch
|
||||
from typing import List, Optional, Dict
|
||||
|
||||
from vllm.worker.worker import Worker
|
||||
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.sequence import SequenceGroupMetadata, SequenceData
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from dataclasses import dataclass, fields
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecuteModelData:
|
||||
"""Helper data structure which facilitates cleaner tests.
|
||||
"""
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata]
|
||||
blocks_to_swap_in: Dict[int, int]
|
||||
blocks_to_swap_out: Dict[int, int]
|
||||
blocks_to_copy: Dict[int, List[int]]
|
||||
|
||||
def to_dict(self):
|
||||
return dict(
|
||||
(field.name, getattr(self, field.name)) for field in fields(self))
|
||||
|
||||
|
||||
def round_up_to_next_block(seq_len: int, block_size: int) -> int:
|
||||
return (seq_len + block_size - 1) // block_size
|
||||
|
||||
|
||||
def create_execute_model_data(
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
blocks_to_swap_in: Optional[Dict[int, int]] = None,
|
||||
blocks_to_swap_out: Optional[Dict[int, int]] = None,
|
||||
blocks_to_copy: Optional[Dict[int, int]] = None,
|
||||
) -> ExecuteModelData:
|
||||
if blocks_to_swap_in is None:
|
||||
blocks_to_swap_in = {}
|
||||
if blocks_to_swap_out is None:
|
||||
blocks_to_swap_out = {}
|
||||
if blocks_to_copy is None:
|
||||
blocks_to_copy = {}
|
||||
|
||||
return ExecuteModelData(
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
)
|
||||
|
||||
|
||||
def patch_execute_model_with_seeds(worker: Worker, rand_seeds: List[int]):
|
||||
seed_iter = iter(rand_seeds)
|
||||
original_execute_model = worker.execute_model
|
||||
|
||||
def new_execute_model(*args, **kwargs):
|
||||
result = original_execute_model(*args, **kwargs)
|
||||
set_random_seed(next(seed_iter))
|
||||
return result
|
||||
|
||||
return new_execute_model
|
||||
|
||||
|
||||
def zero_kv_cache(cache_engine: CacheEngine):
|
||||
assert cache_engine.gpu_cache
|
||||
for key_blocks, value_blocks in cache_engine.gpu_cache:
|
||||
key_blocks.zero_()
|
||||
value_blocks.zero_()
|
||||
|
||||
|
||||
def create_worker(cls: type,
|
||||
model_name: str,
|
||||
block_size: int,
|
||||
num_gpu_blocks: int,
|
||||
seed: int,
|
||||
is_driver_worker: bool = True,
|
||||
enforce_eager: bool = True):
|
||||
engine_args = EngineArgs(
|
||||
model=model_name,
|
||||
seed=seed,
|
||||
block_size=block_size,
|
||||
enforce_eager=enforce_eager,
|
||||
)
|
||||
|
||||
(model_config, cache_config, parallel_config,
|
||||
scheduler_config) = engine_args.create_engine_configs()
|
||||
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
get_ip(), get_open_port())
|
||||
|
||||
worker = cls(
|
||||
model_config=model_config,
|
||||
parallel_config=parallel_config,
|
||||
scheduler_config=scheduler_config,
|
||||
local_rank=0,
|
||||
rank=0,
|
||||
distributed_init_method=distributed_init_method,
|
||||
is_driver_worker=is_driver_worker,
|
||||
)
|
||||
|
||||
worker.init_model()
|
||||
worker.load_model()
|
||||
|
||||
cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
cache_config.num_cpu_blocks = 0
|
||||
worker.init_cache_engine(cache_config)
|
||||
worker.warm_up_model()
|
||||
|
||||
return worker
|
||||
|
||||
|
||||
def create_seq_group_metadata_from_prompts(
|
||||
prompts: List[List[int]],
|
||||
num_gpu_blocks: int,
|
||||
block_size: int,
|
||||
final_seq_lens: List[int],
|
||||
continuations: Optional[List[List[int]]] = None,
|
||||
num_tokens_processed: Optional[List[int]] = None,
|
||||
seq_ids: Optional[List[int]] = None,
|
||||
) -> List[SequenceGroupMetadata]:
|
||||
|
||||
if continuations is None:
|
||||
continuations = [[] for _ in prompts]
|
||||
|
||||
if num_tokens_processed is None:
|
||||
# Default to 1 token missing from kv cache for generation sequences.
|
||||
num_tokens_processed = []
|
||||
for continuation, prompt in zip(continuations, prompts):
|
||||
# If prefill, then default to zero tokens processed.
|
||||
if not continuation:
|
||||
num_tokens_processed.append(0)
|
||||
else:
|
||||
# If generation, then default to all but one tokens processed.
|
||||
num_tokens_processed.append(
|
||||
len(continuation) + len(prompt) - 1)
|
||||
|
||||
if seq_ids is None:
|
||||
seq_ids = list(i for i, _ in enumerate(prompts))
|
||||
|
||||
free_gpu_blocks = list(range(num_gpu_blocks))
|
||||
|
||||
block_allocations = {
|
||||
i: [
|
||||
free_gpu_blocks.pop()
|
||||
for _ in range(round_up_to_next_block(final_len, block_size))
|
||||
]
|
||||
for i, final_len in enumerate(final_seq_lens)
|
||||
}
|
||||
|
||||
return [
|
||||
SequenceGroupMetadata(
|
||||
request_id=str(i),
|
||||
is_prompt=len(cont_token_ids) == 0,
|
||||
seq_data={
|
||||
i:
|
||||
SequenceData(prompt_token_ids=prompt_token_ids[:] +
|
||||
cont_token_ids[:])
|
||||
},
|
||||
sampling_params=SamplingParams(temperature=0.0, ),
|
||||
block_tables={i: block_allocations[i][:]},
|
||||
) for i, (prompt_token_ids, cont_token_ids, num_tokens_saved) in
|
||||
enumerate(zip(prompts, continuations, num_tokens_processed))
|
||||
]
|
||||
|
||||
|
||||
def assert_logprobs_dict_allclose(
|
||||
actual_logprobs: List[Dict[int, float]],
|
||||
expected_logprobs: List[Dict[int, float]]) -> None:
|
||||
for single_step_actual_logprobs, single_step_expected_logprobs in zip(
|
||||
actual_logprobs, expected_logprobs):
|
||||
assert set(single_step_actual_logprobs.keys()) == set(
|
||||
single_step_expected_logprobs.keys())
|
||||
for token_id in single_step_actual_logprobs:
|
||||
actual = torch.tensor(single_step_actual_logprobs[token_id])
|
||||
expected = torch.tensor(single_step_expected_logprobs[token_id])
|
||||
assert torch.allclose(actual, expected)
|
@ -18,7 +18,7 @@ from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
|
||||
SequenceGroupOutput, SequenceOutput, SequenceStatus)
|
||||
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
|
||||
get_tokenizer)
|
||||
from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port
|
||||
from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port, get_distributed_init_method
|
||||
|
||||
if ray:
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
@ -132,7 +132,8 @@ class LLMEngine:
|
||||
"Ray is required if parallel_config.world_size > 1.")
|
||||
|
||||
self.workers: List[Worker] = []
|
||||
distributed_init_method = f"tcp://{get_ip()}:{get_open_port()}"
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
get_ip(), get_open_port())
|
||||
self.driver_worker = Worker(
|
||||
self.model_config,
|
||||
self.parallel_config,
|
||||
@ -207,7 +208,8 @@ class LLMEngine:
|
||||
for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
|
||||
worker.set_cuda_visible_devices.remote(node_gpus[node_id])
|
||||
|
||||
distributed_init_method = f"tcp://{driver_ip}:{get_open_port()}"
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
driver_ip, get_open_port)
|
||||
|
||||
# Lazy import the Worker to avoid importing torch.cuda/xformers
|
||||
# before CUDA_VISIBLE_DEVICES is set in the Worker
|
||||
|
@ -65,10 +65,9 @@ def initialize_cluster(
|
||||
the default Ray cluster address.
|
||||
|
||||
Returns:
|
||||
A tuple of (`distributed_init_method`, `placement_group`). The
|
||||
`distributed_init_method` is the address for initializing the
|
||||
distributed backend. `placement_group` includes the specification
|
||||
of the resources for each distributed worker.
|
||||
An optional `PlacementGroup`. It includes the specification
|
||||
of the resources for each distributed worker. None if Ray is
|
||||
not used.
|
||||
"""
|
||||
if parallel_config.worker_use_ray or engine_use_ray:
|
||||
if ray is None:
|
||||
|
@ -83,6 +83,31 @@ def initialize_model_parallel(
|
||||
_PIPELINE_GLOBAL_RANKS = ranks
|
||||
|
||||
|
||||
def ensure_model_parallel_initialized(
|
||||
tensor_model_parallel_size: int,
|
||||
pipeline_model_parallel_size: int,
|
||||
) -> None:
|
||||
"""Helper to initialize model parallel groups if they are not initialized,
|
||||
or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
|
||||
values if the model parallel groups are initialized.
|
||||
"""
|
||||
if not model_parallel_is_initialized():
|
||||
initialize_model_parallel(tensor_model_parallel_size,
|
||||
pipeline_model_parallel_size)
|
||||
return
|
||||
|
||||
assert (
|
||||
get_tensor_model_parallel_world_size() == tensor_model_parallel_size
|
||||
), ("tensor parallel group already initialized, but of unexpected size: "
|
||||
f"{get_tensor_model_parallel_world_size()=} vs. "
|
||||
f"{tensor_model_parallel_size=}")
|
||||
assert (get_pipeline_model_parallel_world_size(
|
||||
) == pipeline_model_parallel_size), (
|
||||
"pipeline parallel group already initialized, but of unexpected size: "
|
||||
f"{get_pipeline_model_parallel_world_size()=} vs. "
|
||||
f"{pipeline_model_parallel_size=}")
|
||||
|
||||
|
||||
def model_parallel_is_initialized():
|
||||
"""Check if tensor and pipeline parallel groups are initialized."""
|
||||
return (_TENSOR_MODEL_PARALLEL_GROUP is not None
|
||||
|
@ -65,6 +65,10 @@ def get_ip() -> str:
|
||||
return s.getsockname()[0]
|
||||
|
||||
|
||||
def get_distributed_init_method(ip: str, port: int) -> str:
|
||||
return f"tcp://{ip}:{port}"
|
||||
|
||||
|
||||
def get_open_port() -> int:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", 0))
|
||||
|
@ -277,8 +277,8 @@ class ModelRunner:
|
||||
input_block_tables[i, :len(block_table)] = block_table
|
||||
block_tables = torch.tensor(input_block_tables, device="cuda")
|
||||
else:
|
||||
max_block_table_len = (max_context_len + self.block_size -
|
||||
1) // self.block_size
|
||||
max_block_table_len = max(
|
||||
len(block_table) for block_table in block_tables)
|
||||
block_tables = _make_tensor_with_pad(
|
||||
block_tables,
|
||||
max_len=max_block_table_len,
|
||||
|
178
vllm/worker/spec_decode/multi_step_worker.py
Normal file
178
vllm/worker/spec_decode/multi_step_worker.py
Normal file
@ -0,0 +1,178 @@
|
||||
from typing import List, Dict
|
||||
import copy
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||
from vllm.worker.worker import Worker
|
||||
|
||||
|
||||
class MultiStepWorker(Worker):
|
||||
"""The MultiStepWorker is equivalent to a Worker except that it allows
|
||||
multiple forward passes in a single call, assuming the scheduler has
|
||||
allocated enough space to store the additional KV. This reduces overhead
|
||||
by invoking the scheduler less.
|
||||
|
||||
The MultiStepWorker does not support cache swap operations, or beam search.
|
||||
Cache swap operations do not require large modifications. On the other hand,
|
||||
beam search requires memory allocations during sequence forks and thus
|
||||
requires more thought for MultiStepWorker support.
|
||||
"""
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model_multi_step(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
blocks_to_swap_in: Dict[int, int],
|
||||
blocks_to_swap_out: Dict[int, int],
|
||||
blocks_to_copy: Dict[int, List[int]],
|
||||
num_steps: int,
|
||||
) -> List[SamplerOutput]:
|
||||
"""Run the model forward pass num_steps times. Returns the list of
|
||||
sampler output, one per model forward pass.
|
||||
"""
|
||||
self._raise_if_unsupported(seq_group_metadata_list, blocks_to_swap_in,
|
||||
blocks_to_swap_out, blocks_to_copy)
|
||||
|
||||
# Shallow copy input data so modifications (such as appending tokens)
|
||||
# do not cause side-effects.
|
||||
copied_seq_group_metadata_list = self._shallow_copy_inputs(
|
||||
seq_group_metadata_list)
|
||||
|
||||
# Assert enough KV space for num_steps tokens per sequence.
|
||||
self._assert_enough_kv_space(seq_group_metadata_list, num_steps)
|
||||
|
||||
# Run model num_steps times.
|
||||
model_outputs = []
|
||||
for _ in range(num_steps):
|
||||
model_output = super().execute_model(
|
||||
seq_group_metadata_list=copied_seq_group_metadata_list,
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
)
|
||||
|
||||
self._append_new_tokens(model_output,
|
||||
copied_seq_group_metadata_list)
|
||||
model_outputs.append(model_output)
|
||||
|
||||
return model_outputs
|
||||
|
||||
def _append_new_tokens(
|
||||
self, model_output: SamplerOutput,
|
||||
seq_group_metadata_list: SequenceGroupMetadata) -> None:
|
||||
"""Given model output from a single run, append the tokens to the
|
||||
sequences. This is normally done outside of the worker, but it is
|
||||
required if the worker is to perform multiple forward passes.
|
||||
"""
|
||||
for seq_group_metadata, sequence_group_outputs in zip(
|
||||
seq_group_metadata_list, model_output):
|
||||
seq_group_metadata.is_prompt = False
|
||||
|
||||
for seq_output in sequence_group_outputs.samples:
|
||||
# NOTE: Beam search is not supported, so we can assume that
|
||||
# parent_seq_id == seq_id.
|
||||
seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]
|
||||
|
||||
token_id = seq_output.output_token
|
||||
token_logprob = seq_output.logprobs[token_id]
|
||||
|
||||
seq.append_token_id(token_id, token_logprob)
|
||||
|
||||
def _shallow_copy_inputs(
|
||||
self, seq_group_metadata_list: List[SequenceGroupMetadata]
|
||||
) -> List[SequenceGroupMetadata]:
|
||||
"""Copy input data structures to remove side-effects when input data
|
||||
structures are shared with other modules.
|
||||
|
||||
The multi-step worker must be able to append tokens to sequences after
|
||||
a forward pass. This necessitates modification of the data structures
|
||||
used by the worker. Since these data structures are shared with other
|
||||
parts of vLLM, like the scheduler, we must take care not to introduce
|
||||
unexpected side-effects.
|
||||
|
||||
When Ray is used to orchestrate worker processes (such as when the
|
||||
tensor-parallel degree is >1), this is not a problem because the input
|
||||
datastructures will be serialized and created anew in the worker
|
||||
process.
|
||||
|
||||
However, when Ray is not used to orchestrate the worker processes (such
|
||||
as when the tensor-parallel degree is 1), this is a problem. We avoid
|
||||
the problem by shallow-copying the input datastructures (specifically,
|
||||
the parts that will change in multiple steps).
|
||||
"""
|
||||
|
||||
# Shallow-copy the list of SequenceGroupMetadata. This allows us to
|
||||
# append tokens and change is_prompt without external side-effects.
|
||||
new_seq_group_metadata_list = []
|
||||
|
||||
for old_seq_group_metadata in seq_group_metadata_list:
|
||||
# We must shallow-copy seq_group_metadata as is_prompt could change.
|
||||
seq_group_metadata = copy.copy(old_seq_group_metadata)
|
||||
new_seq_group_metadata_list.append(seq_group_metadata)
|
||||
|
||||
# We must shallow-copy seq_data as we will append token ids
|
||||
new_seq_data = {}
|
||||
for seq_id, old_seq_data in seq_group_metadata.seq_data.items():
|
||||
new_seq_data[seq_id] = copy.copy(old_seq_data)
|
||||
new_seq_data[
|
||||
seq_id].output_token_ids = old_seq_data.output_token_ids[:]
|
||||
|
||||
seq_group_metadata.seq_data = new_seq_data
|
||||
|
||||
return new_seq_group_metadata_list
|
||||
|
||||
def _assert_enough_kv_space(
|
||||
self, seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
num_steps: int) -> None:
|
||||
"""Assert there are enough physical blocks per sequence to store the
|
||||
current KV plus additional KV from num_steps tokens.
|
||||
"""
|
||||
assert self.model_runner.block_size is not None
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
# Only one seq_id is guaranteed because there is no beam search.
|
||||
seq_id = list(seq_group_metadata.seq_data.keys())[0]
|
||||
seq = seq_group_metadata.seq_data[seq_id]
|
||||
|
||||
# After num_steps, the seq len will be the current seq len
|
||||
# plus one token per step.
|
||||
final_seq_len = seq.get_len() + num_steps
|
||||
|
||||
# We will have final_seq_len - 1 KV because vLLM saves KV for a
|
||||
# token in the iteration after the token was generated.
|
||||
required_num_kv_slots = final_seq_len - 1
|
||||
|
||||
# The allocated number of kv slots is the number of allocated blocks
|
||||
# times the number of slots of block.
|
||||
number_physical_blocks = len(
|
||||
seq_group_metadata.block_tables[seq_id])
|
||||
allocated_kv_slots = (number_physical_blocks *
|
||||
self.model_runner.block_size)
|
||||
|
||||
if required_num_kv_slots > allocated_kv_slots:
|
||||
request_id = seq_group_metadata.request_id
|
||||
raise ValueError(
|
||||
"The worker attempted to run "
|
||||
f"{num_steps} times but found insufficient KV space for "
|
||||
f"{request_id=} {seq_id=}. ({allocated_kv_slots=} "
|
||||
f"{required_num_kv_slots=}).")
|
||||
|
||||
def _raise_if_unsupported(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
blocks_to_swap_in: Dict[int, int],
|
||||
blocks_to_swap_out: Dict[int, int],
|
||||
blocks_to_copy: Dict[int, List[int]],
|
||||
) -> None:
|
||||
"""MultiStepWorker does not yet implement support for cache swap
|
||||
operations or beam search.
|
||||
"""
|
||||
if any([blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy]):
|
||||
raise NotImplementedError(
|
||||
"MultiStepWorker does not support cache operations")
|
||||
|
||||
if any(
|
||||
len(seq_group_metadata.seq_data.keys()) != 1
|
||||
for seq_group_metadata in seq_group_metadata_list):
|
||||
raise NotImplementedError(
|
||||
"MultiStepWorker does not support beam search.")
|
@ -11,7 +11,7 @@ from vllm.model_executor import set_random_seed
|
||||
from vllm.model_executor.parallel_utils.communication_op import (
|
||||
broadcast_tensor_dict)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
initialize_model_parallel)
|
||||
ensure_model_parallel_initialized)
|
||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
from vllm.worker.model_runner import ModelRunner
|
||||
@ -227,8 +227,8 @@ def _init_distributed_environment(
|
||||
|
||||
# A small all_reduce for warmup.
|
||||
torch.distributed.all_reduce(torch.zeros(1).cuda())
|
||||
initialize_model_parallel(parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size)
|
||||
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size)
|
||||
|
||||
|
||||
def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
|
||||
|
Loading…
x
Reference in New Issue
Block a user