vllm/cacheflow/worker/worker.py
2023-05-20 13:06:59 -07:00

319 lines
12 KiB
Python

"""A GPU worker class."""
from typing import Dict, List, Tuple
import torch
from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig)
from cacheflow.model_executor import get_model, InputMetadata, set_random_seed
from cacheflow.model_executor.parallel_utils.parallel_state import (
initialize_model_parallel, initialize_all_reduce_launcher)
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import (SequenceData, SequenceGroupMetadata,
SequenceOutputs)
from cacheflow.worker.cache_engine import CacheEngine
from cacheflow.utils import get_gpu_memory
class Worker:
    """A worker class that executes (a partition of) the model on a GPU.

    Each worker is associated with a single GPU. The worker is responsible for
    maintaining the KV cache and executing the model on the GPU. In case of
    distributed inference, each worker is assigned a partition of the model.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        rank: int,
        distributed_init_method: str,
    ) -> None:
        """Initialize the distributed environment and load the model.

        The cache engine is NOT created here; call ``init_cache_engine()``
        before ``execute_model()``.

        Args:
            model_config: Model settings; supplies the seed, hidden size and
                dtype used during initialization.
            parallel_config: Parallelism settings forwarded to the
                distributed initialization.
            scheduler_config: Scheduler limits; its
                ``max_num_batched_tokens`` sizes the all-reduce launcher.
            rank: Rank of this worker in the process group.
            distributed_init_method: Value passed as ``init_method`` to
                ``torch.distributed.init_process_group``.
        """
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.rank = rank
        self.distributed_init_method = distributed_init_method

        # Initialize the distributed environment.
        _init_distributed_environment(parallel_config, rank,
                                      distributed_init_method)

        # Initialize the model.
        set_random_seed(self.model_config.seed)
        self.model = get_model(model_config)
        initialize_all_reduce_launcher(
            self.scheduler_config.max_num_batched_tokens,
            self.model_config.get_hidden_size(),
            self.model_config.dtype,
        )

        # Uninitialized cache engine. Will be initialized by
        # self.init_cache_engine().
        self.cache_config = None
        self.block_size = None
        self.cache_engine = None
        self.cache_events = None
        self.gpu_cache = None

    @torch.inference_mode()
    def profile_num_available_blocks(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        cpu_swap_space: int,
    ) -> Tuple[int, int]:
        """Profile peak GPU memory and derive how many cache blocks fit.

        Runs one forward pass over dummy prompts (token id 0) sized to the
        scheduler limits, with ``(None, None)`` in place of every layer's KV
        cache, and measures the peak allocated CUDA memory.

        Args:
            block_size: Block size forwarded to
                ``CacheEngine.get_cache_block_size``.
            gpu_memory_utilization: Multiplier applied to the value returned
                by ``get_gpu_memory()`` to obtain the memory budget.
            cpu_swap_space: Swap space to divide into blocks; must be in the
                same unit as the cache block size (presumably bytes).

        Returns:
            ``(num_gpu_blocks, num_cpu_blocks)``. Both values are clamped to
            be non-negative.
        """
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.

        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(
            top_p=0.99, top_k=self.model.config.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs
        seqs = []
        for group_id in range(max_num_seqs):
            # Spread the token budget evenly; the first
            # (max_num_batched_tokens % max_num_seqs) groups get one extra
            # token so that the lengths sum to max_num_batched_tokens.
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            seq_data = SequenceData([0] * seq_len)
            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
            )
            seqs.append(seq)

        input_tokens, input_positions, input_metadata = (
            self._prepare_inputs(seqs))

        # Execute the model.
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        self.model(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=[(None, None)] * num_layers,
            input_metadata=input_metadata,
            cache_events=None,
        )

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        torch.cuda.synchronize()
        peak_memory = torch.cuda.max_memory_allocated()
        total_gpu_memory = get_gpu_memory()
        cache_block_size = CacheEngine.get_cache_block_size(
            block_size, self.model_config, self.parallel_config)
        num_gpu_blocks = int((total_gpu_memory * gpu_memory_utilization
                              - peak_memory) // cache_block_size)
        num_cpu_blocks = int(cpu_swap_space // cache_block_size)
        # The memory budget can be smaller than the profiled peak (or the
        # swap space negative); report zero blocks instead of a negative
        # count, which no caller could allocate.
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)
        torch.cuda.empty_cache()

        # Reset the seed to ensure that the random state is not affected by
        # the model initialization and profiling.
        set_random_seed(self.model_config.seed)
        return num_gpu_blocks, num_cpu_blocks

    def init_cache_engine(self, cache_config: CacheConfig) -> None:
        """Create the cache engine and expose its events and GPU cache.

        Args:
            cache_config: Cache settings; ``block_size`` is read from it and
                the whole object is handed to ``CacheEngine``.
        """
        self.cache_config = cache_config
        self.block_size = cache_config.block_size
        self.cache_engine = CacheEngine(
            self.cache_config, self.model_config, self.parallel_config)
        self.cache_events = self.cache_engine.events
        self.gpu_cache = self.cache_engine.gpu_cache

    def _prepare_inputs(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
        """Flatten the scheduled sequence groups into model inputs.

        All prompt groups are laid out first, followed by all generation
        groups. Token ids and positions are zero-padded to a multiple of 8;
        the slot mapping is left unpadded.

        Args:
            seq_group_metadata_list: Sequence groups to run in this step.

        Returns:
            ``(tokens, positions, input_metadata)`` where the first two are
            ``torch.long`` CUDA tensors and ``input_metadata`` carries the
            per-group sampling parameters, prompt lengths, slot mapping,
            context lengths and padded block tables.
        """
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []

        # Add prompt tokens.
        prompt_lens: List[int] = []
        for seq_group_metadata in seq_group_metadata_list:
            if not seq_group_metadata.is_prompt:
                continue

            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            # Use any sequence in the group.
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            input_tokens.extend(prompt_tokens)
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.extend(range(len(prompt_tokens)))

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not
                # initialized yet. In this case, we just use a dummy slot
                # mapping.
                slot_mapping.extend([0] * prompt_len)
                continue

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
            for i in range(prompt_len):
                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        # Add generation tokens.
        max_context_len = 0
        max_num_blocks_per_seq = 0
        context_lens: List[int] = []
        generation_block_tables: List[List[int]] = []
        for seq_group_metadata in seq_group_metadata_list:
            if seq_group_metadata.is_prompt:
                continue

            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append(generation_token)

                # Only the last token is fed in, at the last position of the
                # sequence.
                context_len = seq_data.get_len()
                position = context_len - 1
                input_positions.append(position)

                block_table = seq_group_metadata.block_tables[seq_id]
                generation_block_tables.append(block_table)

                max_context_len = max(max_context_len, context_len)
                max_num_blocks_per_seq = max(
                    max_num_blocks_per_seq, len(block_table))
                context_lens.append(context_len)

                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        # Optimization: Pad the input length to be a multiple of 8.
        # This is required for utilizing the Tensor Cores in NVIDIA GPUs.
        input_tokens = _pad_to_alignment(input_tokens, multiple_of=8)
        input_positions = _pad_to_alignment(input_positions, multiple_of=8)

        # Convert to tensors.
        tokens_tensor = torch.tensor(
            input_tokens, dtype=torch.long, device='cuda')
        positions_tensor = torch.tensor(
            input_positions, dtype=torch.long, device='cuda')
        slot_mapping_tensor = torch.tensor(
            slot_mapping, dtype=torch.int, device='cuda')
        context_lens_tensor = torch.tensor(
            context_lens, dtype=torch.int, device='cuda')
        # Block tables differ in length; zero-pad them to the longest one so
        # that they form a rectangular tensor.
        padded_block_tables = [
            _pad_to_max(block_table, max_num_blocks_per_seq)
            for block_table in generation_block_tables]
        block_tables_tensor = torch.tensor(
            padded_block_tables, dtype=torch.int, device='cuda')

        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_data.update(seq_group_metadata.seq_data)

        input_metadata = InputMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
            slot_mapping=slot_mapping_tensor,
            context_lens=context_lens_tensor,
            max_context_len=max_context_len,
            block_tables=block_tables_tensor,
        )
        return tokens_tensor, positions_tensor, input_metadata

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> Dict[int, SequenceOutputs]:
        """Issue the requested cache operations, then run the model.

        ``init_cache_engine()`` must have been called first; the cache
        engine attributes are ``None`` until then.

        Args:
            seq_group_metadata_list: Sequence groups to run in this step.
            blocks_to_swap_in: Block mapping handed to
                ``cache_engine.swap_in`` (presumably source block ->
                destination block; confirm against ``CacheEngine``).
            blocks_to_swap_out: Block mapping handed to
                ``cache_engine.swap_out``.
            blocks_to_copy: Block mapping handed to ``cache_engine.copy``.

        Returns:
            The model's output, or an empty dict when there is no sequence
            group to run.
        """
        # Issue cache operations.
        issued_cache_op = False
        if blocks_to_swap_in:
            self.cache_engine.swap_in(blocks_to_swap_in)
            issued_cache_op = True
        if blocks_to_swap_out:
            self.cache_engine.swap_out(blocks_to_swap_out)
            issued_cache_op = True
        if blocks_to_copy:
            self.cache_engine.copy(blocks_to_copy)
            issued_cache_op = True

        # Only hand the events to the model when there is something to wait
        # for.
        if issued_cache_op:
            cache_events = self.cache_events
        else:
            cache_events = None

        # If there is no input, we don't need to execute the model.
        if not seq_group_metadata_list:
            if cache_events is not None:
                # No model call will wait on the events, so wait here.
                for event in cache_events:
                    event.wait()
            return {}

        # Prepare input tensors.
        input_tokens, input_positions, input_metadata = self._prepare_inputs(
            seq_group_metadata_list)

        # Execute the model.
        output = self.model(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=self.gpu_cache,
            input_metadata=input_metadata,
            cache_events=cache_events,
        )
        return output
def _init_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    distributed_init_method: str,
) -> None:
    """Join the NCCL process group and set up model parallelism.

    Args:
        parallel_config: Supplies ``world_size`` for the process group and
            the tensor/pipeline parallel sizes for model parallelism.
        rank: Rank of the calling process in the process group.
        distributed_init_method: Passed as ``init_method`` to
            ``torch.distributed.init_process_group``.
    """
    process_group_options = dict(
        backend="nccl",
        init_method=distributed_init_method,
        rank=rank,
        world_size=parallel_config.world_size,
    )
    torch.distributed.init_process_group(**process_group_options)

    # Run one tiny all_reduce up front as a warmup.
    warmup_tensor = torch.zeros(1).cuda()
    torch.distributed.all_reduce(warmup_tensor)

    tensor_parallel_size = parallel_config.tensor_parallel_size
    pipeline_parallel_size = parallel_config.pipeline_parallel_size
    initialize_model_parallel(tensor_parallel_size, pipeline_parallel_size)
def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
return x + [0] * ((-len(x)) % multiple_of)
def _pad_to_max(x: List[int], max_len: int) -> List[int]:
return x + [0] * (max_len - len(x))