[Neuron] Support inference with transformers-neuronx (#2569)

Liangfu Chen 2024-02-28 09:34:34 -08:00 committed by GitHub
parent e46fa5d52e
commit 3b7178cfa4
18 changed files with 516 additions and 42 deletions

View File

@ -0,0 +1,33 @@
from vllm import LLM, SamplingParams
# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create an LLM.
llm = LLM(
model="openlm-research/open_llama_3b",
max_num_seqs=8,
# The max_model_len and block_size arguments must equal the maximum sequence length
# when targeting the Neuron device. This is currently a known limitation of the
# continuous batching support in transformers-neuronx.
# TODO(liangfu): Support paged-attention in transformers-neuronx.
max_model_len=128,
block_size=128,
# The device can be detected automatically when the AWS Neuron SDK is installed.
# The device argument can be left unspecified for automatic detection, or assigned explicitly.
device="neuron")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

View File

@ -131,9 +131,11 @@ def llama_2_7b_engine_extra_embeddings() -> nn.Module:
cleanup()
get_model_old = get_model
def get_model_patched(model_config, device_config, lora_config=None):
return get_model_old(model_config, device_config,
LoRAConfig(max_loras=4, max_lora_rank=8))
def get_model_patched(model_config, device_config, **kwargs):
return get_model_old(model_config,
device_config,
lora_config=LoRAConfig(max_loras=4,
max_lora_rank=8))
with patch("vllm.worker.model_runner.get_model", get_model_patched):
engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)

View File

@ -8,7 +8,7 @@ from transformers import PretrainedConfig
from vllm.logger import init_logger
from vllm.transformers_utils.config import get_config
from vllm.utils import get_cpu_memory, is_hip, get_nvcc_cuda_version
from vllm.utils import get_cpu_memory, is_hip, is_neuron, get_nvcc_cuda_version
logger = init_logger(__name__)
@ -380,13 +380,21 @@ class ParallelConfig:
disable_custom_all_reduce: bool = False,
) -> None:
self.pipeline_parallel_size = pipeline_parallel_size
self.tensor_parallel_size = tensor_parallel_size
if is_neuron():
# For Neuron device support, we set TP=1 here to avoid sharding within vLLM directly.
# transformers-neuronx reads the neuron_tp_degree attribute and distributes the workload
# across multiple NeuronCores.
self.tensor_parallel_size = 1
self.neuron_tp_degree = tensor_parallel_size
else:
self.tensor_parallel_size = tensor_parallel_size
self.worker_use_ray = worker_use_ray
self.max_parallel_loading_workers = max_parallel_loading_workers
self.disable_custom_all_reduce = disable_custom_all_reduce
self.world_size = pipeline_parallel_size * tensor_parallel_size
if self.world_size > 1:
self.world_size = pipeline_parallel_size * self.tensor_parallel_size
# Ray workers are not supported for the Neuron backend.
if self.world_size > 1 and not is_neuron():
self.worker_use_ray = True
self._verify_args()
@ -465,8 +473,29 @@ class SchedulerConfig:
class DeviceConfig:
def __init__(self, device: str = "cuda") -> None:
self.device = torch.device(device)
def __init__(self, device: str = "auto") -> None:
if device == "auto":
# Automated device type detection
if torch.cuda.is_available():
self.device_type = "cuda"
elif is_neuron():
self.device_type = "neuron"
else:
raise RuntimeError("No supported device detected.")
else:
# Device type is assigned explicitly
self.device_type = device
# Some device types require processing inputs on CPU
if self.device_type in ["neuron"]:
self.device = torch.device("cpu")
else:
# Set device with device type
self.device = torch.device(self.device_type)
@property
def is_neuron(self):
return self.device_type == "neuron"
@dataclass
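A minimal sketch of how the new "auto" default resolves, assuming a Neuron host with no visible CUDA device:

    from vllm.config import DeviceConfig

    device_config = DeviceConfig()        # device defaults to "auto"
    print(device_config.device_type)      # "neuron" here; "cuda" if torch.cuda.is_available()
    print(device_config.is_neuron)        # True
    print(device_config.device)           # cpu -- Neuron inputs are prepared on the CPU side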

View File

@ -44,7 +44,7 @@ class EngineArgs:
lora_extra_vocab_size: int = 256
lora_dtype = 'auto'
max_cpu_loras: Optional[int] = None
device: str = 'cuda'
device: str = 'auto'
def __post_init__(self):
if self.tokenizer is None:
@ -171,7 +171,7 @@ class EngineArgs:
parser.add_argument('--block-size',
type=int,
default=EngineArgs.block_size,
choices=[8, 16, 32],
choices=[8, 16, 32, 128],
help='token block size')
parser.add_argument('--seed',
type=int,
@ -264,13 +264,11 @@ class EngineArgs:
help=('Maximum number of LoRAs to store in CPU memory. '
'Must be >= max_num_seqs. '
'Defaults to max_num_seqs.'))
parser.add_argument(
"--device",
type=str,
default=EngineArgs.device,
choices=["cuda"],
help=('Device type for vLLM execution. '
'Currently, only CUDA-compatible devices are supported.'))
parser.add_argument("--device",
type=str,
default=EngineArgs.device,
choices=["auto", "cuda", "neuron"],
help='Device type for vLLM execution.')
return parser
@classmethod
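A hedged sketch of how these defaults surface through EngineArgs; the keyword arguments mirror the CLI flags above, and create_engine_configs() is assumed to be this version's entry point for turning the args into config objects:

    from vllm.engine.arg_utils import EngineArgs

    engine_args = EngineArgs(model="openlm-research/open_llama_3b",
                             device="auto",       # resolved to cuda or neuron at runtime
                             block_size=128,      # now a valid choice for --block-size
                             max_model_len=128,
                             max_num_seqs=8)
    model_config, cache_config, parallel_config, scheduler_config, *rest = (
        engine_args.create_engine_configs())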

View File

@ -3,6 +3,7 @@ from collections import defaultdict
import os
import time
import pickle
import importlib
from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
Union)
@ -20,7 +21,8 @@ from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
SequenceGroupOutput, SequenceOutput, SequenceStatus)
from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
TokenizerGroup)
from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port, get_distributed_init_method
from vllm.utils import (Counter, set_cuda_visible_devices, get_ip,
get_open_port, get_distributed_init_method)
if ray:
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
@ -31,6 +33,12 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5
# A map from the device type (in device config) to its worker module.
DEVICE_TO_WORKER_MODULE_MAP = {
"cuda": "vllm.worker.worker",
"neuron": "vllm.worker.neuron_worker",
}
# If the env var is set, it uses Ray's compiled DAG API
# which optimizes the control plane overhead.
# Run VLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
@ -138,10 +146,17 @@ class LLMEngine:
def get_tokenizer_for_seq(self, sequence: Sequence):
return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
def _dispatch_worker(self):
worker_module = DEVICE_TO_WORKER_MODULE_MAP[
self.device_config.device_type]
imported_worker = importlib.import_module(worker_module)
Worker = imported_worker.Worker
return Worker
def _init_workers(self):
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
from vllm.worker.worker import Worker
Worker = self._dispatch_worker()
assert self.parallel_config.world_size == 1, (
"Ray is required if parallel_config.world_size > 1.")
@ -243,7 +258,7 @@ class LLMEngine:
# Lazy import the Worker to avoid importing torch.cuda/xformers
# before CUDA_VISIBLE_DEVICES is set in the Worker
from vllm.worker.worker import Worker
Worker = self._dispatch_worker()
# Initialize torch distributed process group for the workers.
model_config = copy.deepcopy(self.model_config)
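The worker dispatch is a plain "module path in a dict, import lazily" pattern; a self-contained sketch of the same idea (generic names, not vLLM's own):

    import importlib

    _WORKER_MODULES = {
        "cuda": "vllm.worker.worker",
        "neuron": "vllm.worker.neuron_worker",
    }

    def resolve_worker_cls(device_type: str):
        # Import the backend module only when it is needed, so CUDA-only
        # dependencies are never pulled in on a Neuron host.
        module = importlib.import_module(_WORKER_MODULES[device_type])
        return module.Worker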

View File

@ -795,6 +795,10 @@ class SamplerWithLoRA(BaseLayerWithLoRA):
self.dtype = dtype
self.device = device
@property
def logits_as_hidden_states(self):
return self.base_layer.logits_as_hidden_states
@property
def vocab_size(self):
return self.base_layer.vocab_size

View File

@ -1,7 +1,6 @@
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed
from vllm.model_executor.utils import set_random_seed, get_model
__all__ = [
"InputMetadata",

View File

@ -10,6 +10,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTens
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
SequenceData, SequenceGroupOutput, SequenceOutput)
from vllm.utils import is_neuron
class Sampler(nn.Module):
@ -32,6 +33,8 @@ class Sampler(nn.Module):
org_vocab_size: Optional[int] = None) -> None:
super().__init__()
self.vocab_size = vocab_size
# transformers-neuronx generates logits directly as its outputs.
self.logits_as_hidden_states = is_neuron()
# original vocabulary size (without LoRA).
self.org_vocab_size = org_vocab_size or vocab_size
@ -55,10 +58,14 @@ class Sampler(nn.Module):
embedding_bias: Optional[torch.Tensor] = None,
) -> Optional[SamplerOutput]:
# Get the hidden states that we use for sampling.
hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
if self.logits_as_hidden_states:
logits = hidden_states
else:
hidden_states = _prune_hidden_states(hidden_states,
sampling_metadata)
# Get the logits for the next tokens.
logits = self._get_logits(hidden_states, embedding, embedding_bias)
# Get the logits for the next tokens.
logits = self._get_logits(hidden_states, embedding, embedding_bias)
# Only perform sampling in the driver worker.
# Note: `_get_logits` is still distributed across TP workers because
@ -395,7 +402,8 @@ def _sample(
sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
is_prompts, sample_indices)
if sampling_type == SamplingType.GREEDY:
greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1)
greedy_samples = torch.argmax(logprobs[sample_indices.long()],
dim=-1)
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
max_best_of = 1
for seq_group, is_prompt in zip(seq_groups, is_prompts):
@ -407,7 +415,7 @@ def _sample(
"generators": sampling_metadata.generators,
}
multinomial_samples[sampling_type] = _multinomial(
probs[sample_indices], max_best_of, **seeded_args)
probs[sample_indices.long()], max_best_of, **seeded_args)
elif sampling_type == SamplingType.BEAM:
beam_search_logprobs = logprobs[sample_indices]
else:

View File

@ -1,11 +1,11 @@
"""Utilities for selecting and loading models."""
import contextlib
from typing import Optional, Type
from typing import Type
import torch
import torch.nn as nn
from vllm.config import DeviceConfig, ModelConfig, LoRAConfig
from vllm.config import DeviceConfig, ModelConfig
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.weight_utils import (get_quant_config,
initialize_dummy_weights)
@ -37,9 +37,9 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
def get_model(model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig] = None) -> nn.Module:
def get_model(model_config: ModelConfig, device_config: DeviceConfig,
**kwargs) -> nn.Module:
lora_config = kwargs.get("lora_config", None)
model_class = _get_model_architecture(model_config)
# Get the (maybe quantized) linear method.

View File

@ -4,7 +4,7 @@ from typing import List, Optional, Type
import torch.nn as nn
from vllm.logger import init_logger
from vllm.utils import is_hip
from vllm.utils import is_hip, is_neuron
logger = init_logger(__name__)
@ -61,6 +61,9 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS = {
"Sliding window attention is not yet supported in ROCm's flash attention",
}
# Models supported by the Neuron backend.
_NEURON_SUPPORTED_MODELS = {"LlamaForCausalLM": "neuron.llama"}
class ModelRegistry:
@ -77,8 +80,15 @@ class ModelRegistry:
logger.warning(
f"Model architecture {model_arch} is partially supported "
"by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
elif is_neuron():
if model_arch not in _NEURON_SUPPORTED_MODELS:
raise ValueError(
f"Model architecture {model_arch} is not supported by "
"Neuron for now.")
module_name, model_cls_name = _MODELS[model_arch]
if is_neuron():
module_name = _NEURON_SUPPORTED_MODELS[model_arch]
module = importlib.import_module(
f"vllm.model_executor.models.{module_name}")
return getattr(module, model_cls_name, None)
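A short usage sketch of the registry behaviour added here; the printed module path assumes a Neuron host, where the Llama entry is rerouted to the neuron.llama module:

    from vllm.model_executor.models import ModelRegistry

    model_cls = ModelRegistry.load_model_cls("LlamaForCausalLM")
    # On CUDA this resolves to vllm.model_executor.models.llama;
    # with transformers-neuronx installed it resolves to
    # vllm.model_executor.models.neuron.llama instead.
    print(model_cls.__module__)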

View File

@ -0,0 +1,79 @@
"""Inference-only LLaMA model compatible with HuggingFace weights."""
import os
from typing import List, Optional, Tuple
import torch
from torch import nn
from transformers import LlamaConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class LlamaForCausalLM(nn.Module):
def __init__(
self,
config: LlamaConfig,
linear_method=None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = None
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
with torch.inference_mode():
block_size = self.model.context_buckets[-1]
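# block_size equals max_model_len on Neuron, so each sequence owns exactly one
# KV-cache block and the block index recovered below doubles as the sequence id
# that transformers-neuronx uses to address its per-sequence cache.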
if input_metadata.is_prompt:
seq_ids = input_metadata.slot_mapping[:, 0] // block_size
else:
seq_ids = input_metadata.block_tables
logits = self.model(input_ids,
cache_ids=positions,
start_ids=seq_ids.flatten())
return logits
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.model.chkpt_model.lm_head,
hidden_states, sampling_metadata)
return next_tokens
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None,
**kwargs):
from transformers_neuronx.llama.model import LlamaForSampling
split_model_dir = f"{model_name_or_path}-split"
if os.path.isdir(os.path.join(model_name_or_path,
"pytorch_model.bin")):
split_model_dir = model_name_or_path
elif not os.path.exists(f"{model_name_or_path}-split"):
from transformers.models.llama import LlamaForCausalLM
from transformers_neuronx.module import save_pretrained_split
hf_model = LlamaForCausalLM.from_pretrained(model_name_or_path,
low_cpu_mem_usage=True)
save_pretrained_split(hf_model, f"{model_name_or_path}-split")
self.model = LlamaForSampling.from_pretrained(split_model_dir,
**kwargs)
self.model.to_neuron()
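For reference, a hedged sketch of the underlying transformers-neuronx flow that load_weights wraps, using only the calls that appear above; the tp_degree and amp values are illustrative assumptions:

    from transformers.models.llama import LlamaForCausalLM as HFLlamaForCausalLM
    from transformers_neuronx.llama.model import LlamaForSampling
    from transformers_neuronx.module import save_pretrained_split

    # One-time conversion of the HF checkpoint into the split layout Neuron expects.
    hf_model = HFLlamaForCausalLM.from_pretrained("openlm-research/open_llama_3b",
                                                  low_cpu_mem_usage=True)
    save_pretrained_split(hf_model, "open_llama_3b-split")

    # Compile the split checkpoint for the NeuronCores.
    neuron_model = LlamaForSampling.from_pretrained("open_llama_3b-split",
                                                    tp_degree=1, amp="f32")
    neuron_model.to_neuron()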

View File

@ -0,0 +1,66 @@
"""Utilities for selecting and loading models."""
from typing import Type
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config import ModelConfig, DeviceConfig
from vllm.model_executor.models import ModelRegistry
TORCH_DTYPE_TO_NEURON_AMP = {
"auto": "f32",
"half": "f16",
"float16": "f16",
"bfloat16": "bf16",
"float": "f32",
"float32": "f32",
torch.float16: "f16",
torch.bfloat16: "bf16",
torch.float32: "f32",
}
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
architectures = getattr(config, "architectures", [])
for arch in architectures:
model_cls = ModelRegistry.load_model_cls(arch)
if model_cls is not None:
return model_cls
raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
def get_model(model_config: ModelConfig, device_config: DeviceConfig,
**kwargs) -> nn.Module:
from transformers_neuronx.config import NeuronConfig, ContinuousBatchingConfig
parallel_config = kwargs.get("parallel_config")
scheduler_config = kwargs.get("scheduler_config")
model_class = _get_model_architecture(model_config.hf_config)
linear_method = None
# Create a model instance.
model = model_class(model_config.hf_config, linear_method)
continuous_batching_config = ContinuousBatchingConfig(
batch_size_for_shared_caches=scheduler_config.max_num_seqs)
neuron_config = NeuronConfig(
continuous_batching=continuous_batching_config)
# Load the weights from the cached or downloaded files.
model.load_weights(
model_config.model,
model_config.download_dir,
model_config.load_format,
model_config.revision,
tp_degree=parallel_config.neuron_tp_degree,
amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
neuron_config=neuron_config,
context_length_estimate=[scheduler_config.max_model_len],
n_positions=[scheduler_config.max_model_len],
batch_size=scheduler_config.max_num_seqs)
return model.eval()

View File

@ -5,7 +5,7 @@ import torch
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData
from vllm.utils import in_wsl
from vllm.utils import in_wsl, is_neuron
_SAMPLING_EPS = 1e-5
@ -155,7 +155,7 @@ class SamplingTensors:
dtype: torch.dtype) -> "SamplingTensors":
# Note that the performance will be very bad without
# pinned memory.
pin_memory = not in_wsl()
pin_memory = not in_wsl() and not is_neuron()
prompt_max_len = max(len(tokens) for tokens in prompt_tokens)
prompt_padded_tokens = [
tokens + [vocab_size] * (prompt_max_len - len(tokens))

View File

@ -1,10 +1,18 @@
"""Utils for model executor."""
import random
import importlib
from typing import Any, Dict, Optional
import numpy as np
import torch
from vllm.config import DeviceConfig, ModelConfig
DEVICE_TO_MODEL_LOADER_MAP = {
"cuda": "model_loader",
"neuron": "neuron_model_loader",
}
def set_random_seed(seed: int) -> None:
random.seed(seed)
@ -33,3 +41,12 @@ def set_weight_attrs(
assert not hasattr(
weight, key), (f"Overwriting existing tensor attribute: {key}")
setattr(weight, key, value)
def get_model(model_config: ModelConfig, device_config: DeviceConfig,
**kwargs) -> torch.nn.Module:
model_loader_module = DEVICE_TO_MODEL_LOADER_MAP[device_config.device_type]
imported_model_loader = importlib.import_module(
f"vllm.model_executor.{model_loader_module}")
get_model_fn = imported_model_loader.get_model
return get_model_fn(model_config, device_config, **kwargs)

View File

@ -118,6 +118,14 @@ def is_hip() -> bool:
return torch.version.hip is not None
def is_neuron() -> bool:
try:
import transformers_neuronx
except ImportError:
transformers_neuronx = None
return transformers_neuronx is not None
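A tiny usage sketch of the optional-dependency probe; nothing beyond the function above is assumed:

    from vllm.utils import is_neuron

    if is_neuron():
        # transformers_neuronx imported successfully, so the Neuron backend is available.
        print("Neuron backend available")
    else:
        print("transformers-neuronx not installed; using the default CUDA path")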
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
"""Returns the maximum shared memory per thread block in bytes."""
# NOTE: This import statement should be executed lazily since

View File

@ -3,10 +3,9 @@ from typing import Dict, List, Tuple
import torch
from vllm._C import cache_ops
from vllm.config import CacheConfig, ModelConfig, ParallelConfig
from vllm.logger import init_logger
from vllm.utils import in_wsl, STR_DTYPE_TO_TORCH_DTYPE
from vllm.utils import in_wsl, is_neuron, STR_DTYPE_TO_TORCH_DTYPE
logger = init_logger(__name__)
@ -39,6 +38,10 @@ class CacheEngine:
self.num_gpu_blocks = cache_config.num_gpu_blocks
self.num_cpu_blocks = cache_config.num_cpu_blocks
# Skip initializing CUDA stream and buffer for Neuron backend.
if is_neuron():
return
if cache_config.cache_dtype == "auto":
self.dtype = model_config.dtype
else:
@ -121,6 +124,8 @@ class CacheEngine:
dst: List[KVCache],
src_to_dst: Dict[int, int],
) -> None:
from vllm._C import cache_ops
with torch.cuda.stream(self.cache_stream):
for i in range(self.num_layers):
src_key_cache, src_value_cache = src[i]
@ -140,6 +145,8 @@ class CacheEngine:
self._swap(self.gpu_cache, self.cpu_cache, src_to_dst)
def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
from vllm._C import cache_ops
key_caches = [key_cache for key_cache, _ in self.gpu_cache]
value_caches = [value_cache for _, value_cache in self.gpu_cache]
# NOTE(woosuk): This operation implicitly synchronizes the CPU and GPU.

View File

@ -80,9 +80,16 @@ class ModelRunner:
self.in_wsl = in_wsl()
self.kv_cache_dtype = kv_cache_dtype
# Set enforce_eager to True for the Neuron backend, to avoid capturing the graph.
if self.device_config.is_neuron:
self.model_config.enforce_eager = True
def load_model(self) -> None:
self.model = get_model(self.model_config, self.device_config,
self.lora_config)
self.model = get_model(self.model_config,
self.device_config,
lora_config=self.lora_config,
parallel_config=self.parallel_config,
scheduler_config=self.scheduler_config)
vocab_size = self.model.config.vocab_size
@ -393,6 +400,7 @@ class ModelRunner:
selected_token_start_idx = 0
categorized_sample_indices = {t: [] for t in SamplingType}
categorized_sample_indices_start_idx = 0
pin_memory = not self.in_wsl and not self.device_config.is_neuron
max_subquery_len = max(subquery_lens) if subquery_lens else 1
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
@ -443,12 +451,12 @@ class ModelRunner:
selected_token_indices = _async_h2d(selected_token_indices,
dtype=torch.long,
target_device=self.device,
pin_memory=not self.in_wsl)
pin_memory=pin_memory)
categorized_sample_indices = {
t: _async_h2d(seq_ids,
dtype=torch.int,
target_device=self.device,
pin_memory=not self.in_wsl)
pin_memory=pin_memory)
for t, seq_ids in categorized_sample_indices.items()
}

View File

@ -0,0 +1,191 @@
"""A Neuron worker class."""
from typing import Dict, List, Optional, Tuple
import torch
import torch.distributed
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, LoRAConfig)
from vllm.model_executor import set_random_seed
from vllm.model_executor.parallel_utils.communication_op import (
broadcast_tensor_dict)
from vllm.model_executor.parallel_utils.parallel_state import (
ensure_model_parallel_initialized)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner
class Worker:
"""A worker class that executes the model on a group of neuron cores.
"""
def __init__(
self,
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
lora_config: Optional[LoRAConfig] = None,
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
) -> None:
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
self.lora_config = lora_config
self.is_driver_worker = is_driver_worker
if self.is_driver_worker:
assert self.rank == 0, "The driver worker must have rank 0."
self.model_runner = ModelRunner(model_config,
parallel_config,
scheduler_config,
device_config,
lora_config=self.lora_config,
is_driver_worker=is_driver_worker)
# Uninitialized cache engine. Will be initialized by
# self.init_cache_engine().
self.cache_config = None
self.cache_engine = None
self.cache_events = None
self.gpu_cache = None
def init_model(self) -> None:
# Initialize the distributed environment.
_init_distributed_environment(self.parallel_config,
self.rank,
self.distributed_init_method,
distributed_backend="gloo")
# Initialize the model.
set_random_seed(self.model_config.seed)
def load_model(self):
self.model_runner.load_model()
@torch.inference_mode()
def profile_num_available_blocks(
self,
block_size: int = 128,
gpu_memory_utilization: float = 0.9,
cpu_swap_space: int = 0,
cache_dtype: str = "float16",
) -> Tuple[int, int]:
"""Simply returns max_num_seqs as num_gpu_blocks, 0 as num_cpu_blocks."""
num_gpu_blocks = self.scheduler_config.max_num_seqs
num_cpu_blocks = 0
return num_gpu_blocks, num_cpu_blocks
def init_cache_engine(self, cache_config: CacheConfig) -> None:
self.cache_config = cache_config
self.cache_engine = CacheEngine(self.cache_config, self.model_config,
self.parallel_config)
self.model_runner.set_block_size(self.cache_engine.block_size)
def warm_up_model(self) -> None:
# Warm-up is handled inside transformers-neuronx.
pass
def cache_swap(
self,
blocks_to_swap_in: Dict[int, int],
blocks_to_swap_out: Dict[int, int],
blocks_to_copy: Dict[int, List[int]],
) -> None:
# Issue cache operations.
issued_cache_op = False
if blocks_to_swap_in:
self.cache_engine.swap_in(blocks_to_swap_in)
issued_cache_op = True
if blocks_to_swap_out:
self.cache_engine.swap_out(blocks_to_swap_out)
issued_cache_op = True
if blocks_to_copy:
self.cache_engine.copy(blocks_to_copy)
issued_cache_op = True
cache_events = self.cache_events if issued_cache_op else None
# Wait for cache operations to finish.
if cache_events is not None:
raise NotImplementedError(
"cache operations are not implemented for neuron backend.")
@torch.inference_mode()
def execute_model(
self,
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
blocks_to_swap_in: Optional[Dict[int, int]] = None,
blocks_to_swap_out: Optional[Dict[int, int]] = None,
blocks_to_copy: Optional[Dict[int, List[int]]] = None,
) -> Optional[SamplerOutput]:
if self.is_driver_worker:
assert seq_group_metadata_list is not None
num_seq_groups = len(seq_group_metadata_list)
assert blocks_to_swap_in is not None
assert blocks_to_swap_out is not None
assert blocks_to_copy is not None
data = {
"num_seq_groups": num_seq_groups,
"blocks_to_swap_in": blocks_to_swap_in,
"blocks_to_swap_out": blocks_to_swap_out,
"blocks_to_copy": blocks_to_copy,
}
broadcast_tensor_dict(data, src=0)
else:
data = broadcast_tensor_dict(src=0)
num_seq_groups = data["num_seq_groups"]
blocks_to_swap_in = data["blocks_to_swap_in"]
blocks_to_swap_out = data["blocks_to_swap_out"]
blocks_to_copy = data["blocks_to_copy"]
self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)
# If there is no input, we don't need to execute the model.
if num_seq_groups == 0:
return {}
output = self.model_runner.execute_model(seq_group_metadata_list,
self.gpu_cache)
return output
def _init_distributed_environment(
parallel_config: ParallelConfig,
rank: int,
distributed_init_method: Optional[str] = None,
distributed_backend: Optional[str] = None,
) -> None:
"""Initialize the distributed environment."""
if torch.distributed.is_initialized():
torch_world_size = torch.distributed.get_world_size()
if torch_world_size != parallel_config.world_size:
raise RuntimeError(
"torch.distributed is already initialized but the torch world "
"size does not match parallel_config.world_size "
f"({torch_world_size} vs. {parallel_config.world_size}).")
elif not distributed_init_method:
raise ValueError(
"distributed_init_method must be set if torch.distributed "
"is not already initialized")
else:
distributed_backend = distributed_backend if distributed_backend else "nccl"
torch.distributed.init_process_group(
backend=distributed_backend,
world_size=parallel_config.world_size,
rank=rank,
init_method=distributed_init_method,
)
# A small all_reduce for warmup.
torch.distributed.all_reduce(torch.zeros(1))
ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)