[9/N] torch.compile LLM usage (#10552)
Signed-off-by: youkaichao <youkaichao@gmail.com>
parent aed074860a · commit 33e0a2540a
@@ -4,7 +4,7 @@ import tempfile
 
 import depyf
 
-from vllm.config import CompilationConfig, CompilationLevel
+from vllm.config import CompilationLevel
 
 temp_dir = tempfile.mkdtemp()
 with depyf.prepare_debug(temp_dir):
@@ -34,8 +34,7 @@ with depyf.prepare_debug(temp_dir):
     # all the control
     llm = LLM(model="google/gemma-2b",
               enforce_eager=True,
-              compilation_config=CompilationConfig(
-                  level=CompilationLevel.DYNAMO_AS_IS))
+              compilation_config={"level": CompilationLevel.DYNAMO_AS_IS})
     outputs = llm.generate(prompts, sampling_params)
     for output, answer in zip(outputs, answers):
         prompt = output.prompt
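For context, a minimal sketch of the dict-based usage the test above now exercises; the prompt and sampling parameters are placeholders, while `google/gemma-2b` and the config dict come from the test itself:

from vllm import LLM, SamplingParams
from vllm.config import CompilationLevel

# CompilationLevel members are plain integers, so the dict stays
# JSON-serializable when LLM forwards it to the engine config.
llm = LLM(model="google/gemma-2b",
          enforce_eager=True,
          compilation_config={"level": CompilationLevel.DYNAMO_AS_IS})
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)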
@@ -1,4 +1,5 @@
 import itertools
+import json
 import warnings
 from contextlib import contextmanager
 from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
@@ -9,6 +10,7 @@ from tqdm import tqdm
 from vllm import envs
 from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                               BeamSearchSequence, get_beam_search_score)
+from vllm.config import CompilationConfig
 from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                    TaskOption)
 from vllm.engine.llm_engine import LLMEngine
@@ -107,13 +109,16 @@ class LLM:
         hf_overrides: If a dictionary, contains arguments to be forwarded to the
             HuggingFace config. If a callable, it is called to update the
             HuggingFace config.
+        compilation_config: Either an integer or a dictionary. If it is an integer,
+            it is used as the level of compilation optimization. If it is a dictionary,
+            it can specify the full compilation configuration.
         **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
             :ref:`engine_args`)
 
     Note:
         This class is intended to be used for offline inference. For online
         serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
-    """
+    """ # noqa
 
     DEPRECATE_LEGACY: ClassVar[bool] = False
     """A flag to toggle whether to deprecate the legacy generate/encode API."""
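As a brief illustration of the two accepted forms documented above (the model name is a placeholder, not part of this commit):

from vllm import LLM

# Integer form: the value is taken as the optimization level.
llm = LLM(model="facebook/opt-125m", compilation_config=3)

# Dictionary form: spells out the full compilation configuration.
llm = LLM(model="facebook/opt-125m",
          compilation_config={"level": 3})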
@@ -166,6 +171,7 @@ class LLM:
         # After positional args are removed, move this right below `model`
         task: TaskOption = "auto",
         override_pooler_config: Optional[PoolerConfig] = None,
+        compilation_config: Optional[Union[int, Dict[str, Any]]] = None,
         **kwargs,
     ) -> None:
         '''
@@ -178,6 +184,12 @@
         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
 
+        if compilation_config is not None:
+            compilation_config_instance = CompilationConfig.from_cli(
+                json.dumps(compilation_config))
+        else:
+            compilation_config_instance = None
+
         engine_args = EngineArgs(
             model=model,
             task=task,
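A sketch of the round trip performed above, assuming only what the diff shows: the user-facing int or dict is serialized with json.dumps and re-parsed by CompilationConfig.from_cli, so the constructor shares one parsing path with CLI-style config strings. The printed result is an assumption based on the `level` field used elsewhere in this commit:

import json

from vllm.config import CompilationConfig

# Both user-facing forms become a string that from_cli understands:
assert json.dumps(3) == "3"
assert json.dumps({"level": 3}) == '{"level": 3}'

config = CompilationConfig.from_cli('{"level": 3}')
print(config.level)  # expected: 3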
@@ -202,6 +214,7 @@
             hf_overrides=hf_overrides,
             mm_processor_kwargs=mm_processor_kwargs,
             override_pooler_config=override_pooler_config,
+            compilation_config=compilation_config_instance,
             **kwargs,
         )
         # Logic to switch between engines is done at runtime instead of import