[HotFix] Fix final output truncation with stop string + streaming (#8468)
This commit is contained in:
parent f57092c00b
commit 18e9e1f7b3
@ -159,7 +159,8 @@ def should_do_global_cleanup_after_test(request) -> bool:
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="module")
|
||||
async def test_asyncio_run(async_engine):
|
||||
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
|
||||
async def test_asyncio_run(async_engine, stop):
|
||||
|
||||
scheduler_config = await async_engine.get_scheduler_config()
|
||||
num_scheduler_steps = scheduler_config.num_scheduler_steps
|
||||
@ -169,6 +170,7 @@ async def test_asyncio_run(async_engine):
|
||||
temperature=0,
|
||||
max_tokens=32,
|
||||
min_tokens=32,
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
output_count = 0
|
||||
@ -203,7 +205,8 @@ async def test_asyncio_run(async_engine):
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="module")
|
||||
async def test_output_kinds(async_engine):
|
||||
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
|
||||
async def test_output_kinds(async_engine, stop):
|
||||
"""Test that output_kind works as expected and that
|
||||
results are equivalent across different kinds."""
|
||||
|
||||
@ -214,6 +217,7 @@ async def test_output_kinds(async_engine):
|
||||
temperature=0,
|
||||
max_tokens=32,
|
||||
min_tokens=32,
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
async def run(prompt: str, kind: RequestOutputKind):
|
||||
@ -229,6 +233,8 @@ async def test_output_kinds(async_engine):
|
||||
final_output = output
|
||||
|
||||
assert final_output is not None
|
||||
assert final_output.finished
|
||||
|
||||
return (final_output.prompt_token_ids,
|
||||
final_output.outputs[0].token_ids,
|
||||
final_output.outputs[0].text, output_count)
|
||||
@ -241,16 +247,18 @@ async def test_output_kinds(async_engine):
|
||||
output_tokens: List[int] = []
|
||||
output_text = ""
|
||||
output_count = 0
|
||||
final_output = None
|
||||
async for output in async_engine.generate(prompt,
|
||||
params,
|
||||
request_id=uid()):
|
||||
token_ids = output.outputs[0].token_ids
|
||||
text = output.outputs[0].text
|
||||
final_output = output
|
||||
|
||||
# Ensure we get prompt ids iff we haven't yet received output tokens
|
||||
if output_tokens:
|
||||
assert 1 <= len(token_ids) <= num_scheduler_steps
|
||||
assert text
|
||||
assert stop or text
|
||||
assert not output.prompt_token_ids
|
||||
else:
|
||||
assert output.prompt_token_ids
|
||||
@ -260,6 +268,10 @@ async def test_output_kinds(async_engine):
|
||||
output_text += text
|
||||
|
||||
output_count += 1
|
||||
|
||||
assert final_output is not None
|
||||
assert final_output.finished
|
||||
|
||||
return prompt_tokens, output_tokens, output_text, output_count
|
||||
|
||||
results = await asyncio.gather(
|
||||
@ -291,7 +303,8 @@ async def test_output_kinds(async_engine):
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="module")
|
||||
async def test_cancellation(async_engine):
|
||||
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
|
||||
async def test_cancellation(async_engine, stop):
|
||||
scheduler_config = await async_engine.get_scheduler_config()
|
||||
num_scheduler_steps = scheduler_config.num_scheduler_steps
|
||||
|
||||
@ -299,6 +312,7 @@ async def test_cancellation(async_engine):
|
||||
temperature=0,
|
||||
min_tokens=13,
|
||||
max_tokens=13,
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
stop_at = 5 if num_scheduler_steps == 1 else 1
|
||||
@ -319,7 +333,8 @@ async def test_cancellation(async_engine):
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="module")
|
||||
async def test_delayed_generator(async_engine):
|
||||
@pytest.mark.parametrize("stop", [None, ["a stop string"]])
|
||||
async def test_delayed_generator(async_engine, stop):
|
||||
scheduler_config = await async_engine.get_scheduler_config()
|
||||
|
||||
if scheduler_config.num_scheduler_steps != 1:
|
||||
@ -329,6 +344,7 @@ async def test_delayed_generator(async_engine):
|
||||
temperature=0,
|
||||
min_tokens=10,
|
||||
max_tokens=10,
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
stream = async_engine.generate("test3", sampling_params, request_id=uid())
|
||||
|
@ -477,7 +477,9 @@ class Sequence:
|
||||
if not delta:
|
||||
return self.output_text[:-buffer_length] if truncate else (
|
||||
self.output_text)
|
||||
length = len(self.output_text) - buffer_length
|
||||
length = len(self.output_text)
|
||||
if truncate:
|
||||
length -= buffer_length
|
||||
last_offset = self._last_output_text_offset
|
||||
if last_offset < length:
|
||||
self._last_output_text_offset = length
|
||||
|
Loading…
x
Reference in New Issue
Block a user