[Bugfix][V1] Fix bug from putting llm_engine.model_executor in a background process (#15367)
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
This commit is contained in:
parent 5e125e74d1
commit 463bbb1835
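
Under V1 the EngineCore, and with it the model executor, may run in a background process, so the frontend can no longer reach llm.llm_engine.model_executor the way the V0 example did. A minimal sketch of the access pattern this commit adopts, assuming placeholder model and output paths:

from vllm import LLM

llm = LLM(model="/path/to/model", tensor_parallel_size=8)  # placeholder arguments

if hasattr(llm.llm_engine, "engine_core"):
    # V1: the executor lives behind the EngineCore (possibly in another process),
    # so the save request is routed through the engine-core client added below.
    llm.llm_engine.engine_core.save_sharded_state(path="/path/to/output")
else:
    # V0: the executor is directly reachable from the frontend engine.
    llm.llm_engine.model_executor.save_sharded_state(path="/path/to/output")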
examples/offline_inference/load_sharded_state.py (new file, 93 lines)
@@ -0,0 +1,93 @@
# SPDX-License-Identifier: Apache-2.0
"""
Validates the loading of a model saved with the sharded_state format.
This script demonstrates how to load a model that was previously saved
using save_sharded_state.py and validates it by running inference.
Example usage:
(First, save a model in the sharded_state format:)

python save_sharded_state.py \
    --model /path/to/load \
    --quantization deepspeedfp \
    --tensor-parallel-size 8 \
    --output /path/to/save/sharded/model

python load_sharded_state.py \
    --model /path/to/saved/sharded/model \
    --load-format sharded_state \
    --quantization deepspeedfp \
    --tensor-parallel-size 8 \
    --prompt "Hello, my name is" \
    --max-tokens 50
"""

import dataclasses

from vllm import LLM, EngineArgs, SamplingParams
from vllm.utils import FlexibleArgumentParser


def parse_args():
    parser = FlexibleArgumentParser()
    # Add engine arguments
    EngineArgs.add_cli_args(parser)

    # Override default load_format for clarity
    parser.set_defaults(load_format="sharded_state")

    # Add validation arguments
    parser.add_argument("--prompt",
                        type=str,
                        default="Hello, world!",
                        help="Prompt for validation")
    parser.add_argument("--max-tokens",
                        type=int,
                        default=100,
                        help="Maximum number of tokens to generate")
    parser.add_argument("--temperature",
                        type=float,
                        default=0.7,
                        help="Sampling temperature")
    parser.add_argument("--top-p",
                        type=float,
                        default=1.0,
                        help="Top-p sampling parameter")

    return parser.parse_args()


def main():
    args = parse_args()
    engine_args = EngineArgs.from_cli_args(args)

    print(f"Loading model from {engine_args.model} "
          f"using format {engine_args.load_format}")
    print(f"Tensor parallel size: {engine_args.tensor_parallel_size}")

    # Load the model using engine args
    llm = LLM(**dataclasses.asdict(engine_args))

    # Prepare sampling parameters
    sampling_params = SamplingParams(
        temperature=args.temperature,
        top_p=args.top_p,
        max_tokens=args.max_tokens,
    )

    print("\nRunning inference:")
    print(f"Prompt: {args.prompt}")

    # Generate completion
    outputs = llm.generate(args.prompt, sampling_params)

    # Display generated text
    print("\nGenerated outputs:")
    for output in outputs:
        generated_text = output.outputs[0].text
        print("-" * 50)
        print(f"Full output: {args.prompt}{generated_text}")
        print("-" * 50)


if __name__ == "__main__":
    main()
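
The same validation can also be driven programmatically rather than through the CLI wrapper; a minimal sketch, reusing the placeholder paths and arguments from the docstring above:

from vllm import LLM, SamplingParams

llm = LLM(model="/path/to/saved/sharded/model",
          load_format="sharded_state",
          quantization="deepspeedfp",
          tensor_parallel_size=8)
outputs = llm.generate("Hello, my name is",
                       SamplingParams(temperature=0.7, max_tokens=50))
print(outputs[0].outputs[0].text)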
@@ -57,10 +57,25 @@ def main(args):
     # Prepare output directory
     Path(args.output).mkdir(exist_ok=True)
     # Dump worker states to output directory
-    model_executor = llm.llm_engine.model_executor
-    model_executor.save_sharded_state(path=args.output,
-                                      pattern=args.file_pattern,
-                                      max_size=args.max_file_size)
+
+    # Check which engine version is being used
+    is_v1_engine = hasattr(llm.llm_engine, "engine_core")
+
+    if is_v1_engine:
+        # For V1 engine, we need to use engine_core.save_sharded_state
+        print("Using V1 engine save path")
+        llm.llm_engine.engine_core.save_sharded_state(
+            path=args.output,
+            pattern=args.file_pattern,
+            max_size=args.max_file_size)
+    else:
+        # For V0 engine
+        print("Using V0 engine save path")
+        model_executor = llm.llm_engine.model_executor
+        model_executor.save_sharded_state(path=args.output,
+                                          pattern=args.file_pattern,
+                                          max_size=args.max_file_size)
+
     # Copy metadata files to output directory
     for file in os.listdir(model_path):
         if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"):
@@ -285,6 +285,16 @@ class EngineCore:
     def pin_lora(self, lora_id: int) -> bool:
         return self.model_executor.pin_lora(lora_id)
 
+    def save_sharded_state(
+        self,
+        path: str,
+        pattern: Optional[str] = None,
+        max_size: Optional[int] = None,
+    ) -> None:
+        self.model_executor.save_sharded_state(path=path,
+                                               pattern=pattern,
+                                               max_size=max_size)
+
     def collective_rpc(self,
                        method: Union[str, Callable[..., _R]],
                        timeout: Optional[float] = None,
@@ -119,6 +119,12 @@ class EngineCoreClient(ABC):
     def pin_lora(self, lora_id: int) -> bool:
         raise NotImplementedError
 
+    def save_sharded_state(self,
+                           path: str,
+                           pattern: Optional[str] = None,
+                           max_size: Optional[int] = None) -> None:
+        raise NotImplementedError
+
     def collective_rpc(self,
                        method: Union[str, Callable[..., _R]],
                        timeout: Optional[float] = None,
@@ -162,6 +168,12 @@ class EngineCoreClient(ABC):
     async def pin_lora_async(self, lora_id: int) -> bool:
         raise NotImplementedError
 
+    async def save_sharded_state_async(self,
+                                       path: str,
+                                       pattern: Optional[str] = None,
+                                       max_size: Optional[int] = None) -> None:
+        raise NotImplementedError
+
     async def collective_rpc_async(
         self,
         method: Union[str, Callable[..., _R]],
@@ -227,6 +239,12 @@ class InprocClient(EngineCoreClient):
     def pin_lora(self, lora_id: int) -> bool:
         return self.engine_core.pin_lora(lora_id)
 
+    def save_sharded_state(self,
+                           path: str,
+                           pattern: Optional[str] = None,
+                           max_size: Optional[int] = None) -> None:
+        self.engine_core.save_sharded_state(path, pattern, max_size)
+
     def collective_rpc(self,
                        method: Union[str, Callable[..., _R]],
                        timeout: Optional[float] = None,
@@ -537,6 +555,12 @@ class SyncMPClient(MPClient):
         return self.call_utility("collective_rpc", method, timeout, args,
                                  kwargs)
 
+    def save_sharded_state(self,
+                           path: str,
+                           pattern: Optional[str] = None,
+                           max_size: Optional[int] = None) -> None:
+        self.call_utility("save_sharded_state", path, pattern, max_size)
+
 
 class AsyncMPClient(MPClient):
     """Asyncio-compatible client for multi-proc EngineCore."""
@@ -668,6 +692,13 @@ class AsyncMPClient(MPClient):
     async def pin_lora_async(self, lora_id: int) -> bool:
         return await self.call_utility_async("pin_lora", lora_id)
 
+    async def save_sharded_state_async(self,
+                                       path: str,
+                                       pattern: Optional[str] = None,
+                                       max_size: Optional[int] = None) -> None:
+        await self.call_utility_async("save_sharded_state", path, pattern,
+                                      max_size)
+
     async def collective_rpc_async(
         self,
         method: Union[str, Callable[..., _R]],
@@ -269,6 +269,20 @@ class Worker(WorkerBase):
         # worker will always be healthy as long as it's running.
         return
 
+    def save_sharded_state(
+        self,
+        path: str,
+        pattern: Optional[str] = None,
+        max_size: Optional[int] = None,
+    ) -> None:
+        from vllm.model_executor.model_loader.loader import ShardedStateLoader
+        ShardedStateLoader.save_model(
+            self.model_runner.model,
+            path,
+            pattern=pattern,
+            max_size=max_size,
+        )
+
 
 def init_worker_distributed_environment(
     parallel_config: ParallelConfig,
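
Taken together, the V1 hunks wire save_sharded_state through the engine-core stack: EngineCoreClient declares the method (plus an async variant), InprocClient calls EngineCore directly, SyncMPClient and AsyncMPClient forward it as a "save_sharded_state" utility call across the process boundary, EngineCore delegates to its model executor, and the Worker hands the model to ShardedStateLoader.save_model. The toy sketch below is not vLLM code (all class names are invented); it just condenses that dispatch pattern into one runnable file:

from typing import Optional


class ToyWorker:
    # Stands in for the GPU worker that actually owns the model weights.
    def save_sharded_state(self, path: str, pattern: Optional[str] = None,
                           max_size: Optional[int] = None) -> None:
        print(f"worker: saving sharded state to {path}")


class ToyEngineCore:
    # Stands in for EngineCore, which may run in a background process under V1.
    def __init__(self) -> None:
        self.model_executor = ToyWorker()

    def save_sharded_state(self, path: str, pattern: Optional[str] = None,
                           max_size: Optional[int] = None) -> None:
        self.model_executor.save_sharded_state(path, pattern, max_size)


class ToyClient:
    # Stands in for SyncMPClient.call_utility, which really crosses an IPC boundary.
    def __init__(self, core: ToyEngineCore) -> None:
        self.core = core

    def call_utility(self, method: str, *args):
        return getattr(self.core, method)(*args)


if __name__ == "__main__":
    ToyClient(ToyEngineCore()).call_utility("save_sharded_state", "/tmp/out")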