[Bugfix] Fix score api for missing max_model_len validation (#12119)
Signed-off-by: Wallas Santos <wallashss@ibm.com>
commit 58fd57ff1d
parent 87a0c076af
@@ -12,6 +12,9 @@ MODEL_NAME = "BAAI/bge-reranker-v2-m3"
 def server():
     args = [
         "--enforce-eager",
+        # Will be used in tests to compare prompt input length
+        "--max-model-len",
+        "100"
     ]
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
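The `--max-model-len 100` cap above is what the new test added later in this file relies on. As a rough illustration (not part of this commit), the sketch below shows why the repeated question used in that test overflows the cap; it assumes the `transformers` library is installed and can download the BAAI/bge-reranker-v2-m3 tokenizer.

# Illustration only, not part of the commit.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-v2-m3")

text_1 = "What is the capital of France?" * 20
text_2 = "The capital of Brazil is Brasilia."

# Cross-encoder scoring tokenizes each (text_1, text_2) pair as one sequence.
input_ids = tokenizer(text=text_1, text_pair=text_2)["input_ids"]
print(len(input_ids))  # far more than the 100 tokens configured above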
@@ -20,8 +23,7 @@ def server():
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_text_1_str_text_2_list(server: RemoteOpenAIServer,
-                                      model_name: str):
+def test_text_1_str_text_2_list(server: RemoteOpenAIServer, model_name: str):
     text_1 = "What is the capital of France?"
     text_2 = [
         "The capital of Brazil is Brasilia.", "The capital of France is Paris."
@@ -45,8 +47,7 @@ async def test_text_1_str_text_2_list(server: RemoteOpenAIServer,
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_text_1_list_text_2_list(server: RemoteOpenAIServer,
-                                       model_name: str):
+def test_text_1_list_text_2_list(server: RemoteOpenAIServer, model_name: str):
     text_1 = [
         "What is the capital of the United States?",
         "What is the capital of France?"
@@ -73,8 +74,7 @@ async def test_text_1_list_text_2_list(server: RemoteOpenAIServer,
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_text_1_str_text_2_str(server: RemoteOpenAIServer,
-                                     model_name: str):
+def test_text_1_str_text_2_str(server: RemoteOpenAIServer, model_name: str):
     text_1 = "What is the capital of France?"
     text_2 = "The capital of France is Paris."
 
@@ -91,3 +91,36 @@ async def test_text_1_str_text_2_str(server: RemoteOpenAIServer,
     assert score.data is not None
     assert len(score.data) == 1
     assert score.data[0].score >= 0.9
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+def test_score_max_model_len(server: RemoteOpenAIServer, model_name: str):
+
+    text_1 = "What is the capital of France?" * 20
+    text_2 = [
+        "The capital of Brazil is Brasilia.", "The capital of France is Paris."
+    ]
+
+    score_response = requests.post(server.url_for("score"),
+                                   json={
+                                       "model": model_name,
+                                       "text_1": text_1,
+                                       "text_2": text_2,
+                                   })
+    assert score_response.status_code == 400
+    # Assert just a small fragment of the response
+    assert "Please reduce the length of the input." in \
+        score_response.text
+
+    # Test truncation
+    score_response = requests.post(server.url_for("score"),
+                                   json={
+                                       "model": model_name,
+                                       "text_1": text_1,
+                                       "text_2": text_2,
+                                       "truncate_prompt_tokens": 101
+                                   })
+    assert score_response.status_code == 400
+    assert "Please, select a smaller truncation size." in \
+        score_response.text
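For contrast, a hypothetical companion test (not part of this commit) sketching the success path: when `truncate_prompt_tokens` is at or below the configured `max_model_len`, the oversized pair is expected to be truncated server-side and scored normally. The response layout (a "data" list, as asserted in the earlier tests of this file) is an assumption here.

# Hypothetical companion test, not part of the commit.
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_score_truncation_within_limit(server: RemoteOpenAIServer,
                                       model_name: str):
    text_1 = "What is the capital of France?" * 20
    text_2 = ["The capital of France is Paris."]

    # 100 == max_model_len, so the pair should be truncated, not rejected.
    score_response = requests.post(server.url_for("score"),
                                   json={
                                       "model": model_name,
                                       "text_1": text_1,
                                       "text_2": text_2,
                                       "truncate_prompt_tokens": 100
                                   })
    assert score_response.status_code == 200
    assert len(score_response.json()["data"]) == 1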
@@ -203,15 +203,19 @@ class OpenAIServing:
     ) -> TextTokensPrompt:
         token_num = len(input_ids)
 
-        # Note: EmbeddingRequest doesn't have max_tokens
-        if isinstance(request,
-                      (EmbeddingChatRequest, EmbeddingCompletionRequest)):
+        # Note: EmbeddingRequest and ScoreRequest don't have max_tokens
+        if isinstance(
+                request,
+            (EmbeddingChatRequest, EmbeddingCompletionRequest, ScoreRequest)):
+
+            operation = "score" if isinstance(request, ScoreRequest) \
+                else "embedding generation"
             if token_num > self.max_model_len:
                 raise ValueError(
                     f"This model's maximum context length is "
                     f"{self.max_model_len} tokens. However, you requested "
-                    f"{token_num} tokens in the input for embedding "
-                    f"generation. Please reduce the length of the input.")
+                    f"{token_num} tokens in the input for {operation}. "
+                    f"Please reduce the length of the input.")
             return TextTokensPrompt(prompt=input_text,
                                     prompt_token_ids=input_ids)
 
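To make the new branch easier to read in isolation, here is a standalone sketch (illustration only, not the actual vLLM code) of the length check, with the request classes replaced by a boolean flag:

# Standalone illustration of the validation above; `is_score_request` stands
# in for the isinstance(request, ScoreRequest) check.
def validate_input_length(token_num: int, max_model_len: int,
                          is_score_request: bool) -> None:
    operation = "score" if is_score_request else "embedding generation"
    if token_num > max_model_len:
        raise ValueError(
            f"This model's maximum context length is "
            f"{max_model_len} tokens. However, you requested "
            f"{token_num} tokens in the input for {operation}. "
            f"Please reduce the length of the input.")

# validate_input_length(180, 100, is_score_request=True) raises:
#   "... you requested 180 tokens in the input for score. Please reduce ..."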
@@ -101,6 +101,38 @@ class OpenAIServingScores(OpenAIServing):
             if not self.model_config.is_cross_encoder:
                 raise ValueError("Model is not cross encoder.")
 
+            if truncate_prompt_tokens is not None and \
+                    truncate_prompt_tokens > self.max_model_len:
+                raise ValueError(
+                    f"truncate_prompt_tokens value ({truncate_prompt_tokens}) "
+                    f"is greater than max_model_len ({self.max_model_len})."
+                    f" Please, select a smaller truncation size.")
+
+            input_pairs = make_pairs(request.text_1, request.text_2)
+            for q, t in input_pairs:
+                request_prompt = f"{q}{tokenizer.sep_token}{t}"
+
+                tokenization_kwargs: Dict[str, Any] = {}
+                if truncate_prompt_tokens is not None:
+                    tokenization_kwargs["truncation"] = True
+                    tokenization_kwargs["max_length"] = truncate_prompt_tokens
+
+                tokenize_async = make_async(tokenizer.__call__,
+                                            executor=self._tokenizer_executor)
+                prompt_inputs = await tokenize_async(text=q,
+                                                     text_pair=t,
+                                                     **tokenization_kwargs)
+
+                input_ids = prompt_inputs["input_ids"]
+                text_token_prompt = \
+                    self._validate_input(request, input_ids, request_prompt)
+                engine_prompt = TokensPrompt(
+                    prompt_token_ids=text_token_prompt["prompt_token_ids"],
+                    token_type_ids=prompt_inputs.get("token_type_ids"))
+
+                request_prompts.append(request_prompt)
+                engine_prompts.append(engine_prompt)
+
         except ValueError as e:
             logger.exception("Error in preprocessing prompt inputs")
             return self.create_error_response(str(e))
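The tokenization_kwargs built above are passed through to the tokenizer's pair-encoding call. Below is a synchronous sketch of the same truncation behaviour (illustration only; it bypasses make_async and the request plumbing, and assumes the `transformers` tokenizer for the test model):

# Illustration only, not part of the commit.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-v2-m3")

q = "What is the capital of France?" * 20
t = "The capital of France is Paris."

truncate_prompt_tokens = 100
prompt_inputs = tokenizer(text=q,
                          text_pair=t,
                          truncation=True,
                          max_length=truncate_prompt_tokens)

# The joint sequence is cut to the requested budget, so the subsequent
# _validate_input check against max_model_len can pass.
assert len(prompt_inputs["input_ids"]) <= truncate_prompt_tokens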
@@ -108,28 +140,6 @@ class OpenAIServingScores(OpenAIServing):
         # Schedule the request and get the result generator.
         generators: List[AsyncGenerator[PoolingRequestOutput, None]] = []
 
-        input_pairs = make_pairs(request.text_1, request.text_2)
-
-        for q, t in input_pairs:
-            request_prompt = f"{q}{tokenizer.sep_token}{t}"
-
-            tokenization_kwargs: Dict[str, Any] = {}
-            if truncate_prompt_tokens is not None:
-                tokenization_kwargs["truncation"] = True
-                tokenization_kwargs["max_length"] = truncate_prompt_tokens
-
-            tokenize_async = make_async(tokenizer.__call__,
-                                        executor=self._tokenizer_executor)
-            prompt_inputs = await tokenize_async(text=q,
-                                                 text_pair=t,
-                                                 **tokenization_kwargs)
-            engine_prompt = TokensPrompt(
-                prompt_token_ids=prompt_inputs["input_ids"],
-                token_type_ids=prompt_inputs.get("token_type_ids"))
-
-            request_prompts.append(request_prompt)
-            engine_prompts.append(engine_prompt)
-
         try:
             pooling_params = request.to_pooling_params()