[Bugfix] Fix a bug in RequestOutput.finished (#202)
parent 2e0d314384
commit 14f0b39cda
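As the hunks below show, `RequestOutput.finished` changes from a method that derived its value from the per-sequence completion outputs (`all(output.finished() for output in self.outputs)`) into a plain boolean attribute, captured once from `SequenceGroup.is_finished()` when the output is constructed in `from_seq_group`. Every call site accordingly drops the call parentheses: `request_output.finished()` becomes `request_output.finished`.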
@@ -30,7 +30,7 @@ def main(args: argparse.Namespace):
         request_outputs = engine.step()
 
         for request_output in request_outputs:
-            if request_output.finished():
+            if request_output.finished:
                 print(request_output)
 
         if not (engine.has_unfinished_requests() or test_prompts):
@@ -154,7 +154,7 @@ class AsyncLLMEngine:
             yield request_output
 
             # Once finished, release the resources of the sequence group.
-            if request_output.finished():
+            if request_output.finished:
                 if self.log_requests:
                     logger.info(f"Finished request {request_id}.")
 
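For context, a minimal consumer-side sketch of the AsyncLLMEngine path touched above, assuming the vLLM API of this era (AsyncEngineArgs, from_engine_args, and generate(prompt, sampling_params, request_id)); the model name and request id are placeholders:

import asyncio

from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams

async def run() -> None:
    # Model and request id are illustrative; any supported HF model works.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))
    # generate() yields a RequestOutput per step; after this commit the
    # final one carries finished=True as an attribute, not a method.
    async for request_output in engine.generate(
            "Hello, my name is", SamplingParams(), request_id="0"):
        if request_output.finished:
            print(request_output)

asyncio.run(run())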
@@ -133,7 +133,7 @@ class LLM:
         while self.llm_engine.has_unfinished_requests():
             step_outputs = self.llm_engine.step()
             for output in step_outputs:
-                if output.finished():
+                if output.finished:
                     outputs.append(output)
                     if use_tqdm:
                         pbar.update(1)
@@ -60,11 +60,13 @@ class RequestOutput:
         prompt: str,
         prompt_token_ids: List[int],
         outputs: List[CompletionOutput],
+        finished: bool,
     ) -> None:
         self.request_id = request_id
         self.prompt = prompt
         self.prompt_token_ids = prompt_token_ids
         self.outputs = outputs
+        self.finished = finished
 
     @classmethod
     def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
@@ -95,13 +97,13 @@ class RequestOutput:
         # Every sequence in the sequence group should have the same prompt.
         prompt = top_n_seqs[0].prompt
         prompt_token_ids = top_n_seqs[0].data.prompt_token_ids
-        return cls(seq_group.request_id, prompt, prompt_token_ids, outputs)
+        finished = seq_group.is_finished()
+        return cls(seq_group.request_id, prompt, prompt_token_ids, outputs,
+                   finished)
 
     def __repr__(self) -> str:
         return (f"RequestOutput(request_id={self.request_id}, "
                 f"prompt={self.prompt!r}, "
                 f"prompt_token_ids={self.prompt_token_ids}, "
-                f"outputs={self.outputs})")
-
-    def finished(self) -> bool:
-        return all(output.finished() for output in self.outputs)
+                f"outputs={self.outputs}, "
+                f"finished={self.finished})")
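And the synchronous equivalent, matching the engine.step() loops in the first and third hunks; again a sketch, assuming LLMEngine.from_engine_args and the facebook/opt-125m model used in the vLLM examples:

from vllm import EngineArgs, LLMEngine, SamplingParams

engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
engine.add_request("0", "Hello, my name is", SamplingParams(temperature=0.8))

while engine.has_unfinished_requests():
    for request_output in engine.step():
        # finished is now a bool set from SequenceGroup.is_finished(),
        # so it is read, not called.
        if request_output.finished:
            print(request_output)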