Fix potential bugs in FastAPI frontend and add comments (#28)

This commit is contained in:
Zhuohan Li 2023-04-06 13:44:24 +08:00 committed by GitHub
parent 12659a0bd7
commit a490aafa36
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -17,8 +17,10 @@ from cacheflow.master.server import (Server, add_server_arguments,
from cacheflow.worker.controller import DeviceID
from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
app = FastAPI()
class FastAPIFrontend:
def __init__(
self,
@ -30,7 +32,7 @@ class FastAPIFrontend:
dtype: str,
seed: int,
swap_space: int,
max_batch_size: int,
max_num_batched_tokens: int,
num_nodes: int,
num_devices_per_node: int,
distributed_init_method: str,
@ -51,7 +53,7 @@ class FastAPIFrontend:
dtype=dtype,
seed=seed,
swap_space=swap_space,
max_batch_size=max_batch_size,
max_num_batched_tokens=max_num_batched_tokens,
num_nodes=num_nodes,
num_devices_per_node=num_devices_per_node,
distributed_init_method=distributed_init_method,
@ -68,12 +70,14 @@ class FastAPIFrontend:
self.is_server_running = True
updated_seq_groups = await self.server.step.remote()
self.is_server_running = False
# Notify the waiting coroutines that there new outputs ready.
for seq_group in updated_seq_groups:
group_id = seq_group.group_id
self.running_seq_groups[group_id] = seq_group
self.sequence_group_events[group_id].set()
async def generate(self, request_dict: Dict):
# Preprocess the request.
prompt = request_dict["prompt"]
sampling_params = SamplingParams.from_dict(request_dict)
sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
@ -87,15 +91,27 @@ class FastAPIFrontend:
arrival_time = time.time()
group_id = next(self.seq_group_counter)
seq_group = SequenceGroup(group_id, seqs, arrival_time)
# Create an event to notify us that there is new output from the
# cacheflow server.
group_event = asyncio.Event()
self.running_seq_groups[group_id] = seq_group
self.sequence_group_events[group_id] = group_event
# Add the request into the cacheflow server's waiting queue.
await self.server.add_sequence_groups.remote([(seq_group, sampling_params)])
# The cacheflow server does not have a background loop that keeps
# processing incoming requests. Therefore, we need to keep kicking
# the server to process the requests.
while True:
# Kick the server if the server is not running.
if not self.is_server_running:
await self.server_step()
# Wait for new output. Add a 1s timeout to prevent dead lock.
await asyncio.wait_for(group_event.wait(), timeout=1)
# Wait for new output. The group_event will be set in server_step
# when there is new output available for the sequence group.
# Added a timeout to prevent deadlock.
await asyncio.wait_for(group_event.wait(), timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
# Reset the event to wait for the next output.
group_event.clear()
# Decode and return new outputs
seq_group = self.running_seq_groups[group_id]
all_outputs = []
for seq in seq_group.seqs:
@ -107,7 +123,16 @@ class FastAPIFrontend:
"error": 0,
}
yield (json.dumps(ret) + "\0").encode("utf-8")
# Once finished, release the resources of the sequence group.
if seq_group.is_finished():
del self.running_seq_groups[group_id]
del self.sequence_group_events[group_id]
# Kick the server if the server is not running. This is to
# prevent that there are still requests in server's waiting
# queue to be executed.
if not self.is_server_running:
await self.server_step()
break
@ -143,7 +168,7 @@ if __name__ == "__main__":
dtype=args.dtype,
seed=args.seed,
swap_space=args.swap_space,
max_batch_size=args.max_batch_size,
max_num_batched_tokens=args.max_num_batched_tokens,
num_nodes=num_nodes,
num_devices_per_node=num_devices_per_node,
distributed_init_method=distributed_init_method,