[Core] Use numpy to speed up padded token processing (#6442)
parent 7508a3dc34
commit 2bb0489cb3
@@ -2,6 +2,7 @@ import random
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple
 
+import numpy as np
 import torch
 
 from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits
@@ -457,16 +458,20 @@ class SamplingTensors:
         if do_penalties:
             prompt_max_len = max([len(tokens) for tokens in prompt_tokens],
                                  default=0)
-            prompt_padded_tokens = [
-                tokens + [vocab_size] * (prompt_max_len - len(tokens))
-                for tokens in prompt_tokens
-            ]
+            prompt_padded_tokens = np.full(
+                (len(prompt_tokens), prompt_max_len),
+                vocab_size,
+                dtype=np.int64)
+            for i, tokens in enumerate(prompt_tokens):
+                prompt_padded_tokens[i, :len(tokens)] = tokens
             output_max_len = max([len(tokens) for tokens in output_tokens],
                                  default=0)
-            output_padded_tokens = [
-                tokens + [vocab_size] * (output_max_len - len(tokens))
-                for tokens in output_tokens
-            ]
+            output_padded_tokens = np.full(
+                (len(output_tokens), output_max_len),
+                vocab_size,
+                dtype=np.int64)
+            for i, tokens in enumerate(output_tokens):
+                output_padded_tokens[i, :len(tokens)] = tokens
 
         temperatures_t = torch.tensor(
             temperatures,
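The padding change above replaces per-sequence Python lists, each grown to the batch maximum with `tokens + [vocab_size] * (...)`, with a single preallocated int64 array filled row by row. A minimal sketch of the two approaches, using made-up token lists and treating `vocab_size` as the pad value, as the diff does:

import numpy as np

vocab_size = 32_000
token_lists = [[1, 2, 3], [4, 5], [6]]  # stand-in for prompt/output tokens

# Old approach: pad each Python list, then let torch.tensor() copy it later.
max_len = max((len(t) for t in token_lists), default=0)
padded_lists = [t + [vocab_size] * (max_len - len(t)) for t in token_lists]

# New approach: preallocate one int64 array filled with the pad value and
# write each row in place.
padded_array = np.full((len(token_lists), max_len), vocab_size, dtype=np.int64)
for i, t in enumerate(token_lists):
    padded_array[i, :len(t)] = t

# Both strategies produce the same matrix.
assert (padded_array == np.asarray(padded_lists)).all()

Both versions yield identical padded matrices, but the numpy one avoids building one temporary Python list per sequence, and the resulting buffer can be wrapped by torch without another copy, as the next hunk shows.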
@@ -517,18 +522,11 @@ class SamplingTensors:
             pin_memory=pin_memory,
         )
         if do_penalties:
-            prompt_tensor = torch.tensor(
-                prompt_padded_tokens,
-                device="cpu",
-                dtype=torch.long,
-                pin_memory=pin_memory,
-            )
-            output_tensor = torch.tensor(
-                output_padded_tokens,
-                device="cpu",
-                dtype=torch.long,
-                pin_memory=pin_memory,
-            )
+            prompt_tensor = torch.from_numpy(prompt_padded_tokens)
+            output_tensor = torch.from_numpy(output_padded_tokens)
+            if pin_memory:
+                prompt_tensor = prompt_tensor.pin_memory()
+                output_tensor = output_tensor.pin_memory()
         else:
             prompt_tensor = None
             output_tensor = None
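The second half of the change swaps torch.tensor(...), which copies the padded data into a fresh CPU tensor, for torch.from_numpy(...), which wraps the existing numpy buffer; page-locking is then applied explicitly only when requested. A small sketch of that conversion path, with a made-up array standing in for prompt_padded_tokens and a locally chosen pin_memory flag (not the vLLM code itself):

import numpy as np
import torch

pin_memory = torch.cuda.is_available()  # pinning only helps with a GPU present
padded = np.full((4, 16), 32_000, dtype=np.int64)

# torch.from_numpy shares memory with `padded` (np.int64 maps to torch.long),
# so no copy happens here.
tensor = torch.from_numpy(padded)

# .pin_memory() does copy, but into page-locked host memory that allows
# faster, asynchronous host-to-device transfers later.
if pin_memory:
    tensor = tensor.pin_memory()

Keeping the pinning step separate leaves the unpinned path entirely copy-free, which is presumably where most of the speedup in the commit title comes from.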