Fix potential bugs in FastAPI frontend and add comments (#28)
parent 12659a0bd7
commit a490aafa36
@@ -17,8 +17,10 @@ from cacheflow.master.server import (Server, add_server_arguments,
 from cacheflow.worker.controller import DeviceID
 from cacheflow.utils import Counter, get_gpu_memory, get_cpu_memory
 
+TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
 app = FastAPI()
 
+
 class FastAPIFrontend:
     def __init__(
         self,
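The new `TIMEOUT_TO_PREVENT_DEADLOCK` constant replaces the magic number `1` used further down in `generate`. For orientation, the module-level `app` is typically connected to the frontend through a streaming route; below is a minimal sketch of that wiring, where the `/generate` route, the `StreamingResponse`, and the module-level `frontend` global are assumptions not shown in this diff:

```python
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse

app = FastAPI()
frontend = None  # assumed to be set to a FastAPIFrontend instance at startup

@app.post("/generate")  # hypothetical route; not part of this diff
async def generate_stream(request: Request):
    # Parse the JSON body and stream the frontend's async generator,
    # which yields null-delimited JSON chunks as new tokens arrive.
    request_dict = await request.json()
    return StreamingResponse(frontend.generate(request_dict))
```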
@@ -30,7 +32,7 @@ class FastAPIFrontend:
         dtype: str,
         seed: int,
         swap_space: int,
-        max_batch_size: int,
+        max_num_batched_tokens: int,
         num_nodes: int,
         num_devices_per_node: int,
         distributed_init_method: str,
@@ -51,7 +53,7 @@ class FastAPIFrontend:
             dtype=dtype,
             seed=seed,
             swap_space=swap_space,
-            max_batch_size=max_batch_size,
+            max_num_batched_tokens=max_num_batched_tokens,
             num_nodes=num_nodes,
             num_devices_per_node=num_devices_per_node,
             distributed_init_method=distributed_init_method,
@@ -68,12 +70,14 @@ class FastAPIFrontend:
         self.is_server_running = True
         updated_seq_groups = await self.server.step.remote()
         self.is_server_running = False
+        # Notify the waiting coroutines that there are new outputs ready.
         for seq_group in updated_seq_groups:
             group_id = seq_group.group_id
             self.running_seq_groups[group_id] = seq_group
             self.sequence_group_events[group_id].set()
 
     async def generate(self, request_dict: Dict):
+        # Preprocess the request.
         prompt = request_dict["prompt"]
         sampling_params = SamplingParams.from_dict(request_dict)
         sampling_params.stop_token_ids.add(self.tokenizer.eos_token_id)
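The comment added to `server_step` documents the handshake at the heart of this frontend: each request registers an `asyncio.Event`, and the server loop publishes the updated `SequenceGroup` state and sets the event to wake the waiting `generate` coroutine. Here is a self-contained sketch of that producer/consumer pattern, with illustrative names rather than the diff's:

```python
import asyncio

events: dict[int, asyncio.Event] = {}
outputs: dict[int, str] = {}

async def step(group_id: int) -> None:
    # Producer: publish the latest output, then wake any waiter.
    outputs[group_id] = "new token"
    events[group_id].set()

async def wait_for_output(group_id: int) -> str:
    # Consumer: register an event, then suspend until the producer signals.
    events[group_id] = asyncio.Event()
    await events[group_id].wait()
    return outputs[group_id]

async def main() -> None:
    waiter = asyncio.create_task(wait_for_output(0))
    await asyncio.sleep(0)  # let the waiter register its event first
    await step(0)
    print(await waiter)

asyncio.run(main())
```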
@@ -87,15 +91,27 @@ class FastAPIFrontend:
         arrival_time = time.time()
         group_id = next(self.seq_group_counter)
         seq_group = SequenceGroup(group_id, seqs, arrival_time)
+        # Create an event to notify us that there is new output from the
+        # cacheflow server.
         group_event = asyncio.Event()
+        self.running_seq_groups[group_id] = seq_group
         self.sequence_group_events[group_id] = group_event
+        # Add the request into the cacheflow server's waiting queue.
         await self.server.add_sequence_groups.remote([(seq_group, sampling_params)])
+        # The cacheflow server does not have a background loop that keeps
+        # processing incoming requests. Therefore, we need to keep kicking
+        # the server to process the requests.
         while True:
+            # Kick the server if the server is not running.
             if not self.is_server_running:
                 await self.server_step()
-            # Wait for new output. Add a 1s timeout to prevent dead lock.
-            await asyncio.wait_for(group_event.wait(), timeout=1)
+            # Wait for new output. The group_event will be set in server_step
+            # when there is new output available for the sequence group.
+            # Added a timeout to prevent deadlock.
+            await asyncio.wait_for(group_event.wait(), timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
+            # Reset the event to wait for the next output.
             group_event.clear()
+            # Decode and return new outputs
             seq_group = self.running_seq_groups[group_id]
             all_outputs = []
             for seq in seq_group.seqs:
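The timeout guards against a race: if `server_step` completes between the `is_server_running` check and `group_event.wait()`, no coroutine is left to set the event and the request would block forever. Bounding the wait forces the loop back to the kick check. Note that `asyncio.wait_for` raises `asyncio.TimeoutError` when it expires, which the code above does not catch; here is a sketch of one guarded variant, an assumption on my part rather than the committed behavior:

```python
import asyncio
from typing import Awaitable, Callable

TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds

async def wait_with_kick(
    event: asyncio.Event,
    kick: Callable[[], Awaitable[None]],
) -> None:
    # Wait for the event, but never indefinitely: on each timeout,
    # kick the engine once so queued work keeps making progress.
    while True:
        try:
            await asyncio.wait_for(event.wait(), timeout=TIMEOUT_TO_PREVENT_DEADLOCK)
            return
        except asyncio.TimeoutError:
            await kick()
```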
@@ -107,7 +123,16 @@ class FastAPIFrontend:
                 "error": 0,
             }
             yield (json.dumps(ret) + "\0").encode("utf-8")
+
+            # Once finished, release the resources of the sequence group.
             if seq_group.is_finished():
+                del self.running_seq_groups[group_id]
+                del self.sequence_group_events[group_id]
+                # Kick the server if the server is not running. This is to
+                # handle the case where there are still requests in the
+                # server's waiting queue to be executed.
+                if not self.is_server_running:
+                    await self.server_step()
                 break
 
 
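The teardown added here does two things: it drops the per-request bookkeeping so `running_seq_groups` and `sequence_group_events` do not grow without bound, and it kicks the server one last time, since this coroutine may have been the only one driving the loop while other requests still sit in the waiting queue. The same logic as a standalone helper, with hypothetical names:

```python
import asyncio
from typing import Any, Awaitable, Callable

async def release_finished_group(
    group_id: int,
    running_seq_groups: dict[int, Any],
    sequence_group_events: dict[int, asyncio.Event],
    is_server_running: Callable[[], bool],
    server_step: Callable[[], Awaitable[None]],
) -> None:
    # Drop per-request state; without this, the dicts leak one entry
    # per completed request.
    del running_seq_groups[group_id]
    del sequence_group_events[group_id]
    # Final kick: requests may still be queued, and with this request
    # finished, no generate() loop may be left to step the server.
    if not is_server_running():
        await server_step()
```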
@@ -143,7 +168,7 @@ if __name__ == "__main__":
         dtype=args.dtype,
         seed=args.seed,
         swap_space=args.swap_space,
-        max_batch_size=args.max_batch_size,
+        max_num_batched_tokens=args.max_num_batched_tokens,
         num_nodes=num_nodes,
         num_devices_per_node=num_devices_per_node,
         distributed_init_method=distributed_init_method,
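For completeness, a client of this frontend would POST a prompt and read back the null-delimited JSON chunks that `generate` yields. A minimal sketch; the host, port, route path, and use of the `requests` library are assumptions, since the diff only shows the server side:

```python
import json
import requests

response = requests.post(
    "http://localhost:10002/generate",  # host, port, and route are assumed
    json={"prompt": "The capital of France is"},
    stream=True,
)
# Each chunk is a JSON object terminated by "\0", matching
# `yield (json.dumps(ret) + "\0").encode("utf-8")` in the frontend.
for chunk in response.iter_lines(delimiter=b"\0"):
    if chunk:
        print(json.loads(chunk.decode("utf-8")))
```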