Fix a bug in 1D input shape (#5)
This commit is contained in:
parent 3e9f991d6a
commit 04e5acc08e
@ -47,9 +47,8 @@ class OPTCacheFlowAttention(nn.Module):
|
|||||||
max_s=max_prompt_len,
|
max_s=max_prompt_len,
|
||||||
causal=True,
|
causal=True,
|
||||||
)[0]
|
)[0]
|
||||||
num_tokens = prefix_sum[-1]
|
|
||||||
# FIXME(woosuk): Unnecessary copy. Optimize this.
|
# FIXME(woosuk): Unnecessary copy. Optimize this.
|
||||||
output[:num_tokens].copy_(out, non_blocking=True)
|
output.copy_(out, non_blocking=True)
|
||||||
|
|
||||||
def single_query_cached_kv_attention(
|
def single_query_cached_kv_attention(
|
||||||
self,
|
self,
|
||||||
@ -108,8 +107,14 @@ class OPTCacheFlowAttention(nn.Module):
|
|||||||
|
|
||||||
# Compute the attention op for prompts.
|
# Compute the attention op for prompts.
|
||||||
if input_metadata.num_prompts > 0:
|
if input_metadata.num_prompts > 0:
|
||||||
|
num_prompt_tokens = sum(input_metadata.prompt_lens)
|
||||||
self.multi_query_kv_attention(
|
self.multi_query_kv_attention(
|
||||||
output, query, key, value, input_metadata.prompt_lens)
|
output[:num_prompt_tokens],
|
||||||
|
query[:num_prompt_tokens],
|
||||||
|
key[:num_prompt_tokens],
|
||||||
|
value[:num_prompt_tokens],
|
||||||
|
input_metadata.prompt_lens,
|
||||||
|
)
|
||||||
|
|
||||||
# Wait until the cache op is done.
|
# Wait until the cache op is done.
|
||||||
if cache_event is not None:
|
if cache_event is not None:
|
||||||
|
@ -24,7 +24,7 @@ class InputMetadata:
|
|||||||
|
|
||||||
self.num_prompts = len(prompt_lens)
|
self.num_prompts = len(prompt_lens)
|
||||||
self.num_generation_tokens = context_lens.shape[0]
|
self.num_generation_tokens = context_lens.shape[0]
|
||||||
self.num_valid_tokens = len(slot_mapping)
|
self.num_valid_tokens = slot_mapping.shape[0]
|
||||||
if block_tables.numel() > 0:
|
if block_tables.numel() > 0:
|
||||||
self.max_num_blocks_per_seq = block_tables.shape[1]
|
self.max_num_blocks_per_seq = block_tables.shape[1]
|
||||||
else:
|
else:
|
||||||
|
@ -57,11 +57,11 @@ def main():
|
|||||||
'UC Berkeley is',
|
'UC Berkeley is',
|
||||||
'The future of cloud computing is',
|
'The future of cloud computing is',
|
||||||
]
|
]
|
||||||
for prompt in test_inputs:
|
|
||||||
frontend.query(prompt)
|
|
||||||
|
|
||||||
# FIXME
|
# FIXME
|
||||||
while True:
|
while True:
|
||||||
|
if test_inputs:
|
||||||
|
frontend.query(test_inputs.pop())
|
||||||
scheduler.step()
|
scheduler.step()
|
||||||
if not scheduler.pending and not scheduler.running:
|
if not scheduler.pending and not scheduler.running:
|
||||||
break
|
break
|
||||||
|
Loading…
x
Reference in New Issue
Block a user