[bugfix] torch profiler bug for single gpu with GPUExecutor (#8354)

This commit is contained in:
William Lin 2024-09-12 21:30:00 -07:00 committed by GitHub
parent 6821020109
commit ba77527955
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 27 additions and 5 deletions

View File

@ -16,7 +16,7 @@ prompts = [
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create an LLM.
llm = LLM(model="facebook/opt-125m")
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1)
llm.start_profile()

View File

@ -13,6 +13,7 @@ from vllm.engine.async_timeout import asyncio_timeout
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
from vllm.engine.metrics_types import StatLoggerBase
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutorAsync
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptInputs
from vllm.logger import init_logger
@ -1019,7 +1020,17 @@ class AsyncLLMEngine:
self.engine.remove_logger(logger_name=logger_name)
async def start_profile(self) -> None:
self.engine.model_executor._run_workers("start_profile")
# using type instead of isinstance to check to avoid capturing
# inherited classes
if type(self.engine.model_executor) == GPUExecutorAsync:
self.engine.model_executor.start_profile()
else:
self.engine.model_executor._run_workers("start_profile")
async def stop_profile(self) -> None:
self.engine.model_executor._run_workers("stop_profile")
# using type instead of isinstance to check to avoid capturing
# inherited classes
if type(self.engine.model_executor) == GPUExecutorAsync:
self.engine.model_executor.stop_profile()
else:
self.engine.model_executor._run_workers("stop_profile")

View File

@ -26,6 +26,7 @@ from vllm.engine.output_processor.interfaces import (
from vllm.engine.output_processor.stop_checker import StopChecker
from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs,
InputRegistry, LLMInputs, PromptInputs)
@ -1597,10 +1598,20 @@ class LLMEngine:
self.model_executor.check_health()
def start_profile(self) -> None:
self.model_executor.start_profile()
# using type instead of isinstance to check to avoid capturing
# inherited classes (MultiprocessingGPUExecutor)
if type(self.model_executor) == GPUExecutor:
self.model_executor.start_profile()
else:
self.model_executor._run_workers("start_profile")
def stop_profile(self) -> None:
self.model_executor.stop_profile()
# using type instead of isinstance to check to avoid capturing
# inherited classes (MultiprocessingGPUExecutor)
if type(self.model_executor) == GPUExecutor:
self.model_executor.stop_profile()
else:
self.model_executor._run_workers("stop_profile")
def is_tracing_enabled(self) -> bool:
return self.tracer is not None