[Neuron] Adding support for adding/overriding neuron configuration a… (#8062)
Co-authored-by: Harsha Bikki <harbikh@amazon.com>
This commit is contained in:
parent 77d9e514a2
commit 008cf886c9
examples/offline_inference_neuron_int8_quantization.py (new file, 50 lines)
@@ -0,0 +1,50 @@
import os

from vllm import LLM, SamplingParams

# Creates XLA HLO graphs for all the context length buckets.
os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048"
# Creates XLA HLO graphs for all the token gen buckets.
os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048"
# Quantizes neuron model weights to int8.
# The default config for quantization is int8 dtype.
os.environ['NEURON_QUANT_DTYPE'] = "s8"

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
llm = LLM(
    model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    max_num_seqs=8,
    # The max_model_len and block_size arguments are required to be the same
    # as the max sequence length when targeting the neuron device.
    # Currently, this is a known limitation in continuous batching support
    # in transformers-neuronx.
    # TODO(liangfu): Support paged-attention in transformers-neuronx.
    max_model_len=2048,
    block_size=2048,
    # The device can be automatically detected when the AWS Neuron SDK is
    # installed. The device argument can be either unspecified for automated
    # detection, or explicitly assigned.
    device="neuron",
    quantization="neuron_quant",
    override_neuron_config={
        "cast_logits_dtype": "bfloat16",
    },
    tensor_parallel_size=2)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
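The two bucket environment variables above are plain comma-separated integer lists. A hedged sketch (not part of the commit) of how such a value could be parsed into the bucket list that the _get_buckets helper further down in this diff is expected to return; parse_buckets is a hypothetical stand-in, not the committed code:

import os
from typing import List

def parse_buckets(env: str, default_value: List[int]) -> List[int]:
    # Hypothetical parser: split a comma-separated env var into integer buckets.
    raw = os.getenv(env)
    if raw is None:
        return default_value
    return [int(bucket) for bucket in raw.split(",")]

# With the environment set as in the example above:
# parse_buckets("NEURON_CONTEXT_LENGTH_BUCKETS", [2048]) -> [128, 512, 1024, 2048]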
@@ -1,8 +1,8 @@
 import enum
 import json
 from dataclasses import dataclass, field, fields
-from typing import (TYPE_CHECKING, ClassVar, List, Mapping, Optional, Tuple,
-                    Type, Union)
+from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Mapping,
+                    Optional, Tuple, Type, Union)

 import torch
 from transformers import PretrainedConfig
@@ -115,35 +115,39 @@ class ModelConfig:
             the model name will be the same as `model`.
         limit_mm_per_prompt: Maximum number of data instances per modality
             per prompt. Only applicable for multimodal models.
+        override_neuron_config: Initialize non-default neuron config or
+            override default neuron config that are specific to Neuron devices;
+            this argument is used to configure the neuron config that
+            cannot be gathered from the vLLM arguments.
     """

     def __init__(
-        self,
-        model: str,
-        tokenizer: str,
-        tokenizer_mode: str,
-        trust_remote_code: bool,
-        dtype: Union[str, torch.dtype],
-        seed: int,
-        revision: Optional[str] = None,
-        code_revision: Optional[str] = None,
-        rope_scaling: Optional[dict] = None,
-        rope_theta: Optional[float] = None,
-        tokenizer_revision: Optional[str] = None,
-        max_model_len: Optional[int] = None,
-        spec_target_max_model_len: Optional[int] = None,
-        quantization: Optional[str] = None,
-        quantization_param_path: Optional[str] = None,
-        enforce_eager: Optional[bool] = None,
-        max_context_len_to_capture: Optional[int] = None,
-        max_seq_len_to_capture: Optional[int] = None,
-        max_logprobs: int = 20,
-        disable_sliding_window: bool = False,
-        skip_tokenizer_init: bool = False,
-        served_model_name: Optional[Union[str, List[str]]] = None,
-        limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
-        use_async_output_proc: bool = True,
-    ) -> None:
+            self,
+            model: str,
+            tokenizer: str,
+            tokenizer_mode: str,
+            trust_remote_code: bool,
+            dtype: Union[str, torch.dtype],
+            seed: int,
+            revision: Optional[str] = None,
+            code_revision: Optional[str] = None,
+            rope_scaling: Optional[dict] = None,
+            rope_theta: Optional[float] = None,
+            tokenizer_revision: Optional[str] = None,
+            max_model_len: Optional[int] = None,
+            spec_target_max_model_len: Optional[int] = None,
+            quantization: Optional[str] = None,
+            quantization_param_path: Optional[str] = None,
+            enforce_eager: Optional[bool] = None,
+            max_context_len_to_capture: Optional[int] = None,
+            max_seq_len_to_capture: Optional[int] = None,
+            max_logprobs: int = 20,
+            disable_sliding_window: bool = False,
+            skip_tokenizer_init: bool = False,
+            served_model_name: Optional[Union[str, List[str]]] = None,
+            limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
+            use_async_output_proc: bool = True,
+            override_neuron_config: Optional[Dict[str, Any]] = None) -> None:
         self.model = model
         self.tokenizer = tokenizer
         self.tokenizer_mode = tokenizer_mode
@@ -227,6 +231,9 @@ class ModelConfig:
                 limit_mm_per_prompt)
         if not self.skip_tokenizer_init:
             self._verify_tokenizer_mode()
+
+        self.override_neuron_config = override_neuron_config if is_neuron(
+        ) else None
         self._verify_embedding_mode()
         self._verify_quantization()
         self._verify_cuda_graph()
@@ -275,6 +282,7 @@ class ModelConfig:
             "experts_int8"
         ]
         tpu_supported_quantization = ["tpu_int8"]
+        neuron_supported_quantization = ["neuron_quant"]
         if self.quantization is not None:
             self.quantization = self.quantization.lower()

@@ -329,6 +337,11 @@ class ModelConfig:
                     "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                     " is not set, enabling VLLM_USE_TRITON_AWQ.")
                 envs.VLLM_USE_TRITON_AWQ = True
+        if is_neuron(
+        ) and self.quantization not in neuron_supported_quantization:
+            raise ValueError(
+                f"{self.quantization} quantization is currently not "
+                f"supported in Neuron Backend.")

     def _verify_cuda_graph(self) -> None:
         if self.max_seq_len_to_capture is None:
@@ -2,8 +2,8 @@ import argparse
 import dataclasses
 import json
 from dataclasses import dataclass
-from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Type,
-                    Union)
+from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple,
+                    Type, Union)

 import torch

@@ -149,6 +149,7 @@ class EngineArgs:
     otlp_traces_endpoint: Optional[str] = None
     collect_detailed_traces: Optional[str] = None
     disable_async_output_proc: bool = False
+    override_neuron_config: Optional[Dict[str, Any]] = None

     def __post_init__(self):
         if self.tokenizer is None:
@@ -742,6 +743,16 @@ class EngineArgs:
             default=EngineArgs.disable_async_output_proc,
             help="Disable async output processing. This may result in "
             "lower performance.")
+        parser.add_argument(
+            '--override-neuron-config',
+            type=lambda configs: {
+                str(key): value
+                for key, value in
+                (config.split(':') for config in configs.split(','))
+            },
+            default=None,
+            help="override or set neuron device configuration.")

         return parser

     @classmethod
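The parsing lambda registered above turns a comma- and colon-delimited flag value into a plain string-to-string dict before it reaches ModelConfig. A hedged sketch of that behaviour (the sample keys are illustrative, not exhaustive):

parse = lambda configs: {
    str(key): value
    for key, value in (config.split(':') for config in configs.split(','))
}

print(parse("cast_logits_dtype:bfloat16"))
# {'cast_logits_dtype': 'bfloat16'}
# Values stay strings at this stage; e.g. "fuse_qkv:true" maps to the string 'true'.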
@@ -802,7 +813,7 @@ class EngineArgs:
             served_model_name=self.served_model_name,
             limit_mm_per_prompt=self.limit_mm_per_prompt,
             use_async_output_proc=not self.disable_async_output_proc,
-        )
+            override_neuron_config=self.override_neuron_config)
         cache_config = CacheConfig(
             block_size=self.block_size if self.device != "neuron" else
             self.max_model_len,  # neuron needs block_size = max_model_len
@@ -214,6 +214,7 @@ class LLMEngine:
             "Initializing an LLM engine (v%s) with config: "
             "model=%r, speculative_config=%r, tokenizer=%r, "
             "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
+            "override_neuron_config=%s, "
             "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, "
             "trust_remote_code=%s, dtype=%s, max_seq_len=%d, "
             "download_dir=%r, load_format=%s, tensor_parallel_size=%d, "
@@ -232,6 +233,7 @@ class LLMEngine:
             model_config.skip_tokenizer_init,
             model_config.tokenizer_mode,
             model_config.revision,
+            model_config.override_neuron_config,
             model_config.rope_scaling,
             model_config.rope_theta,
             model_config.tokenizer_revision,
@@ -22,6 +22,8 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
 from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
     GPTQMarlin24Config)
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
+from vllm.model_executor.layers.quantization.neuron_quant import (
+    NeuronQuantConfig)
 from vllm.model_executor.layers.quantization.qqq import QQQConfig
 from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
 from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
@@ -46,6 +48,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "bitsandbytes": BitsAndBytesConfig,
     "qqq": QQQConfig,
     "experts_int8": ExpertsInt8Config,
+    "neuron_quant": NeuronQuantConfig,
 }

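With the registry entry above, the name-based lookup used elsewhere in vLLM resolves to the new class. A hedged sketch, assuming only the get_quantization_config helper that is already imported later in this diff:

from vllm.model_executor.layers.quantization import get_quantization_config
from vllm.model_executor.layers.quantization.neuron_quant import NeuronQuantConfig

# "neuron_quant" now resolves through the same registry as every other method.
assert get_quantization_config("neuron_quant") is NeuronQuantConfig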
vllm/model_executor/layers/quantization/neuron_quant.py (new file, 67 lines)
@@ -0,0 +1,67 @@
import os
from importlib.util import find_spec
from typing import Any, Dict, List, Optional

from torch.nn import Module

from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)

SUPPORTED_QUANT_DTYPE_LIST = ['s8', 'f8e4m3fn']


class NeuronQuantConfig(QuantizationConfig):
    """Int8 Quantization Config class for Neuron Backend."""

    def __init__(
        self,
        dequant_dtype: str = "f16",
        quantize_method: str = "vector_dynamic",
    ) -> None:
        self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8")
        if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST:
            raise ValueError(
                f"Neuron quantization datatype {self.quant_dtype} is not valid,"
                f" the quantization datatype should match one of the below"
                f" types: {SUPPORTED_QUANT_DTYPE_LIST}")
        self.dequant_dtype = dequant_dtype
        self.quantize_method = quantize_method

    def get_name(self) -> str:
        return "neuron_quant"

    def get_supported_act_dtypes(self) -> List[str]:
        return SUPPORTED_QUANT_DTYPE_LIST

    @classmethod
    def get_min_capability(cls) -> int:
        raise NotImplementedError(
            "This function should not be called with Neuron Backend")

    @staticmethod
    def get_config_filenames() -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "NeuronQuantConfig":
        quantize_method = cls.get_from_keys(config, ["quantize_method"])
        dequant_dtype = cls.get_from_keys(config, ["dequant_dtype"])
        return cls(dequant_dtype=dequant_dtype,
                   quantize_method=quantize_method)

    def get_quant_method(self, layer: Module, prefix: str) -> Optional[Any]:
        if find_spec("transformers_neuronx") is not None:
            return self.get_quantization_config()
        else:
            raise NotImplementedError(
                "Neuron Quantization is only supported through"
                " transformers_neuronx.")

    def get_scaled_act_names(self) -> List[str]:
        return []

    def get_quantization_config(self):
        from transformers_neuronx.config import QuantizationConfig
        return QuantizationConfig(quant_dtype=self.quant_dtype,
                                  dequant_dtype=self.dequant_dtype,
                                  quantize_method=self.quantize_method)
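A hedged usage sketch of the class above (the environment variable name and the defaults come straight from __init__; transformers_neuronx only becomes necessary once get_quant_method is called):

import os

os.environ["NEURON_QUANT_DTYPE"] = "s8"  # read in __init__, defaults to "s8"

from vllm.model_executor.layers.quantization.neuron_quant import NeuronQuantConfig

quant_config = NeuronQuantConfig.from_config({
    "dequant_dtype": "f16",
    "quantize_method": "vector_dynamic",
})
print(quant_config.get_name())  # neuron_quant
# get_quant_method(None, "") returns a transformers_neuronx QuantizationConfig when
# transformers_neuronx is importable, and raises NotImplementedError otherwise.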
@@ -10,6 +10,7 @@ from transformers import PretrainedConfig

 from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.quantization import get_quantization_config
 from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.sampling_metadata import SamplingMetadata

@@ -81,8 +82,7 @@ class NeuronCasualLM(nn.Module):
         neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)

         split_model_dir = f"{model_name_or_path}-split"
-        if os.path.isdir(os.path.join(model_name_or_path,
-                                      "pytorch_model.bin")):
+        if _is_pretrained_neuron_checkpoint(model_name_or_path):
             split_model_dir = model_name_or_path
         elif not os.path.exists(f"{model_name_or_path}-split"):
             hf_model_cls = getattr(transformers, hf_model_cls_name)
@@ -97,6 +97,23 @@ class NeuronCasualLM(nn.Module):
         self.model.to_neuron()


+def _is_pretrained_neuron_checkpoint(model_name_or_path: str) -> bool:
+    # Checking if the neuron checkpoint is saved in the old format.
+    if os.path.isdir(os.path.join(model_name_or_path, "pytorch_model.bin")):
+        return True
+    # Checking if the neuron checkpoint is saved in the new format.
+    pretrained_split_files = ["config.json", "generation_config.json"]
+    pretrained_split_format = ".safetensors"
+    for file in pretrained_split_files:
+        file_path = os.path.join(model_name_or_path, file)
+        if not os.path.isfile(file_path):
+            return False
+    for file in os.listdir(model_name_or_path):
+        if file.endswith(pretrained_split_format):
+            return True
+    return False
+
+
 def _get_model_architecture(config: PretrainedConfig) -> str:
     architectures = getattr(config, "architectures", [])
     for arch in architectures:
@@ -119,19 +136,51 @@ def _get_buckets(env: str, default_value: List[int]) -> List[int]:
     return buckets_list


+def _get_default_neuron_config(model_config: ModelConfig,
+                               parallel_config: ParallelConfig,
+                               scheduler_config: SchedulerConfig):
+    from transformers_neuronx.config import ContinuousBatchingConfig
+    from transformers_neuronx.constants import LAYOUT_BSH
+
+    continuous_batching_config = ContinuousBatchingConfig(
+        batch_size_for_shared_caches=scheduler_config.max_num_seqs)
+    quant_config = dict(
+        dequant_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
+        quantize_method="vector_dynamic")
+    neuron_quantization_config_builder = lambda quant: get_quantization_config(
+        quant).from_config(quant_config).get_quant_method(None, "")
+    # TODO: Add Paged attention config to the default neuron arguments.
+    default_neuron_args = dict(
+        collectives_layout=LAYOUT_BSH,
+        attention_layout=LAYOUT_BSH,
+        fuse_qkv=True,
+        quant=neuron_quantization_config_builder(model_config.quantization)
+        if model_config.quantization else None,
+        continuous_batching=continuous_batching_config,
+        weight_tiling=bool(model_config.quantization))
+    return default_neuron_args
+
+
+def _get_neuron_config_after_override(default_neuron_config,
+                                      overridden_neuron_config):
+    from transformers_neuronx.config import NeuronConfig
+    overridden_neuron_config = overridden_neuron_config or {}
+    default_neuron_config.update(overridden_neuron_config)
+    return NeuronConfig(**default_neuron_config)
+
+
 def get_neuron_model(model_config: ModelConfig,
                      parallel_config: ParallelConfig,
                      scheduler_config: SchedulerConfig) -> nn.Module:
-    from transformers_neuronx.config import (ContinuousBatchingConfig,
-                                             NeuronConfig)

     # Create a model instance.
     model = NeuronCasualLM(model_config.hf_config)

-    continuous_batching_config = ContinuousBatchingConfig(
-        batch_size_for_shared_caches=scheduler_config.max_num_seqs)
-    neuron_config = NeuronConfig(
-        continuous_batching=continuous_batching_config)
+    default_neuron_config_args = _get_default_neuron_config(
+        model_config, parallel_config, scheduler_config)
+
+    neuron_config = _get_neuron_config_after_override(
+        default_neuron_config_args, model_config.override_neuron_config)

     context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
                                             [scheduler_config.max_model_len])
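The override applied through _get_neuron_config_after_override above is a plain dict update: user-supplied keys win, and every key not mentioned keeps its default. A hedged sketch with plain dicts (the override values are illustrative; the merged dict is ultimately passed to transformers_neuronx.config.NeuronConfig as keyword arguments):

default_neuron_config = dict(fuse_qkv=True, weight_tiling=False, quant=None)
override_neuron_config = {"cast_logits_dtype": "bfloat16", "fuse_qkv": False}

default_neuron_config.update(override_neuron_config or {})
# -> {'fuse_qkv': False, 'weight_tiling': False, 'quant': None,
#     'cast_logits_dtype': 'bfloat16'}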
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from importlib.util import find_spec
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

 import torch
@@ -76,9 +77,14 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
         self.model: nn.Module  # initialize after load_model.

     def load_model(self) -> None:
-        self.model = get_neuron_model(self.model_config,
-                                      parallel_config=self.parallel_config,
-                                      scheduler_config=self.scheduler_config)
+        if find_spec("transformers_neuronx") is not None:
+            self.model = get_neuron_model(
+                self.model_config,
+                parallel_config=self.parallel_config,
+                scheduler_config=self.scheduler_config)
+        else:
+            raise NotImplementedError(
+                "Supports only Transformer-NeuronX based models.")

     def _prepare_prompt(
         self,