diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 57c6b652..3aaaf34b 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -250,13 +250,12 @@ def prepare_input_kernel( num_tokens = end_pos - start_pos index_start = tl.load(cu_query_lens_ptr + pid) - indices = index_start + tl.arange(0, BLOCK_SIZE) num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE) for i in tl.range(num_blocks): offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) tl.store( out_ptr + start_pos + offset, - indices, + index_start + offset, mask=offset < num_tokens, )