[Bugfix] Fix broken GritLM model and tests (missing pooling_metadata) (#16631)

Signed-off-by: Pooya Davoodi <pooya.davoodi@parasail.io>
This commit is contained in:
Pooya Davoodi 2025-04-14 23:09:58 -07:00 committed by GitHub
parent dbb036cf61
commit bc5dd4f669
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 13 additions and 11 deletions

View File

@@ -57,24 +57,25 @@ def test_find_array(monkeypatch: pytest.MonkeyPatch):
def server_embedding():
# GritLM embedding implementation is only supported by XFormers backend.
args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
with pytest.MonkeyPatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest.fixture(scope="module")
def server_generate():
args = ["--task", "generate", "--max_model_len", str(MAX_MODEL_LEN)]
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
with pytest.MonkeyPatch.context() as m:
m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
yield remote_server
@pytest_asyncio.fixture
async def client_embedding(monkeypatch: pytest.MonkeyPatch,
server_embedding: RemoteOpenAIServer):
with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
async with server_embedding.get_async_client() as async_client:
yield async_client
async def client_embedding(server_embedding: RemoteOpenAIServer):
async with server_embedding.get_async_client() as async_client:
yield async_client
@pytest_asyncio.fixture

View File

@@ -170,7 +170,8 @@ class GritLMPooler(nn.Module):
mean_embeddings = sum_embeddings / num_non_instruction_tokens.unsqueeze(
1)
pooled_data = self.head(mean_embeddings)
pooled_data = self.head(mean_embeddings,
pooling_metadata=pooling_metadata)
pooled_outputs = [
PoolingSequenceGroupOutput(data) for data in pooled_data