From 463bbb1835b8bbb1a80e6286c6396f6f3182c4f7 Mon Sep 17 00:00:00 2001
From: wwl2755
Date: Thu, 3 Apr 2025 02:32:10 -0500
Subject: [PATCH] [Bugfix][V1] Fix bug from putting llm_engine.model_executor
 in a background process (#15367)

Signed-off-by: wwl2755
---
 .../offline_inference/load_sharded_state.py | 93 +++++++++++++++++++
 .../offline_inference/save_sharded_state.py | 23 ++++-
 vllm/v1/engine/core.py                      | 10 ++
 vllm/v1/engine/core_client.py               | 31 +++++++
 vllm/v1/worker/gpu_worker.py                | 14 +++
 5 files changed, 167 insertions(+), 4 deletions(-)
 create mode 100644 examples/offline_inference/load_sharded_state.py

diff --git a/examples/offline_inference/load_sharded_state.py b/examples/offline_inference/load_sharded_state.py
new file mode 100644
index 00000000..7e90d5d2
--- /dev/null
+++ b/examples/offline_inference/load_sharded_state.py
@@ -0,0 +1,93 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Validates the loading of a model saved with the sharded_state format.
+This script demonstrates how to load a model that was previously saved
+using save_sharded_state.py and validates it by running inference.
+Example usage:
+(First you need to save a sharded_state model)
+
+python save_sharded_state.py \
+    --model /path/to/load \
+    --quantization deepspeedfp \
+    --tensor-parallel-size 8 \
+    --output /path/to/save/sharded/model
+
+python load_sharded_state.py \
+    --model /path/to/saved/sharded/model \
+    --load-format sharded_state \
+    --quantization deepspeedfp \
+    --tensor-parallel-size 8 \
+    --prompt "Hello, my name is" \
+    --max-tokens 50
+"""
+
+import dataclasses
+
+from vllm import LLM, EngineArgs, SamplingParams
+from vllm.utils import FlexibleArgumentParser
+
+
+def parse_args():
+    parser = FlexibleArgumentParser()
+    # Add engine arguments
+    EngineArgs.add_cli_args(parser)
+
+    # Override default load_format for clarity
+    parser.set_defaults(load_format="sharded_state")
+
+    # Add validation arguments
+    parser.add_argument("--prompt",
+                        type=str,
+                        default="Hello, world!",
+                        help="Prompt for validation")
+    parser.add_argument("--max-tokens",
+                        type=int,
+                        default=100,
+                        help="Maximum number of tokens to generate")
+    parser.add_argument("--temperature",
+                        type=float,
+                        default=0.7,
+                        help="Sampling temperature")
+    parser.add_argument("--top-p",
+                        type=float,
+                        default=1.0,
+                        help="Top-p sampling parameter")
+
+    return parser.parse_args()
+
+
+def main():
+    args = parse_args()
+    engine_args = EngineArgs.from_cli_args(args)
+
+    print(f"Loading model from {engine_args.model} "
+          f"using format {engine_args.load_format}")
+    print(f"Tensor parallel size: {engine_args.tensor_parallel_size}")
+
+    # Load the model using engine args
+    llm = LLM(**dataclasses.asdict(engine_args))
+
+    # Prepare sampling parameters
+    sampling_params = SamplingParams(
+        temperature=args.temperature,
+        top_p=args.top_p,
+        max_tokens=args.max_tokens,
+    )
+
+    print("\nRunning inference:")
+    print(f"Prompt: {args.prompt}")
+
+    # Generate completion
+    outputs = llm.generate(args.prompt, sampling_params)
+
+    # Display generated text
+    print("\nGenerated outputs:")
+    for output in outputs:
+        generated_text = output.outputs[0].text
+        print("-" * 50)
+        print(f"Full output: {args.prompt}{generated_text}")
+        print("-" * 50)
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/examples/offline_inference/save_sharded_state.py b/examples/offline_inference/save_sharded_state.py
index 86327643..6aac9b75 100644
--- a/examples/offline_inference/save_sharded_state.py
+++ b/examples/offline_inference/save_sharded_state.py
@@ -57,10 +57,25 @@ def main(args):
     # Prepare output directory
     Path(args.output).mkdir(exist_ok=True)
     # Dump worker states to output directory
-    model_executor = llm.llm_engine.model_executor
-    model_executor.save_sharded_state(path=args.output,
-                                      pattern=args.file_pattern,
-                                      max_size=args.max_file_size)
+
+    # Check which engine version is being used
+    is_v1_engine = hasattr(llm.llm_engine, "engine_core")
+
+    if is_v1_engine:
+        # For V1 engine, we need to use engine_core.save_sharded_state
+        print("Using V1 engine save path")
+        llm.llm_engine.engine_core.save_sharded_state(
+            path=args.output,
+            pattern=args.file_pattern,
+            max_size=args.max_file_size)
+    else:
+        # For V0 engine
+        print("Using V0 engine save path")
+        model_executor = llm.llm_engine.model_executor
+        model_executor.save_sharded_state(path=args.output,
+                                          pattern=args.file_pattern,
+                                          max_size=args.max_file_size)
+
     # Copy metadata files to output directory
     for file in os.listdir(model_path):
         if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"):
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 19c7799b..39caca0c 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -285,6 +285,16 @@ class EngineCore:
     def pin_lora(self, lora_id: int) -> bool:
         return self.model_executor.pin_lora(lora_id)
 
+    def save_sharded_state(
+        self,
+        path: str,
+        pattern: Optional[str] = None,
+        max_size: Optional[int] = None,
+    ) -> None:
+        self.model_executor.save_sharded_state(path=path,
+                                               pattern=pattern,
+                                               max_size=max_size)
+
     def collective_rpc(self,
                        method: Union[str, Callable[..., _R]],
                        timeout: Optional[float] = None,
diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py
index 99774ff4..e948e59b 100644
--- a/vllm/v1/engine/core_client.py
+++ b/vllm/v1/engine/core_client.py
@@ -119,6 +119,12 @@ class EngineCoreClient(ABC):
     def pin_lora(self, lora_id: int) -> bool:
         raise NotImplementedError
 
+    def save_sharded_state(self,
+                           path: str,
+                           pattern: Optional[str] = None,
+                           max_size: Optional[int] = None) -> None:
+        raise NotImplementedError
+
     def collective_rpc(self,
                        method: Union[str, Callable[..., _R]],
                        timeout: Optional[float] = None,
@@ -162,6 +168,12 @@ class EngineCoreClient(ABC):
     async def pin_lora_async(self, lora_id: int) -> bool:
         raise NotImplementedError
 
+    async def save_sharded_state_async(self,
+                                       path: str,
+                                       pattern: Optional[str] = None,
+                                       max_size: Optional[int] = None) -> None:
+        raise NotImplementedError
+
     async def collective_rpc_async(
         self,
         method: Union[str, Callable[..., _R]],
@@ -227,6 +239,12 @@ class InprocClient(EngineCoreClient):
     def pin_lora(self, lora_id: int) -> bool:
         return self.engine_core.pin_lora(lora_id)
 
+    def save_sharded_state(self,
+                           path: str,
+                           pattern: Optional[str] = None,
+                           max_size: Optional[int] = None) -> None:
+        self.engine_core.save_sharded_state(path, pattern, max_size)
+
     def collective_rpc(self,
                        method: Union[str, Callable[..., _R]],
                        timeout: Optional[float] = None,
@@ -537,6 +555,12 @@ class SyncMPClient(MPClient):
         return self.call_utility("collective_rpc", method, timeout, args,
                                  kwargs)
 
+    def save_sharded_state(self,
+                           path: str,
+                           pattern: Optional[str] = None,
+                           max_size: Optional[int] = None) -> None:
+        self.call_utility("save_sharded_state", path, pattern, max_size)
+
 
 class AsyncMPClient(MPClient):
     """Asyncio-compatible client for multi-proc EngineCore."""
@@ -668,6 +692,13 @@ class AsyncMPClient(MPClient):
     async def pin_lora_async(self, lora_id: int) -> bool:
         return await self.call_utility_async("pin_lora", lora_id)
 
+    async def save_sharded_state_async(self,
+                                       path: str,
+                                       pattern: Optional[str] = None,
+                                       max_size: Optional[int] = None) -> None:
+        await self.call_utility_async("save_sharded_state", path, pattern,
+                                      max_size)
+
     async def collective_rpc_async(
         self,
         method: Union[str, Callable[..., _R]],
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 19144368..2972e0ff 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -269,6 +269,20 @@ class Worker(WorkerBase):
         # worker will always be healthy as long as it's running.
         return
 
+    def save_sharded_state(
+        self,
+        path: str,
+        pattern: Optional[str] = None,
+        max_size: Optional[int] = None,
+    ) -> None:
+        from vllm.model_executor.model_loader.loader import ShardedStateLoader
+        ShardedStateLoader.save_model(
+            self.model_runner.model,
+            path,
+            pattern=pattern,
+            max_size=max_size,
+        )
+
 
 def init_worker_distributed_environment(
     parallel_config: ParallelConfig,
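
Usage sketch (not part of the patch above): the snippet below condenses the save path that this change fixes, following examples/offline_inference/save_sharded_state.py. It is a minimal sketch, assuming a V1 LLMEngine that exposes engine_core (the same hasattr check the example uses); the model path and output directory are placeholders.

# sharded_save_sketch.py -- hypothetical condensed example, not part of the patch
from pathlib import Path

from vllm import LLM

model_path = "/path/to/load"                # placeholder
output_dir = "/path/to/save/sharded/model"  # placeholder

llm = LLM(model=model_path)
Path(output_dir).mkdir(exist_ok=True)

if hasattr(llm.llm_engine, "engine_core"):
    # V1 engine: routed through the new EngineCoreClient.save_sharded_state(),
    # which reaches EngineCore -> model executor -> Worker.save_sharded_state().
    llm.llm_engine.engine_core.save_sharded_state(path=output_dir)
else:
    # V0 engine: the pre-existing executor path.
    llm.llm_engine.model_executor.save_sharded_state(path=output_dir)

# As in save_sharded_state.py, the model's metadata files (config, tokenizer)
# must also be copied into output_dir before the checkpoint can be reloaded
# with --load-format sharded_state (see load_sharded_state.py).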