[Misc] refactor examples (#16563)
Signed-off-by: reidliu41 <reid201711@gmail.com>
Co-authored-by: reidliu41 <reid201711@gmail.com>

parent ce4ddd2d1a
commit 7cbfc10943
@@ -95,7 +95,7 @@ def run_decode(prefill_done):
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")


-if __name__ == "__main__":
+def main():
     prefill_done = Event()
     prefill_process = Process(target=run_prefill, args=(prefill_done, ))
     decode_process = Process(target=run_decode, args=(prefill_done, ))
@@ -109,3 +109,7 @@ if __name__ == "__main__":
     # Terminate the prefill node when decode is finished
     decode_process.join()
     prefill_process.terminate()
+
+
+if __name__ == "__main__":
+    main()
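The two hunks above capture the pattern applied throughout this PR: module-level setup stays where it is, the orchestration code moves into a main() function, and the script entry point is guarded. A minimal, self-contained sketch of that shape, with placeholder prefill/decode bodies rather than the real vLLM example code:

from multiprocessing import Event, Process


def run_prefill(prefill_done):
    # Placeholder for the prefill-node work; signal completion when done.
    print("prefill: done")
    prefill_done.set()


def run_decode(prefill_done):
    # Placeholder for the decode-node work; wait until prefill has finished.
    prefill_done.wait()
    print("decode: done")


def main():
    prefill_done = Event()
    prefill_process = Process(target=run_prefill, args=(prefill_done, ))
    decode_process = Process(target=run_decode, args=(prefill_done, ))

    prefill_process.start()
    decode_process.start()

    # Terminate the prefill node when decode is finished
    decode_process.join()
    prefill_process.terminate()

if __name__ == "__main__":
    main()

The guard is more than style here: these examples use multiprocessing, and on spawn-based platforms child processes re-import the module, so top-level orchestration code would otherwise run again in every child.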
@@ -38,6 +38,10 @@ os.environ["LMCACHE_REMOTE_URL"] = f"lm://localhost:{port}"
 # `naive` indicates using raw bytes of the tensor without any compression
 os.environ["LMCACHE_REMOTE_SERDE"] = "naive"
 
+prompts = [
+    "Hello, how are you?" * 1000,
+]
+
 
 def run_prefill(prefill_done, prompts):
     # We use GPU 0 for prefill node.
@@ -106,12 +110,7 @@ def run_lmcache_server(port):
     return server_proc


-if __name__ == "__main__":
-
-    prompts = [
-        "Hello, how are you?" * 1000,
-    ]
-
+def main():
     prefill_done = Event()
     prefill_process = Process(target=run_prefill, args=(prefill_done, prompts))
     decode_process = Process(target=run_decode, args=(prefill_done, prompts))
@@ -128,3 +127,7 @@ if __name__ == "__main__":
     prefill_process.terminate()
     lmcache_server_process.terminate()
     lmcache_server_process.wait()
+
+
+if __name__ == "__main__":
+    main()
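Same refactor for the LMCache disaggregated-prefill example: prompts moves to module level so it is defined once, and main() owns process start-up and tear-down, ending with terminate() followed by wait() so the LMCache server subprocess is fully reaped. A rough sketch of the resulting structure, with the prefill/decode/server bodies replaced by stubs (the real example launches vLLM engines and an LMCache server, and the port below is only illustrative):

import subprocess
import sys
from multiprocessing import Event, Process

prompts = [
    "Hello, how are you?" * 1000,
]


def run_prefill(prefill_done, prompts):
    # Stub for the prefill node (GPU 0 in the real example).
    prefill_done.set()


def run_decode(prefill_done, prompts):
    # Stub for the decode node (GPU 1 in the real example).
    prefill_done.wait()


def run_lmcache_server(port):
    # Stub: the real example starts the LMCache server on the given port.
    return subprocess.Popen(
        [sys.executable, "-c", "import time; time.sleep(30)"])


def main():
    lmcache_server_process = run_lmcache_server(8100)

    prefill_done = Event()
    prefill_process = Process(target=run_prefill, args=(prefill_done, prompts))
    decode_process = Process(target=run_decode, args=(prefill_done, prompts))
    prefill_process.start()
    decode_process.start()

    decode_process.join()
    prefill_process.terminate()
    lmcache_server_process.terminate()
    lmcache_server_process.wait()


if __name__ == "__main__":
    main()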
@@ -2,32 +2,46 @@
 """
 Example of using the OpenAI entrypoint's rerank API which is compatible with
 the Cohere SDK: https://github.com/cohere-ai/cohere-python
+
+Note that `pip install cohere` is needed to run this example.
 
 run: vllm serve BAAI/bge-reranker-base
 """
-import cohere
+from typing import Union
 
+import cohere
+from cohere import Client, ClientV2
+
+model = "BAAI/bge-reranker-base"
+
+query = "What is the capital of France?"
 
-# cohere v1 client
-co = cohere.Client(base_url="http://localhost:8000", api_key="sk-fake-key")
-rerank_v1_result = co.rerank(
-    model="BAAI/bge-reranker-base",
-    query="What is the capital of France?",
 documents = [
     "The capital of France is Paris", "Reranking is fun!",
     "vLLM is an open-source framework for fast AI serving"
-    ])
+]
 
-print(rerank_v1_result)
+
+def cohere_rerank(client: Union[Client, ClientV2], model: str, query: str,
+                  documents: list[str]) -> dict:
+    return client.rerank(model=model, query=query, documents=documents)
+
+
+def main():
+    # cohere v1 client
+    cohere_v1 = cohere.Client(base_url="http://localhost:8000",
+                              api_key="sk-fake-key")
+    rerank_v1_result = cohere_rerank(cohere_v1, model, query, documents)
+    print("-" * 50)
+    print("rerank_v1_result:\n", rerank_v1_result)
+    print("-" * 50)
 
 # or the v2
-co2 = cohere.ClientV2("sk-fake-key", base_url="http://localhost:8000")
+    cohere_v2 = cohere.ClientV2("sk-fake-key",
+                                base_url="http://localhost:8000")
+    rerank_v2_result = cohere_rerank(cohere_v2, model, query, documents)
+    print("rerank_v2_result:\n", rerank_v2_result)
+    print("-" * 50)
 
-v2_rerank_result = co2.rerank(
-    model="BAAI/bge-reranker-base",
-    query="What is the capital of France?",
-    documents=[
-        "The capital of France is Paris", "Reranking is fun!",
-        "vLLM is an open-source framework for fast AI serving"
-    ])
 
-print(v2_rerank_result)
+if __name__ == "__main__":
+    main()
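For the Cohere rerank client, the two near-identical rerank calls collapse into one cohere_rerank() helper that accepts either SDK client, with model, query, and documents shared at module level (lines carried into main() are also re-indented under it; those whitespace-only changes are not spelled out in the hunk above). A condensed sketch of how the helper ends up being used; the loop over both clients is my framing, not the example's literal layout:

from typing import Union

import cohere
from cohere import Client, ClientV2


def cohere_rerank(client: Union[Client, ClientV2], model: str, query: str,
                  documents: list[str]) -> dict:
    # Works with either SDK client; both expose the same rerank() signature.
    return client.rerank(model=model, query=query, documents=documents)


def main():
    # Assumes `vllm serve BAAI/bge-reranker-base` is running on port 8000;
    # the API key is ignored by vLLM but required by the SDK.
    clients = [
        cohere.Client(base_url="http://localhost:8000", api_key="sk-fake-key"),
        cohere.ClientV2("sk-fake-key", base_url="http://localhost:8000"),
    ]
    for client in clients:
        result = cohere_rerank(client,
                               model="BAAI/bge-reranker-base",
                               query="What is the capital of France?",
                               documents=[
                                   "The capital of France is Paris",
                                   "Reranking is fun!",
                               ])
        print(result)


if __name__ == "__main__":
    main()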
@@ -23,6 +23,9 @@ data = {
         "The capital of France is Paris.", "Horses and cows are both animals"
     ]
 }
 
+
+def main():
 response = requests.post(url, headers=headers, json=data)
 
 # Check the response
@@ -32,3 +35,7 @@ if response.status_code == 200:
 else:
     print(f"Request failed with status code: {response.status_code}")
     print(response.text)
+
+
+if __name__ == "__main__":
+    main()
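The raw-HTTP rerank client gets the same treatment: the request and the status-code check move into main() (and are re-indented under it, a whitespace-only change the hunks above do not show). A minimal sketch of the resulting flow; the URL and the payload fields here are assumptions based on the surrounding example, since the diff only shows the documents list:

import requests

# Assumed endpoint and payload shape for vLLM's rerank API.
url = "http://127.0.0.1:8000/rerank"
headers = {"accept": "application/json", "Content-Type": "application/json"}
data = {
    "model": "BAAI/bge-reranker-base",
    "query": "What is the capital of France?",
    "documents": [
        "The capital of France is Paris.", "Horses and cows are both animals"
    ],
}


def main():
    response = requests.post(url, headers=headers, json=data)

    # Check the response
    if response.status_code == 200:
        print(response.json())
    else:
        print(f"Request failed with status code: {response.status_code}")
        print(response.text)


if __name__ == "__main__":
    main()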
@@ -1,11 +1,30 @@
 # SPDX-License-Identifier: Apache-2.0
+"""Example Python client for OpenAI Chat Completion using vLLM API server
+NOTE: start a supported chat completion model server with `vllm serve`, e.g.
+    vllm serve meta-llama/Llama-2-7b-chat-hf
+"""
 from openai import OpenAI
 
 # Modify OpenAI's API key and API base to use vLLM's API server.
 openai_api_key = "EMPTY"
 openai_api_base = "http://localhost:8000/v1"
 
+messages = [{
+    "role": "system",
+    "content": "You are a helpful assistant."
+}, {
+    "role": "user",
+    "content": "Who won the world series in 2020?"
+}, {
+    "role": "assistant",
+    "content": "The Los Angeles Dodgers won the World Series in 2020."
+}, {
+    "role": "user",
+    "content": "Where was it played?"
+}]
+
+
+def main():
 client = OpenAI(
     # defaults to os.environ.get("OPENAI_API_KEY")
     api_key=openai_api_key,
@@ -16,23 +35,15 @@ models = client.models.list()
 model = models.data[0].id
 
 chat_completion = client.chat.completions.create(
-    messages=[{
-        "role": "system",
-        "content": "You are a helpful assistant."
-    }, {
-        "role": "user",
-        "content": "Who won the world series in 2020?"
-    }, {
-        "role":
-        "assistant",
-        "content":
-        "The Los Angeles Dodgers won the World Series in 2020."
-    }, {
-        "role": "user",
-        "content": "Where was it played?"
-    }],
+    messages=messages,
     model=model,
 )
 
+print("-" * 50)
 print("Chat completion results:")
 print(chat_completion)
+print("-" * 50)
+
+
+if __name__ == "__main__":
+    main()
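Net effect for the chat-completion client: the conversation becomes a module-level messages list, messages=messages replaces the inline literal in the create() call, and everything from client construction to printing sits in main() (with the corresponding re-indentation, which the hunks above do not show). A condensed sketch of the refactored flow, assuming a vLLM OpenAI-compatible server on localhost:8000; the shortened messages list is mine:

from openai import OpenAI

# vLLM ignores the key, but the SDK requires one.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8000/v1"

messages = [{
    "role": "system",
    "content": "You are a helpful assistant."
}, {
    "role": "user",
    "content": "Who won the world series in 2020?"
}]


def main():
    client = OpenAI(api_key=openai_api_key, base_url=openai_api_base)

    # Use whichever model the server is currently serving.
    model = client.models.list().data[0].id

    chat_completion = client.chat.completions.create(
        messages=messages,
        model=model,
    )

    print("Chat completion results:")
    print(chat_completion)


if __name__ == "__main__":
    main()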