Fix a bug in 1D input shape (#5)
This commit is contained in:
parent 3e9f991d6a
commit 04e5acc08e
@@ -47,9 +47,8 @@ class OPTCacheFlowAttention(nn.Module):
             max_s=max_prompt_len,
             causal=True,
         )[0]
-        num_tokens = prefix_sum[-1]
         # FIXME(woosuk): Unnecessary copy. Optimize this.
-        output[:num_tokens].copy_(out, non_blocking=True)
+        output.copy_(out, non_blocking=True)
 
     def single_query_cached_kv_attention(
         self,
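A note on this hunk (an inference from the diff, not stated in the commit message): the `[:num_tokens]` slice moves to the caller in the next hunk, so the `output` received here is already a view holding only the prompt-token rows, and a full-tensor copy suffices. A minimal sketch of that copy-into-a-presliced-view pattern, with illustrative names and sizes rather than the real model shapes:

import torch

# Illustrative sizes; the real hidden size and token counts come from the model.
num_prompt_tokens, num_total_tokens, hidden = 5, 8, 16
out = torch.randn(num_prompt_tokens, hidden)     # attention result for the prompt tokens
output = torch.zeros(num_total_tokens, hidden)   # packed output buffer for all tokens

dst = output[:num_prompt_tokens]    # caller-side slicing (see the next hunk)
dst.copy_(out, non_blocking=True)   # callee copies into the whole view it was given
assert torch.equal(output[:num_prompt_tokens], out)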
@@ -108,8 +107,14 @@ class OPTCacheFlowAttention(nn.Module):
 
         # Compute the attention op for prompts.
         if input_metadata.num_prompts > 0:
+            num_prompt_tokens = sum(input_metadata.prompt_lens)
             self.multi_query_kv_attention(
-                output, query, key, value, input_metadata.prompt_lens)
+                output[:num_prompt_tokens],
+                query[:num_prompt_tokens],
+                key[:num_prompt_tokens],
+                value[:num_prompt_tokens],
+                input_metadata.prompt_lens,
+            )
 
         # Wait until the cache op is done.
         if cache_event is not None:
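The slicing added here reflects the packed token layout implied by the diff (the layout itself is defined elsewhere in the model code): prompt tokens occupy the leading rows of `query`/`key`/`value`/`output`, and generation tokens follow. A small sketch with made-up sizes:

import torch

prompt_lens = [3, 2]                    # two prompt sequences (illustrative)
num_prompt_tokens = sum(prompt_lens)    # 5
num_generation_tokens = 3
hidden = 16

query = torch.randn(num_prompt_tokens + num_generation_tokens, hidden)

prompt_query = query[:num_prompt_tokens]       # rows handled by multi_query_kv_attention
generation_query = query[num_prompt_tokens:]   # rows handled by the cached-KV path
assert prompt_query.shape[0] + generation_query.shape[0] == query.shape[0]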
@@ -24,7 +24,7 @@ class InputMetadata:
 
         self.num_prompts = len(prompt_lens)
         self.num_generation_tokens = context_lens.shape[0]
-        self.num_valid_tokens = len(slot_mapping)
+        self.num_valid_tokens = slot_mapping.shape[0]
         if block_tables.numel() > 0:
             self.max_num_blocks_per_seq = block_tables.shape[1]
         else:
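For a 1D tensor, `len(t)` and `t.shape[0]` return the same value, so this change presumably brings `slot_mapping` in line with the tensor-style accessor used for `context_lens` on the neighbouring line (the hunk itself does not say). A two-line check with an illustrative slot mapping:

import torch

slot_mapping = torch.tensor([0, 5, 9, 12])   # illustrative 1D slot mapping
assert len(slot_mapping) == slot_mapping.shape[0] == 4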
@@ -57,11 +57,11 @@ def main():
         'UC Berkeley is',
         'The future of cloud computing is',
     ]
-    for prompt in test_inputs:
-        frontend.query(prompt)
 
     # FIXME
     while True:
+        if test_inputs:
+            frontend.query(test_inputs.pop())
         scheduler.step()
         if not scheduler.pending and not scheduler.running:
             break
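The test driver now submits one prompt per scheduler step instead of enqueueing every prompt before stepping; presumably this is what exercises the single-prompt (1D input) case the commit fixes. A self-contained sketch of the same loop shape, using toy stand-ins for the real frontend and scheduler (the real classes live elsewhere in the repo):

# Toy stand-ins, only to make the loop runnable here.
class ToyFrontend:
    def __init__(self):
        self.queue = []
    def query(self, prompt):
        self.queue.append(prompt)

class ToyScheduler:
    def __init__(self, frontend):
        self.frontend = frontend
        self.pending = []
        self.running = []
    def step(self):
        if self.frontend.queue:
            self.running.append(self.frontend.queue.pop(0))   # admit one request
        elif self.running:
            self.running.pop()                                # pretend a request finished

frontend = ToyFrontend()
scheduler = ToyScheduler(frontend)
test_inputs = ['UC Berkeley is', 'The future of cloud computing is']

while True:
    if test_inputs:
        frontend.query(test_inputs.pop())   # feed one prompt per step
    scheduler.step()
    if not scheduler.pending and not scheduler.running:
        break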