Fix a bug in 1D input shape (#5)

This commit is contained in:
Woosuk Kwon 2023-03-06 10:05:27 -08:00 committed by GitHub
parent 3e9f991d6a
commit 04e5acc08e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 11 additions and 6 deletions

View File

@ -47,9 +47,8 @@ class OPTCacheFlowAttention(nn.Module):
max_s=max_prompt_len, max_s=max_prompt_len,
causal=True, causal=True,
)[0] )[0]
num_tokens = prefix_sum[-1]
# FIXME(woosuk): Unnecessary copy. Optimize this. # FIXME(woosuk): Unnecessary copy. Optimize this.
output[:num_tokens].copy_(out, non_blocking=True) output.copy_(out, non_blocking=True)
def single_query_cached_kv_attention( def single_query_cached_kv_attention(
self, self,
@ -108,8 +107,14 @@ class OPTCacheFlowAttention(nn.Module):
# Compute the attention op for prompts. # Compute the attention op for prompts.
if input_metadata.num_prompts > 0: if input_metadata.num_prompts > 0:
num_prompt_tokens = sum(input_metadata.prompt_lens)
self.multi_query_kv_attention( self.multi_query_kv_attention(
output, query, key, value, input_metadata.prompt_lens) output[:num_prompt_tokens],
query[:num_prompt_tokens],
key[:num_prompt_tokens],
value[:num_prompt_tokens],
input_metadata.prompt_lens,
)
# Wait until the cache op is done. # Wait until the cache op is done.
if cache_event is not None: if cache_event is not None:

View File

@ -24,7 +24,7 @@ class InputMetadata:
self.num_prompts = len(prompt_lens) self.num_prompts = len(prompt_lens)
self.num_generation_tokens = context_lens.shape[0] self.num_generation_tokens = context_lens.shape[0]
self.num_valid_tokens = len(slot_mapping) self.num_valid_tokens = slot_mapping.shape[0]
if block_tables.numel() > 0: if block_tables.numel() > 0:
self.max_num_blocks_per_seq = block_tables.shape[1] self.max_num_blocks_per_seq = block_tables.shape[1]
else: else:

View File

@ -57,11 +57,11 @@ def main():
'UC Berkeley is', 'UC Berkeley is',
'The future of cloud computing is', 'The future of cloud computing is',
] ]
for prompt in test_inputs:
frontend.query(prompt)
# FIXME # FIXME
while True: while True:
if test_inputs:
frontend.query(test_inputs.pop())
scheduler.step() scheduler.step()
if not scheduler.pending and not scheduler.running: if not scheduler.pending and not scheduler.running:
break break