[HotFix] Fix final output truncation with stop string + streaming (#8468)
parent f57092c00b
commit 18e9e1f7b3
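The underlying bug: in delta-streaming mode, `Sequence.get_output_text_to_return` subtracted the stop-string hold-back buffer from the returned length on every call, including the final one where `truncate` is False, so the last `buffer_length` characters were never streamed when stop strings were set. A minimal standalone sketch of the offset logic and the one-line fix (an illustrative reimplementation, not vLLM's actual class):

# Standalone sketch of the delta-streaming offset logic (illustrative,
# not vLLM's actual Sequence class).
class DeltaStreamer:
    def __init__(self) -> None:
        self.output_text = ""  # full text decoded so far
        self._last_offset = 0  # how much has already been streamed

    def next_delta(self, buffer_length: int, truncate: bool) -> str:
        length = len(self.output_text)
        # The fix: only hold back the stop-string buffer while a stop
        # string could still match (truncate=True). The buggy version
        # subtracted buffer_length unconditionally, so the final call
        # (truncate=False) never released the held-back characters.
        if truncate:
            length -= buffer_length
        delta = self.output_text[self._last_offset:length]
        if self._last_offset < length:
            self._last_offset = length
        return delta

s = DeltaStreamer()
s.output_text = "hello world"                          # 11 chars so far
print(s.next_delta(buffer_length=4, truncate=True))    # "hello w": 4 chars held back
s.output_text = "hello world!"                         # finishes without hitting a stop
print(s.next_delta(buffer_length=4, truncate=False))   # "orld!": nothing lost with the fix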
tests/async_engine/test_async_llm_engine.py
@@ -159,7 +159,8 @@ def should_do_global_cleanup_after_test(request) -> bool:
 
 
 @pytest.mark.asyncio(scope="module")
-async def test_asyncio_run(async_engine):
+@pytest.mark.parametrize("stop", [None, ["a stop string"]])
+async def test_asyncio_run(async_engine, stop):
     scheduler_config = await async_engine.get_scheduler_config()
     num_scheduler_steps = scheduler_config.num_scheduler_steps
@@ -169,6 +170,7 @@ async def test_asyncio_run(async_engine):
         temperature=0,
         max_tokens=32,
         min_tokens=32,
+        stop=stop,
     )
 
     output_count = 0
@@ -203,7 +205,8 @@ async def test_asyncio_run(async_engine):
 
 
 @pytest.mark.asyncio(scope="module")
-async def test_output_kinds(async_engine):
+@pytest.mark.parametrize("stop", [None, ["a stop string"]])
+async def test_output_kinds(async_engine, stop):
     """Test that output_kind works as expected and that
     results are equivalent across different kinds."""
 
@@ -214,6 +217,7 @@ async def test_output_kinds(async_engine):
         temperature=0,
         max_tokens=32,
         min_tokens=32,
+        stop=stop,
     )
 
     async def run(prompt: str, kind: RequestOutputKind):
@@ -229,6 +233,8 @@ async def test_output_kinds(async_engine):
             final_output = output
 
         assert final_output is not None
+        assert final_output.finished
+
         return (final_output.prompt_token_ids,
                 final_output.outputs[0].token_ids,
                 final_output.outputs[0].text, output_count)
@@ -241,16 +247,18 @@ async def test_output_kinds(async_engine):
         output_tokens: List[int] = []
         output_text = ""
         output_count = 0
+        final_output = None
         async for output in async_engine.generate(prompt,
                                                   params,
                                                   request_id=uid()):
             token_ids = output.outputs[0].token_ids
             text = output.outputs[0].text
+            final_output = output
 
             # Ensure we get prompt ids iff we haven't yet received output tokens
             if output_tokens:
                 assert 1 <= len(token_ids) <= num_scheduler_steps
-                assert text
+                assert stop or text
                 assert not output.prompt_token_ids
             else:
                 assert output.prompt_token_ids
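Why `assert text` had to become `assert stop or text`: when a stop string matches, it is stripped from the output, so a delta chunk can be empty even though generation advanced. A worked example with assumed values (plain string arithmetic mirroring the hold-back of len(stop) characters, not vLLM internals):

# Assumed values for illustration only.
stop = "a stop string"
full_text = "some text " + stop      # decoded text when the stop string matches
streamed = full_text[:-len(stop)]    # mid-stream, the hold-back window was withheld
final_text = full_text[:-len(stop)]  # the stop string is excluded from the output
final_delta = final_text[len(streamed):]
assert final_delta == ""             # an empty delta is now legitimate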
@@ -260,6 +268,10 @@ async def test_output_kinds(async_engine):
             output_text += text
 
             output_count += 1
+
+        assert final_output is not None
+        assert final_output.finished
+
         return prompt_tokens, output_tokens, output_text, output_count
 
     results = await asyncio.gather(
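The `final_output` bookkeeping added to the helper is the general pattern for consuming the delta stream: accumulate chunks, keep the last output, and check that it finished. A sketch against the async generate API (prompt, params, and request id are placeholders; assumes an initialized AsyncLLMEngine-style object with the same generate() signature):

async def collect_deltas(async_engine, prompt, params, request_id="req-0"):
    output_text = ""
    final_output = None
    async for output in async_engine.generate(prompt, params,
                                              request_id=request_id):
        output_text += output.outputs[0].text  # delta mode: each chunk is new text
        final_output = output
    # Before this fix, the held-back buffer never arrived, so output_text
    # came up short whenever stop strings were set.
    assert final_output is not None and final_output.finished
    return output_text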
@@ -291,7 +303,8 @@ async def test_output_kinds(async_engine):
 
 
 @pytest.mark.asyncio(scope="module")
-async def test_cancellation(async_engine):
+@pytest.mark.parametrize("stop", [None, ["a stop string"]])
+async def test_cancellation(async_engine, stop):
     scheduler_config = await async_engine.get_scheduler_config()
     num_scheduler_steps = scheduler_config.num_scheduler_steps
 
@@ -299,6 +312,7 @@ async def test_cancellation(async_engine):
         temperature=0,
         min_tokens=13,
         max_tokens=13,
+        stop=stop,
     )
 
     stop_at = 5 if num_scheduler_steps == 1 else 1
@@ -319,7 +333,8 @@ async def test_cancellation(async_engine):
 
 
 @pytest.mark.asyncio(scope="module")
-async def test_delayed_generator(async_engine):
+@pytest.mark.parametrize("stop", [None, ["a stop string"]])
+async def test_delayed_generator(async_engine, stop):
     scheduler_config = await async_engine.get_scheduler_config()
 
     if scheduler_config.num_scheduler_steps != 1:
@@ -329,6 +344,7 @@ async def test_delayed_generator(async_engine):
         temperature=0,
         min_tokens=10,
         max_tokens=10,
+        stop=stop,
     )
 
     stream = async_engine.generate("test3", sampling_params, request_id=uid())
vllm/sequence.py
@@ -477,7 +477,9 @@ class Sequence:
         if not delta:
             return self.output_text[:-buffer_length] if truncate else (
                 self.output_text)
-        length = len(self.output_text) - buffer_length
+        length = len(self.output_text)
+        if truncate:
+            length -= buffer_length
         last_offset = self._last_output_text_offset
         if last_offset < length:
             self._last_output_text_offset = length
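Concretely, what the old line dropped on the final call, with assumed numbers (the same arithmetic as the method above):

# Assumed state at the final call: generation finished without hitting a stop.
output_text = "hello world"   # 11 chars of final output
buffer_length = 4             # hold-back for the longest stop string
last_offset = 7               # mid-stream, text[:11 - 4] was already returned

# Old: length = len(output_text) - buffer_length == 7 == last_offset,
# so the final delta was "" and "orld" was silently lost.
# New: truncate is False on the final call, so length == 11.
truncate = False
length = len(output_text)
if truncate:
    length -= buffer_length
print(output_text[last_offset:length])  # -> "orld"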