From bc5dd4f669e2f83adec58b38ea11d75c74bc1706 Mon Sep 17 00:00:00 2001 From: Pooya Davoodi Date: Mon, 14 Apr 2025 23:09:58 -0700 Subject: [PATCH] [Bugfix] Fix broken GritLM model and tests (missing pooling_metadata) (#16631) Signed-off-by: Pooya Davoodi --- .../models/embedding/language/test_gritlm.py | 21 ++++++++++--------- vllm/model_executor/models/gritlm.py | 3 ++- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/tests/models/embedding/language/test_gritlm.py b/tests/models/embedding/language/test_gritlm.py index d6bf7d27..87a1dde9 100644 --- a/tests/models/embedding/language/test_gritlm.py +++ b/tests/models/embedding/language/test_gritlm.py @@ -57,24 +57,25 @@ def test_find_array(monkeypatch: pytest.MonkeyPatch): def server_embedding(): # GritLM embedding implementation is only supported by XFormers backend. args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)] - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: - yield remote_server + with pytest.MonkeyPatch.context() as m: + m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS") + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server @pytest.fixture(scope="module") def server_generate(): args = ["--task", "generate", "--max_model_len", str(MAX_MODEL_LEN)] - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: - yield remote_server + with pytest.MonkeyPatch.context() as m: + m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS") + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server @pytest_asyncio.fixture -async def client_embedding(monkeypatch: pytest.MonkeyPatch, - server_embedding: RemoteOpenAIServer): - with monkeypatch.context() as m: - m.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS") - async with server_embedding.get_async_client() as async_client: - yield async_client +async def client_embedding(server_embedding: RemoteOpenAIServer): + async with server_embedding.get_async_client() as async_client: + yield async_client @pytest_asyncio.fixture diff --git a/vllm/model_executor/models/gritlm.py b/vllm/model_executor/models/gritlm.py index 2984f224..e4692c45 100644 --- a/vllm/model_executor/models/gritlm.py +++ b/vllm/model_executor/models/gritlm.py @@ -170,7 +170,8 @@ class GritLMPooler(nn.Module): mean_embeddings = sum_embeddings / num_non_instruction_tokens.unsqueeze( 1) - pooled_data = self.head(mean_embeddings) + pooled_data = self.head(mean_embeddings, + pooling_metadata=pooling_metadata) pooled_outputs = [ PoolingSequenceGroupOutput(data) for data in pooled_data