[Spec Decode] Fix input triton kernel for eagle (#15909)
This commit is contained in:
parent
58f5a59769
commit
24b7fb455a
@@ -250,13 +250,12 @@ def prepare_input_kernel(
|
|||||||
num_tokens = end_pos - start_pos
|
num_tokens = end_pos - start_pos
|
||||||
|
|
||||||
index_start = tl.load(cu_query_lens_ptr + pid)
|
index_start = tl.load(cu_query_lens_ptr + pid)
|
||||||
indices = index_start + tl.arange(0, BLOCK_SIZE)
|
|
||||||
|
|
||||||
num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE)
|
num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE)
|
||||||
for i in tl.range(num_blocks):
|
for i in tl.range(num_blocks):
|
||||||
offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||||
tl.store(
|
tl.store(
|
||||||
out_ptr + start_pos + offset,
|
out_ptr + start_pos + offset,
|
||||||
indices,
|
index_start + offset,
|
||||||
mask=offset < num_tokens,
|
mask=offset < num_tokens,
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user