[Misc] Consolidate pooler config overrides (#10351)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung (committed by GitHub)
Date: 2024-11-15 14:59:00 +08:00
Commit: 2ac6d0e75b
Parent: 2ec8827288
7 changed files with 141 additions and 190 deletions

diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst

@@ -345,6 +345,9 @@ Text Embedding
 Some model architectures support both generation and embedding tasks.
 In this case, you have to pass :code:`--task embedding` to run the model in embedding mode.
 
+.. tip::
+    You can override the model's pooling method by passing :code:`--override-pooler-config`.
+
 Reward Modeling
 ---------------
@@ -364,7 +367,7 @@ Reward Modeling
      - ✅︎
 
 .. note::
-    As an interim measure, these models are supported via Embeddings API. See `this RFC <https://github.com/vllm-project/vllm/issues/8967>`_ for upcoming changes.
+    As an interim measure, these models are supported in both offline and online inference via Embeddings API.
 
 Classification
 ---------------
@@ -385,7 +388,7 @@ Classification
      - ✅︎
 
 .. note::
-    As an interim measure, these models are supported via Embeddings API. It will be supported via Classification API in the future (no reference APIs exist now).
+    As an interim measure, these models are supported in both offline and online inference via Embeddings API.
 
 Multimodal Language Models
@@ -600,6 +603,9 @@ Multimodal Embedding
 Some model architectures support both generation and embedding tasks.
 In this case, you have to pass :code:`--task embedding` to run the model in embedding mode.
 
+.. tip::
+    You can override the model's pooling method by passing :code:`--override-pooler-config`.
+
 Model Support Policy
 =====================
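
The new flag is also available offline through the `LLM` constructor (see the entrypoint changes below). A minimal sketch, assuming a vLLM build that includes this commit; the model is the same embedding checkpoint the tests in this commit use:

    from vllm import LLM
    from vllm.config import PoolerConfig

    # Offline equivalent of serving with:
    #   --task embedding --override-pooler-config '{"pooling_type": "MEAN", "normalize": true}'
    llm = LLM(
        model="sentence-transformers/all-MiniLM-L12-v2",
        task="embedding",
        override_pooler_config=PoolerConfig(pooling_type="MEAN", normalize=True),
    )
    outputs = llm.encode(["Example sentence to embed."])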

diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py

@@ -2,6 +2,7 @@ from argparse import ArgumentTypeError
 
 import pytest
 
+from vllm.config import PoolerConfig
 from vllm.engine.arg_utils import EngineArgs, nullable_kvs
 from vllm.utils import FlexibleArgumentParser
@@ -32,9 +33,13 @@ def test_limit_mm_per_prompt_parser(arg, expected):
 
 def test_valid_pooling_config():
     parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
-    args = parser.parse_args(["--pooling-type=MEAN"])
+    args = parser.parse_args([
+        '--override-pooler-config',
+        '{"pooling_type": "MEAN"}',
+    ])
     engine_args = EngineArgs.from_cli_args(args=args)
-    assert engine_args.pooling_type == 'MEAN'
+    assert engine_args.override_pooler_config == PoolerConfig(
+        pooling_type="MEAN", )
 
 
 @pytest.mark.parametrize(
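
The flag's value is materialized by `PoolerConfig.from_json` (added in vllm/config.py below), and the assertion relies on dataclass field-wise equality. The same check outside pytest, assuming vLLM with this commit is installed:

    from vllm.config import PoolerConfig

    # from_json feeds the parsed JSON dict straight to the dataclass constructor;
    # unspecified fields stay None, so the two sides compare equal field by field.
    assert PoolerConfig.from_json('{"pooling_type": "MEAN"}') == \
        PoolerConfig(pooling_type="MEAN")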

diff --git a/tests/test_config.py b/tests/test_config.py

@@ -1,6 +1,8 @@
+from dataclasses import asdict
+
 import pytest
 
-from vllm.config import ModelConfig
+from vllm.config import ModelConfig, PoolerConfig
 from vllm.model_executor.layers.pooler import PoolingType
 from vllm.platforms import current_platform
@@ -108,7 +110,7 @@ def test_get_sliding_window():
                     reason="Xformers backend is not supported on ROCm.")
 def test_get_pooling_config():
     model_id = "sentence-transformers/all-MiniLM-L12-v2"
-    minilm_model_config = ModelConfig(
+    model_config = ModelConfig(
         model_id,
         task="auto",
         tokenizer=model_id,
@@ -119,39 +121,31 @@ def test_get_pooling_config():
         revision=None,
     )
 
-    minilm_pooling_config = minilm_model_config._init_pooler_config(
-        pooling_type=None,
-        pooling_norm=None,
-        pooling_returned_token_ids=None,
-        pooling_softmax=None,
-        pooling_step_tag_id=None)
+    pooling_config = model_config._init_pooler_config(None)
+    assert pooling_config is not None
 
-    assert minilm_pooling_config.pooling_norm
-    assert minilm_pooling_config.pooling_type == PoolingType.MEAN.name
+    assert pooling_config.normalize
+    assert pooling_config.pooling_type == PoolingType.MEAN.name
 
 
 @pytest.mark.skipif(current_platform.is_rocm(),
                     reason="Xformers backend is not supported on ROCm.")
 def test_get_pooling_config_from_args():
     model_id = "sentence-transformers/all-MiniLM-L12-v2"
-    minilm_model_config = ModelConfig(model_id,
-                                      task="auto",
-                                      tokenizer=model_id,
-                                      tokenizer_mode="auto",
-                                      trust_remote_code=False,
-                                      seed=0,
-                                      dtype="float16",
-                                      revision=None)
+    model_config = ModelConfig(model_id,
+                               task="auto",
+                               tokenizer=model_id,
+                               tokenizer_mode="auto",
+                               trust_remote_code=False,
+                               seed=0,
+                               dtype="float16",
+                               revision=None)
 
-    minilm_pooling_config = minilm_model_config._init_pooler_config(
-        pooling_type='CLS',
-        pooling_norm=True,
-        pooling_returned_token_ids=None,
-        pooling_softmax=None,
-        pooling_step_tag_id=None)
+    override_config = PoolerConfig(pooling_type='CLS', normalize=True)
 
-    assert minilm_pooling_config.pooling_norm
-    assert minilm_pooling_config.pooling_type == PoolingType.CLS.name
+    pooling_config = model_config._init_pooler_config(override_config)
+    assert pooling_config is not None
+    assert asdict(pooling_config) == asdict(override_config)
 
 
 @pytest.mark.skipif(current_platform.is_rocm(),

diff --git a/vllm/config.py b/vllm/config.py

@@ -112,10 +112,6 @@ class ModelConfig:
             the model name will be the same as `model`.
         limit_mm_per_prompt: Maximum number of data items per modality
             per prompt. Only applicable for multimodal models.
-        override_neuron_config: Initialize non default neuron config or
-            override default neuron config that are specific to Neuron devices,
-            this argument will be used to configure the neuron config that
-            can not be gathered from the vllm arguments.
         config_format: The config format which shall be loaded.
             Defaults to 'auto' which defaults to 'hf'.
         hf_overrides: If a dictionary, contains arguments to be forwarded to the
@@ -123,20 +119,12 @@ class ModelConfig:
             HuggingFace config.
         mm_processor_kwargs: Arguments to be forwarded to the model's processor
             for multi-modal data, e.g., image processor.
-        pooling_type: Used to configure the pooling method in the embedding
-            model.
-        pooling_norm: Used to determine whether to normalize the pooled
-            data in the embedding model.
-        pooling_softmax: Used to determine whether to softmax the pooled
-            data in the embedding model.
-        pooling_step_tag_id: When pooling_step_tag_id is not -1, it indicates
-            that the score corresponding to the pooling_step_tag_id in the
-            generated sentence should be returned. Otherwise, it returns
-            the scores for all tokens.
-        pooling_returned_token_ids: pooling_returned_token_ids represents a
-            list of indices for the vocabulary dimensions to be extracted,
-            such as the token IDs of good_token and bad_token in the
-            math-shepherd-mistral-7b-prm model.
+        override_neuron_config: Initialize non default neuron config or
+            override default neuron config that are specific to Neuron devices,
+            this argument will be used to configure the neuron config that
+            can not be gathered from the vllm arguments.
+        override_pooler_config: Initialize non default pooling config or
+            override default pooling config for the embedding model.
     """
 
     def __init__(
@@ -166,16 +154,12 @@ class ModelConfig:
         served_model_name: Optional[Union[str, List[str]]] = None,
         limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
         use_async_output_proc: bool = True,
-        override_neuron_config: Optional[Dict[str, Any]] = None,
         config_format: ConfigFormat = ConfigFormat.AUTO,
         chat_template_text_format: str = "string",
         hf_overrides: Optional[HfOverrides] = None,
         mm_processor_kwargs: Optional[Dict[str, Any]] = None,
-        pooling_type: Optional[str] = None,
-        pooling_norm: Optional[bool] = None,
-        pooling_softmax: Optional[bool] = None,
-        pooling_step_tag_id: Optional[int] = None,
-        pooling_returned_token_ids: Optional[List[int]] = None) -> None:
+        override_neuron_config: Optional[Dict[str, Any]] = None,
+        override_pooler_config: Optional["PoolerConfig"] = None) -> None:
         self.model = model
         self.tokenizer = tokenizer
         self.tokenizer_mode = tokenizer_mode
@@ -280,13 +264,7 @@ class ModelConfig:
         supported_tasks, task = self._resolve_task(task, self.hf_config)
         self.supported_tasks = supported_tasks
         self.task: Final = task
-        self.pooler_config = self._init_pooler_config(
-            pooling_type,
-            pooling_norm,
-            pooling_softmax,
-            pooling_step_tag_id,
-            pooling_returned_token_ids,
-        )
+        self.pooler_config = self._init_pooler_config(override_pooler_config)
 
         self._verify_quantization()
         self._verify_cuda_graph()
@@ -311,27 +289,21 @@ class ModelConfig:
     def _init_pooler_config(
         self,
-        pooling_type: Optional[str] = None,
-        pooling_norm: Optional[bool] = None,
-        pooling_softmax: Optional[bool] = None,
-        pooling_step_tag_id: Optional[int] = None,
-        pooling_returned_token_ids: Optional[List[int]] = None
+        override_pooler_config: Optional["PoolerConfig"],
     ) -> Optional["PoolerConfig"]:
         if self.task == "embedding":
-            pooling_config = get_pooling_config(self.model, self.revision)
-            if pooling_config is not None:
-                # override if user does not
-                # specifies pooling_type and/or pooling_norm
-                if pooling_type is None:
-                    pooling_type = pooling_config["pooling_type"]
-                if pooling_norm is None:
-                    pooling_norm = pooling_config["normalize"]
-            return PoolerConfig(
-                pooling_type=pooling_type,
-                pooling_norm=pooling_norm,
-                pooling_softmax=pooling_softmax,
-                pooling_step_tag_id=pooling_step_tag_id,
-                pooling_returned_token_ids=pooling_returned_token_ids)
+            user_config = override_pooler_config or PoolerConfig()
+
+            base_config = get_pooling_config(self.model, self.revision)
+            if base_config is not None:
+                # Only set values that are not overridden by the user
+                for k, v in base_config.items():
+                    if getattr(user_config, k) is None:
+                        setattr(user_config, k, v)
+
+            return user_config
+
         return None
 
     def _init_attention_free(self) -> bool:
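
The rewrite replaces per-field if/else merging with a generic loop: fields the user set explicitly win, and anything left as None is filled from the model's own pooling config. A standalone illustration of that rule, using a trimmed stand-in for PoolerConfig (get_pooling_config is assumed to return a dict of such field values, as above):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class _PoolerConfig:  # trimmed stand-in with two of the real fields
        pooling_type: Optional[str] = None
        normalize: Optional[bool] = None

    def merge(user: _PoolerConfig, base: dict) -> _PoolerConfig:
        for k, v in base.items():
            if getattr(user, k) is None:  # fill only fields the user left unset
                setattr(user, k, v)
        return user

    base = {"pooling_type": "MEAN", "normalize": True}
    assert merge(_PoolerConfig(), dict(base)) == _PoolerConfig("MEAN", True)
    # An explicit user choice survives the merge:
    assert merge(_PoolerConfig(pooling_type="CLS"), dict(base)) == _PoolerConfig("CLS", True)
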
@@ -1786,13 +1758,43 @@ class MultiModalConfig:
 
 @dataclass
 class PoolerConfig:
-    """Controls the behavior of pooler in embedding model"""
+    """Controls the behavior of output pooling in embedding models."""
 
     pooling_type: Optional[str] = None
-    pooling_norm: Optional[bool] = None
-    pooling_softmax: Optional[bool] = None
-    pooling_step_tag_id: Optional[int] = None
-    pooling_returned_token_ids: Optional[List[int]] = None
+    """
+    The pooling method of the embedding model. This should be a key in
+    :class:`vllm.model_executor.layers.pooler.PoolingType`.
+    """
+
+    normalize: Optional[bool] = None
+    """
+    Whether to normalize the pooled outputs. Usually, this should be set to
+    ``True`` for embedding outputs.
+    """
+
+    softmax: Optional[bool] = None
+    """
+    Whether to apply softmax to the pooled outputs. Usually, this should be
+    set to ``True`` for classification outputs.
+    """
+
+    step_tag_id: Optional[int] = None
+    """
+    If set, only the score corresponding to the ``step_tag_id`` in the
+    generated sentence is returned. Otherwise, the scores for all tokens
+    are returned.
+    """
+
+    returned_token_ids: Optional[List[int]] = None
+    """
+    A list of indices for the vocabulary dimensions to be extracted,
+    such as the token IDs of ``good_token`` and ``bad_token`` in the
+    ``math-shepherd-mistral-7b-prm`` model.
+    """
+
+    @staticmethod
+    def from_json(json_str: str) -> "PoolerConfig":
+        return PoolerConfig(**json.loads(json_str))
 
 
 _STR_DTYPE_TO_TORCH_DTYPE = {
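
Because from_json expands the parsed JSON directly into the dataclass constructor, unknown keys (including the old pooling_* names) fail fast with a TypeError rather than being silently dropped. A self-contained check of that behavior with a trimmed stand-in:

    import json
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class _PoolerConfig:
        pooling_type: Optional[str] = None
        normalize: Optional[bool] = None

        @staticmethod
        def from_json(json_str: str) -> "_PoolerConfig":
            return _PoolerConfig(**json.loads(json_str))

    assert _PoolerConfig.from_json('{"normalize": true}') == _PoolerConfig(normalize=True)
    try:
        _PoolerConfig.from_json('{"pooling_norm": true}')  # pre-commit field name
    except TypeError:
        pass  # unexpected keyword argument: old names are rejected, not ignored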

diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py

@@ -11,12 +11,11 @@ import vllm.envs as envs
 from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
                          DeviceConfig, HfOverrides, LoadConfig, LoadFormat,
                          LoRAConfig, ModelConfig, ObservabilityConfig,
-                         ParallelConfig, PromptAdapterConfig, SchedulerConfig,
-                         SpeculativeConfig, TaskOption, TokenizerPoolConfig,
-                         VllmConfig)
+                         ParallelConfig, PoolerConfig, PromptAdapterConfig,
+                         SchedulerConfig, SpeculativeConfig, TaskOption,
+                         TokenizerPoolConfig, VllmConfig)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
-from vllm.model_executor.layers.pooler import PoolingType
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.platforms import current_platform
 from vllm.transformers_utils.utils import check_gguf_file
@@ -187,15 +186,10 @@ class EngineArgs:
     otlp_traces_endpoint: Optional[str] = None
     collect_detailed_traces: Optional[str] = None
     disable_async_output_proc: bool = False
-    override_neuron_config: Optional[Dict[str, Any]] = None
     scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
 
-    # Pooling configuration.
-    pooling_type: Optional[str] = None
-    pooling_norm: Optional[bool] = None
-    pooling_softmax: Optional[bool] = None
-    pooling_step_tag_id: Optional[int] = None
-    pooling_returned_token_ids: Optional[List[int]] = None
+    override_neuron_config: Optional[Dict[str, Any]] = None
+    override_pooler_config: Optional[PoolerConfig] = None
 
     def __post_init__(self):
         if not self.tokenizer:
@@ -859,12 +853,6 @@ class EngineArgs:
             default=EngineArgs.disable_async_output_proc,
             help="Disable async output processing. This may result in "
             "lower performance.")
-        parser.add_argument(
-            '--override-neuron-config',
-            type=json.loads,
-            default=None,
-            help="Override or set neuron device configuration. "
-            "e.g. {\"cast_logits_dtype\": \"bloat16\"}.'")
 
         parser.add_argument(
             '--scheduling-policy',
@@ -877,56 +865,17 @@ class EngineArgs:
             'arrival deciding any ties).')
 
         parser.add_argument(
-            '--pooling-type',
-            choices=[pt.name for pt in PoolingType],
-            default=None,
-            help='Used to configure the pooling method in the embedding model.'
-        )
-
-        parser.add_argument('--pooling-norm',
-                            default=None,
-                            action='store_true',
-                            help="Used to determine whether to normalize "
-                            "the pooled data in the embedding model.")
-
-        parser.add_argument('--no-pooling-norm',
-                            default=None,
-                            action='store_false',
-                            dest='pooling_norm',
-                            help="Used to determine whether to normalize "
-                            "the pooled data in the embedding model.")
-
-        parser.add_argument('--pooling-softmax',
-                            default=None,
-                            action='store_true',
-                            help="Used to determine whether to softmax "
-                            "the pooled data in the embedding model.")
-
-        parser.add_argument('--no-pooling-softmax',
-                            default=None,
-                            action='store_false',
-                            dest='pooling_softmax',
-                            help="Used to determine whether to softmax "
-                            "the pooled data in the embedding model.")
+            '--override-neuron-config',
+            type=json.loads,
+            default=None,
+            help="Override or set neuron device configuration. "
+            "e.g. {\"cast_logits_dtype\": \"bloat16\"}.'")
 
         parser.add_argument(
-            '--pooling-step-tag-id',
-            type=int,
-            default=None,
-            help="When pooling-step-tag-id is not -1, it indicates "
-            "that the score corresponding to the step-tag-ids in the "
-            "generated sentence should be returned. Otherwise, it "
-            "returns the scores for all tokens.")
-
-        parser.add_argument(
-            '--pooling-returned-token-ids',
-            nargs='+',
-            type=int,
-            default=None,
-            help="pooling-returned-token-ids represents a list of "
-            "indices for the vocabulary dimensions to be extracted, "
-            "such as the token IDs of good_token and bad_token in "
-            "the math-shepherd-mistral-7b-prm model.")
+            '--override-pooler-config',
+            type=PoolerConfig.from_json,
+            default=None,
+            help="Override or set the pooling method in the embedding model. "
+            "e.g. {\"pooling_type\": \"mean\", \"normalize\": false}.'")
 
         return parser
@@ -967,14 +916,10 @@ class EngineArgs:
             served_model_name=self.served_model_name,
             limit_mm_per_prompt=self.limit_mm_per_prompt,
             use_async_output_proc=not self.disable_async_output_proc,
-            override_neuron_config=self.override_neuron_config,
            config_format=self.config_format,
             mm_processor_kwargs=self.mm_processor_kwargs,
-            pooling_type=self.pooling_type,
-            pooling_norm=self.pooling_norm,
-            pooling_softmax=self.pooling_softmax,
-            pooling_step_tag_id=self.pooling_step_tag_id,
-            pooling_returned_token_ids=self.pooling_returned_token_ids,
+            override_neuron_config=self.override_neuron_config,
+            override_pooler_config=self.override_pooler_config,
         )
 
     def create_load_config(self) -> LoadConfig:
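
argparse applies the callable passed as `type` to the raw string before assignment, so `override_pooler_config` already holds a PoolerConfig by the time `from_cli_args` copies it into EngineArgs. A self-contained sketch of the pattern with a local stand-in for the dataclass:

    import argparse
    import json
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class _PoolerConfig:
        pooling_type: Optional[str] = None
        normalize: Optional[bool] = None

        @staticmethod
        def from_json(json_str: str) -> "_PoolerConfig":
            return _PoolerConfig(**json.loads(json_str))

    parser = argparse.ArgumentParser()
    parser.add_argument('--override-pooler-config',
                        type=_PoolerConfig.from_json,  # argparse runs this on the raw string
                        default=None)

    args = parser.parse_args(
        ['--override-pooler-config', '{"pooling_type": "MEAN"}'])
    assert args.override_pooler_config == _PoolerConfig(pooling_type="MEAN")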

diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py

@@ -9,7 +9,8 @@ from tqdm import tqdm
 from vllm import envs
 from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                               BeamSearchSequence, get_beam_search_score)
-from vllm.engine.arg_utils import EngineArgs, HfOverrides, TaskOption
+from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
+                                   TaskOption)
 from vllm.engine.llm_engine import LLMEngine
 from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                          apply_hf_chat_template,
@@ -162,11 +163,7 @@ class LLM:
         mm_processor_kwargs: Optional[Dict[str, Any]] = None,
         # After positional args are removed, move this right below `model`
         task: TaskOption = "auto",
-        pooling_type: Optional[str] = None,
-        pooling_norm: Optional[bool] = None,
-        pooling_softmax: Optional[bool] = None,
-        pooling_step_tag_id: Optional[int] = None,
-        pooling_returned_token_ids: Optional[List[int]] = None,
+        override_pooler_config: Optional[PoolerConfig] = None,
         **kwargs,
     ) -> None:
         '''
@@ -202,11 +199,7 @@ class LLM:
             disable_async_output_proc=disable_async_output_proc,
             hf_overrides=hf_overrides,
             mm_processor_kwargs=mm_processor_kwargs,
-            pooling_type=pooling_type,
-            pooling_norm=pooling_norm,
-            pooling_softmax=pooling_softmax,
-            pooling_step_tag_id=pooling_step_tag_id,
-            pooling_returned_token_ids=pooling_returned_token_ids,
+            override_pooler_config=override_pooler_config,
             **kwargs,
         )
 
     # Logic to switch between engines is done at runtime instead of import
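
For process-reward models, the step/token selection knobs move into the same object. A hedged sketch: the model repo name and token IDs below are illustrative stand-ins for the math-shepherd example named in the PoolerConfig docstring, and the real values depend on the checkpoint and its tokenizer:

    from vllm import LLM
    from vllm.config import PoolerConfig

    llm = LLM(
        model="math-shepherd-mistral-7b-prm",  # placeholder repo name from the docstring
        task="embedding",
        override_pooler_config=PoolerConfig(
            pooling_type="STEP",
            step_tag_id=12345,              # illustrative step-tag token id
            returned_token_ids=[648, 387],  # illustrative good/bad token ids
        ),
    )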

diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py

@@ -63,14 +63,14 @@ class Pooler(nn.Module):
         return cls(
             pooling_type=PoolingType[pooler_config.pooling_type]
             if pooler_config.pooling_type is not None else pooling_type,
-            normalize=pooler_config.pooling_norm
-            if pooler_config.pooling_norm is not None else normalize,
-            softmax=pooler_config.pooling_softmax
-            if pooler_config.pooling_softmax is not None else softmax,
-            step_tag_id=pooler_config.pooling_step_tag_id
-            if pooler_config.pooling_step_tag_id is not None else step_tag_id,
-            returned_token_ids=pooler_config.pooling_returned_token_ids
-            if pooler_config.pooling_returned_token_ids is not None else
+            normalize=pooler_config.normalize
+            if pooler_config.normalize is not None else normalize,
+            softmax=pooler_config.softmax
+            if pooler_config.softmax is not None else softmax,
+            step_tag_id=pooler_config.step_tag_id
+            if pooler_config.step_tag_id is not None else step_tag_id,
+            returned_token_ids=pooler_config.returned_token_ids
+            if pooler_config.returned_token_ids is not None else
             returned_token_ids,
         )
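
The renamed fields feed the unchanged fallback rule: a config value participates only when it is not None, so an explicit False still beats a True default, which a naive `config_value or default` would get wrong. In two lines:

    def resolve(config_value, default):
        # Mirrors the `x if x is not None else default` chain above.
        return config_value if config_value is not None else default

    assert resolve(None, True) is True    # unset field falls back to the model default
    assert resolve(False, True) is False  # explicit False is respected
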
@@ -94,10 +94,14 @@ class Pooler(nn.Module):
             pooled_data = hidden_states[last_token_flat_indices]
         elif self.pooling_type == PoolingType.ALL:
             offset = 0
-            pooled_data = []
+            pooled_data_lst = []
             for prompt_len in prompt_lens:
-                pooled_data.append(hidden_states[offset:offset + prompt_len])
+                pooled_data_i = hidden_states[offset:offset + prompt_len]
+                pooled_data_lst.append(pooled_data_i)
                 offset += prompt_len
+
+            pooled_data = torch.stack(pooled_data_lst)
         elif self.pooling_type == PoolingType.MEAN:
             # Calculate mean pooling
             cumsum = torch.cumsum(hidden_states, dim=0)
@@ -110,24 +114,26 @@ class Pooler(nn.Module):
                 cumsum[end_indices - 1] - cumsum[start_indices] +
                 hidden_states[start_indices]) / prompt_lens.unsqueeze(1)
         elif self.pooling_type == PoolingType.STEP:
-            if self.returned_token_ids is not None and len(
-                    self.returned_token_ids) > 0:
-                logits = hidden_states[:,
-                                       self.returned_token_ids].softmax(dim=-1)
-            else:
-                logits = hidden_states.softmax(dim=-1)
+            returned_token_ids = self.returned_token_ids
+            if returned_token_ids is not None and len(returned_token_ids) > 0:
+                hidden_states = hidden_states[:, returned_token_ids]
+
+            logits = hidden_states.softmax(dim=-1)
+            step_tag_id = self.step_tag_id
+
             offset = 0
-            pooled_data = []
+            pooled_data_lst = []
             for prompt_len, seq_data_i in zip(
                     prompt_lens, pooling_metadata.seq_data.values()):
-                if self.step_tag_id is None:
-                    pooled_data.append(logits[offset:offset + prompt_len])
-                else:
-                    step_idxs = torch.tensor(
-                        seq_data_i.prompt_token_ids) == self.step_tag_id
-                    pooled_data.append(logits[offset:offset +
-                                              prompt_len][step_idxs])
+                pooled_data_i = logits[offset:offset + prompt_len]
+                if step_tag_id is not None:
+                    token_ids = torch.tensor(seq_data_i.prompt_token_ids)
+                    pooled_data_i = pooled_data_i[token_ids == step_tag_id]
+
                 offset += prompt_len
+                pooled_data_lst.append(pooled_data_i)
+
+            pooled_data = torch.stack(pooled_data_lst)
         else:
             raise ValueError(f"Invalid pooling type: {self.pooling_type}")
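
For reference, the cumulative-sum formulation in the MEAN branch computes the same per-prompt averages as an explicit per-prompt loop; a small self-contained check with illustrative shapes:

    import torch

    hidden_states = torch.randn(10, 4)  # 10 tokens across all prompts, hidden size 4
    prompt_lens = torch.tensor([3, 7])  # two prompts of 3 and 7 tokens

    # Reference: slice each prompt and average it.
    expected = torch.stack([hidden_states[0:3].mean(dim=0),
                            hidden_states[3:10].mean(dim=0)])

    # Cumsum formulation used in the MEAN branch above.
    cumsum = torch.cumsum(hidden_states, dim=0)
    end_indices = torch.cumsum(prompt_lens, dim=0)
    start_indices = torch.cat([torch.tensor([0]), end_indices[:-1]])
    pooled_data = (cumsum[end_indices - 1] - cumsum[start_indices] +
                   hidden_states[start_indices]) / prompt_lens.unsqueeze(1)

    assert torch.allclose(pooled_data, expected)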