FastAPI-based working frontend (#10)

Zhuohan Li authored on 2023-03-29 14:48:56 +08:00, committed by GitHub
parent d359cda5fa
commit 721fa3df15
15 changed files with 536 additions and 146 deletions

README.md

@@ -8,9 +8,46 @@ pip install flash-attn # This may take up to 10 mins.
pip install -e .
```
## Run
## Test simple server
```bash
ray start --head
python server.py [--tensor-parallel-size <N>]
python simple_server.py
```
The detailed arguments for `simple_server.py` can be found by:
```bash
python simple_server.py --help
```
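For example, an invocation that spells out a few of the defaults from `add_server_arguments` explicitly (the flag values below simply restate the defaults defined later in this commit):
```bash
ray start --head
python simple_server.py --model facebook/opt-125m --swap-space 20 --max-batch-size 2560
```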
## FastAPI server
Install the following additional dependencies:
```bash
pip install fastapi uvicorn
```
To start the server:
```bash
ray start --head
python -m cacheflow.http_frontend.fastapi_frontend
```
To test the server:
```bash
python -m cacheflow.http_frontend.test_cli_client
```
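The test client simply POSTs to the frontend's streaming `/generate` endpoint. As a minimal sketch, the same request can be issued with `curl` (field names follow `SamplingParams.from_dict`; host and port assume the defaults above):
```bash
curl -N http://localhost:10002/generate \
    -H "Content-Type: application/json" \
    -d '{"prompt": "Ion Stoica is a", "n": 1, "temperature": 0.8, "max_num_steps": 32}'
```
The response is a stream of `\0`-delimited JSON chunks, each containing the text generated so far.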
## Gradio web server
Install the following additional dependencies:
```bash
pip install gradio
```
Start the server:
```bash
python -m cacheflow.http_frontend.fastapi_frontend
# At another terminal
python -m cacheflow.http_frontend.gradio_webserver
```
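The web UI forwards each prompt to the FastAPI frontend. If the frontend runs on a different host or port, point the Gradio app at it via `--model-url` (its default is `http://localhost:10002/generate`), for example:
```bash
python -m cacheflow.http_frontend.gradio_webserver \
    --port 10003 --model-url http://localhost:10002/generate
```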

cacheflow/http_frontend/fastapi_frontend.py (new file)

@@ -0,0 +1,152 @@
import argparse
import asyncio
import time
from typing import List, Dict
import json

import ray
from transformers import AutoTokenizer
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import uvicorn

from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence, SequenceGroup
from cacheflow.master.server import (Server, add_server_arguments,
                                     initialize_ray_cluster)
from cacheflow.worker.controller import DeviceID
from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory

app = FastAPI()


class FastAPIFrontend:
    def __init__(
        self,
        model: str,
        model_path: str,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        block_size: int,
        dtype: str,
        seed: int,
        swap_space: int,
        max_batch_size: int,
        num_nodes: int,
        num_devices_per_node: int,
        distributed_init_method: str,
        all_stage_devices: List[List[DeviceID]],
    ):
        self.block_size = block_size

        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.seq_group_counter = Counter()
        self.seq_counter = Counter()
        remote_server_class = ray.remote(num_cpus=0)(Server)
        self.server = remote_server_class.remote(
            model=model,
            model_path=model_path,
            pipeline_parallel_size=pipeline_parallel_size,
            tensor_parallel_size=tensor_parallel_size,
            block_size=block_size,
            dtype=dtype,
            seed=seed,
            swap_space=swap_space,
            max_batch_size=max_batch_size,
            num_nodes=num_nodes,
            num_devices_per_node=num_devices_per_node,
            distributed_init_method=distributed_init_method,
            all_stage_devices=all_stage_devices,
            gpu_memory=get_gpu_memory(),
            cpu_memory=get_cpu_memory(),
        )

        self.running_seq_groups: Dict[int, SequenceGroup] = {}
        self.sequence_group_events: Dict[int, asyncio.Event] = {}
        self.is_server_running = False

    async def server_step(self):
        self.is_server_running = True
        updated_seq_groups = await self.server.step.remote()
        self.is_server_running = False
        for seq_group in updated_seq_groups:
            group_id = seq_group.group_id
            self.running_seq_groups[group_id] = seq_group
            self.sequence_group_events[group_id].set()

    async def generate(self, request_dict: Dict):
        prompt = request_dict["prompt"]
        sampling_params = SamplingParams.from_dict(request_dict)
        sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
        token_ids = self.tokenizer.encode(prompt)
        seqs: List[Sequence] = []
        for _ in range(sampling_params.n):
            seq_id = next(self.seq_counter)
            seq = Sequence(seq_id, token_ids, block_size=self.block_size)
            seqs.append(seq)

        group_id = next(self.seq_group_counter)
        seq_group = SequenceGroup(group_id, seqs)
        group_event = asyncio.Event()
        self.sequence_group_events[group_id] = group_event
        await self.server.add_sequence_groups.remote([(seq_group, sampling_params)])
        while True:
            if not self.is_server_running:
                await self.server_step()
            # Wait for new output. Add a 1s timeout to prevent deadlock.
            await asyncio.wait_for(group_event.wait(), timeout=1)
            group_event.clear()
            seq_group = self.running_seq_groups[group_id]
            all_outputs = []
            for seq in seq_group.seqs:
                token_ids = seq.get_token_ids()
                output = self.tokenizer.decode(token_ids, skip_special_tokens=True)
                all_outputs.append(output)
            ret = {
                "text": all_outputs,
                "error": 0,
            }
            yield (json.dumps(ret) + "\0").encode("utf-8")
            if seq_group.is_finished():
                break


@app.post("/generate")
async def generate_stream(request: Request):
    request_dict = await request.json()
    return StreamingResponse(frontend.generate(request_dict))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=10002)
    parser = add_server_arguments(parser)
    args = parser.parse_args()

    # TODO(zhuohan): Support pipeline parallelism.
    assert args.pipeline_parallel_size == 1, (
        'Pipeline parallelism is not supported yet.')

    (num_nodes, num_devices_per_node, distributed_init_method,
     all_stage_devices) = (
        initialize_ray_cluster(
            pipeline_parallel_size=args.pipeline_parallel_size,
            tensor_parallel_size=args.tensor_parallel_size))

    frontend = FastAPIFrontend(
        model=args.model,
        model_path=args.model_path,
        pipeline_parallel_size=args.pipeline_parallel_size,
        tensor_parallel_size=args.tensor_parallel_size,
        block_size=args.block_size,
        dtype=args.dtype,
        seed=args.seed,
        swap_space=args.swap_space,
        max_batch_size=args.max_batch_size,
        num_nodes=num_nodes,
        num_devices_per_node=num_devices_per_node,
        distributed_init_method=distributed_init_method,
        all_stage_devices=all_stage_devices,
    )

    uvicorn.run(app, host=args.host, port=args.port, log_level="info")

cacheflow/http_frontend/gradio_webserver.py (new file)

@@ -0,0 +1,43 @@
import argparse
import json
import time

import gradio as gr
import requests


def http_bot(prompt):
    headers = {"User-Agent": "Cacheflow Client"}
    pload = {
        "prompt": prompt,
        "max_num_steps": 128,
    }
    response = requests.post(args.model_url, headers=headers, json=pload, stream=True)

    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"][0]
            yield output


def build_demo():
    with gr.Blocks() as demo:
        gr.Markdown(
            "# Cacheflow demo\n"
        )
        inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER")  # .style(container=False)
        outputbox = gr.Textbox(label="Output", placeholder="Generated result from the model")
        inputbox.submit(http_bot, [inputbox], [outputbox])
    return demo


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=10003)
    parser.add_argument("--model-url", type=str, default="http://localhost:10002/generate")
    args = parser.parse_args()

    demo = build_demo()
    demo.queue(concurrency_count=100).launch(server_name=args.host, server_port=args.port)

cacheflow/http_frontend/test_cli_client.py (new file)

@@ -0,0 +1,23 @@
import requests
import json


def http_request():
    prompt = "Ion Stoica is a"

    headers = {"User-Agent": "Test Client"}
    pload = {
        "prompt": prompt,
        "n": 4,
        "use_beam_search": True,
        "temperature": 0.0,
    }
    response = requests.post("http://localhost:10002/generate", headers=headers, json=pload, stream=True)

    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"]
            yield output


for h in http_request():
    print(h, flush=True)

cacheflow/master/scheduler.py

@@ -1,7 +1,6 @@
from typing import Dict, List
from typing import Dict, List, Tuple

from cacheflow.master.block_manager import BlockSpaceManager
from cacheflow.master.frontend import Frontend
from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup
@@ -14,14 +13,12 @@ class Scheduler:
    def __init__(
        self,
        frontend: Frontend,
        controllers: List,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        max_num_batched_tokens: int,
    ) -> None:
        self.frontend = frontend
        self.controllers = controllers
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks
@@ -47,9 +44,12 @@
        # Pending sequence groups (FIFO).
        self.pending: List[SequenceGroup] = []

    def _fetch_inputs(self) -> None:
        inputs = self.frontend.get_inputs()
        for seq_group, sampling_params in inputs:
    def add_sequence_groups(
        self,
        sequence_groups: List[Tuple[SequenceGroup, SamplingParams]],
    ) -> None:
        # Add sequence groups to the pending queue.
        for seq_group, sampling_params in sequence_groups:
            self.pending.append(seq_group)
            self.sampling_params[seq_group.group_id] = sampling_params
@@ -104,7 +104,7 @@
            seq.status = SequenceStatus.SWAPPED
        self.swapped.append(seq_group)

    def step(self) -> None:
    def step(self) -> List[SequenceGroup]:
        # Blocks that need to be swapped or copied before model execution.
        blocks_to_swap_in: Dict[int, int] = {}
        blocks_to_swap_out: Dict[int, int] = {}
@@ -158,7 +158,6 @@
        # 3. Join new sequences if possible.
        # NOTE: Here we implicitly assume FCFS scheduling.
        # TODO(woosuk): Add a batching policy to control the batch size.
        self._fetch_inputs()
        if not self.swapped:
            for i, seq_group in enumerate(self.pending):
                num_prompt_tokens = seq_group.seqs[0].get_len()
@@ -176,6 +175,8 @@
        # 4. Create input data structures.
        input_seq_groups: List[SequenceGroupInputs] = []
        updated_seq_groups: List[SequenceGroup] = self.running.copy()

        for seq_group in self.running:
            group_id = seq_group.group_id
            num_steps = self.num_steps[group_id]
@@ -219,6 +220,8 @@
            blocks_to_copy,
        )

        return updated_seq_groups

    def post_step(
        self,
        seq_outputs: Dict[int, SequenceOutputs],
@@ -268,13 +271,12 @@
        running: List[SequenceGroup] = []
        for seq_group in self.running:
            if seq_group.is_finished():
                self._return(seq_group)
                self._free_seq_group(seq_group)
            else:
                running.append(seq_group)
        self.running = running

    def _return(self, seq_group: SequenceGroup) -> None:
    def _free_seq_group(self, seq_group: SequenceGroup) -> None:
        group_id = seq_group.group_id
        del self.num_steps[group_id]
        del self.sampling_params[group_id]
        self.frontend.print_response(seq_group)

cacheflow/master/server.py

@@ -1,13 +1,98 @@
import argparse
from typing import List, Tuple
import random
from typing import List, Tuple, Dict

import ray

from cacheflow.master.frontend import Frontend
from cacheflow.master.scheduler import Scheduler
from cacheflow.models import get_memory_analyzer
from cacheflow.worker.controller import Controller, DeviceID
from cacheflow.sequence import SequenceGroup
from cacheflow.sampling_params import SamplingParams


class Server:
    def __init__(
        self,
        model: str,
        model_path: str,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        block_size: int,
        dtype: str,
        seed: int,
        swap_space: int,
        max_batch_size: int,
        num_nodes: int,
        num_devices_per_node: int,
        distributed_init_method: str,
        all_stage_devices: List[List[DeviceID]],
        gpu_memory: int,
        cpu_memory: int,
    ):
        self.num_nodes = num_nodes
        self.num_devices_per_node = num_devices_per_node
        self.world_size = pipeline_parallel_size * tensor_parallel_size

        self.memory_analyzer = get_memory_analyzer(
            model_name=model,
            block_size=block_size,
            dtype=dtype,
            gpu_memory=gpu_memory,
            cpu_memory=cpu_memory,
            tensor_parallel_size=tensor_parallel_size,
        )
        self.num_gpu_blocks = self.memory_analyzer.get_max_num_gpu_blocks(
            max_num_batched_tokens=max_batch_size)
        self.num_cpu_blocks = self.memory_analyzer.get_max_num_cpu_blocks(
            swap_space=swap_space)
        print(f'# GPU blocks: {self.num_gpu_blocks}, '
              f'# CPU blocks: {self.num_cpu_blocks}')

        # Create a controller for each pipeline stage.
        self.controllers: List[Controller] = []
        for i in range(pipeline_parallel_size):
            controller = Controller(
                stage_id=i,
                stage_devices=all_stage_devices[i],
                world_size=self.world_size,
                pipeline_parallel_size=pipeline_parallel_size,
                tensor_parallel_size=tensor_parallel_size,
                distributed_init_method=distributed_init_method,
                model_name=model,
                block_size=block_size,
                num_gpu_blocks=self.num_gpu_blocks,
                num_cpu_blocks=self.num_cpu_blocks,
                dtype=dtype,
                seed=seed,
                model_path=model_path,
            )
            self.controllers.append(controller)

        # Create a scheduler.
        self.scheduler = Scheduler(
            controllers=self.controllers,
            block_size=block_size,
            num_gpu_blocks=self.num_gpu_blocks,
            num_cpu_blocks=self.num_cpu_blocks,
            max_num_batched_tokens=max_batch_size,
        )
        # Connect the controllers.
        for i in range(len(self.controllers) - 1):
            self.controllers[i].set_next(self.controllers[i + 1])
        self.controllers[-1].set_next(self.scheduler)

    def add_sequence_groups(
        self,
        sequence_groups: List[Tuple[SequenceGroup, SamplingParams]]
    ):
        self.scheduler.add_sequence_groups(sequence_groups)

    def step(self):
        return self.scheduler.step()

    def has_unfinished_requests(self):
        return (self.scheduler.pending or self.scheduler.running or
                self.scheduler.swapped)


def initialize_ray_cluster(
@@ -76,88 +161,7 @@ def initialize_ray_cluster(
            all_stage_devices)


def main(args: argparse.Namespace):
    # TODO(zhuohan): Support pipeline parallelism.
    assert args.pipeline_parallel_size == 1, (
        'Pipeline parallelism is not supported yet.')

    (num_nodes, num_devices_per_node, distributed_init_method,
     all_stage_devices) = (
        initialize_ray_cluster(
            pipeline_parallel_size=args.pipeline_parallel_size,
            tensor_parallel_size=args.tensor_parallel_size))

    world_size = args.pipeline_parallel_size * args.tensor_parallel_size

    memory_analyzer = get_memory_analyzer(
        model_name=args.model,
        block_size=args.block_size,
        dtype=args.dtype,
        tensor_parallel_size=args.tensor_parallel_size,
    )
    num_gpu_blocks = memory_analyzer.get_max_num_gpu_blocks(
        max_num_batched_tokens=args.max_batch_size)
    num_cpu_blocks = memory_analyzer.get_max_num_cpu_blocks(
        swap_space=args.swap_space)
    print(f'# GPU blocks: {num_gpu_blocks}, # CPU blocks: {num_cpu_blocks}')

    # Create a controller for each pipeline stage.
    controllers: List[Controller] = []
    for i in range(args.pipeline_parallel_size):
        controller = Controller(
            stage_id=i,
            stage_devices=all_stage_devices[i],
            world_size=world_size,
            pipeline_parallel_size=args.pipeline_parallel_size,
            tensor_parallel_size=args.tensor_parallel_size,
            distributed_init_method=distributed_init_method,
            model_name=args.model,
            block_size=args.block_size,
            num_gpu_blocks=num_gpu_blocks,
            num_cpu_blocks=num_cpu_blocks,
            dtype=args.dtype,
            seed=args.seed,
            model_path=args.model_path,
        )
        controllers.append(controller)

    # Create a frontend.
    frontend = Frontend(
        model_name=args.model,
        block_size=args.block_size,
    )

    # Create a scheduler.
    scheduler = Scheduler(
        frontend=frontend,
        controllers=controllers,
        block_size=args.block_size,
        num_gpu_blocks=num_gpu_blocks,
        num_cpu_blocks=num_cpu_blocks,
        max_num_batched_tokens=args.max_batch_size,
    )
    # Connect the controllers.
    for i in range(len(controllers) - 1):
        controllers[i].set_next(controllers[i + 1])
    controllers[-1].set_next(scheduler)

    # Test the following inputs.
    test_inputs = [
        ('Ion Stoica is a', {'n': 4, 'use_beam_search': True, 'temperature': 0.0}),
        ('UC Berkeley is', {'n': 3, 'temperature': 0.8, 'top_p': 0.99}),
        ('The future of cloud computing is', {}),  # Use default parameters.
    ]
    while True:
        if test_inputs:
            text, sampling_params = test_inputs.pop(0)
            frontend.query(text, **sampling_params)
        scheduler.step()
        if not (scheduler.pending or scheduler.running or test_inputs):
            break


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='CacheFlow server')
def add_server_arguments(parser: argparse.ArgumentParser):
    # Model arguments
    parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
    parser.add_argument('--model-path', type=str, default='~/.cacheflow/model_weights',
@@ -173,6 +177,4 @@ if __name__ == '__main__':
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
    parser.add_argument('--max-batch-size', type=int, default=2560, help='maximum number of batched tokens')
    args = parser.parse_args()
    main(args)
    return parser

cacheflow/master/simple_frontend.py

@@ -3,12 +3,11 @@ from typing import List, Optional, Set, Tuple
from transformers import AutoTokenizer

from cacheflow.sampling_params import SamplingParams
from cacheflow.sequence import Sequence
from cacheflow.sequence import SequenceGroup
from cacheflow.sequence import Sequence, SequenceGroup
from cacheflow.utils import Counter


class Frontend:
class SimpleFrontend:

    def __init__(
        self,
@@ -22,30 +21,16 @@ class Frontend:
        self.seq_counter = Counter()
        self.inputs: List[Tuple[SequenceGroup, SamplingParams]] = []

    def add_eos_token(self, sampling_params: SamplingParams) -> SamplingParams:
        # Stop generation when we see an EOS token.
        sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
        return sampling_params

    def query(
        self,
        prompt: str,
        n: int = 1,
        temperature: float = 1.0,
        top_p: float = 1.0,
        use_beam_search: bool = False,
        stop_token_ids: Set[int] = set(),
        max_num_steps: int = 16,  # From OpenAI API.
        num_logprobs: int = 0,
        context_window_size: Optional[int] = None,
        sampling_params: SamplingParams,
    ) -> None:
        # Stop when we see an EOS token.
        stop_token_ids.add(self.tokenizer.eos_token_id)
        sampling_params = SamplingParams(
            n=n,
            temperature=temperature,
            top_p=top_p,
            use_beam_search=use_beam_search,
            stop_token_ids=stop_token_ids,
            max_num_steps=max_num_steps,
            num_logprobs=num_logprobs,
            context_window_size=context_window_size,
        )
        token_ids = self.tokenizer.encode(prompt)
        self._add_query(token_ids, sampling_params)

cacheflow/models/memory_analyzer.py

@@ -1,9 +1,7 @@
import torch
from transformers import AutoConfig

from cacheflow.models.utils import get_cpu_memory
from cacheflow.models.utils import get_dtype_size
from cacheflow.models.utils import get_gpu_memory

_GiB = 1 << 30
@@ -31,11 +29,15 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
        model_name: str,
        block_size: int,
        dtype: torch.dtype,
        gpu_memory: int,
        cpu_memory: int,
        tensor_parallel_size: int,
    ) -> None:
        self.model_name = model_name
        self.block_size = block_size
        self.dtype = dtype
        self.gpu_memory = gpu_memory
        self.cpu_memory = cpu_memory
        self.tensor_parallel_size = tensor_parallel_size

        config = AutoConfig.from_pretrained(model_name)
@@ -106,8 +108,7 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
        memory_utilization: float = 0.95,
    ) -> int:
        # NOTE(woosuk): This assumes that the machine has homogeneous GPUs.
        gpu_memory = get_gpu_memory()
        usable_memory = int(memory_utilization * gpu_memory)
        usable_memory = int(memory_utilization * self.gpu_memory)

        param_size = self._get_param_size()
        act_size = self._get_max_act_size(max_num_batched_tokens)
@@ -122,16 +123,15 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
        swap_space: int,
    ) -> int:
        swap_space = swap_space * _GiB
        cpu_memory = get_cpu_memory()
        if swap_space > 0.8 * cpu_memory:
        if swap_space > 0.8 * self.cpu_memory:
            raise ValueError(f'The swap space ({swap_space / _GiB:.2f} GiB) '
                             'takes more than 80% of the available memory '
                             f'({cpu_memory / _GiB:.2f} GiB).'
                             f'({self.cpu_memory / _GiB:.2f} GiB).'
                             'Please check the swap space size.')
        if swap_space > 0.5 * cpu_memory:
        if swap_space > 0.5 * self.cpu_memory:
            print(f'WARNING: The swap space ({swap_space / _GiB:.2f} GiB) '
                  'takes more than 50% of the available memory '
                  f'({cpu_memory / _GiB:.2f} GiB).'
                  f'({self.cpu_memory / _GiB:.2f} GiB).'
                  'This may slow the system performance.')
        max_num_blocks = swap_space // self._get_cache_block_size()
        return max_num_blocks

View File

@@ -44,11 +44,14 @@ def get_memory_analyzer(
    model_name: str,
    block_size: int,
    dtype: Union[torch.dtype, str],
    gpu_memory: int,
    cpu_memory: int,
    tensor_parallel_size: int = 1,
) -> CacheFlowMemoryAnalyzer:
    torch_dtype = get_torch_dtype(dtype)
    for model_class, memory_analyzer in _MEMORY_ANALYZERS.items():
        if model_class in model_name:
            return memory_analyzer(
                model_name, block_size, torch_dtype, tensor_parallel_size)
                model_name, block_size, torch_dtype, gpu_memory, cpu_memory,
                tensor_parallel_size)
    raise ValueError(f'Unsupported model name: {model_name}')

cacheflow/models/utils.py

@@ -1,9 +1,5 @@
from typing import Union

import random
import numpy as np
import psutil
import torch

_STR_DTYPE_TO_TORCH_DTYPE = {
@@ -26,10 +22,3 @@ def get_dtype_size(dtype: Union[torch.dtype, str]) -> int:
    torch_dtype = get_torch_dtype(dtype)
    return torch.tensor([], dtype=torch_dtype).element_size()


def get_gpu_memory(gpu: int = 0) -> int:
    return torch.cuda.get_device_properties(gpu).total_memory


def get_cpu_memory() -> int:
    return psutil.virtual_memory().total

cacheflow/sampling_params.py

@@ -1,4 +1,4 @@
from typing import Optional, Set
from typing import Optional, Set, Dict


class SamplingParams:
@@ -69,3 +69,16 @@ class SamplingParams:
                f'max_num_steps={self.max_num_steps}, '
                f'num_logprobs={self.num_logprobs}, '
                f'context_window_size={self.context_window_size})')

    @classmethod
    def from_dict(cls, d: Dict) -> 'SamplingParams':
        return cls(
            n=d.get('n', 1),
            temperature=d.get('temperature', 1.0),
            top_p=d.get('top_p', 1.0),
            use_beam_search=d.get('use_beam_search', False),
            stop_token_ids=set(d.get('stop_token_ids', set())),
            max_num_steps=d.get('max_num_steps', 16),
            num_logprobs=d.get('num_logprobs', 0),
            context_window_size=d.get('context_window_size', None),
        )

cacheflow/utils.py

@@ -1,5 +1,6 @@
import enum
import random

import psutil
import numpy as np
import torch
@@ -26,6 +27,7 @@ class Counter:
    def reset(self) -> None:
        self.counter = 0


def set_random_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
@@ -35,3 +37,11 @@ def set_random_seed(seed: int):
    if model_parallel_is_initialized():
        model_parallel_cuda_manual_seed(seed)


def get_gpu_memory(gpu: int = 0) -> int:
    return torch.cuda.get_device_properties(gpu).total_memory


def get_cpu_memory() -> int:
    return psutil.virtual_memory().total

playground/http_client.py (new file)

@@ -0,0 +1,20 @@
import requests
import json


def http_bot():
    prompt = "How are you? I'm fine."

    headers = {"User-Agent": "Test Client"}
    pload = {
        "prompt": prompt,
    }
    response = requests.post("http://localhost:10002", headers=headers, json=pload, stream=True)

    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"]
            yield output


for h in http_bot():
    print(h, end="", flush=True)

View File

@@ -0,0 +1,40 @@
import argparse
import asyncio
import time
from typing import Union
import json

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import uvicorn

app = FastAPI()


async def text_streamer(args):
    context = args["prompt"]
    words = context.split(" ")
    for word in words:
        await asyncio.sleep(1)
        print("word:", word)
        ret = {
            "text": word + " ",
            "error": 0,
        }
        yield (json.dumps(ret) + "\0").encode("utf-8")


@app.post("/")
async def read_root(request: Request):
    args = await request.json()
    return StreamingResponse(text_streamer(args))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=10002)
    args = parser.parse_args()
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")

simple_server.py (new file)

@@ -0,0 +1,71 @@
import argparse
from typing import List

from cacheflow.master.simple_frontend import SimpleFrontend
from cacheflow.master.server import (Server, add_server_arguments,
                                     initialize_ray_cluster)
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import get_gpu_memory, get_cpu_memory


def main(args: argparse.Namespace):
    # TODO(zhuohan): Support pipeline parallelism.
    assert args.pipeline_parallel_size == 1, (
        'Pipeline parallelism is not supported yet.')

    (num_nodes, num_devices_per_node, distributed_init_method,
     all_stage_devices) = (
        initialize_ray_cluster(
            pipeline_parallel_size=args.pipeline_parallel_size,
            tensor_parallel_size=args.tensor_parallel_size))

    # Create a server.
    server = Server(
        model=args.model,
        model_path=args.model_path,
        pipeline_parallel_size=args.pipeline_parallel_size,
        tensor_parallel_size=args.tensor_parallel_size,
        block_size=args.block_size,
        dtype=args.dtype,
        seed=args.seed,
        swap_space=args.swap_space,
        max_batch_size=args.max_batch_size,
        num_nodes=num_nodes,
        num_devices_per_node=num_devices_per_node,
        distributed_init_method=distributed_init_method,
        all_stage_devices=all_stage_devices,
        gpu_memory=get_gpu_memory(),
        cpu_memory=get_cpu_memory(),
    )

    # Create a frontend.
    frontend = SimpleFrontend(
        model_name=args.model,
        block_size=args.block_size,
    )

    # Test the following inputs.
    test_inputs = [
        ('Ion Stoica is a', {'n': 4, 'use_beam_search': True, 'temperature': 0.0}),
        ('UC Berkeley is', {'n': 3, 'temperature': 0.8, 'top_p': 0.99}),
        ('The future of cloud computing is', {}),  # Use default parameters.
    ]
    while True:
        if test_inputs:
            text, sampling_params_dict = test_inputs.pop(0)
            sampling_params = SamplingParams.from_dict(sampling_params_dict)
            sampling_params = frontend.add_eos_token(sampling_params)
            frontend.query(text, sampling_params)
        server.add_sequence_groups(frontend.get_inputs())
        updated_seq_groups = server.step()
        for seq_group in updated_seq_groups:
            if seq_group.is_finished():
                frontend.print_response(seq_group)
        if not (server.has_unfinished_requests() or test_inputs):
            break


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='CacheFlow simple server.')
    parser = add_server_arguments(parser)
    args = parser.parse_args()
    main(args)