[9/N] torch.compile LLM usage (#10552)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
aed074860a
commit
33e0a2540a
@ -4,7 +4,7 @@ import tempfile
|
|||||||
|
|
||||||
import depyf
|
import depyf
|
||||||
|
|
||||||
from vllm.config import CompilationConfig, CompilationLevel
|
from vllm.config import CompilationLevel
|
||||||
|
|
||||||
temp_dir = tempfile.mkdtemp()
|
temp_dir = tempfile.mkdtemp()
|
||||||
with depyf.prepare_debug(temp_dir):
|
with depyf.prepare_debug(temp_dir):
|
||||||
@ -34,8 +34,7 @@ with depyf.prepare_debug(temp_dir):
|
|||||||
# all the control
|
# all the control
|
||||||
llm = LLM(model="google/gemma-2b",
|
llm = LLM(model="google/gemma-2b",
|
||||||
enforce_eager=True,
|
enforce_eager=True,
|
||||||
compilation_config=CompilationConfig(
|
compilation_config={"level": CompilationLevel.DYNAMO_AS_IS})
|
||||||
level=CompilationLevel.DYNAMO_AS_IS))
|
|
||||||
outputs = llm.generate(prompts, sampling_params)
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
for output, answer in zip(outputs, answers):
|
for output, answer in zip(outputs, answers):
|
||||||
prompt = output.prompt
|
prompt = output.prompt
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import itertools
|
import itertools
|
||||||
|
import json
|
||||||
import warnings
|
import warnings
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
|
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
|
||||||
@ -9,6 +10,7 @@ from tqdm import tqdm
|
|||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
|
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
|
||||||
BeamSearchSequence, get_beam_search_score)
|
BeamSearchSequence, get_beam_search_score)
|
||||||
|
from vllm.config import CompilationConfig
|
||||||
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
|
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
|
||||||
TaskOption)
|
TaskOption)
|
||||||
from vllm.engine.llm_engine import LLMEngine
|
from vllm.engine.llm_engine import LLMEngine
|
||||||
@ -107,13 +109,16 @@ class LLM:
|
|||||||
hf_overrides: If a dictionary, contains arguments to be forwarded to the
|
hf_overrides: If a dictionary, contains arguments to be forwarded to the
|
||||||
HuggingFace config. If a callable, it is called to update the
|
HuggingFace config. If a callable, it is called to update the
|
||||||
HuggingFace config.
|
HuggingFace config.
|
||||||
|
compilation_config: Either an integer or a dictionary. If it is an integer,
|
||||||
|
it is used as the level of compilation optimization. If it is a dictionary,
|
||||||
|
it can specify the full compilation configuration.
|
||||||
**kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
|
**kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
|
||||||
:ref:`engine_args`)
|
:ref:`engine_args`)
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
This class is intended to be used for offline inference. For online
|
This class is intended to be used for offline inference. For online
|
||||||
serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
|
serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
|
||||||
"""
|
""" # noqa
|
||||||
|
|
||||||
DEPRECATE_LEGACY: ClassVar[bool] = False
|
DEPRECATE_LEGACY: ClassVar[bool] = False
|
||||||
"""A flag to toggle whether to deprecate the legacy generate/encode API."""
|
"""A flag to toggle whether to deprecate the legacy generate/encode API."""
|
||||||
@ -166,6 +171,7 @@ class LLM:
|
|||||||
# After positional args are removed, move this right below `model`
|
# After positional args are removed, move this right below `model`
|
||||||
task: TaskOption = "auto",
|
task: TaskOption = "auto",
|
||||||
override_pooler_config: Optional[PoolerConfig] = None,
|
override_pooler_config: Optional[PoolerConfig] = None,
|
||||||
|
compilation_config: Optional[Union[int, Dict[str, Any]]] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
'''
|
'''
|
||||||
@ -178,6 +184,12 @@ class LLM:
|
|||||||
if "disable_log_stats" not in kwargs:
|
if "disable_log_stats" not in kwargs:
|
||||||
kwargs["disable_log_stats"] = True
|
kwargs["disable_log_stats"] = True
|
||||||
|
|
||||||
|
if compilation_config is not None:
|
||||||
|
compilation_config_instance = CompilationConfig.from_cli(
|
||||||
|
json.dumps(compilation_config))
|
||||||
|
else:
|
||||||
|
compilation_config_instance = None
|
||||||
|
|
||||||
engine_args = EngineArgs(
|
engine_args = EngineArgs(
|
||||||
model=model,
|
model=model,
|
||||||
task=task,
|
task=task,
|
||||||
@ -202,6 +214,7 @@ class LLM:
|
|||||||
hf_overrides=hf_overrides,
|
hf_overrides=hf_overrides,
|
||||||
mm_processor_kwargs=mm_processor_kwargs,
|
mm_processor_kwargs=mm_processor_kwargs,
|
||||||
override_pooler_config=override_pooler_config,
|
override_pooler_config=override_pooler_config,
|
||||||
|
compilation_config=compilation_config_instance,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
# Logic to switch between engines is done at runtime instead of import
|
# Logic to switch between engines is done at runtime instead of import
|
||||||
|
Loading…
x
Reference in New Issue
Block a user