[Bugfix][V1] Fix bug from putting llm_engine.model_executor in a background process (#15367)
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
parent 5e125e74d1
commit 463bbb1835
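With the V1 engine, the model executor lives inside the EngineCore background process, so caller-side code such as the save_sharded_state.py example can no longer reach llm.llm_engine.model_executor directly. This commit routes the save through the engine-core client instead. As a caller-facing summary, here is a minimal sketch of the dispatch the diff below introduces (model and output paths are placeholders, and the EngineArgs values are illustrative):

import dataclasses

from vllm import LLM, EngineArgs

# Placeholder arguments; in the real example scripts these come from the CLI.
engine_args = EngineArgs(model="/path/to/load", tensor_parallel_size=8)
llm = LLM(**dataclasses.asdict(engine_args))

output_dir = "/path/to/save/sharded/model"  # placeholder
if hasattr(llm.llm_engine, "engine_core"):
    # V1: the executor runs in a background process; go through the
    # engine-core client, which either calls EngineCore directly
    # (in-process client) or forwards the call as a utility RPC
    # (multiprocess client).
    llm.llm_engine.engine_core.save_sharded_state(path=output_dir)
else:
    # V0: the executor is in-process and can be called directly.
    llm.llm_engine.model_executor.save_sharded_state(path=output_dir)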
examples/offline_inference/load_sharded_state.py (new file, 93 lines)
@@ -0,0 +1,93 @@
# SPDX-License-Identifier: Apache-2.0
"""
Validates the loading of a model saved with the sharded_state format.
This script demonstrates how to load a model that was previously saved
using save_sharded_state.py and validates it by running inference.

Example usage:

(First save a model in the sharded_state format)
python save_sharded_state.py \
    --model /path/to/load \
    --quantization deepspeedfp \
    --tensor-parallel-size 8 \
    --output /path/to/save/sharded/model

python load_sharded_state.py \
    --model /path/to/saved/sharded/model \
    --load-format sharded_state \
    --quantization deepspeedfp \
    --tensor-parallel-size 8 \
    --prompt "Hello, my name is" \
    --max-tokens 50
"""

import dataclasses

from vllm import LLM, EngineArgs, SamplingParams
from vllm.utils import FlexibleArgumentParser


def parse_args():
    parser = FlexibleArgumentParser()
    # Add engine arguments
    EngineArgs.add_cli_args(parser)

    # Override default load_format for clarity
    parser.set_defaults(load_format="sharded_state")

    # Add validation arguments
    parser.add_argument("--prompt",
                        type=str,
                        default="Hello, world!",
                        help="Prompt for validation")
    parser.add_argument("--max-tokens",
                        type=int,
                        default=100,
                        help="Maximum number of tokens to generate")
    parser.add_argument("--temperature",
                        type=float,
                        default=0.7,
                        help="Sampling temperature")
    parser.add_argument("--top-p",
                        type=float,
                        default=1.0,
                        help="Top-p sampling parameter")

    return parser.parse_args()


def main():
    args = parse_args()
    engine_args = EngineArgs.from_cli_args(args)

    print(f"Loading model from {engine_args.model} "
          f"using format {engine_args.load_format}")
    print(f"Tensor parallel size: {engine_args.tensor_parallel_size}")

    # Load the model using engine args
    llm = LLM(**dataclasses.asdict(engine_args))

    # Prepare sampling parameters
    sampling_params = SamplingParams(
        temperature=args.temperature,
        top_p=args.top_p,
        max_tokens=args.max_tokens,
    )

    print("\nRunning inference:")
    print(f"Prompt: {args.prompt}")

    # Generate completion
    outputs = llm.generate(args.prompt, sampling_params)

    # Display generated text
    print("\nGenerated outputs:")
    for output in outputs:
        generated_text = output.outputs[0].text
        print("-" * 50)
        print(f"Full output: {args.prompt}{generated_text}")
        print("-" * 50)


if __name__ == "__main__":
    main()
@@ -57,10 +57,25 @@ def main(args):
    # Prepare output directory
    Path(args.output).mkdir(exist_ok=True)
    # Dump worker states to output directory
    # Check which engine version is being used
    is_v1_engine = hasattr(llm.llm_engine, "engine_core")

    if is_v1_engine:
        # For V1 engine, we need to use engine_core.save_sharded_state
        print("Using V1 engine save path")
        llm.llm_engine.engine_core.save_sharded_state(
            path=args.output,
            pattern=args.file_pattern,
            max_size=args.max_file_size)
    else:
        # For V0 engine
        print("Using V0 engine save path")
        model_executor = llm.llm_engine.model_executor
        model_executor.save_sharded_state(path=args.output,
                                          pattern=args.file_pattern,
                                          max_size=args.max_file_size)

    # Copy metadata files to output directory
    for file in os.listdir(model_path):
        if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"):
@@ -285,6 +285,16 @@ class EngineCore:
    def pin_lora(self, lora_id: int) -> bool:
        return self.model_executor.pin_lora(lora_id)

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        self.model_executor.save_sharded_state(path=path,
                                               pattern=pattern,
                                               max_size=max_size)

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
@@ -119,6 +119,12 @@ class EngineCoreClient(ABC):
    def pin_lora(self, lora_id: int) -> bool:
        raise NotImplementedError

    def save_sharded_state(self,
                           path: str,
                           pattern: Optional[str] = None,
                           max_size: Optional[int] = None) -> None:
        raise NotImplementedError

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
@@ -162,6 +168,12 @@ class EngineCoreClient(ABC):
    async def pin_lora_async(self, lora_id: int) -> bool:
        raise NotImplementedError

    async def save_sharded_state_async(self,
                                       path: str,
                                       pattern: Optional[str] = None,
                                       max_size: Optional[int] = None) -> None:
        raise NotImplementedError

    async def collective_rpc_async(
        self,
        method: Union[str, Callable[..., _R]],
@@ -227,6 +239,12 @@ class InprocClient(EngineCoreClient):
    def pin_lora(self, lora_id: int) -> bool:
        return self.engine_core.pin_lora(lora_id)

    def save_sharded_state(self,
                           path: str,
                           pattern: Optional[str] = None,
                           max_size: Optional[int] = None) -> None:
        self.engine_core.save_sharded_state(path, pattern, max_size)

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
@@ -537,6 +555,12 @@ class SyncMPClient(MPClient):
        return self.call_utility("collective_rpc", method, timeout, args,
                                 kwargs)

    def save_sharded_state(self,
                           path: str,
                           pattern: Optional[str] = None,
                           max_size: Optional[int] = None) -> None:
        self.call_utility("save_sharded_state", path, pattern, max_size)


class AsyncMPClient(MPClient):
    """Asyncio-compatible client for multi-proc EngineCore."""
@@ -668,6 +692,13 @@ class AsyncMPClient(MPClient):
    async def pin_lora_async(self, lora_id: int) -> bool:
        return await self.call_utility_async("pin_lora", lora_id)

    async def save_sharded_state_async(self,
                                       path: str,
                                       pattern: Optional[str] = None,
                                       max_size: Optional[int] = None) -> None:
        await self.call_utility_async("save_sharded_state", path, pattern,
                                      max_size)

    async def collective_rpc_async(
        self,
        method: Union[str, Callable[..., _R]],
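For code holding one of these clients, a short usage sketch follows; the client variable and helper names are illustrative, not part of this commit:

def save_via_sync_client(client, path: str) -> None:
    # client: an InprocClient or SyncMPClient. With the multiprocess client
    # this travels over the "save_sharded_state" utility RPC shown above.
    client.save_sharded_state(path)

async def save_via_async_client(client, path: str) -> None:
    # client: an AsyncMPClient; same RPC, awaited.
    await client.save_sharded_state_async(path)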
@@ -269,6 +269,20 @@ class Worker(WorkerBase):
        # worker will always be healthy as long as it's running.
        return

    def save_sharded_state(
        self,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        from vllm.model_executor.model_loader.loader import ShardedStateLoader
        ShardedStateLoader.save_model(
            self.model_runner.model,
            path,
            pattern=pattern,
            max_size=max_size,
        )


def init_worker_distributed_environment(
    parallel_config: ParallelConfig,