"""A GPU worker class."""
from typing import Dict, List, Optional, Tuple

import torch

from cacheflow.model_executor import (get_model, get_cache_block_size,
                                      InputMetadata, set_random_seed)
from cacheflow.model_executor.parallel_utils.parallel_state import (
    initialize_model_parallel,
    initialize_all_reduce_launcher,
    get_tensor_model_parallel_world_size)
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import (SequenceData, SequenceGroupMetadata,
                                SequenceOutputs)
from cacheflow.worker.cache_engine import CacheEngine
from cacheflow.utils import get_gpu_memory


class Worker:
    """A worker class that executes (a partition of) the model on a GPU.

    Each worker is associated with a single GPU. The worker is responsible for
    maintaining the KV cache and executing the model on the GPU. In case of
    distributed inference, each worker is assigned a partition of the model.
    """

    def __init__(
        self,
        model_name: str,
        dtype: str,
        seed: int,
        distributed_init_method: str,
        rank: int,
        world_size: int,
        cache_dir: Optional[str],
        use_dummy_weights: bool,
        use_np_cache: bool,
        max_num_batched_tokens: int,
        max_num_sequences: int,
        tensor_parallel_size: int = 1,
        pipeline_parallel_size: int = 1,
    ) -> None:
        self.init_distributed_environment(distributed_init_method,
                                          rank,
                                          world_size,
                                          tensor_parallel_size,
                                          pipeline_parallel_size)
        self.worker_id = rank
        self.seed = seed
        set_random_seed(self.seed)

        # Initialize the model.
        self.model, self.dtype = get_model(
            model_name, dtype=dtype, cache_dir=cache_dir,
            use_dummy_weights=use_dummy_weights, use_np_cache=use_np_cache)
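        # Number of GPUs participating in tensor parallelism for this model.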
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        self.max_num_batched_tokens = max_num_batched_tokens
        initialize_all_reduce_launcher(
            self.max_num_batched_tokens, self.model.config.hidden_size,
            self.dtype)
        self.max_num_sequences = max_num_sequences
        self.num_layers = self.model.config.num_hidden_layers
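        # Evenly partition the attention heads across the tensor-parallel GPUs.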
        assert (self.model.config.num_attention_heads %
                tensor_model_parallel_world_size == 0)
        self.num_heads = (self.model.config.num_attention_heads //
                          tensor_model_parallel_world_size)
        self.head_size = self.model.config.hidden_size // (
            self.num_heads * tensor_model_parallel_world_size)

        # We reset the seed after initializing the model to ensure that
        # the random state is not affected by the model initialization.
        set_random_seed(seed)

        # Uninitialized cache engine. Will be initialized with
        # self.init_cache_engine().
        self.block_size = None
        self.cache_engine = None
        self.cache_events = None
        self.gpu_cache = None

    @torch.inference_mode()
    def get_num_available_blocks(
        self, block_size: int, cpu_swap_space: int,
        gpu_memory_utilization: float) -> Tuple[int, int]:
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

        # Profile memory usage with max_num_sequences sequences and the total
        # number of tokens equal to max_num_batched_tokens.

        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99,
                                         top_k=self.model.config.vocab_size - 1)
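        # Spread max_num_batched_tokens as evenly as possible across the
        # max_num_sequences dummy sequence groups.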
        seqs = []
        for group_id in range(self.max_num_sequences):
            seq_len = (self.max_num_batched_tokens // self.max_num_sequences +
                       (group_id < self.max_num_batched_tokens %
                        self.max_num_sequences))
            seq_data = SequenceData([0] * seq_len)
            seq = SequenceGroupMetadata(
                group_id=group_id,
                is_prompt=True,
                seq_data={group_id: seq_data},
                sampling_params=sampling_params,
                block_tables=None,
            )
            seqs.append(seq)

        input_tokens, input_positions, input_metadata = self.prepare_inputs(seqs)

        # Execute the model.
        self.model(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=[(None, None)] * self.num_layers,
            input_metadata=input_metadata,
            cache_events=None,
        )

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        torch.cuda.synchronize()
        peak_memory = torch.cuda.max_memory_allocated()
        total_gpu_memory = get_gpu_memory()
        cache_block_size = get_cache_block_size(block_size, self.num_heads,
                                                self.head_size, self.num_layers,
                                                self.dtype)
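        # The GPU cache gets whatever memory remains below the utilization cap
        # after the profiled peak usage; CPU blocks are bounded by the swap space.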
        num_gpu_blocks = int((total_gpu_memory * gpu_memory_utilization
                              - peak_memory) // cache_block_size)
        num_cpu_blocks = int(cpu_swap_space // cache_block_size)
        torch.cuda.empty_cache()
        # Reset the seed to ensure that the model output is not affected by
        # the profiling.
        set_random_seed(self.seed)
        return num_gpu_blocks, num_cpu_blocks

    def init_cache_engine(self, block_size: int, num_gpu_blocks: int,
                          num_cpu_blocks: int):
        self.block_size = block_size
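        # The cache engine owns the GPU and CPU KV cache blocks and the events
        # used to synchronize cache operations (swap in, swap out, copy).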
        self.cache_engine = CacheEngine(
            worker_id=self.worker_id,
            num_layers=self.num_layers,
            num_heads=self.num_heads,
            head_size=self.head_size,
            block_size=self.block_size,
            num_gpu_blocks=num_gpu_blocks,
            num_cpu_blocks=num_cpu_blocks,
            dtype=self.dtype,
        )
        self.cache_events = self.cache_engine.events
        self.gpu_cache = self.cache_engine.gpu_cache

    def init_distributed_environment(self,
                                     distributed_init_method: str,
                                     rank: int,
                                     world_size: int,
                                     tensor_parallel_size: int = 1,
                                     pipeline_parallel_size: int = 1) -> None:
        """Initialize the distributed environment."""
        torch.distributed.init_process_group(
            backend='nccl',
            init_method=distributed_init_method,
            world_size=world_size,
            rank=rank,
        )
        # A small all_reduce for warmup.
        torch.distributed.all_reduce(torch.zeros(1).cuda())
        initialize_model_parallel(tensor_parallel_size,
                                  pipeline_parallel_size)

    def prepare_inputs(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
        seq_groups: List[Tuple[List[int], SamplingParams]] = []
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []

        # Add prompt tokens.
        prompt_lens: List[int] = []
        for seq_group_metadata in seq_group_metadata_list:
            if not seq_group_metadata.is_prompt:
                continue

            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))

            # Use any sequence in the group.
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            input_tokens.extend(prompt_tokens)
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            input_positions.extend(range(len(prompt_tokens)))

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not initialized
                # yet. In this case, we just use a dummy slot mapping.
                slot_mapping.extend([0] * prompt_len)
                continue

            # Compute the slot mapping.
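            # A slot is the flat index of a token's KV cache entry:
            # block_number * block_size + offset within the block.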
            block_table = seq_group_metadata.block_tables[seq_id]
            for i in range(prompt_len):
                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        # Add generation tokens.
        max_context_len = 0
        max_num_blocks_per_seq = 0
        context_lens: List[int] = []
        generation_block_tables: List[List[int]] = []
        for seq_group_metadata in seq_group_metadata_list:
            if seq_group_metadata.is_prompt:
                continue

            seq_ids = list(seq_group_metadata.seq_data.keys())
            sampling_params = seq_group_metadata.sampling_params
            seq_groups.append((seq_ids, sampling_params))
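
            # Each running sequence contributes one token (its latest) per step.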
            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append(generation_token)

                context_len = seq_data.get_len()
                position = context_len - 1
                input_positions.append(position)

                block_table = seq_group_metadata.block_tables[seq_id]
                generation_block_tables.append(block_table)

                max_context_len = max(max_context_len, context_len)
                max_num_blocks_per_seq = max(
                    max_num_blocks_per_seq, len(block_table))
                context_lens.append(context_len)

                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        # Optimization: Pad the input length to be a multiple of 8.
        # This is required for utilizing the Tensor Cores in NVIDIA GPUs.
        input_tokens = _pad_to_alignment(input_tokens, multiple_of=8)
        input_positions = _pad_to_alignment(input_positions, multiple_of=8)

        # Convert to tensors.
        tokens_tensor = torch.tensor(
            input_tokens, dtype=torch.long, device='cuda')
        positions_tensor = torch.tensor(
            input_positions, dtype=torch.long, device='cuda')
        slot_mapping_tensor = torch.tensor(
            slot_mapping, dtype=torch.int, device='cuda')
        context_lens_tensor = torch.tensor(
            context_lens, dtype=torch.int, device='cuda')
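        # Pad the block tables to the same length so they can be stacked into
        # one tensor.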
        padded_block_tables = [
            _pad_to_max(block_table, max_num_blocks_per_seq)
            for block_table in generation_block_tables]
        block_tables_tensor = torch.tensor(
            padded_block_tables, dtype=torch.int, device='cuda')

        seq_data: Dict[int, SequenceData] = {}
        for seq_group_metadata in seq_group_metadata_list:
            seq_data.update(seq_group_metadata.seq_data)

        input_metadata = InputMetadata(
            seq_groups=seq_groups,
            seq_data=seq_data,
            prompt_lens=prompt_lens,
            slot_mapping=slot_mapping_tensor,
            context_lens=context_lens_tensor,
            max_context_len=max_context_len,
            block_tables=block_tables_tensor,
        )
        return tokens_tensor, positions_tensor, input_metadata

    @torch.inference_mode()
    def execute_stage(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> Dict[int, SequenceOutputs]:
        # Issue cache operations.
        issued_cache_op = False
        if blocks_to_swap_in:
            self.cache_engine.swap_in(blocks_to_swap_in)
            issued_cache_op = True
        if blocks_to_swap_out:
            self.cache_engine.swap_out(blocks_to_swap_out)
            issued_cache_op = True
        if blocks_to_copy:
            self.cache_engine.copy(blocks_to_copy)
            issued_cache_op = True
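
        # Pass the cache events downstream only if cache operations were
        # actually issued, so the computation can wait for them before using
        # the KV cache.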
        if issued_cache_op:
            cache_events = self.cache_events
        else:
            cache_events = None

        # If there is no input, we don't need to execute the model.
        if not seq_group_metadata_list:
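            # Still wait for any issued cache operations to finish before
            # returning.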
            if cache_events is not None:
                for event in cache_events:
                    event.wait()
            return {}

        # Prepare input tensors.
        input_tokens, input_positions, input_metadata = self.prepare_inputs(
            seq_group_metadata_list)

        # Execute the model.
        output = self.model(
            input_ids=input_tokens,
            positions=input_positions,
            kv_caches=self.gpu_cache,
            input_metadata=input_metadata,
            cache_events=cache_events,
        )
        return output


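# Right-pad the list with zeros so its length is a multiple of multiple_of.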
def _pad_to_alignment(x: List[int], multiple_of: int) -> List[int]:
    return x + [0] * ((-len(x)) % multiple_of)


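# Right-pad the list with zeros to exactly max_len entries.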
def _pad_to_max(x: List[int], max_len: int) -> List[int]:
    return x + [0] * (max_len - len(x))