[HotFix] Fix final output truncation with stop string + streaming (#8468)

Authored by Nick Hill on 2024-09-13 19:31:12 +01:00, committed by GitHub
parent f57092c00b
commit 18e9e1f7b3
2 changed files with 24 additions and 6 deletions

@@ -159,7 +159,8 @@ def should_do_global_cleanup_after_test(request) -> bool:
 @pytest.mark.asyncio(scope="module")
-async def test_asyncio_run(async_engine):
+@pytest.mark.parametrize("stop", [None, ["a stop string"]])
+async def test_asyncio_run(async_engine, stop):
     scheduler_config = await async_engine.get_scheduler_config()
     num_scheduler_steps = scheduler_config.num_scheduler_steps
@@ -169,6 +170,7 @@ async def test_asyncio_run(async_engine):
         temperature=0,
         max_tokens=32,
         min_tokens=32,
+        stop=stop,
     )
     output_count = 0
@@ -203,7 +205,8 @@ async def test_asyncio_run(async_engine):
 @pytest.mark.asyncio(scope="module")
-async def test_output_kinds(async_engine):
+@pytest.mark.parametrize("stop", [None, ["a stop string"]])
+async def test_output_kinds(async_engine, stop):
     """Test that output_kind works as expected and that
     results are equivalent across different kinds."""
@@ -214,6 +217,7 @@ async def test_output_kinds(async_engine):
         temperature=0,
         max_tokens=32,
         min_tokens=32,
+        stop=stop,
     )
     async def run(prompt: str, kind: RequestOutputKind):
@@ -229,6 +233,8 @@ async def test_output_kinds(async_engine):
             final_output = output
         assert final_output is not None
+        assert final_output.finished
         return (final_output.prompt_token_ids,
                 final_output.outputs[0].token_ids,
                 final_output.outputs[0].text, output_count)
@@ -241,16 +247,18 @@ async def test_output_kinds(async_engine):
         output_tokens: List[int] = []
         output_text = ""
         output_count = 0
+        final_output = None
         async for output in async_engine.generate(prompt,
                                                   params,
                                                   request_id=uid()):
             token_ids = output.outputs[0].token_ids
             text = output.outputs[0].text
+            final_output = output
             # Ensure we get prompt ids iff we haven't yet received output tokens
             if output_tokens:
                 assert 1 <= len(token_ids) <= num_scheduler_steps
-                assert text
+                assert stop or text
                 assert not output.prompt_token_ids
             else:
                 assert output.prompt_token_ids
@@ -260,6 +268,10 @@ async def test_output_kinds(async_engine):
             output_text += text
             output_count += 1
+        assert final_output is not None
+        assert final_output.finished
         return prompt_tokens, output_tokens, output_text, output_count
     results = await asyncio.gather(
@@ -291,7 +303,8 @@ async def test_output_kinds(async_engine):
 @pytest.mark.asyncio(scope="module")
-async def test_cancellation(async_engine):
+@pytest.mark.parametrize("stop", [None, ["a stop string"]])
+async def test_cancellation(async_engine, stop):
     scheduler_config = await async_engine.get_scheduler_config()
     num_scheduler_steps = scheduler_config.num_scheduler_steps
@@ -299,6 +312,7 @@ async def test_cancellation(async_engine):
         temperature=0,
         min_tokens=13,
         max_tokens=13,
+        stop=stop,
     )
     stop_at = 5 if num_scheduler_steps == 1 else 1
@@ -319,7 +333,8 @@ async def test_cancellation(async_engine):
 @pytest.mark.asyncio(scope="module")
-async def test_delayed_generator(async_engine):
+@pytest.mark.parametrize("stop", [None, ["a stop string"]])
+async def test_delayed_generator(async_engine, stop):
     scheduler_config = await async_engine.get_scheduler_config()
     if scheduler_config.num_scheduler_steps != 1:
@@ -329,6 +344,7 @@ async def test_delayed_generator(async_engine):
         temperature=0,
         min_tokens=10,
         max_tokens=10,
+        stop=stop,
     )
     stream = async_engine.generate("test3", sampling_params, request_id=uid())
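
The test changes above parametrize each streaming test over stop=[None, ["a stop string"]], track the last output received, and assert that it is marked finished and that the accumulated text is complete. A minimal consumer sketch of the same pattern follows; it mirrors the test module, so the async_engine fixture, the RequestOutputKind import path, and the output_kind parameter of SamplingParams are assumptions taken from that context rather than part of this diff:

from vllm.sampling_params import RequestOutputKind, SamplingParams


async def consume_stream(async_engine, prompt: str, stop):
    # Sketch only: `async_engine` stands in for the AsyncLLMEngine fixture
    # used by the tests above.
    params = SamplingParams(
        temperature=0,
        min_tokens=32,
        max_tokens=32,
        stop=stop,
        output_kind=RequestOutputKind.DELTA,  # stream only the new text each step
    )
    text = ""
    final_output = None
    async for output in async_engine.generate(prompt, params, request_id="req-0"):
        # While the generated text is still shorter than the stop-string
        # matching buffer, a step's delta can be empty -- which is why the
        # test asserts `stop or text` instead of `text`.
        text += output.outputs[0].text
        final_output = output
    # The Sequence fix below ensures the stream ends with a finished output
    # whose held-back characters have been released.
    assert final_output is not None and final_output.finished
    return text

Before the fix, the held-back characters were never released in delta mode, so text here would come back short whenever a stop string ended the request.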

@@ -477,7 +477,9 @@ class Sequence:
         if not delta:
             return self.output_text[:-buffer_length] if truncate else (
                 self.output_text)
-        length = len(self.output_text) - buffer_length
+        length = len(self.output_text)
+        if truncate:
+            length -= buffer_length
         last_offset = self._last_output_text_offset
         if last_offset < length:
             self._last_output_text_offset = length
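
This second hunk is the actual fix. In delta (streaming) mode, buffer_length characters are held back from the returned text while a stop string might still match. The old code subtracted buffer_length unconditionally, so when the sequence finished (and truncate was therefore false) the final delta still omitted those characters and the streamed output ended up truncated; the change applies the hold-back only while truncate is true. A standalone toy reproduction of the difference is below; the helper names and the example values are illustrative, not vLLM code:

def delta_old(output_text: str, last_offset: int, buffer_length: int,
              finished: bool) -> tuple:
    # Pre-fix logic: the hold-back is applied even after the sequence has
    # finished (`finished` is ignored) -- this is the bug.
    length = len(output_text) - buffer_length
    if last_offset < length:
        return output_text[last_offset:length], length
    return "", last_offset


def delta_new(output_text: str, last_offset: int, buffer_length: int,
              finished: bool) -> tuple:
    # Post-fix logic: only hold characters back while still generating.
    length = len(output_text)
    if buffer_length and not finished:
        length -= buffer_length
    if last_offset < length:
        return output_text[last_offset:length], length
    return "", last_offset


# A request whose full (stop-string-stripped) output is "Hello world" and
# whose stop strings require holding back 3 characters:
print(delta_old("Hello world", 0, 3, finished=True))  # ('Hello wo', 8): tail lost
print(delta_new("Hello world", 0, 3, finished=True))  # ('Hello world', 11): complete

Intermediate deltas are unaffected; only the final one, where the sequence has finished, changes.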