vllm/vllm/v1/worker/gpu_model_runner.py

614 lines
27 KiB
Python
Raw Normal View History

import gc
import time
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
2024-10-22 01:24:07 -07:00
import numpy as np
import torch
import torch.distributed
import torch.nn as nn
from vllm.config import CompilationLevel, VllmConfig
from vllm.distributed.parallel_state import graph_capture
2024-10-22 01:24:07 -07:00
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
2024-10-22 01:24:07 -07:00
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MultiModalKwargs
from vllm.sampling_params import SamplingType
2024-10-22 01:24:07 -07:00
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv,
is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
FlashAttentionMetadata)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
2024-10-22 01:24:07 -07:00
# Module-level logger for this runner.
logger = init_logger(__name__)

# SchedulerOutput is referenced only in string annotations
# ("SchedulerOutput") below, so it is imported for type checkers only.
if TYPE_CHECKING:
    from vllm.v1.core.scheduler import SchedulerOutput
class GPUModelRunner:
    """Prepare inputs for, run, and sample from the model on the GPU.

    The runner keeps:

    * per-request cached state (``self.requests``),
    * a persistent batch (``self.input_batch``) that is updated incrementally
      from each ``SchedulerOutput`` instead of being rebuilt every step,
    * a cache of multi-modal encoder outputs (``self.encoder_cache``), and
    * persistent device buffers (``self.positions``, ``self.inputs_embeds``)
      that are sliced for every forward pass, including CUDA-graph capture
      and replay.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        input_registry: InputRegistry = INPUT_REGISTRY,
    ):
        """Build the runner from ``vllm_config``.

        Args:
            vllm_config: Engine configuration. Its sub-configs are cached as
                attributes for convenient access.
            input_registry: Registry stored for multi-modal input support.
        """
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.cache_config = vllm_config.cache_config
        self.lora_config = vllm_config.lora_config
        self.load_config = vllm_config.load_config
        self.parallel_config = vllm_config.parallel_config
        self.scheduler_config = vllm_config.scheduler_config
        self.device_config = vllm_config.device_config
        self.speculative_config = vllm_config.speculative_config
        self.prompt_adapter_config = vllm_config.prompt_adapter_config
        self.observability_config = vllm_config.observability_config

        model_config = self.model_config
        cache_config = self.cache_config
        scheduler_config = self.scheduler_config
        parallel_config = self.parallel_config

        self.device = self.device_config.device
        self.pin_memory = is_pin_memory_available()
        self.dtype = self.model_config.dtype
        # "auto" means the KV cache uses the model's own dtype.
        if cache_config.cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
        else:
            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                cache_config.cache_dtype]

        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.max_model_len = model_config.max_model_len
        self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
        self.max_num_tokens = scheduler_config.max_num_batched_tokens

        # Model-related.
        self.num_attn_layers = model_config.get_num_attention_layers(
            parallel_config)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        self.head_size = model_config.get_head_size()
        self.hidden_size = model_config.get_hidden_size()

        # Multi-modal data support
        self.input_registry = input_registry

        # Lazy initialization
        # self.model: nn.Module  # Set after load_model
        self.kv_caches: List[torch.Tensor] = []
        # req_id -> (input_id -> encoder_output)
        self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}

        # Request states.
        self.requests: Dict[str, CachedRequestState] = {}
        # Persistent batch.
        self.input_batch = InputBatch(
            max_num_reqs=self.scheduler_config.max_num_seqs,
            max_model_len=self.max_model_len,
            max_num_blocks_per_req=self.max_num_blocks_per_req,
            device=self.device,
            pin_memory=self.pin_memory,
        )

        self.use_cuda_graph = (self.vllm_config.compilation_config.level
                               == CompilationLevel.PIECEWISE
                               and not self.model_config.enforce_eager)
        # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
        # self.cudagraph_batch_sizes must be in ascending order: the last
        # entry is treated as the largest captured size, and
        # _get_padded_batch_size returns the first size that fits. Sort
        # explicitly rather than assuming the config lists the sizes in
        # descending order. For descending input this equals the old
        # list(reversed(...)).
        self.cudagraph_batch_sizes = sorted(
            self.vllm_config.compilation_config.capture_sizes)

        # Persistent buffers. Forward passes (and CUDA-graph capture in
        # _dummy_run) read slices of these, so they are sized for the
        # largest possible number of batched tokens.
        self.positions = torch.zeros(self.max_num_tokens,
                                     dtype=torch.int64,
                                     device=self.device)
        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=self.device)

    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
        """Sync cached request states and the persistent batch with the
        scheduler's decisions for this step.

        In order, this method:

        1. drops finished requests and their encoder caches (preempted
           requests keep their cached state);
        2. frees encoder outputs the scheduler marked as no longer needed;
        3. removes preempted and finished requests from the persistent batch;
        4. applies new computed-token counts and block ids to running
           requests;
        5. registers new requests and updates resumed ones;
        6. places new and resumed requests into freed batch slots (smallest
           slot first), appending the rest at the end;
        7. condenses the batch if freed slots remain.
        """
        # Remove stopped requests from the cached states.
        # Keep the states of the pre-empted requests.
        for req_id in scheduler_output.finished_req_ids:
            self.requests.pop(req_id, None)
            self.encoder_cache.pop(req_id, None)

        # Free the cached encoder outputs.
        for req_id, input_id in scheduler_output.free_encoder_input_ids:
            encoder_outputs = self.encoder_cache.get(req_id)
            if encoder_outputs is not None:
                encoder_outputs.pop(input_id, None)
                if not encoder_outputs:
                    self.encoder_cache.pop(req_id, None)

        # Remove the requests from the persistent batch.
        stopped_req_ids = set().union(
            scheduler_output.preempted_req_ids,
            scheduler_output.finished_req_ids,
        )
        removed_req_indices: List[int] = []
        for req_id in stopped_req_ids:
            req_index = self.input_batch.remove_request(req_id)
            if req_index is not None:
                removed_req_indices.append(req_index)

        # Update the states of the running requests.
        for req_data in scheduler_output.scheduled_running_reqs:
            req_id = req_data.req_id
            req_state = self.requests[req_id]
            req_index = self.input_batch.req_id_to_index[req_id]

            # Update the num_computed_tokens.
            req_state.num_computed_tokens = req_data.num_computed_tokens
            self.input_batch.num_computed_tokens_cpu[req_index] = (
                req_data.num_computed_tokens)

            # Update the block table.
            num_new_blocks = len(req_data.new_block_ids)
            if num_new_blocks == 0:
                continue
            start_index = len(req_state.block_ids)
            end_index = start_index + num_new_blocks
            req_state.block_ids.extend(req_data.new_block_ids)
            self.input_batch.block_table_cpu[
                req_index, start_index:end_index] = req_data.new_block_ids

        req_ids_to_add: List[str] = []
        # Add new requests to the cached states.
        for req_data in scheduler_output.scheduled_new_reqs:
            req_id = req_data.req_id
            sampling_params = req_data.sampling_params
            # Seeded requests get their own device generator so that their
            # sampling is reproducible.
            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
                generator = torch.Generator(device=self.device)
                generator.manual_seed(sampling_params.seed)
            else:
                generator = None

            self.requests[req_id] = CachedRequestState(
                req_id=req_id,
                prompt_token_ids=req_data.prompt_token_ids,
                prompt=req_data.prompt,
                mm_inputs=req_data.mm_inputs,
                mm_positions=req_data.mm_positions,
                sampling_params=sampling_params,
                generator=generator,
                block_ids=req_data.block_ids,
                num_computed_tokens=req_data.num_computed_tokens,
                output_token_ids=[],
            )
            req_ids_to_add.append(req_id)

        # Update the cached states of the resumed requests.
        for req_data in scheduler_output.scheduled_resumed_reqs:
            req_id = req_data.req_id
            req_state = self.requests[req_id]

            req_state.block_ids = req_data.block_ids
            req_state.num_computed_tokens = req_data.num_computed_tokens
            req_ids_to_add.append(req_id)

        # Add the new or resumed requests to the persistent batch.
        # The smaller empty indices are filled first: the list is sorted in
        # descending order so that pop() yields the smallest index.
        removed_req_indices = sorted(removed_req_indices, reverse=True)
        for req_id in req_ids_to_add:
            req_state = self.requests[req_id]
            if removed_req_indices:
                # Fill the empty index.
                req_index = removed_req_indices.pop()
            else:
                # Append to the end.
                req_index = None
            self.input_batch.add_request(req_state, req_index)

        # Condense the batched states if there are empty indices.
        # NOTE(review): the remaining indices are passed in descending order;
        # presumably InputBatch.condense relies on that — confirm.
        if removed_req_indices:
            self.input_batch.condense(removed_req_indices)

    def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
        """Build the flattened decoder inputs and attention metadata.

        Side effects:
            * Starts a non-blocking copy of the block table to the device.
            * Writes this step's positions into
              ``self.positions[:total_num_scheduled_tokens]``.

        Returns:
            A tuple ``(input_ids, attn_metadata, logits_indices)``:

            * ``input_ids``: int32 token ids for every scheduled token, on
              ``self.device``.
            * ``attn_metadata``: ``FlashAttentionMetadata`` for this step.
            * ``logits_indices``: the flat index of each request's last
              scheduled token (``query_start_loc[1:] - 1``).
        """
        total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        assert total_num_scheduled_tokens > 0
        num_reqs = self.input_batch.num_reqs
        assert num_reqs > 0

        # OPTIMIZATION: Start copying the block table first.
        # This way, we can overlap the copy with the following CPU operations.
        self.input_batch.block_table[:num_reqs].copy_(
            self.input_batch.block_table_cpu_tensor[:num_reqs],
            non_blocking=True)

        # Get the number of scheduled tokens for each request, in persistent
        # batch order.
        num_scheduled_tokens = np.array(
            [
                scheduler_output.num_scheduled_tokens[req_id]
                for req_id in self.input_batch.req_ids[:num_reqs]
            ],
            dtype=np.int32,
        )
        # Non-empty because num_reqs > 0 is asserted above.
        max_num_scheduled_tokens = int(num_scheduled_tokens.max())
        assert max_num_scheduled_tokens > 0

        # Get request indices.
        # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2]
        indices = np.arange(num_reqs)
        req_indices = np.repeat(indices, num_scheduled_tokens)

        # Get batched arange.
        # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        arange_matrix = np.tile(np.arange(max_num_scheduled_tokens),
                                (num_reqs, 1))
        mask = arange_matrix < num_scheduled_tokens[:, np.newaxis]
        arange = arange_matrix[mask]

        # Get positions.
        positions = torch.empty((total_num_scheduled_tokens, ),
                                dtype=torch.int32,
                                device="cpu",
                                pin_memory=self.pin_memory)
        positions_np = positions.numpy()
        np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
               arange,
               out=positions_np)

        # Get token indices.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2]
        # where M is the max_model_len.
        token_indices = (positions_np +
                         req_indices * self.input_batch.token_ids_cpu.shape[1])
        token_indices = torch.from_numpy(token_indices)
        input_ids = torch.empty((total_num_scheduled_tokens, ),
                                dtype=torch.int32,
                                device="cpu",
                                pin_memory=self.pin_memory)
        torch.index_select(torch.from_numpy(
            self.input_batch.token_ids_cpu).flatten(),
                           0,
                           token_indices,
                           out=input_ids)

        # Calculate the slot mapping.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
        # where K is the max_num_blocks_per_req and the block size is 2.
        # NOTE(woosuk): We can't simply use `token_indices // block_size` here
        # because M (max_model_len) is not necessarily divisible by block_size.
        block_numbers = self.input_batch.block_table_cpu_tensor.flatten()[
            req_indices * self.max_num_blocks_per_req +
            positions_np // self.block_size]
        block_offsets = torch.from_numpy(positions_np % self.block_size)
        slot_mapping = torch.empty((total_num_scheduled_tokens, ),
                                   dtype=torch.int32,
                                   device="cpu",
                                   pin_memory=self.pin_memory)
        torch.add(block_numbers * self.block_size,
                  block_offsets,
                  out=slot_mapping)

        # Prepare the attention metadata.
        query_start_loc = torch.empty((num_reqs + 1, ),
                                      dtype=torch.int32,
                                      device="cpu",
                                      pin_memory=self.pin_memory)
        query_start_loc_np = query_start_loc.numpy()
        query_start_loc_np[0] = 0
        np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1:])

        seq_lens = (self.input_batch.num_computed_tokens_cpu[:num_reqs] +
                    num_scheduled_tokens)
        # Convert to a Python int to match the metadata's `int` field.
        max_seq_len = int(seq_lens.max())
        seq_start_loc = torch.empty((num_reqs + 1, ),
                                    dtype=torch.int32,
                                    device="cpu",
                                    pin_memory=self.pin_memory)
        seq_start_loc_np = seq_start_loc.numpy()
        seq_start_loc_np[0] = 0
        np.cumsum(seq_lens, out=seq_start_loc_np[1:])

        input_ids = input_ids.to(self.device, non_blocking=True)
        self.positions[:total_num_scheduled_tokens].copy_(positions,
                                                          non_blocking=True)
        query_start_loc = query_start_loc.to(self.device, non_blocking=True)
        seq_start_loc = seq_start_loc.to(self.device, non_blocking=True)
        slot_mapping = slot_mapping.to(self.device, non_blocking=True).long()
        attn_metadata = FlashAttentionMetadata(
            num_actual_tokens=total_num_scheduled_tokens,
            max_query_len=max_num_scheduled_tokens,
            query_start_loc=query_start_loc,
            max_seq_len=max_seq_len,
            seq_start_loc=seq_start_loc,
            block_table=self.input_batch.block_table[:num_reqs],
            slot_mapping=slot_mapping,
        )

        # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial
        # request in the batch. While we should not sample any token from this
        # partial request, we do so for simplicity. We will ignore the sampled
        # token from the partial request.
        # TODO: Support prompt logprobs.
        logits_indices = query_start_loc[1:] - 1
        return input_ids, attn_metadata, logits_indices

    def _prepare_sampling(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> SamplingMetadata:
        """Build the sampling metadata for the current persistent batch.

        ``skip_copy`` is True only when no request left the batch (finished
        or preempted) and none joined it (new or resumed) this step.
        NOTE(review): presumably this lets InputBatch reuse previously copied
        sampling tensors when the batch membership is unchanged — confirm in
        InputBatch.make_sampling_metadata.
        """
        batch_changed = bool(scheduler_output.finished_req_ids
                             or scheduler_output.preempted_req_ids
                             or scheduler_output.scheduled_new_reqs
                             or scheduler_output.scheduled_resumed_reqs)
        # Create the sampling metadata.
        return self.input_batch.make_sampling_metadata(not batch_changed)

    def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
        """Run the multi-modal encoder on this step's scheduled encoder inputs.

        Each output is stored in ``self.encoder_cache[req_id][input_id]``.
        Outputs are matched to inputs positionally, in the order the inputs
        were batched. Does nothing when no encoder inputs are scheduled.
        """
        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
        if not scheduled_encoder_inputs:
            return

        # Batch the multi-modal inputs.
        mm_inputs: List[MultiModalKwargs] = []
        req_input_ids: List[Tuple[str, int]] = []
        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
            req_state = self.requests[req_id]
            for input_id in encoder_input_ids:
                mm_inputs.append(req_state.mm_inputs[input_id])
                req_input_ids.append((req_id, input_id))
        batched_mm_inputs = MultiModalKwargs.batch(mm_inputs)
        batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
                                                       device=self.device)

        # Run the encoder.
        # `encoder_outputs` is either of the following:
        # 1. A tensor of shape [num_images, feature_size, hidden_size]
        # in case when feature_size is fixed across all images.
        # 2. A list (length: num_images) of tensors, each of shape
        # [feature_size, hidden_size] in case when the feature size is
        # dynamic depending on input images.
        encoder_outputs = self.model.get_multimodal_embeddings(
            **batched_mm_inputs)

        # Cache the encoder outputs.
        for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
            self.encoder_cache.setdefault(req_id, {})[input_id] = output

    def _gather_encoder_outputs(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> List[torch.Tensor]:
        """Collect the cached encoder-output slices needed this step.

        For each request in persistent batch order, select the part of every
        multi-modal placeholder range that overlaps the token window
        ``[num_computed_tokens, num_computed_tokens + num_scheduled_tokens)``.
        Return the matching slices of the cached encoder outputs.
        """
        encoder_outputs: List[torch.Tensor] = []
        num_reqs = self.input_batch.num_reqs
        for req_id in self.input_batch.req_ids[:num_reqs]:
            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
                req_id]
            req_state = self.requests[req_id]
            num_computed_tokens = req_state.num_computed_tokens
            mm_positions = req_state.mm_positions
            for i, pos_info in enumerate(mm_positions):
                start_pos = pos_info["offset"]
                num_encoder_tokens = pos_info["length"]

                # The encoder output is needed if the two ranges overlap:
                # [num_computed_tokens,
                #  num_computed_tokens + num_scheduled_tokens) and
                # [start_pos, start_pos + num_encoder_tokens)
                if start_pos >= num_computed_tokens + num_scheduled_tokens:
                    # The encoder output is not needed in this step.
                    # NOTE(review): `break` assumes mm_positions is sorted
                    # by offset — confirm with the input processor.
                    break
                if start_pos + num_encoder_tokens <= num_computed_tokens:
                    # The encoder output is already processed and stored
                    # in the decoder's KV cache.
                    continue

                start_idx = max(num_computed_tokens - start_pos, 0)
                end_idx = min(
                    num_computed_tokens - start_pos + num_scheduled_tokens,
                    num_encoder_tokens)
                assert start_idx < end_idx
                assert req_id in self.encoder_cache
                assert i in self.encoder_cache[req_id]
                encoder_output = self.encoder_cache[req_id][i]
                encoder_outputs.append(encoder_output[start_idx:end_idx])
        return encoder_outputs

    @torch.inference_mode()
    def execute_model(
        self,
        scheduler_output: "SchedulerOutput",
    ) -> ModelRunnerOutput:
        """Run one model step and sample one token per request.

        The step updates cached states, runs any scheduled encoder work,
        executes the decoder, and samples. The decoder runs through piecewise
        CUDA graphs, padded to a captured size, when enabled and the token
        count fits; otherwise it runs eagerly. Sampled tokens are appended
        only for requests whose scheduled tokens reach the end of their
        known tokens. For a partial (chunked-prefill) request the sample is
        discarded and its seeded generator, if any, is rewound.
        """
        self._update_states(scheduler_output)

        # Run the encoder.
        self._execute_encoder(scheduler_output)
        encoder_outputs = self._gather_encoder_outputs(scheduler_output)

        # Prepare the decoder inputs.
        input_ids, attn_metadata, logits_indices = self._prepare_inputs(
            scheduler_output)
        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
        # Check that the capture list is non-empty before reading its last
        # (largest) entry; an empty list would otherwise raise IndexError.
        if (self.use_cuda_graph and self.cudagraph_batch_sizes
                and num_scheduled_tokens <= self.cudagraph_batch_sizes[-1]):
            # Use piecewise CUDA graphs.
            # Add padding to the batch size.
            num_input_tokens = self._get_padded_batch_size(
                num_scheduled_tokens)
        else:
            # Eager mode.
            num_input_tokens = num_scheduled_tokens

        # Get the inputs embeds.
        if encoder_outputs:
            inputs_embeds = self.model.get_input_embeddings(
                input_ids, encoder_outputs)
        else:
            inputs_embeds = self.model.get_input_embeddings(input_ids)
        # NOTE(woosuk): To unify token ids and soft tokens (vision embeddings),
        # always use embeddings (rather than token ids) as input to the model.
        # TODO(woosuk): Avoid the copy. Optimize.
        self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)

        # Run the decoder.
        # Use persistent buffers for CUDA graphs.
        with set_forward_context(attn_metadata, self.vllm_config):
            hidden_states = self.model(
                input_ids=None,
                positions=self.positions[:num_input_tokens],
                kv_caches=self.kv_caches,
                attn_metadata=None,
                inputs_embeds=self.inputs_embeds[:num_input_tokens],
            )
        # Drop padding rows, then keep only the rows logits are needed for.
        hidden_states = hidden_states[:num_scheduled_tokens]
        hidden_states = hidden_states[logits_indices]
        logits = self.model.compute_logits(hidden_states, None)

        # Sample the next token and get logprobs if needed.
        sampling_metadata = self._prepare_sampling(scheduler_output)
        sampler_output = self.model.sample(
            logits=logits,
            sampling_metadata=sampling_metadata,
        )

        # NOTE: CPU-GPU synchronization happens here.
        sampled_token_ids = sampler_output.sampled_token_ids.cpu()
        sampled_token_ids_list = sampled_token_ids.tolist()
        # TODO(woosuk): The following loop can be slow since it iterates over
        # the requests one by one. Optimize.
        num_reqs = self.input_batch.num_reqs
        for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
            req_state = self.requests[req_id]
            seq_len = (req_state.num_computed_tokens +
                       scheduler_output.num_scheduled_tokens[req_id])
            assert seq_len <= req_state.num_tokens
            if seq_len == req_state.num_tokens:
                # Append the sampled token to the output token ids.
                token_id = sampled_token_ids_list[i]
                self.input_batch.token_ids_cpu[i, seq_len] = token_id
                req_state.output_token_ids.append(token_id)
            else:
                # Ignore the sampled token from the partial request.
                # Rewind the generator state as if the token was not sampled.
                generator = self.input_batch.generators.get(i)
                if generator is not None:
                    # This relies on cuda-specific torch-internal impl details
                    generator.set_offset(generator.get_offset() - 4)

        if sampler_output.logprob_token_ids is None:
            logprob_token_ids = None
        else:
            logprob_token_ids = sampler_output.logprob_token_ids.cpu()
        if sampler_output.logprobs is None:
            logprobs = None
        else:
            logprobs = sampler_output.logprobs.cpu()
        model_runner_output = ModelRunnerOutput(
            req_ids=self.input_batch.req_ids[:num_reqs],
            req_id_to_index=self.input_batch.req_id_to_index,
            sampled_token_ids_cpu=sampled_token_ids,
            logprob_token_ids_cpu=logprob_token_ids,
            logprobs_cpu=logprobs,
        )
        return model_runner_output

    def load_model(self) -> None:
        """Load the model into ``self.model`` and record the memory it used.

        The consumed memory is stored in ``self.model_memory_usage``.
        """
        logger.info("Starting to load model %s...", self.model_config.model)
        with DeviceMemoryProfiler() as m:  # noqa: SIM117
            self.model = get_model(vllm_config=self.vllm_config)

        self.model_memory_usage = m.consumed_memory
        logger.info("Loading model weights took %.4f GB",
                    self.model_memory_usage / float(2**30))

    @torch.inference_mode()
    def _dummy_run(
        self,
        model: nn.Module,
        num_tokens: int,
        kv_caches: List[torch.Tensor],
    ) -> torch.Tensor:
        """Run ``model`` on the first ``num_tokens`` rows of the persistent
        buffers, with no attention metadata, and return its hidden states.

        Used for memory profiling and CUDA-graph warmup/capture.
        """
        with set_forward_context(None, self.vllm_config):
            hidden_states = model(
                input_ids=None,
                positions=self.positions[:num_tokens],
                kv_caches=kv_caches,
                attn_metadata=None,
                inputs_embeds=self.inputs_embeds[:num_tokens])
        return hidden_states

    def profile_run(self) -> None:
        """Run a max-size dummy forward pass and logits computation.

        This triggers compilation for the general shape and exercises the
        peak activation memory.
        """
        # TODO(woosuk): Profile the max memory usage of the encoder and
        # the encoder cache.
        # use an empty tensor instead of `None`` to force Dynamo to pass
        # it by reference, rather by specializing on the value `None`.
        # the `dtype` argument does not matter, and we use `float32` as
        # a placeholder (it has wide hardware support).
        # it is important to create tensors inside the loop, rather than
        # multiplying the list, to avoid Dynamo from treating them as
        # tensor aliasing.
        dummy_kv_caches = [
            torch.tensor([], dtype=torch.float32, device=self.device)
            for _ in range(self.num_attn_layers)
        ]
        # Trigger compilation for general shape.
        hidden_states = self._dummy_run(self.model, self.max_num_tokens,
                                        dummy_kv_caches)
        logits = self.model.compute_logits(hidden_states, None)
        # TODO(woosuk): Consider the memory usage of the sampler.
        torch.cuda.synchronize()
        del hidden_states, logits
        gc.collect()

    def capture_model(self) -> None:
        """Capture piecewise CUDA graphs for every configured batch size.

        Each size gets the configured number of warmup runs, then one more
        run. The elapsed time and the drop in free GPU memory are logged.
        Returns early, with a warning, when CUDA graphs are disabled.
        """
        if not self.use_cuda_graph:
            logger.warning(
                "Skipping CUDA graph capture. Please add "
                "-O %s to use CUDA graphs.", CompilationLevel.PIECEWISE)
            return

        start_time = time.perf_counter()
        start_free_gpu_memory = torch.cuda.mem_get_info()[0]

        # Trigger CUDA graph capture for specific shapes.
        # Capture the large shapes first so that the smaller shapes
        # can reuse the memory pool allocated for the large shapes.
        with graph_capture():
            for num_tokens in reversed(self.cudagraph_batch_sizes):
                for _ in range(self.vllm_config.compilation_config.
                               cudagraph_num_of_warmups):
                    self._dummy_run(self.model, num_tokens, self.kv_caches)
                self._dummy_run(self.model, num_tokens, self.kv_caches)

        end_time = time.perf_counter()
        end_free_gpu_memory = torch.cuda.mem_get_info()[0]
        elapsed_time = end_time - start_time
        cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
        # This usually takes 5~20 seconds.
        logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                    elapsed_time, cuda_graph_size / (1 << 30))

    def initialize_kv_cache(self, num_blocks: int) -> None:
        """Allocate one zero-filled KV-cache tensor per attention layer.

        The shape comes from ``FlashAttentionBackend.get_kv_cache_shape``.
        Must be called at most once (asserts the cache list is empty).
        """
        assert len(self.kv_caches) == 0
        kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_kv_heads, self.head_size)
        for _ in range(self.num_attn_layers):
            self.kv_caches.append(
                torch.zeros(kv_cache_shape,
                            dtype=self.kv_cache_dtype,
                            device=self.device))

    def _get_padded_batch_size(self, batch_size: int) -> Optional[int]:
        """Return the smallest captured CUDA-graph size >= ``batch_size``.

        ``self.cudagraph_batch_sizes`` is kept in ascending order, so the
        first match gives the tightest padding. Returns ``None`` when
        ``batch_size`` exceeds every captured size.
        """
        return next(
            (size for size in self.cudagraph_batch_sizes
             if size >= batch_size),
            None,
        )