[Spec Decode] Fix input triton kernel for eagle (#15909)
This commit is contained in:
parent
58f5a59769
commit
24b7fb455a
@@ -250,13 +250,12 @@ def prepare_input_kernel(
|
||||
num_tokens = end_pos - start_pos
|
||||
|
||||
index_start = tl.load(cu_query_lens_ptr + pid)
|
||||
- indices = index_start + tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE)
|
||||
for i in tl.range(num_blocks):
|
||||
offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
tl.store(
|
||||
out_ptr + start_pos + offset,
|
||||
- indices,
|
||||
+ index_start + offset,
|
||||
mask=offset < num_tokens,
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user