[Neuron] Adding support for context-length, token-gen buckets. (#7885)

Co-authored-by: Harsha Bikki <harbikh@amazon.com>
Author: Harsha vardhan manoj Bikki
Date: 2024-08-29 13:58:14 -07:00 (committed by GitHub)
Parent: 86a677de42
Commit: 257afc37c5
2 changed files with 33 additions and 11 deletions


@@ -1,5 +1,12 @@
+import os
+
 from vllm import LLM, SamplingParams
 
+# creates XLA hlo graphs for all the context length buckets.
+os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048"
+# creates XLA hlo graphs for all the token gen buckets.
+os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048"
+
 # Sample prompts.
 prompts = [
     "Hello, my name is",
@@ -19,8 +26,8 @@ llm = LLM(
     # Currently, this is a known limitation in continuous batching support
     # in transformers-neuronx.
     # TODO(liangfu): Support paged-attention in transformers-neuronx.
-    max_model_len=128,
-    block_size=128,
+    max_model_len=2048,
+    block_size=2048,
     # The device can be automatically detected when AWS Neuron SDK is installed.
     # The device argument can be either unspecified for automated detection,
     # or explicitly assigned.

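Note that the bump of max_model_len and block_size from 128 to 2048 matches the largest bucket configured above. A minimal sanity-check sketch, assuming (as these values suggest) that the largest context-length and token-gen bucket must be able to cover the model's maximum sequence length; check_buckets is a hypothetical helper, not part of this commit:

import os

# Hypothetical helper: verify each bucket list covers max_model_len,
# so at least one compiled graph can serve the longest request.
def check_buckets(env: str, max_model_len: int) -> None:
    buckets = [int(b) for b in os.environ[env].split(",") if b.strip()]
    assert max(buckets) >= max_model_len, (
        f"{env}: largest bucket {max(buckets)} < max_model_len {max_model_len}")

os.environ["NEURON_CONTEXT_LENGTH_BUCKETS"] = "128,512,1024,2048"
os.environ["NEURON_TOKEN_GEN_BUCKETS"] = "128,512,1024,2048"
for env in ("NEURON_CONTEXT_LENGTH_BUCKETS", "NEURON_TOKEN_GEN_BUCKETS"):
    check_buckets(env, max_model_len=2048)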

@@ -1,7 +1,7 @@
 """Utilities for selecting and loading neuron models."""
 import importlib
 import os
-from typing import Dict, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -109,6 +109,17 @@ def _get_model_architecture(config: PretrainedConfig) -> str:
         f"{list(_NEURON_SUPPORTED_MODELS.keys())}")
 
 
+def _get_buckets(env: str, default_value: List[int]) -> List[int]:
+    env_value = os.getenv(env)
+    if env_value is None:
+        return default_value
+    buckets_remove_empty = filter(
+        lambda x: x is not None and len(x.strip()) > 0, env_value.split(","))
+    buckets_int = map(int, buckets_remove_empty)
+    buckets_list = list(buckets_int)
+    return buckets_list
+
+
 def get_neuron_model(model_config: ModelConfig,
                      parallel_config: ParallelConfig,
                      scheduler_config: SchedulerConfig) -> nn.Module:
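The parser tolerates stray whitespace and empty items: empty entries are filtered out before conversion, and int() strips surrounding whitespace on its own. A standalone demo (the function body is copied from the hunk above so the snippet runs on its own; the bucket string is illustrative):

import os
from typing import List

# Copied from the diff above for a self-contained example.
def _get_buckets(env: str, default_value: List[int]) -> List[int]:
    env_value = os.getenv(env)
    if env_value is None:
        return default_value
    buckets_remove_empty = filter(
        lambda x: x is not None and len(x.strip()) > 0, env_value.split(","))
    buckets_int = map(int, buckets_remove_empty)
    buckets_list = list(buckets_int)
    return buckets_list

# Empty items and padding survive parsing cleanly.
os.environ["NEURON_CONTEXT_LENGTH_BUCKETS"] = "128, 512,,1024 ,2048"
assert _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
                    [2048]) == [128, 512, 1024, 2048]

# An unset variable falls back to the supplied default.
os.environ.pop("NEURON_TOKEN_GEN_BUCKETS", None)
assert _get_buckets("NEURON_TOKEN_GEN_BUCKETS", [2048]) == [2048]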
@@ -123,14 +134,18 @@ def get_neuron_model(model_config: ModelConfig,
     neuron_config = NeuronConfig(
         continuous_batching=continuous_batching_config)
 
+    context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
+                                            [scheduler_config.max_model_len])
+    n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
+                               [scheduler_config.max_model_len])
+
     # Load the weights from the cached or downloaded files.
-    model.load_weights(
-        model_config.model,
-        tp_degree=parallel_config.tensor_parallel_size,
-        amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
-        neuron_config=neuron_config,
-        context_length_estimate=[scheduler_config.max_model_len],
-        n_positions=[scheduler_config.max_model_len],
-        batch_size=scheduler_config.max_num_seqs)
+    model.load_weights(model_config.model,
+                       tp_degree=parallel_config.tensor_parallel_size,
+                       amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
+                       neuron_config=neuron_config,
+                       context_length_estimate=context_length_estimates,
+                       n_positions=n_positions,
+                       batch_size=scheduler_config.max_num_seqs)
 
     return model.eval()
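With neither variable set, both bucket lists collapse to [scheduler_config.max_model_len], so the previous single-bucket behavior is preserved. A sketch of that fallback, assuming _get_buckets is importable from vLLM's neuron loader module (the import path is an assumption, not stated in this diff):

import os

# Assumed module path for the loader changed in this commit.
from vllm.model_executor.model_loader.neuron import _get_buckets

max_model_len = 2048
for var in ("NEURON_CONTEXT_LENGTH_BUCKETS", "NEURON_TOKEN_GEN_BUCKETS"):
    os.environ.pop(var, None)  # make sure the defaults apply

assert _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS", [max_model_len]) == [2048]
assert _get_buckets("NEURON_TOKEN_GEN_BUCKETS", [max_model_len]) == [2048]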