Fix potential bugs in FastAPI frontend and add comments (#28)
parent 12659a0bd7
commit a490aafa36
@@ -17,8 +17,10 @@ from cacheflow.master.server import (Server, add_server_arguments,
from cacheflow.worker.controller import DeviceID
from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory

TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds

app = FastAPI()


class FastAPIFrontend:

    def __init__(
        self,
@@ -30,7 +32,7 @@ class FastAPIFrontend:
        dtype: str,
        seed: int,
        swap_space: int,
        max_batch_size: int,
        max_num_batched_tokens: int,
        num_nodes: int,
        num_devices_per_node: int,
        distributed_init_method: str,
@@ -51,7 +53,7 @@ class FastAPIFrontend:
            dtype=dtype,
            seed=seed,
            swap_space=swap_space,
            max_batch_size=max_batch_size,
            max_num_batched_tokens=max_num_batched_tokens,
            num_nodes=num_nodes,
            num_devices_per_node=num_devices_per_node,
            distributed_init_method=distributed_init_method,
@@ -68,12 +70,14 @@ class FastAPIFrontend:
        self.is_server_running = True
        updated_seq_groups = await self.server.step.remote()
        self.is_server_running = False
        # Notify the waiting coroutines that there are new outputs ready.
        for seq_group in updated_seq_groups:
            group_id = seq_group.group_id
            self.running_seq_groups[group_id] = seq_group
            self.sequence_group_events[group_id].set()

    async def generate(self, request_dict: Dict):
        # Preprocess the request.
        prompt = request_dict["prompt"]
        sampling_params = SamplingParams.from_dict(request_dict)
        sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
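The server_step/generate pair above is a classic asyncio producer/consumer handshake: the stepping coroutine publishes new outputs and sets a per-group event, while each request handler waits on its own event. A minimal, self-contained sketch of that pattern follows (all names here are illustrative, not part of the diff); note that asyncio.wait_for raises asyncio.TimeoutError on timeout, so a robust waiter catches it and re-checks state:

import asyncio
from typing import Dict

TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds, mirroring the constant above

events: Dict[int, asyncio.Event] = {}
results: Dict[int, str] = {}

async def producer(group_id: int) -> None:
    # Stand-in for one server.step(): produce output, then wake the waiter.
    await asyncio.sleep(0.1)
    results[group_id] = "new token"
    events[group_id].set()

async def consumer(group_id: int) -> None:
    events[group_id] = asyncio.Event()
    asyncio.create_task(producer(group_id))
    while group_id not in results:
        try:
            # The timeout guards against a missed wake-up deadlocking the waiter.
            await asyncio.wait_for(events[group_id].wait(),
                                   timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
        except asyncio.TimeoutError:
            continue  # timed out: re-check state instead of blocking forever
        events[group_id].clear()
    print(results[group_id])

asyncio.run(consumer(0))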
@@ -87,15 +91,27 @@ class FastAPIFrontend:
        arrival_time = time.time()
        group_id = next(self.seq_group_counter)
        seq_group = SequenceGroup(group_id, seqs, arrival_time)
        # Create an event to notify us that there is new output from the
        # cacheflow server.
        group_event = asyncio.Event()
        self.running_seq_groups[group_id] = seq_group
        self.sequence_group_events[group_id] = group_event
        # Add the request into the cacheflow server's waiting queue.
        await self.server.add_sequence_groups.remote([(seq_group, sampling_params)])
        # The cacheflow server does not have a background loop that keeps
        # processing incoming requests. Therefore, we need to keep kicking
        # the server to process the requests.
        while True:
            # Kick the server if the server is not running.
            if not self.is_server_running:
                await self.server_step()
            # Wait for new output. Add a 1s timeout to prevent deadlock.
            await asyncio.wait_for(group_event.wait(), timeout=1)
            # Wait for new output. The group_event will be set in server_step
            # when there is new output available for the sequence group.
            # Added a timeout to prevent deadlock.
            await asyncio.wait_for(group_event.wait(), timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
            # Reset the event to wait for the next output.
            group_event.clear()
            # Decode and return new outputs.
            seq_group = self.running_seq_groups[group_id]
            all_outputs = []
            for seq in seq_group.seqs:
@@ -107,7 +123,16 @@ class FastAPIFrontend:
                    "error": 0,
                }
                yield (json.dumps(ret) + "\0").encode("utf-8")

            # Once finished, release the resources of the sequence group.
            if seq_group.is_finished():
                del self.running_seq_groups[group_id]
                del self.sequence_group_events[group_id]
                # Kick the server one more time if it is not running, so that
                # any requests still in the server's waiting queue get executed.
                if not self.is_server_running:
                    await self.server_step()
                break
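Each yielded chunk above is a JSON object terminated by a "\0" byte, which gives the client an unambiguous frame delimiter when reading the HTTP stream. A hypothetical client sketch (the /generate route, host, and port are assumptions, not shown in this diff):

import json
import requests  # any HTTP client with streaming support works

response = requests.post("http://localhost:8000/generate",
                         json={"prompt": "Hello"}, stream=True)
buffer = b""
for chunk in response.iter_content(chunk_size=None):
    # Accumulate bytes and split out complete "\0"-terminated JSON frames.
    buffer += chunk
    while b"\0" in buffer:
        frame, _, buffer = buffer.partition(b"\0")
        print(json.loads(frame.decode("utf-8")))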
@@ -143,7 +168,7 @@ if __name__ == "__main__":
        dtype=args.dtype,
        seed=args.seed,
        swap_space=args.swap_space,
        max_batch_size=args.max_batch_size,
        max_num_batched_tokens=args.max_num_batched_tokens,
        num_nodes=num_nodes,
        num_devices_per_node=num_devices_per_node,
        distributed_init_method=distributed_init_method,
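The excerpt ends inside the __main__ block that wires the parsed arguments into FastAPIFrontend; the code that actually serves the app is not part of this diff. A plausible launch, assuming the standard uvicorn ASGI server (host and port here are illustrative):

import uvicorn  # assumed dependency; the launch code is not shown in this excerpt

# Serve the FastAPI app defined above.
uvicorn.run(app, host="0.0.0.0", port=8000)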