Fix a bug in 1D input shape (#5)

Woosuk Kwon 2023-03-06 10:05:27 -08:00 committed by GitHub
parent 3e9f991d6a
commit 04e5acc08e
3 changed files with 11 additions and 6 deletions
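
Read together, the hunks below appear to make the flattened 1D token layout consistent end to end: the attention wrapper now expects tensors already sliced to the prompt tokens, the caller does that slicing with num_prompt_tokens, InputMetadata takes the valid-token count from the tensor shape rather than len(), and the test server feeds prompts one per scheduler step.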


@@ -47,9 +47,8 @@ class OPTCacheFlowAttention(nn.Module):
             max_s=max_prompt_len,
             causal=True,
         )[0]
-        num_tokens = prefix_sum[-1]
         # FIXME(woosuk): Unnecessary copy. Optimize this.
-        output[:num_tokens].copy_(out, non_blocking=True)
+        output.copy_(out, non_blocking=True)
 
     def single_query_cached_kv_attention(
         self,
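
This hunk changes the callee's contract: multi_query_kv_attention no longer derives a token count from prefix_sum and slices output itself; it assumes the caller hands it tensors already restricted to the prompt tokens. A minimal PyTorch sketch of why copying into a pre-sliced view still updates the parent tensor (names and sizes here are illustrative, not CacheFlow's):

import torch

num_prompt_tokens, head_dim = 7, 4        # hypothetical sizes
output = torch.zeros(10, head_dim)        # flattened 1D token axis
out = torch.randn(num_prompt_tokens, head_dim)

view = output[:num_prompt_tokens]         # slicing yields a view, not a copy
view.copy_(out, non_blocking=True)        # fills the view in place

assert torch.equal(output[:num_prompt_tokens], out)  # parent sees the write
assert torch.equal(output[num_prompt_tokens:], torch.zeros(3, head_dim))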
@@ -108,8 +107,14 @@ class OPTCacheFlowAttention(nn.Module):
         # Compute the attention op for prompts.
         if input_metadata.num_prompts > 0:
+            num_prompt_tokens = sum(input_metadata.prompt_lens)
             self.multi_query_kv_attention(
-                output, query, key, value, input_metadata.prompt_lens)
+                output[:num_prompt_tokens],
+                query[:num_prompt_tokens],
+                key[:num_prompt_tokens],
+                value[:num_prompt_tokens],
+                input_metadata.prompt_lens,
+            )
 
         # Wait until the cache op is done.
         if cache_event is not None:
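
This hunk is the matching caller-side slicing: with the 1D layout, the first sum(prompt_lens) rows of the flattened query/key/value/output tensors belong to the prompt tokens, and generation tokens follow. A hedged sketch of that packing (shapes made up):

import torch

prompt_lens = [3, 2]                    # two prompt sequences
num_prompt_tokens = sum(prompt_lens)    # 5 rows at the front of each tensor
num_generation_tokens = 4
hidden_size = 8

query = torch.randn(num_prompt_tokens + num_generation_tokens, hidden_size)

prompt_query = query[:num_prompt_tokens]       # rows 0..4: prompt tokens
generation_query = query[num_prompt_tokens:]   # rows 5..8: generation tokens
assert prompt_query.shape[0] + generation_query.shape[0] == query.shape[0]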


@@ -24,7 +24,7 @@ class InputMetadata:
         self.num_prompts = len(prompt_lens)
         self.num_generation_tokens = context_lens.shape[0]
-        self.num_valid_tokens = len(slot_mapping)
+        self.num_valid_tokens = slot_mapping.shape[0]
         if block_tables.numel() > 0:
             self.max_num_blocks_per_seq = block_tables.shape[1]
         else:
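
The metadata hunk swaps len(slot_mapping) for slot_mapping.shape[0]. Assuming slot_mapping is now a 1D tensor (consistent with the commit title), the two agree for 1D tensors, but .shape[0] states the intent explicitly and also behaves where len() would not (it raises on a 0-dim tensor). A small sketch under that assumption:

import torch

slot_mapping = torch.tensor([12, 13, 14, 40, 41])  # one slot per valid token
num_valid_tokens = slot_mapping.shape[0]

assert num_valid_tokens == 5
assert len(slot_mapping) == slot_mapping.shape[0]  # equal for 1D tensors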


@@ -57,11 +57,11 @@ def main():
         'UC Berkeley is',
         'The future of cloud computing is',
     ]
-    for prompt in test_inputs:
-        frontend.query(prompt)
 
     # FIXME
     while True:
+        if test_inputs:
+            frontend.query(test_inputs.pop())
         scheduler.step()
         if not scheduler.pending and not scheduler.running:
             break
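
The server change feeds at most one prompt per scheduler step instead of enqueuing them all before the loop, so prompt and generation work end up in flight at the same time. The stand-ins below are toys, not CacheFlow's Frontend or Scheduler; they only mimic the pending/running bookkeeping the loop checks:

class ToyScheduler:
    def __init__(self):
        self.pending = []    # queued requests
        self.running = []    # [prompt, steps_left] pairs in flight

    def step(self):
        if self.pending:     # admit one pending request per step
            self.running.append([self.pending.pop(0), 3])
        for req in self.running:   # one decode step per running request
            req[1] -= 1
        self.running = [r for r in self.running if r[1] > 0]

class ToyFrontend:
    def __init__(self, scheduler):
        self.scheduler = scheduler

    def query(self, prompt):
        self.scheduler.pending.append(prompt)

scheduler = ToyScheduler()
frontend = ToyFrontend(scheduler)
test_inputs = ['UC Berkeley is', 'The future of cloud computing is']

while True:
    if test_inputs:
        frontend.query(test_inputs.pop())  # one new prompt per iteration
    scheduler.step()
    if not scheduler.pending and not scheduler.running:
        break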