[Core] Use numpy to speed up padded token processing (#6442)

Author: Peng Guanwen
Date:   2024-07-16 23:13:25 +08:00 (committed by GitHub)
parent 7508a3dc34
commit 2bb0489cb3

@@ -2,6 +2,7 @@ import random
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple

+import numpy as np
 import torch

 from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits
@ -457,16 +458,20 @@ class SamplingTensors:
if do_penalties: if do_penalties:
prompt_max_len = max([len(tokens) for tokens in prompt_tokens], prompt_max_len = max([len(tokens) for tokens in prompt_tokens],
default=0) default=0)
prompt_padded_tokens = [ prompt_padded_tokens = np.full(
tokens + [vocab_size] * (prompt_max_len - len(tokens)) (len(prompt_tokens), prompt_max_len),
for tokens in prompt_tokens vocab_size,
] dtype=np.int64)
for i, tokens in enumerate(prompt_tokens):
prompt_padded_tokens[i, :len(tokens)] = tokens
output_max_len = max([len(tokens) for tokens in output_tokens], output_max_len = max([len(tokens) for tokens in output_tokens],
default=0) default=0)
output_padded_tokens = [ output_padded_tokens = np.full(
tokens + [vocab_size] * (output_max_len - len(tokens)) (len(output_tokens), output_max_len),
for tokens in output_tokens vocab_size,
] dtype=np.int64)
for i, tokens in enumerate(output_tokens):
output_padded_tokens[i, :len(tokens)] = tokens
temperatures_t = torch.tensor( temperatures_t = torch.tensor(
temperatures, temperatures,
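
The hunk above replaces per-sequence Python list concatenation with a single preallocated int64 numpy array, filled with the pad value (vocab_size serves as the padding id) and then overwritten row by row with the real tokens. A minimal standalone sketch of the same pattern; the helper name pad_token_lists is illustrative and not from the vLLM source:

import numpy as np

def pad_token_lists(token_lists, pad_value):
    """Pad variable-length token lists into one (batch, max_len) int64 array."""
    max_len = max((len(tokens) for tokens in token_lists), default=0)
    # Preallocate the whole padded matrix once, filled with the pad value...
    padded = np.full((len(token_lists), max_len), pad_value, dtype=np.int64)
    # ...then overwrite each row's prefix with that sequence's real tokens.
    for i, tokens in enumerate(token_lists):
        padded[i, :len(tokens)] = tokens
    return padded

print(pad_token_lists([[1, 2, 3], [4]], pad_value=9))
# [[1 2 3]
#  [4 9 9]]

Preallocating avoids building one intermediate Python list per sequence and leaves the batch in a contiguous buffer, which the later tensor conversion can wrap without copying.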
@@ -517,18 +522,11 @@ class SamplingTensors:
             pin_memory=pin_memory,
         )
         if do_penalties:
-            prompt_tensor = torch.tensor(
-                prompt_padded_tokens,
-                device="cpu",
-                dtype=torch.long,
-                pin_memory=pin_memory,
-            )
-            output_tensor = torch.tensor(
-                output_padded_tokens,
-                device="cpu",
-                dtype=torch.long,
-                pin_memory=pin_memory,
-            )
+            prompt_tensor = torch.from_numpy(prompt_padded_tokens)
+            output_tensor = torch.from_numpy(output_padded_tokens)
+            if pin_memory:
+                prompt_tensor = prompt_tensor.pin_memory()
+                output_tensor = output_tensor.pin_memory()
         else:
             prompt_tensor = None
             output_tensor = None