[Misc] Improved prefix cache example (#9077)
This commit is contained in:
parent
fbb74420e7
commit
05c531be47
@ -1,7 +1,8 @@
|
|||||||
from time import time
|
|
||||||
|
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
|
# NOTE: This is just a running example. For benchmarking purpose,
|
||||||
|
# please see benchmarks/benchmark_prefix_caching.py
|
||||||
|
|
||||||
# Common prefix.
|
# Common prefix.
|
||||||
prefix = (
|
prefix = (
|
||||||
"You are an expert school principal, skilled in effectively managing "
|
"You are an expert school principal, skilled in effectively managing "
|
||||||
@ -37,9 +38,7 @@ print("Results without `enable_prefix_caching`")
|
|||||||
|
|
||||||
# Generate texts from the prompts. The output is a list of RequestOutput objects
|
# Generate texts from the prompts. The output is a list of RequestOutput objects
|
||||||
# that contain the prompt, generated text, and other information.
|
# that contain the prompt, generated text, and other information.
|
||||||
start_time_regular = time()
|
|
||||||
outputs = regular_llm.generate(generating_prompts, sampling_params)
|
outputs = regular_llm.generate(generating_prompts, sampling_params)
|
||||||
duration_regular = time() - start_time_regular
|
|
||||||
|
|
||||||
regular_generated_texts = []
|
regular_generated_texts = []
|
||||||
# Print the outputs.
|
# Print the outputs.
|
||||||
@ -55,9 +54,7 @@ print("-" * 80)
|
|||||||
prefix_cached_llm.generate(generating_prompts[0], sampling_params)
|
prefix_cached_llm.generate(generating_prompts[0], sampling_params)
|
||||||
|
|
||||||
# Generate with prefix caching.
|
# Generate with prefix caching.
|
||||||
start_time_cached = time()
|
|
||||||
outputs = prefix_cached_llm.generate(generating_prompts, sampling_params)
|
outputs = prefix_cached_llm.generate(generating_prompts, sampling_params)
|
||||||
duration_cached = time() - start_time_cached
|
|
||||||
|
|
||||||
print("Results with `enable_prefix_caching`")
|
print("Results with `enable_prefix_caching`")
|
||||||
|
|
||||||
@ -77,6 +74,3 @@ generated_same = all([
|
|||||||
for i in range(len(prompts))
|
for i in range(len(prompts))
|
||||||
])
|
])
|
||||||
print(f"Generated answers are the same: {generated_same}")
|
print(f"Generated answers are the same: {generated_same}")
|
||||||
|
|
||||||
speedup = round(duration_regular / duration_cached, 2)
|
|
||||||
print(f"Speed up of cached generation compared to the regular is: {speedup}")
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user