[Bugfix][Hardware][CPU] Fix broken encoder-decoder CPU runner (#10218)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py 2024-11-11 20:37:58 +08:00 committed by GitHub
parent 5fb1f935b0
commit 2cebda42bb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 16 additions and 0 deletions

View File

@@ -18,6 +18,8 @@ source /etc/environment
docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN="$HF_TOKEN" --name cpu-test cpu-test
function cpu_tests() {
set -e
# Run basic model test
docker exec cpu-test bash -c "
set -e

View File

@@ -20,6 +20,8 @@ docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/hugg
--cpuset-mems=1 --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-avx2 cpu-test-avx2
function cpu_tests() {
set -e
# offline inference
docker exec cpu-test-avx2 bash -c "
set -e

View File

@@ -95,6 +95,7 @@ class CPUEmbeddingModelRunner(
model_input.seq_lens)
return dataclasses.replace(model_input,
virtual_engine=virtual_engine,
pooling_metadata=pooling_metadata)
def _prepare_pooling(

View File

@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, cast
import torch
from vllm.attention import AttentionMetadata
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MultiModalKwargs
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
@@ -96,11 +97,21 @@ class CPUEncoderDecoderModelRunner(
encoder_input_positions_tensor,
) = self._prepare_encoder_model_input_tensors(seq_group_metadata_list,
model_input)
# Sampling metadata is only required for the final pp group
generators = self.get_generators(finished_requests_ids)
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
model_input.seq_lens,
model_input.query_lens,
self.device,
pin_memory=False,
generators=generators)
return dataclasses.replace(
model_input,
sampling_metadata=sampling_metadata,
attn_metadata=attn_metadata,
encoder_input_tokens=encoder_input_tokens_tensor,
encoder_input_positions=encoder_input_positions_tensor,
virtual_engine=virtual_engine,
)
def _prepare_encoder_model_input_tensors(