[Spec Decode] Fix input triton kernel for eagle (#15909)

This commit is contained in:
Ekagra Ranjan 2025-04-01 21:15:14 -04:00 committed by GitHub
parent 58f5a59769
commit 24b7fb455a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -250,13 +250,12 @@ def prepare_input_kernel(
num_tokens = end_pos - start_pos
index_start = tl.load(cu_query_lens_ptr + pid)
indices = index_start + tl.arange(0, BLOCK_SIZE)
num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE)
for i in tl.range(num_blocks):
offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
tl.store(
out_ptr + start_pos + offset,
indices,
index_start + offset,
mask=offset < num_tokens,
)