From 7cbfc1094359d52508bf18611e6cc46bea2e1d43 Mon Sep 17 00:00:00 2001
From: Reid <61492567+reidliu41@users.noreply.github.com>
Date: Mon, 14 Apr 2025 17:59:15 +0800
Subject: [PATCH] [Misc] refactor examples (#16563)

Signed-off-by: reidliu41
Co-authored-by: reidliu41
---
 .../disaggregated_prefill.py               |  6 +-
 .../disaggregated_prefill_lmcache.py       | 15 +++--
 .../online_serving/cohere_rerank_client.py | 54 +++++++++------
 .../online_serving/jinaai_rerank_client.py | 23 ++++---
 .../openai_chat_completion_client.py       | 67 +++++++++++--------
 5 files changed, 102 insertions(+), 63 deletions(-)

diff --git a/examples/offline_inference/disaggregated_prefill.py b/examples/offline_inference/disaggregated_prefill.py
index 36ee24bf..d6098514 100644
--- a/examples/offline_inference/disaggregated_prefill.py
+++ b/examples/offline_inference/disaggregated_prefill.py
@@ -95,7 +95,7 @@ def run_decode(prefill_done):
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
 
 
-if __name__ == "__main__":
+def main():
     prefill_done = Event()
     prefill_process = Process(target=run_prefill, args=(prefill_done, ))
     decode_process = Process(target=run_decode, args=(prefill_done, ))
@@ -109,3 +109,7 @@ if __name__ == "__main__":
     # Terminate the prefill node when decode is finished
     decode_process.join()
     prefill_process.terminate()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/offline_inference/disaggregated_prefill_lmcache.py b/examples/offline_inference/disaggregated_prefill_lmcache.py
index 5c84bbfc..7da6fb7a 100644
--- a/examples/offline_inference/disaggregated_prefill_lmcache.py
+++ b/examples/offline_inference/disaggregated_prefill_lmcache.py
@@ -38,6 +38,10 @@ os.environ["LMCACHE_REMOTE_URL"] = f"lm://localhost:{port}"
 # `naive` indicates using raw bytes of the tensor without any compression
 os.environ["LMCACHE_REMOTE_SERDE"] = "naive"
 
+prompts = [
+    "Hello, how are you?" * 1000,
+]
+
 
 def run_prefill(prefill_done, prompts):
     # We use GPU 0 for prefill node.
@@ -106,12 +110,7 @@ def run_lmcache_server(port):
     return server_proc
 
 
-if __name__ == "__main__":
-
-    prompts = [
-        "Hello, how are you?" * 1000,
-    ]
-
+def main():
     prefill_done = Event()
     prefill_process = Process(target=run_prefill, args=(prefill_done, prompts))
     decode_process = Process(target=run_decode, args=(prefill_done, prompts))
@@ -128,3 +127,7 @@ if __name__ == "__main__":
     prefill_process.terminate()
     lmcache_server_process.terminate()
     lmcache_server_process.wait()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/online_serving/cohere_rerank_client.py b/examples/online_serving/cohere_rerank_client.py
index fc434ada..c2d4ef08 100644
--- a/examples/online_serving/cohere_rerank_client.py
+++ b/examples/online_serving/cohere_rerank_client.py
@@ -2,32 +2,46 @@
 """
 Example of using the OpenAI entrypoint's rerank API which is compatible with
 the Cohere SDK: https://github.com/cohere-ai/cohere-python
+Note that `pip install cohere` is needed to run this example.
 
 run: vllm serve BAAI/bge-reranker-base
 """
+from typing import Union
+
 import cohere
+from cohere import Client, ClientV2
 
-# cohere v1 client
-co = cohere.Client(base_url="http://localhost:8000", api_key="sk-fake-key")
-rerank_v1_result = co.rerank(
-    model="BAAI/bge-reranker-base",
-    query="What is the capital of France?",
-    documents=[
-        "The capital of France is Paris", "Reranking is fun!",
-        "vLLM is an open-source framework for fast AI serving"
-    ])
+model = "BAAI/bge-reranker-base"
 
-print(rerank_v1_result)
+query = "What is the capital of France?"
 
-# or the v2
-co2 = cohere.ClientV2("sk-fake-key", base_url="http://localhost:8000")
+documents = [
+    "The capital of France is Paris", "Reranking is fun!",
+    "vLLM is an open-source framework for fast AI serving"
+]
 
-v2_rerank_result = co2.rerank(
-    model="BAAI/bge-reranker-base",
-    query="What is the capital of France?",
-    documents=[
-        "The capital of France is Paris", "Reranking is fun!",
-        "vLLM is an open-source framework for fast AI serving"
-    ])
 
-print(v2_rerank_result)
+def cohere_rerank(client: Union[Client, ClientV2], model: str, query: str,
+                  documents: list[str]) -> dict:
+    return client.rerank(model=model, query=query, documents=documents)
+
+
+def main():
+    # cohere v1 client
+    cohere_v1 = cohere.Client(base_url="http://localhost:8000",
+                              api_key="sk-fake-key")
+    rerank_v1_result = cohere_rerank(cohere_v1, model, query, documents)
+    print("-" * 50)
+    print("rerank_v1_result:\n", rerank_v1_result)
+    print("-" * 50)
+
+    # or the v2
+    cohere_v2 = cohere.ClientV2("sk-fake-key",
+                                base_url="http://localhost:8000")
+    rerank_v2_result = cohere_rerank(cohere_v2, model, query, documents)
+    print("rerank_v2_result:\n", rerank_v2_result)
+    print("-" * 50)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/online_serving/jinaai_rerank_client.py b/examples/online_serving/jinaai_rerank_client.py
index 3e760e17..3076bba7 100644
--- a/examples/online_serving/jinaai_rerank_client.py
+++ b/examples/online_serving/jinaai_rerank_client.py
@@ -23,12 +23,19 @@ data = {
         "The capital of France is Paris.", "Horses and cows are both animals"
     ]
 }
-response = requests.post(url, headers=headers, json=data)
 
-# Check the response
-if response.status_code == 200:
-    print("Request successful!")
-    print(json.dumps(response.json(), indent=2))
-else:
-    print(f"Request failed with status code: {response.status_code}")
-    print(response.text)
+
+def main():
+    response = requests.post(url, headers=headers, json=data)
+
+    # Check the response
+    if response.status_code == 200:
+        print("Request successful!")
+        print(json.dumps(response.json(), indent=2))
+    else:
+        print(f"Request failed with status code: {response.status_code}")
+        print(response.text)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/online_serving/openai_chat_completion_client.py b/examples/online_serving/openai_chat_completion_client.py
index a8156204..74e0c045 100644
--- a/examples/online_serving/openai_chat_completion_client.py
+++ b/examples/online_serving/openai_chat_completion_client.py
@@ -1,38 +1,49 @@
 # SPDX-License-Identifier: Apache-2.0
-
+"""Example Python client for OpenAI Chat Completion using vLLM API server
+NOTE: start a supported chat completion model server with `vllm serve`, e.g.
+    vllm serve meta-llama/Llama-2-7b-chat-hf
+"""
 from openai import OpenAI
 
 # Modify OpenAI's API key and API base to use vLLM's API server.
 openai_api_key = "EMPTY"
 openai_api_base = "http://localhost:8000/v1"
 
-client = OpenAI(
-    # defaults to os.environ.get("OPENAI_API_KEY")
-    api_key=openai_api_key,
-    base_url=openai_api_base,
-)
+messages = [{
+    "role": "system",
+    "content": "You are a helpful assistant."
+}, {
+    "role": "user",
+    "content": "Who won the world series in 2020?"
+}, {
+    "role": "assistant",
+    "content": "The Los Angeles Dodgers won the World Series in 2020."
+}, {
+    "role": "user",
+    "content": "Where was it played?"
+}]
 
-models = client.models.list()
-model = models.data[0].id
 
-chat_completion = client.chat.completions.create(
-    messages=[{
-        "role": "system",
-        "content": "You are a helpful assistant."
-    }, {
-        "role": "user",
-        "content": "Who won the world series in 2020?"
-    }, {
-        "role":
-        "assistant",
-        "content":
-        "The Los Angeles Dodgers won the World Series in 2020."
-    }, {
-        "role": "user",
-        "content": "Where was it played?"
-    }],
-    model=model,
-)
+def main():
+    client = OpenAI(
+        # defaults to os.environ.get("OPENAI_API_KEY")
+        api_key=openai_api_key,
+        base_url=openai_api_base,
+    )
 
-print("Chat completion results:")
-print(chat_completion)
+    models = client.models.list()
+    model = models.data[0].id
+
+    chat_completion = client.chat.completions.create(
+        messages=messages,
+        model=model,
+    )
+
+    print("-" * 50)
+    print("Chat completion results:")
+    print(chat_completion)
+    print("-" * 50)
+
+
+if __name__ == "__main__":
+    main()
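For reference, a minimal sketch (not part of the patch above) of the same main()/__main__ layout applied to a streaming chat completion against the OpenAI-compatible server started with `vllm serve`. The base URL and "EMPTY" key come from the examples in the patch; the prompt text is illustrative, and server-side streaming support is assumed here.

from openai import OpenAI

# Same defaults as the examples above: vLLM's OpenAI-compatible server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"


def main():
    client = OpenAI(api_key=openai_api_key, base_url=openai_api_base)

    # Use whichever model the server was launched with.
    model = client.models.list().data[0].id

    # stream=True yields incremental chunks instead of one full response.
    stream = client.chat.completions.create(
        messages=[{
            "role": "user",
            "content": "Briefly explain what vLLM is."  # illustrative prompt
        }],
        model=model,
        stream=True,
    )
    for chunk in stream:
        # The final chunk may carry no delta content, hence the `or ""`.
        print(chunk.choices[0].delta.content or "", end="")
    print()


if __name__ == "__main__":
    main()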