# SPDX-License-Identifier: Apache-2.0
"""
Validates the loading of a model saved with the sharded_state format.

This script demonstrates how to load a model that was previously saved
using save_sharded_state.py and validates it by running inference.

Example usage:

(First need to save a sharded_state model)

python save_sharded_state.py \
    --model /path/to/load \
    --quantization deepspeedfp \
    --tensor-parallel-size 8 \
    --output /path/to/save/sharded/model

python load_sharded_state.py \
    --model /path/to/saved/sharded/model \
    --load-format sharded_state \
    --quantization deepspeedfp \
    --tensor-parallel-size 8 \
    --prompt "Hello, my name is" \
    --max-tokens 50
"""

import dataclasses

from vllm import LLM, EngineArgs, SamplingParams
from vllm.utils import FlexibleArgumentParser


def parse_args():
    parser = FlexibleArgumentParser()
    # Add engine arguments
    EngineArgs.add_cli_args(parser)

    # Override default load_format for clarity
    parser.set_defaults(load_format="sharded_state")

    # Add validation arguments
    parser.add_argument("--prompt",
                        type=str,
                        default="Hello, world!",
                        help="Prompt for validation")
    parser.add_argument("--max-tokens",
                        type=int,
                        default=100,
                        help="Maximum number of tokens to generate")
    parser.add_argument("--temperature",
                        type=float,
                        default=0.7,
                        help="Sampling temperature")
    parser.add_argument("--top-p",
                        type=float,
                        default=1.0,
                        help="Top-p sampling parameter")

    return parser.parse_args()


def main():
    args = parse_args()
    engine_args = EngineArgs.from_cli_args(args)

    print(f"Loading model from {engine_args.model} "
          f"using format {engine_args.load_format}")
    print(f"Tensor parallel size: {engine_args.tensor_parallel_size}")

    # Load the model using engine args
    llm = LLM(**dataclasses.asdict(engine_args))
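
    # EngineArgs is a dataclass, so asdict() turns the parsed CLI flags into
    # keyword arguments for the LLM constructor. A successful construction
    # means the sharded weights were found and loaded; the short generation
    # below confirms they are actually usable.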

    # Prepare sampling parameters
    sampling_params = SamplingParams(
        temperature=args.temperature,
        top_p=args.top_p,
        max_tokens=args.max_tokens,
    )
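
    # Tip: with temperature=0, vLLM decodes greedily, which makes the
    # validation output deterministic across runs.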

    print("\nRunning inference:")
    print(f"Prompt: {args.prompt}")

    # Generate completion
    outputs = llm.generate(args.prompt, sampling_params)
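    # (llm.generate also accepts a list of prompts; each RequestOutput in
    # `outputs` holds one prompt plus its generated completions.)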

    # Display generated text
    print("\nGenerated outputs:")
    for output in outputs:
        generated_text = output.outputs[0].text
        print("-" * 50)
        print(f"Full output: {args.prompt}{generated_text}")
        print("-" * 50)


if __name__ == "__main__":
    main()