[Bugfix][Hardware][CPU] Fix CPU input_positions creation for text-only inputs with mrope (#11434)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py 2024-12-24 17:59:51 +08:00 committed by GitHub
parent b1b1038fbd
commit 7a5286cc04
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -114,8 +114,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
def __init__(self, use_mrope: bool): def __init__(self, use_mrope: bool):
self.use_mrope = use_mrope self.use_mrope = use_mrope
self.input_tokens: List[int] = [] self.input_tokens: List[int] = []
self.input_positions: Optional[ self.input_positions: List[int] = []
List[int]] = [] if not self.use_mrope else None
self.token_type_ids: Optional[List[int]] = [] self.token_type_ids: Optional[List[int]] = []
self.seq_lens: List[int] = [] self.seq_lens: List[int] = []
self.query_lens: List[int] = [] self.query_lens: List[int] = []
@ -130,9 +129,8 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
self.multi_modal_placeholder_maps: Dict[ self.multi_modal_placeholder_maps: Dict[
str, MultiModalPlaceholderMap] = defaultdict( str, MultiModalPlaceholderMap] = defaultdict(
MultiModalPlaceholderMap) MultiModalPlaceholderMap)
self.input_mrope_positions: Optional[List[List[int]]] = [ self.input_mrope_positions: List[List[int]] = [[]
[] for _ in range(3) for _ in range(3)]
] if self.use_mrope else None
def __init__(self, def __init__(self,
runner: "CPUModelRunner", runner: "CPUModelRunner",
@ -167,7 +165,8 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
device="cpu") device="cpu")
input_positions = torch.tensor( input_positions = torch.tensor(
input_data.input_positions input_data.input_positions
if not input_data.use_mrope else input_data.input_mrope_positions, if not any(input_data.input_mrope_positions) else
input_data.input_mrope_positions,
dtype=torch.long, dtype=torch.long,
device="cpu") device="cpu")
token_type_ids = torch.tensor(input_data.token_type_ids, token_type_ids = torch.tensor(input_data.token_type_ids,
@ -236,7 +235,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
block_table = block_table[start_block:] block_table = block_table[start_block:]
# For MRotaryEmbedding # For MRotaryEmbedding
if data.input_positions is None: if seq_data.mrope_position_delta is not None:
next_pos = MRotaryEmbedding.get_next_input_positions( next_pos = MRotaryEmbedding.get_next_input_positions(
seq_data.mrope_position_delta, seq_data.mrope_position_delta,
context_len, context_len,
@ -309,8 +308,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
data.slot_mapping.extend(slot_mapping) data.slot_mapping.extend(slot_mapping)
# The MROPE positions are prepared in _compute_multi_modal_input # The MROPE positions are prepared in _compute_multi_modal_input
if data.input_positions is not None: data.input_positions.extend(token_positions)
data.input_positions.extend(token_positions)
if data.token_type_ids is not None: if data.token_type_ids is not None:
data.token_type_ids.extend(token_types if token_types else []) data.token_type_ids.extend(token_types if token_types else [])