[Bugfix] Fix broken GritLM model and tests (missing pooling_metadata) (#16631)
Signed-off-by: Pooya Davoodi <pooya.davoodi@parasail.io>
This commit is contained in:
parent
dbb036cf61
commit
bc5dd4f669
@ -57,24 +57,25 @@ def test_find_array(monkeypatch: pytest.MonkeyPatch):
|
||||
def server_embedding():
|
||||
# GritLM embedding implementation is only supported by XFormers backend.
|
||||
args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)]
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
with pytest.MonkeyPatch.context() as m:
|
||||
m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server_generate():
|
||||
args = ["--task", "generate", "--max_model_len", str(MAX_MODEL_LEN)]
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
with pytest.MonkeyPatch.context() as m:
|
||||
m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS")
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
async def client_embedding(monkeypatch: pytest.MonkeyPatch,
|
||||
server_embedding: RemoteOpenAIServer):
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
|
||||
async with server_embedding.get_async_client() as async_client:
|
||||
yield async_client
|
||||
async def client_embedding(server_embedding: RemoteOpenAIServer):
|
||||
async with server_embedding.get_async_client() as async_client:
|
||||
yield async_client
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
|
@ -170,7 +170,8 @@ class GritLMPooler(nn.Module):
|
||||
mean_embeddings = sum_embeddings / num_non_instruction_tokens.unsqueeze(
|
||||
1)
|
||||
|
||||
pooled_data = self.head(mean_embeddings)
|
||||
pooled_data = self.head(mean_embeddings,
|
||||
pooling_metadata=pooling_metadata)
|
||||
|
||||
pooled_outputs = [
|
||||
PoolingSequenceGroupOutput(data) for data in pooled_data
|
||||
|
Loading…
x
Reference in New Issue
Block a user