[Bugfix] Fix score api for missing max_model_len validation (#12119)

Signed-off-by: Wallas Santos <wallashss@ibm.com>
Author: Wallas Henrique, 2025-01-17 13:24:22 -03:00, committed by GitHub
parent 87a0c076af
commit 58fd57ff1d
3 changed files with 80 additions and 33 deletions
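In short, the /score endpoint now applies the same prompt-length validation as the embedding endpoints: an input longer than max_model_len is rejected with HTTP 400, and a truncate_prompt_tokens value larger than max_model_len is rejected as well. Below is a minimal client-side sketch of the new behavior (not part of the commit; it assumes a local server started with something like: vllm serve BAAI/bge-reranker-v2-m3 --max-model-len 100):

import requests

BASE_URL = "http://localhost:8000"  # assumed local server address
payload = {
    "model": "BAAI/bge-reranker-v2-m3",
    "text_1": "What is the capital of France?" * 20,  # well over 100 tokens
    "text_2": ["The capital of France is Paris."],
}

# Over-long input without truncation: now rejected with HTTP 400.
resp = requests.post(f"{BASE_URL}/score", json=payload, timeout=30)
print(resp.status_code)                                       # 400
print("Please reduce the length of the input." in resp.text)  # True

# Truncation size larger than max_model_len: also rejected with HTTP 400.
resp = requests.post(f"{BASE_URL}/score",
                     json={**payload, "truncate_prompt_tokens": 101},
                     timeout=30)
print(resp.status_code)                                       # 400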


@@ -12,6 +12,9 @@ MODEL_NAME = "BAAI/bge-reranker-v2-m3"
 def server():
     args = [
         "--enforce-eager",
+        # Will be used in tests to compare prompt input length
+        "--max-model-len",
+        "100"
     ]
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
@@ -20,8 +23,7 @@ def server():
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_text_1_str_text_2_list(server: RemoteOpenAIServer,
-                                      model_name: str):
+def test_text_1_str_text_2_list(server: RemoteOpenAIServer, model_name: str):
     text_1 = "What is the capital of France?"
     text_2 = [
         "The capital of Brazil is Brasilia.", "The capital of France is Paris."
@@ -45,8 +47,7 @@ async def test_text_1_str_text_2_list(server: RemoteOpenAIServer,
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_text_1_list_text_2_list(server: RemoteOpenAIServer,
-                                       model_name: str):
+def test_text_1_list_text_2_list(server: RemoteOpenAIServer, model_name: str):
     text_1 = [
         "What is the capital of the United States?",
         "What is the capital of France?"
@@ -73,8 +74,7 @@ async def test_text_1_list_text_2_list(server: RemoteOpenAIServer,
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_text_1_str_text_2_str(server: RemoteOpenAIServer,
-                                     model_name: str):
+def test_text_1_str_text_2_str(server: RemoteOpenAIServer, model_name: str):
     text_1 = "What is the capital of France?"
     text_2 = "The capital of France is Paris."
@@ -91,3 +91,36 @@ async def test_text_1_str_text_2_str(server: RemoteOpenAIServer,
     assert score.data is not None
     assert len(score.data) == 1
     assert score.data[0].score >= 0.9
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+def test_score_max_model_len(server: RemoteOpenAIServer, model_name: str):
+
+    text_1 = "What is the capital of France?" * 20
+    text_2 = [
+        "The capital of Brazil is Brasilia.", "The capital of France is Paris."
+    ]
+
+    score_response = requests.post(server.url_for("score"),
+                                   json={
+                                       "model": model_name,
+                                       "text_1": text_1,
+                                       "text_2": text_2,
+                                   })
+    assert score_response.status_code == 400
+    # Assert just a small fragment of the response
+    assert "Please reduce the length of the input." in \
+        score_response.text
+
+    # Test truncation
+    score_response = requests.post(server.url_for("score"),
+                                   json={
+                                       "model": model_name,
+                                       "text_1": text_1,
+                                       "text_2": text_2,
+                                       "truncate_prompt_tokens": 101
+                                   })
+    assert score_response.status_code == 400
+    assert "Please, select a smaller truncation size." in \
+        score_response.text
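As a side note (not from the diff), the pair length that the server compares against max_model_len can be approximated offline with the model's tokenizer, since the score endpoint tokenizes text_1 and text_2 together as a (text, text_pair) pair; this sketch assumes the transformers package is installed:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-v2-m3")
text_1 = "What is the capital of France?" * 20
text_2 = "The capital of Brazil is Brasilia."

# The server calls tokenizer(text=q, text_pair=t), so the combined pair
# length is what gets validated against --max-model-len (100 in the fixture).
input_ids = tokenizer(text_1, text_pair=text_2)["input_ids"]
print(len(input_ids))  # comfortably above 100, hence the expected 400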


@@ -203,15 +203,19 @@ class OpenAIServing:
     ) -> TextTokensPrompt:
         token_num = len(input_ids)
 
-        # Note: EmbeddingRequest doesn't have max_tokens
-        if isinstance(request,
-                      (EmbeddingChatRequest, EmbeddingCompletionRequest)):
+        # Note: EmbeddingRequest and ScoreRequest don't have max_tokens
+        if isinstance(
+                request,
+            (EmbeddingChatRequest, EmbeddingCompletionRequest, ScoreRequest)):
+
+            operation = "score" if isinstance(request, ScoreRequest) \
+                else "embedding generation"
+
             if token_num > self.max_model_len:
                 raise ValueError(
                     f"This model's maximum context length is "
                     f"{self.max_model_len} tokens. However, you requested "
-                    f"{token_num} tokens in the input for embedding "
-                    f"generation. Please reduce the length of the input.")
+                    f"{token_num} tokens in the input for {operation}. "
+                    f"Please reduce the length of the input.")
             return TextTokensPrompt(prompt=input_text,
                                     prompt_token_ids=input_ids)
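For illustration only (not part of the commit), the rejection message produced by this hunk for an over-long score request would render along these lines:

# Hypothetical values: a 250-token pair against max_model_len = 100.
max_model_len, token_num, operation = 100, 250, "score"
print(f"This model's maximum context length is "
      f"{max_model_len} tokens. However, you requested "
      f"{token_num} tokens in the input for {operation}. "
      f"Please reduce the length of the input.")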


@@ -101,15 +101,14 @@ class OpenAIServingScores(OpenAIServing):
             if not self.model_config.is_cross_encoder:
                 raise ValueError("Model is not cross encoder.")
 
-        except ValueError as e:
-            logger.exception("Error in preprocessing prompt inputs")
-            return self.create_error_response(str(e))
-
-        # Schedule the request and get the result generator.
-        generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
+            if truncate_prompt_tokens is not None and \
+                    truncate_prompt_tokens > self.max_model_len:
+                raise ValueError(
+                    f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
+                    f"is greater than max_model_len ({self.max_model_len})."
+                    f" Please, select a smaller truncation size.")
 
             input_pairs = make_pairs(request.text_1, request.text_2)
             for q, t in input_pairs:
                 request_prompt = f"{q}{tokenizer.sep_token}{t}"
@@ -123,13 +122,24 @@ class OpenAIServingScores(OpenAIServing):
                 prompt_inputs = await tokenize_async(text=q,
                                                      text_pair=t,
                                                      **tokenization_kwargs)
 
+                input_ids = prompt_inputs["input_ids"]
+
+                text_token_prompt = \
+                    self._validate_input(request, input_ids, request_prompt)
+
                 engine_prompt = TokensPrompt(
-                    prompt_token_ids=prompt_inputs["input_ids"],
+                    prompt_token_ids=text_token_prompt["prompt_token_ids"],
                     token_type_ids=prompt_inputs.get("token_type_ids"))
 
                 request_prompts.append(request_prompt)
                 engine_prompts.append(engine_prompt)
 
+        except ValueError as e:
+            logger.exception("Error in preprocessing prompt inputs")
+            return self.create_error_response(str(e))
+
+        # Schedule the request and get the result generator.
+        generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
+
         try:
             pooling_params = request.to_pooling_params()
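With the two checks above in place, long inputs can still be scored by requesting a truncation size that fits within max_model_len. A hedged follow-up to the client sketch near the top, again assuming the local server started with --max-model-len 100:

import requests

resp = requests.post(
    "http://localhost:8000/score",  # assumed local server address
    json={
        "model": "BAAI/bge-reranker-v2-m3",
        "text_1": "What is the capital of France?" * 20,
        "text_2": "The capital of France is Paris.",
        # 100 <= max_model_len, so the pair is truncated during tokenization
        # instead of the request being rejected.
        "truncate_prompt_tokens": 100,
    },
    timeout=30,
)
print(resp.status_code)        # expected 200
print(resp.json()["data"][0])  # expected: one score entry for the single pair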