[Neuron] Support inference with transformers-neuronx (#2569)
parent: e46fa5d52e · commit: 3b7178cfa4

examples/offline_inference_neuron.py (new file, 33 lines)
@@ -0,0 +1,33 @@
from vllm import LLM, SamplingParams

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(
    model="openlm-research/open_llama_3b",
    max_num_seqs=8,
    # The max_model_len and block_size arguments are required to be the same as
    # the max sequence length when targeting the neuron device. Currently, this
    # is a known limitation in continuous batching support in transformers-neuronx.
    # TODO(liangfu): Support paged-attention in transformers-neuronx.
    max_model_len=128,
    block_size=128,
    # The device can be automatically detected when AWS Neuron SDK is installed.
    # The device argument can be either unspecified for automated detection, or explicitly assigned.
    device="neuron")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
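Since this change set makes "auto" the default device and adds detection in DeviceConfig (see the hunks below), the device keyword in the example above can also be left out. A minimal sketch of that variant, assuming a host where the AWS Neuron SDK (transformers-neuronx) is installed so detection resolves to "neuron":

    from vllm import LLM, SamplingParams

    # device is omitted: DeviceConfig("auto") picks "neuron" when
    # transformers-neuronx is importable, or "cuda" when a GPU is available.
    llm = LLM(model="openlm-research/open_llama_3b",
              max_num_seqs=8,
              max_model_len=128,
              block_size=128)
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(temperature=0.8, top_p=0.95))
    print(outputs[0].outputs[0].text)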
@@ -131,9 +131,11 @@ def llama_2_7b_engine_extra_embeddings() -> nn.Module:
     cleanup()
     get_model_old = get_model

-    def get_model_patched(model_config, device_config, lora_config=None):
-        return get_model_old(model_config, device_config,
-                             LoRAConfig(max_loras=4, max_lora_rank=8))
+    def get_model_patched(model_config, device_config, **kwargs):
+        return get_model_old(model_config,
+                             device_config,
+                             lora_config=LoRAConfig(max_loras=4,
+                                                    max_lora_rank=8))

     with patch("vllm.worker.model_runner.get_model", get_model_patched):
         engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False)
@@ -8,7 +8,7 @@ from transformers import PretrainedConfig

 from vllm.logger import init_logger
 from vllm.transformers_utils.config import get_config
-from vllm.utils import get_cpu_memory, is_hip, get_nvcc_cuda_version
+from vllm.utils import get_cpu_memory, is_hip, is_neuron, get_nvcc_cuda_version

 logger = init_logger(__name__)

@@ -380,13 +380,21 @@ class ParallelConfig:
         disable_custom_all_reduce: bool = False,
     ) -> None:
         self.pipeline_parallel_size = pipeline_parallel_size
-        self.tensor_parallel_size = tensor_parallel_size
+        if is_neuron():
+            # For Neuron device support, here we assign TP=1 to avoid sharding
+            # within vLLM directly. Transformers-neuronx takes the neuron_tp_degree
+            # attribute and distributes the workload to multiple NeuronCores.
+            self.tensor_parallel_size = 1
+            self.neuron_tp_degree = tensor_parallel_size
+        else:
+            self.tensor_parallel_size = tensor_parallel_size
         self.worker_use_ray = worker_use_ray
         self.max_parallel_loading_workers = max_parallel_loading_workers
         self.disable_custom_all_reduce = disable_custom_all_reduce

-        self.world_size = pipeline_parallel_size * tensor_parallel_size
-        if self.world_size > 1:
+        self.world_size = pipeline_parallel_size * self.tensor_parallel_size
+        # Ray worker is not supported for Neuron backend.
+        if self.world_size > 1 and not is_neuron():
             self.worker_use_ray = True
         self._verify_args()

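To make the effect of this ParallelConfig hunk concrete, here is a standalone sketch of the same branching; resolve_tp is a hypothetical helper used only for illustration, and the neuron flag stands in for is_neuron():

    def resolve_tp(tensor_parallel_size: int, neuron: bool) -> dict:
        # Mirrors ParallelConfig.__init__: on Neuron, vLLM itself does not
        # shard the model (TP=1); the requested degree is handed to
        # transformers-neuronx as neuron_tp_degree instead.
        if neuron:
            return {"tensor_parallel_size": 1,
                    "neuron_tp_degree": tensor_parallel_size}
        return {"tensor_parallel_size": tensor_parallel_size}

    assert resolve_tp(2, neuron=True) == {"tensor_parallel_size": 1,
                                          "neuron_tp_degree": 2}
    assert resolve_tp(2, neuron=False) == {"tensor_parallel_size": 2}

Because world_size is now computed from self.tensor_parallel_size, a Neuron run with a requested TP degree of 2 still reports world_size == pipeline_parallel_size, so Ray workers are never enabled for it.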
@@ -465,8 +473,29 @@ class SchedulerConfig:

 class DeviceConfig:

-    def __init__(self, device: str = "cuda") -> None:
-        self.device = torch.device(device)
+    def __init__(self, device: str = "auto") -> None:
+        if device == "auto":
+            # Automated device type detection
+            if torch.cuda.is_available():
+                self.device_type = "cuda"
+            elif is_neuron():
+                self.device_type = "neuron"
+            else:
+                raise RuntimeError("No supported device detected.")
+        else:
+            # Device type is assigned explicitly
+            self.device_type = device
+
+        # Some device types require processing inputs on CPU
+        if self.device_type in ["neuron"]:
+            self.device = torch.device("cpu")
+        else:
+            # Set device with device type
+            self.device = torch.device(self.device_type)
+
+    @property
+    def is_neuron(self):
+        return self.device_type == "neuron"


 @dataclass
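A short usage sketch of the new DeviceConfig, assuming a host where CUDA is unavailable and transformers-neuronx is importable (so "auto" resolves to "neuron"):

    from vllm.config import DeviceConfig

    cfg = DeviceConfig()            # device defaults to "auto"
    print(cfg.device_type)          # "neuron" on such a host; "cuda" if a GPU is visible
    print(cfg.device)               # cpu: Neuron inputs are prepared as CPU tensors
    print(cfg.is_neuron)            # True

    explicit = DeviceConfig(device="cuda")   # explicit assignment skips detection
    print(explicit.device)                   # cuda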
@@ -44,7 +44,7 @@ class EngineArgs:
     lora_extra_vocab_size: int = 256
     lora_dtype = 'auto'
     max_cpu_loras: Optional[int] = None
-    device: str = 'cuda'
+    device: str = 'auto'

     def __post_init__(self):
         if self.tokenizer is None:

@@ -171,7 +171,7 @@ class EngineArgs:
         parser.add_argument('--block-size',
                             type=int,
                             default=EngineArgs.block_size,
-                            choices=[8, 16, 32],
+                            choices=[8, 16, 32, 128],
                             help='token block size')
         parser.add_argument('--seed',
                             type=int,

@@ -264,13 +264,11 @@ class EngineArgs:
                  help=('Maximum number of LoRAs to store in CPU memory. '
                        'Must be >= than max_num_seqs. '
                        'Defaults to max_num_seqs.'))
-        parser.add_argument(
-            "--device",
-            type=str,
-            default=EngineArgs.device,
-            choices=["cuda"],
-            help=('Device type for vLLM execution. '
-                  'Currently, only CUDA-compatible devices are supported.'))
+        parser.add_argument("--device",
+                            type=str,
+                            default=EngineArgs.device,
+                            choices=["auto", "cuda", "neuron"],
+                            help='Device type for vLLM execution.')
         return parser

     @classmethod
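The same switches are reachable programmatically through EngineArgs; a sketch mirroring `--device neuron --block-size 128` on the command line (field names are taken from the class above and the example script; the environment itself is assumed):

    from vllm.engine.arg_utils import EngineArgs
    from vllm.engine.llm_engine import LLMEngine

    # "auto" is now the default, so omitting device falls back to detection.
    engine_args = EngineArgs(model="openlm-research/open_llama_3b",
                             device="neuron",
                             block_size=128,
                             max_model_len=128,
                             max_num_seqs=8)
    engine = LLMEngine.from_engine_args(engine_args)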
@@ -3,6 +3,7 @@ from collections import defaultdict
 import os
 import time
 import pickle
+import importlib
 from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
                     Union)

@@ -20,7 +21,8 @@ from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
                            SequenceGroupOutput, SequenceOutput, SequenceStatus)
 from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                                TokenizerGroup)
-from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port, get_distributed_init_method
+from vllm.utils import (Counter, set_cuda_visible_devices, get_ip,
+                        get_open_port, get_distributed_init_method)

 if ray:
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

@@ -31,6 +33,12 @@ if TYPE_CHECKING:
 logger = init_logger(__name__)
 _LOCAL_LOGGING_INTERVAL_SEC = 5

+# A map from the device type (in device config) to its worker module.
+DEVICE_TO_WORKER_MODULE_MAP = {
+    "cuda": "vllm.worker.worker",
+    "neuron": "vllm.worker.neuron_worker",
+}
+
 # If the env var is set, it uses the Ray's compiled DAG API
 # which optimizes the control plane overhead.
 # Run VLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.

@@ -138,10 +146,17 @@ class LLMEngine:
     def get_tokenizer_for_seq(self, sequence: Sequence):
         return self.tokenizer.get_lora_tokenizer(sequence.lora_request)

+    def _dispatch_worker(self):
+        worker_module = DEVICE_TO_WORKER_MODULE_MAP[
+            self.device_config.device_type]
+        imported_worker = importlib.import_module(worker_module)
+        Worker = imported_worker.Worker
+        return Worker
+
     def _init_workers(self):
         # Lazy import the Worker to avoid importing torch.cuda/xformers
         # before CUDA_VISIBLE_DEVICES is set in the Worker
-        from vllm.worker.worker import Worker
+        Worker = self._dispatch_worker()

         assert self.parallel_config.world_size == 1, (
             "Ray is required if parallel_config.world_size > 1.")

@@ -243,7 +258,7 @@ class LLMEngine:

         # Lazy import the Worker to avoid importing torch.cuda/xformers
         # before CUDA_VISIBLE_DEVICES is set in the Worker
-        from vllm.worker.worker import Worker
+        Worker = self._dispatch_worker()

         # Initialize torch distributed process group for the workers.
         model_config = copy.deepcopy(self.model_config)
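As a quick illustration of the worker dispatch above, this sketch exercises only the mapping and importlib (it assumes vllm is importable but does not start an engine):

    import importlib

    DEVICE_TO_WORKER_MODULE_MAP = {
        "cuda": "vllm.worker.worker",
        "neuron": "vllm.worker.neuron_worker",
    }

    # For device_type == "neuron" the engine imports the new
    # vllm.worker.neuron_worker module and uses its Worker class.
    worker_module = DEVICE_TO_WORKER_MODULE_MAP["neuron"]
    Worker = importlib.import_module(worker_module).Worker
    print(Worker)   # <class 'vllm.worker.neuron_worker.Worker'>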
@@ -795,6 +795,10 @@ class SamplerWithLoRA(BaseLayerWithLoRA):
         self.dtype = dtype
         self.device = device

+    @property
+    def logits_as_hidden_states(self):
+        return self.base_layer.logits_as_hidden_states
+
     @property
     def vocab_size(self):
         return self.base_layer.vocab_size
@@ -1,7 +1,6 @@
 from vllm.model_executor.input_metadata import InputMetadata
-from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.model_executor.utils import set_random_seed
+from vllm.model_executor.utils import set_random_seed, get_model

 __all__ = [
     "InputMetadata",
@@ -10,6 +10,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTens
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
                            SequenceData, SequenceGroupOutput, SequenceOutput)
+from vllm.utils import is_neuron


 class Sampler(nn.Module):

@@ -32,6 +33,8 @@ class Sampler(nn.Module):
                  org_vocab_size: Optional[int] = None) -> None:
         super().__init__()
         self.vocab_size = vocab_size
+        # Transformers-neuronx generates outputs as logits directly.
+        self.logits_as_hidden_states = is_neuron()
         # original vocabulary size (without LoRA).
         self.org_vocab_size = org_vocab_size or vocab_size

@@ -55,10 +58,14 @@ class Sampler(nn.Module):
         embedding_bias: Optional[torch.Tensor] = None,
     ) -> Optional[SamplerOutput]:
         # Get the hidden states that we use for sampling.
-        hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
+        if self.logits_as_hidden_states:
+            logits = hidden_states
+        else:
+            hidden_states = _prune_hidden_states(hidden_states,
+                                                 sampling_metadata)

-        # Get the logits for the next tokens.
-        logits = self._get_logits(hidden_states, embedding, embedding_bias)
+            # Get the logits for the next tokens.
+            logits = self._get_logits(hidden_states, embedding, embedding_bias)

         # Only perform sampling in the driver worker.
         # Note: `_get_logits` is still distributed across TP workers because
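A condensed sketch of the branch this hunk adds to Sampler.forward; project is a stand-in for the real _get_logits path, and the shapes are illustrative only:

    import torch

    def select_logits(hidden_states, logits_as_hidden_states, project):
        # On the Neuron backend, transformers-neuronx already returns logits of
        # shape [num_tokens, vocab_size], so they are used as-is; otherwise the
        # hidden states still need pruning and the lm_head projection.
        if logits_as_hidden_states:
            return hidden_states
        return project(hidden_states)

    vocab, hidden = 320, 64
    lm_head = torch.nn.Linear(hidden, vocab, bias=False)
    neuron_out = torch.randn(2, vocab)    # already logits
    cuda_out = torch.randn(2, hidden)     # hidden states
    assert select_logits(neuron_out, True, lm_head).shape == (2, vocab)
    assert select_logits(cuda_out, False, lm_head).shape == (2, vocab)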
@@ -395,7 +402,8 @@ def _sample(
         sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
                                           is_prompts, sample_indices)
         if sampling_type == SamplingType.GREEDY:
-            greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1)
+            greedy_samples = torch.argmax(logprobs[sample_indices.long()],
+                                          dim=-1)
         elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
             max_best_of = 1
             for seq_group, is_prompt in zip(seq_groups, is_prompts):

@@ -407,7 +415,7 @@ def _sample(
                 "generators": sampling_metadata.generators,
             }
             multinomial_samples[sampling_type] = _multinomial(
-                probs[sample_indices], max_best_of, **seeded_args)
+                probs[sample_indices.long()], max_best_of, **seeded_args)
         elif sampling_type == SamplingType.BEAM:
             beam_search_logprobs = logprobs[sample_indices]
         else:
@@ -1,11 +1,11 @@
 """Utilities for selecting and loading models."""
 import contextlib
-from typing import Optional, Type
+from typing import Type

 import torch
 import torch.nn as nn

-from vllm.config import DeviceConfig, ModelConfig, LoRAConfig
+from vllm.config import DeviceConfig, ModelConfig
 from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.weight_utils import (get_quant_config,
                                               initialize_dummy_weights)

@@ -37,9 +37,9 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
         f"Supported architectures: {ModelRegistry.get_supported_archs()}")


-def get_model(model_config: ModelConfig,
-              device_config: DeviceConfig,
-              lora_config: Optional[LoRAConfig] = None) -> nn.Module:
+def get_model(model_config: ModelConfig, device_config: DeviceConfig,
+              **kwargs) -> nn.Module:
+    lora_config = kwargs.get("lora_config", None)
     model_class = _get_model_architecture(model_config)

     # Get the (maybe quantized) linear method.
@@ -4,7 +4,7 @@ from typing import List, Optional, Type
 import torch.nn as nn

 from vllm.logger import init_logger
-from vllm.utils import is_hip
+from vllm.utils import is_hip, is_neuron

 logger = init_logger(__name__)


@@ -61,6 +61,9 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS = {
     "Sliding window attention is not yet supported in ROCm's flash attention",
 }

+# Models supported by Neuron.
+_NEURON_SUPPORTED_MODELS = {"LlamaForCausalLM": "neuron.llama"}
+

 class ModelRegistry:


@@ -77,8 +80,15 @@ class ModelRegistry:
             logger.warning(
                 f"Model architecture {model_arch} is partially supported "
                 "by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
+        elif is_neuron():
+            if model_arch not in _NEURON_SUPPORTED_MODELS:
+                raise ValueError(
+                    f"Model architecture {model_arch} is not supported by "
+                    "Neuron for now.")

         module_name, model_cls_name = _MODELS[model_arch]
+        if is_neuron():
+            module_name = _NEURON_SUPPORTED_MODELS[model_arch]
         module = importlib.import_module(
             f"vllm.model_executor.models.{module_name}")
         return getattr(module, model_cls_name, None)
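Putting the two tables together, the registry redirect for Neuron resolves as follows (a sketch of the string handling only; _MODELS is abbreviated to the one relevant entry):

    _MODELS = {"LlamaForCausalLM": ("llama", "LlamaForCausalLM")}
    _NEURON_SUPPORTED_MODELS = {"LlamaForCausalLM": "neuron.llama"}

    model_arch = "LlamaForCausalLM"
    module_name, model_cls_name = _MODELS[model_arch]
    # On Neuron, "llama" is swapped for "neuron.llama", so the import target
    # becomes vllm.model_executor.models.neuron.llama instead of ...models.llama.
    module_name = _NEURON_SUPPORTED_MODELS[model_arch]
    print(f"vllm.model_executor.models.{module_name}")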

vllm/model_executor/models/neuron/llama.py (new file, 79 lines)
@@ -0,0 +1,79 @@
"""Inference-only LLaMA model compatible with HuggingFace weights."""
import os
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import LlamaConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class LlamaForCausalLM(nn.Module):

    def __init__(
        self,
        config: LlamaConfig,
        linear_method=None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = None
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        with torch.inference_mode():
            block_size = self.model.context_buckets[-1]
            if input_metadata.is_prompt:
                seq_ids = input_metadata.slot_mapping[:, 0] // block_size
            else:
                seq_ids = input_metadata.block_tables
            logits = self.model(input_ids,
                                cache_ids=positions,
                                start_ids=seq_ids.flatten())
        return logits

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(self.model.chkpt_model.lm_head,
                                   hidden_states, sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None,
                     **kwargs):
        from transformers_neuronx.llama.model import LlamaForSampling

        split_model_dir = f"{model_name_or_path}-split"
        if os.path.isdir(os.path.join(model_name_or_path,
                                      "pytorch_model.bin")):
            split_model_dir = model_name_or_path
        elif not os.path.exists(f"{model_name_or_path}-split"):
            from transformers.models.llama import LlamaForCausalLM
            from transformers_neuronx.module import save_pretrained_split

            hf_model = LlamaForCausalLM.from_pretrained(model_name_or_path,
                                                        low_cpu_mem_usage=True)
            save_pretrained_split(hf_model, f"{model_name_or_path}-split")

        self.model = LlamaForSampling.from_pretrained(split_model_dir,
                                                      **kwargs)
        self.model.to_neuron()

vllm/model_executor/neuron_model_loader.py (new file, 66 lines)
@@ -0,0 +1,66 @@
"""Utilities for selecting and loading models."""
from typing import Type

import torch
import torch.nn as nn
from transformers import PretrainedConfig

from vllm.config import ModelConfig, DeviceConfig
from vllm.model_executor.models import ModelRegistry

TORCH_DTYPE_TO_NEURON_AMP = {
    "auto": "f32",
    "half": "f16",
    "float16": "f16",
    "bfloat16": "bf16",
    "float": "f32",
    "float32": "f32",
    torch.float16: "f16",
    torch.bfloat16: "bf16",
    torch.float32: "f32",
}


def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
    architectures = getattr(config, "architectures", [])
    for arch in architectures:
        model_cls = ModelRegistry.load_model_cls(arch)
        if model_cls is not None:
            return model_cls
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {ModelRegistry.get_supported_archs()}")


def get_model(model_config: ModelConfig, device_config: DeviceConfig,
              **kwargs) -> nn.Module:
    from transformers_neuronx.config import NeuronConfig, ContinuousBatchingConfig

    parallel_config = kwargs.get("parallel_config")
    scheduler_config = kwargs.get("scheduler_config")

    model_class = _get_model_architecture(model_config.hf_config)
    linear_method = None

    # Create a model instance.
    model = model_class(model_config.hf_config, linear_method)

    continuous_batching_config = ContinuousBatchingConfig(
        batch_size_for_shared_caches=scheduler_config.max_num_seqs)
    neuron_config = NeuronConfig(
        continuous_batching=continuous_batching_config)

    # Load the weights from the cached or downloaded files.
    model.load_weights(
        model_config.model,
        model_config.download_dir,
        model_config.load_format,
        model_config.revision,
        tp_degree=parallel_config.neuron_tp_degree,
        amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
        neuron_config=neuron_config,
        context_length_estimate=[scheduler_config.max_model_len],
        n_positions=[scheduler_config.max_model_len],
        batch_size=scheduler_config.max_num_seqs)

    return model.eval()
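For orientation, a sketch of how the loader's keyword arguments line up with the configs, using the values from the example script (the dict only restates the mapping above and is not a call into transformers-neuronx):

    import torch

    TORCH_DTYPE_TO_NEURON_AMP = {torch.float16: "f16", torch.bfloat16: "bf16",
                                 torch.float32: "f32"}   # abbreviated

    max_model_len, max_num_seqs, requested_tp = 128, 8, 2
    neuron_kwargs = {
        "tp_degree": requested_tp,                   # ParallelConfig.neuron_tp_degree
        "amp": TORCH_DTYPE_TO_NEURON_AMP[torch.float16],
        "context_length_estimate": [max_model_len],  # single bucket == max_model_len
        "n_positions": [max_model_len],
        "batch_size": max_num_seqs,                  # shared-cache batch size
    }
    print(neuron_kwargs["amp"])   # f16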
@@ -5,7 +5,7 @@ import torch

 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SequenceData
-from vllm.utils import in_wsl
+from vllm.utils import in_wsl, is_neuron

 _SAMPLING_EPS = 1e-5

@@ -155,7 +155,7 @@ class SamplingTensors:
                    dtype: torch.dtype) -> "SamplingTensors":
         # Note that the performance will be very bad without
         # pinned memory.
-        pin_memory = not in_wsl()
+        pin_memory = not in_wsl() and not is_neuron()
         prompt_max_len = max(len(tokens) for tokens in prompt_tokens)
         prompt_padded_tokens = [
             tokens + [vocab_size] * (prompt_max_len - len(tokens))
@@ -1,10 +1,18 @@
 """Utils for model executor."""
 import random
+import importlib
 from typing import Any, Dict, Optional

 import numpy as np
 import torch

+from vllm.config import DeviceConfig, ModelConfig
+
+DEVICE_TO_MODEL_LOADER_MAP = {
+    "cuda": "model_loader",
+    "neuron": "neuron_model_loader",
+}
+

 def set_random_seed(seed: int) -> None:
     random.seed(seed)

@@ -33,3 +41,12 @@ def set_weight_attrs(
         assert not hasattr(
             weight, key), (f"Overwriting existing tensor attribute: {key}")
         setattr(weight, key, value)
+
+
+def get_model(model_config: ModelConfig, device_config: DeviceConfig,
+              **kwargs) -> torch.nn.Module:
+    model_loader_module = DEVICE_TO_MODEL_LOADER_MAP[device_config.device_type]
+    imported_model_loader = importlib.import_module(
+        f"vllm.model_executor.{model_loader_module}")
+    get_model_fn = imported_model_loader.get_model
+    return get_model_fn(model_config, device_config, **kwargs)
@@ -118,6 +118,14 @@ def is_hip() -> bool:
     return torch.version.hip is not None


+def is_neuron() -> bool:
+    try:
+        import transformers_neuronx
+    except ImportError:
+        transformers_neuronx = None
+    return transformers_neuronx is not None
+
+
 def get_max_shared_memory_bytes(gpu: int = 0) -> int:
     """Returns the maximum shared memory per thread block in bytes."""
     # NOTE: This import statement should be executed lazily since
@@ -3,10 +3,9 @@ from typing import Dict, List, Tuple

 import torch

-from vllm._C import cache_ops
 from vllm.config import CacheConfig, ModelConfig, ParallelConfig
 from vllm.logger import init_logger
-from vllm.utils import in_wsl, STR_DTYPE_TO_TORCH_DTYPE
+from vllm.utils import in_wsl, is_neuron, STR_DTYPE_TO_TORCH_DTYPE

 logger = init_logger(__name__)

@@ -39,6 +38,10 @@ class CacheEngine:
         self.num_gpu_blocks = cache_config.num_gpu_blocks
         self.num_cpu_blocks = cache_config.num_cpu_blocks

+        # Skip initializing CUDA stream and buffer for Neuron backend.
+        if is_neuron():
+            return
+
         if cache_config.cache_dtype == "auto":
             self.dtype = model_config.dtype
         else:

@@ -121,6 +124,8 @@ class CacheEngine:
         dst: List[KVCache],
         src_to_dst: Dict[int, int],
     ) -> None:
+        from vllm._C import cache_ops
+
         with torch.cuda.stream(self.cache_stream):
             for i in range(self.num_layers):
                 src_key_cache, src_value_cache = src[i]

@@ -140,6 +145,8 @@ class CacheEngine:
         self._swap(self.gpu_cache, self.cpu_cache, src_to_dst)

     def copy(self, src_to_dsts: Dict[int, List[int]]) -> None:
+        from vllm._C import cache_ops
+
         key_caches = [key_cache for key_cache, _ in self.gpu_cache]
         value_caches = [value_cache for _, value_cache in self.gpu_cache]
         # NOTE(woosuk): This operation implicitly synchronizes the CPU and GPU.
@@ -80,9 +80,16 @@ class ModelRunner:
         self.in_wsl = in_wsl()
         self.kv_cache_dtype = kv_cache_dtype

+        # Set enforce_eager to True for the Neuron backend, to avoid graph capture.
+        if self.device_config.is_neuron:
+            self.model_config.enforce_eager = True
+
     def load_model(self) -> None:
-        self.model = get_model(self.model_config, self.device_config,
-                               self.lora_config)
+        self.model = get_model(self.model_config,
+                               self.device_config,
+                               lora_config=self.lora_config,
+                               parallel_config=self.parallel_config,
+                               scheduler_config=self.scheduler_config)

         vocab_size = self.model.config.vocab_size

@@ -393,6 +400,7 @@ class ModelRunner:
         selected_token_start_idx = 0
         categorized_sample_indices = {t: [] for t in SamplingType}
         categorized_sample_indices_start_idx = 0
+        pin_memory = not self.in_wsl and not self.device_config.is_neuron

         max_subquery_len = max(subquery_lens) if subquery_lens else 1
         for i, seq_group_metadata in enumerate(seq_group_metadata_list):

@@ -443,12 +451,12 @@ class ModelRunner:
         selected_token_indices = _async_h2d(selected_token_indices,
                                             dtype=torch.long,
                                             target_device=self.device,
-                                            pin_memory=not self.in_wsl)
+                                            pin_memory=pin_memory)
         categorized_sample_indices = {
             t: _async_h2d(seq_ids,
                           dtype=torch.int,
                           target_device=self.device,
-                          pin_memory=not self.in_wsl)
+                          pin_memory=pin_memory)
             for t, seq_ids in categorized_sample_indices.items()
         }

vllm/worker/neuron_worker.py (new file, 191 lines)
@@ -0,0 +1,191 @@
"""A Neuron worker class."""
from typing import Dict, List, Optional, Tuple

import torch
import torch.distributed

from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, LoRAConfig)
from vllm.model_executor import set_random_seed
from vllm.model_executor.parallel_utils.communication_op import (
    broadcast_tensor_dict)
from vllm.model_executor.parallel_utils.parallel_state import (
    ensure_model_parallel_initialized)
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner import ModelRunner


class Worker:
    """A worker class that executes the model on a group of neuron cores."""

    def __init__(
        self,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        lora_config: Optional[LoRAConfig] = None,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
    ) -> None:
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.lora_config = lora_config
        self.is_driver_worker = is_driver_worker
        if self.is_driver_worker:
            assert self.rank == 0, "The driver worker must have rank 0."

        self.model_runner = ModelRunner(model_config,
                                        parallel_config,
                                        scheduler_config,
                                        device_config,
                                        lora_config=self.lora_config,
                                        is_driver_worker=is_driver_worker)
        # Uninitialized cache engine. Will be initialized by
        # self.init_cache_engine().
        self.cache_config = None
        self.cache_engine = None
        self.cache_events = None
        self.gpu_cache = None

    def init_model(self) -> None:
        # Initialize the distributed environment.
        _init_distributed_environment(self.parallel_config,
                                      self.rank,
                                      self.distributed_init_method,
                                      distributed_backend="gloo")

        # Initialize the model.
        set_random_seed(self.model_config.seed)

    def load_model(self):
        self.model_runner.load_model()

    @torch.inference_mode()
    def profile_num_available_blocks(
        self,
        block_size: int = 128,
        gpu_memory_utilization: float = 0.9,
        cpu_swap_space: int = 0,
        cache_dtype: str = "float16",
    ) -> Tuple[int, int]:
        """Simply returns max_num_seqs as num_gpu_blocks, 0 as num_cpu_blocks."""
        num_gpu_blocks = self.scheduler_config.max_num_seqs
        num_cpu_blocks = 0
        return num_gpu_blocks, num_cpu_blocks

    def init_cache_engine(self, cache_config: CacheConfig) -> None:
        self.cache_config = cache_config
        self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                        self.parallel_config)
        self.model_runner.set_block_size(self.cache_engine.block_size)

    def warm_up_model(self) -> None:
        # Warm up is maintained in transformers-neuronx
        pass

    def cache_swap(
        self,
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        # Issue cache operations.
        issued_cache_op = False
        if blocks_to_swap_in:
            self.cache_engine.swap_in(blocks_to_swap_in)
            issued_cache_op = True
        if blocks_to_swap_out:
            self.cache_engine.swap_out(blocks_to_swap_out)
            issued_cache_op = True
        if blocks_to_copy:
            self.cache_engine.copy(blocks_to_copy)
            issued_cache_op = True

        cache_events = self.cache_events if issued_cache_op else None

        # Wait for cache operations to finish.
        if cache_events is not None:
            raise NotImplementedError(
                "cache operations are not implemented for neuron backend.")

    @torch.inference_mode()
    def execute_model(
        self,
        seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None,
        blocks_to_swap_in: Optional[Dict[int, int]] = None,
        blocks_to_swap_out: Optional[Dict[int, int]] = None,
        blocks_to_copy: Optional[Dict[int, List[int]]] = None,
    ) -> Optional[SamplerOutput]:
        if self.is_driver_worker:
            assert seq_group_metadata_list is not None
            num_seq_groups = len(seq_group_metadata_list)
            assert blocks_to_swap_in is not None
            assert blocks_to_swap_out is not None
            assert blocks_to_copy is not None
            data = {
                "num_seq_groups": num_seq_groups,
                "blocks_to_swap_in": blocks_to_swap_in,
                "blocks_to_swap_out": blocks_to_swap_out,
                "blocks_to_copy": blocks_to_copy,
            }
            broadcast_tensor_dict(data, src=0)
        else:
            data = broadcast_tensor_dict(src=0)
            num_seq_groups = data["num_seq_groups"]
            blocks_to_swap_in = data["blocks_to_swap_in"]
            blocks_to_swap_out = data["blocks_to_swap_out"]
            blocks_to_copy = data["blocks_to_copy"]

        self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)

        # If there is no input, we don't need to execute the model.
        if num_seq_groups == 0:
            return {}

        output = self.model_runner.execute_model(seq_group_metadata_list,
                                                 self.gpu_cache)
        return output


def _init_distributed_environment(
    parallel_config: ParallelConfig,
    rank: int,
    distributed_init_method: Optional[str] = None,
    distributed_backend: Optional[str] = None,
) -> None:
    """Initialize the distributed environment."""
    if torch.distributed.is_initialized():
        torch_world_size = torch.distributed.get_world_size()
        if torch_world_size != parallel_config.world_size:
            raise RuntimeError(
                "torch.distributed is already initialized but the torch world "
                "size does not match parallel_config.world_size "
                f"({torch_world_size} vs. {parallel_config.world_size}).")
    elif not distributed_init_method:
        raise ValueError(
            "distributed_init_method must be set if torch.distributed "
            "is not already initialized")
    else:
        distributed_backend = distributed_backend if distributed_backend else "nccl"
        torch.distributed.init_process_group(
            backend=distributed_backend,
            world_size=parallel_config.world_size,
            rank=rank,
            init_method=distributed_init_method,
        )

    # A small all_reduce for warmup.
    torch.distributed.all_reduce(torch.zeros(1))
    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                      parallel_config.pipeline_parallel_size)
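The cache accounting this worker reports is deliberately trivial; a small sketch of what it implies for the example configuration (max_num_seqs=8, max_model_len=block_size=128, values from the example script):

    max_num_seqs, max_model_len, block_size = 8, 128, 128

    # profile_num_available_blocks: one block per sequence, no CPU swap blocks.
    num_gpu_blocks, num_cpu_blocks = max_num_seqs, 0

    # With block_size == max_model_len, a single block covers a whole sequence,
    # so no paging is needed (paged attention is the TODO noted in the example).
    blocks_per_seq = max_model_len // block_size
    assert (num_gpu_blocks, num_cpu_blocks, blocks_per_seq) == (8, 0, 1)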