[9/N] torch.compile LLM usage (#10552)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Authored by youkaichao on 2024-11-21 19:13:31 -08:00, committed by GitHub
parent aed074860a
commit 33e0a2540a
2 changed files with 16 additions and 4 deletions

View File

@@ -4,7 +4,7 @@ import tempfile
 import depyf
 
-from vllm.config import CompilationConfig, CompilationLevel
+from vllm.config import CompilationLevel
 
 temp_dir = tempfile.mkdtemp()
 
 with depyf.prepare_debug(temp_dir):
@@ -34,8 +34,7 @@ with depyf.prepare_debug(temp_dir):
     # all the control
     llm = LLM(model="google/gemma-2b",
               enforce_eager=True,
-              compilation_config=CompilationConfig(
-                  level=CompilationLevel.DYNAMO_AS_IS))
+              compilation_config={"level": CompilationLevel.DYNAMO_AS_IS})
     outputs = llm.generate(prompts, sampling_params)
     for output, answer in zip(outputs, answers):
         prompt = output.prompt
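The example now passes a plain dict instead of constructing a CompilationConfig, so the CompilationConfig import can be dropped; the LLM constructor converts the dict internally (see the second file below). A minimal sketch of the new call shape, assuming prompts and sampling_params are defined as in the full example:

    from vllm import LLM
    from vllm.config import CompilationLevel

    llm = LLM(model="google/gemma-2b",
              enforce_eager=True,
              # dict form; converted to a CompilationConfig inside LLM.__init__
              compilation_config={"level": CompilationLevel.DYNAMO_AS_IS})
    outputs = llm.generate(prompts, sampling_params)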

View File

@@ -1,4 +1,5 @@
 import itertools
+import json
 import warnings
 from contextlib import contextmanager
 from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
@@ -9,6 +10,7 @@ from tqdm import tqdm
 from vllm import envs
 from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                               BeamSearchSequence, get_beam_search_score)
+from vllm.config import CompilationConfig
 from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                    TaskOption)
 from vllm.engine.llm_engine import LLMEngine
@@ -107,13 +109,16 @@ class LLM:
         hf_overrides: If a dictionary, contains arguments to be forwarded to the
             HuggingFace config. If a callable, it is called to update the
             HuggingFace config.
+        compilation_config: Either an integer or a dictionary. If it is an
+            integer, it is used as the level of compilation optimization. If it
+            is a dictionary, it can specify the full compilation configuration.
         **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
             :ref:`engine_args`)
 
     Note:
         This class is intended to be used for offline inference. For online
         serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
-    """
+    """ # noqa
 
     DEPRECATE_LEGACY: ClassVar[bool] = False
     """A flag to toggle whether to deprecate the legacy generate/encode API."""
@@ -166,6 +171,7 @@ class LLM:
         # After positional args are removed, move this right below `model`
         task: TaskOption = "auto",
         override_pooler_config: Optional[PoolerConfig] = None,
+        compilation_config: Optional[Union[int, Dict[str, Any]]] = None,
         **kwargs,
     ) -> None:
         '''
@@ -178,6 +184,12 @@
         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
 
+        if compilation_config is not None:
+            compilation_config_instance = CompilationConfig.from_cli(
+                json.dumps(compilation_config))
+        else:
+            compilation_config_instance = None
+
         engine_args = EngineArgs(
             model=model,
             task=task,
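The conversion goes through the CLI parsing path rather than constructing the config directly: json.dumps renders an int as a bare number string and a dict as a JSON object string, and CompilationConfig.from_cli accepts both. A sketch of the round trip (values illustrative):

    import json
    from vllm.config import CompilationConfig

    # 3 -> "3": parsed by from_cli as the optimization level
    cfg_from_int = CompilationConfig.from_cli(json.dumps(3))

    # {"level": 3} -> '{"level": 3}': parsed as a full configuration
    cfg_from_dict = CompilationConfig.from_cli(json.dumps({"level": 3}))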
@@ -202,6 +214,7 @@
             hf_overrides=hf_overrides,
             mm_processor_kwargs=mm_processor_kwargs,
             override_pooler_config=override_pooler_config,
+            compilation_config=compilation_config_instance,
             **kwargs,
         )
         # Logic to switch between engines is done at runtime instead of import