diff --git a/Dockerfile b/Dockerfile
index 343364da..4c0f5aeb 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -191,6 +191,9 @@ ADD . /vllm-workspace/
 RUN --mount=type=cache,target=/root/.cache/pip \
     python3 -m pip install -r requirements-dev.txt
 
+# Copy in the v1 package for testing (it isn't distributed yet)
+COPY vllm/v1 /usr/local/lib/python3.12/dist-packages/vllm/v1
+
 # doc requires source code
 # we hide them inside `test_docs/` , so that this source code
 # will not be imported by other tests
diff --git a/pyproject.toml b/pyproject.toml
index 35625696..1aebc543 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -97,4 +97,5 @@ markers = [
     "skip_global_cleanup",
     "core_model: run this model test in each PR instead of just daily",
     "distributed_2_gpus: run this test only in distributed tests for 2 GPUs",
+    "skip_v1: do not run this test with v1",
 ]
diff --git a/tests/conftest.py b/tests/conftest.py
index f9dfabc8..6cf791dc 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -5,6 +5,7 @@ from collections import UserList
 from enum import Enum
 from typing import (Any, Callable, Dict, List, Optional, Tuple, Type,
                     TypedDict, TypeVar, Union)
+from unittest.mock import patch
 
 import numpy as np
 import pytest
@@ -108,6 +109,23 @@ VIDEO_ASSETS = _VideoAssets()
 """Singleton instance of :class:`_VideoAssets`."""
 
 
+@pytest.fixture(params=[True, False])
+def run_with_both_engines(request):
+    # Automatically runs tests twice, once with V1 and once without
+    use_v1 = request.param
+    # Tests decorated with `@skip_v1` are only run without v1
+    skip_v1 = request.node.get_closest_marker("skip_v1")
+
+    if use_v1:
+        if skip_v1:
+            pytest.skip("Skipping test on vllm V1")
+        with patch('vllm.envs.VLLM_USE_V1', True):
+            yield
+    else:
+        with patch('vllm.envs.VLLM_USE_V1', False):
+            yield
+
+
 @pytest.fixture(autouse=True)
 def init_test_http_connection():
     # pytest_asyncio may use a different event loop per test
diff --git a/tests/entrypoints/llm/test_prompt_validation.py b/tests/entrypoints/llm/test_prompt_validation.py
index 675a980a..ee7010a2 100644
--- a/tests/entrypoints/llm/test_prompt_validation.py
+++ b/tests/entrypoints/llm/test_prompt_validation.py
@@ -3,12 +3,21 @@ import pytest
 from vllm import LLM
 
 
+@pytest.fixture(autouse=True)
+def v1(run_with_both_engines):
+    # Simple autouse wrapper to run both engines for each test
+    # This can be promoted up to conftest.py to run for every
+    # test in a package
+    pass
+
+
 def test_empty_prompt():
     llm = LLM(model="gpt2", enforce_eager=True)
     with pytest.raises(ValueError, match='Prompt cannot be empty'):
         llm.generate([""])
 
 
+@pytest.mark.skip_v1
 def test_out_of_vocab_token():
     llm = LLM(model="gpt2", enforce_eager=True)
     with pytest.raises(ValueError, match='out of vocabulary'):
diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py
index 3fe9ca0b..169ce040 100644
--- a/tests/kernels/test_attention_selector.py
+++ b/tests/kernels/test_attention_selector.py
@@ -44,6 +44,8 @@ def test_env(name: str, device: str, monkeypatch):
 
 def test_flash_attn(monkeypatch):
     """Test FlashAttn validation."""
+    # TODO: When testing for v1, pipe in `use_v1` as an argument to
+    # which_attn_to_use
 
     override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)
 
diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py
index a1dd5eee..3d3724c5 100644
--- a/tests/kernels/test_encoder_decoder_attn.py
+++ b/tests/kernels/test_encoder_decoder_attn.py
@@ -16,7 +16,7 @@ from tests.kernels.utils import *
 from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
                             AttentionType)
 from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
-from vllm.attention.selector import (_Backend, get_attn_backend,
+from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
                                      global_force_attn_backend_context_manager)
 from vllm.forward_context import set_forward_context
 from vllm.platforms import current_platform
@@ -774,7 +774,7 @@ def set_reset_environment(attn_backend):
     default_dtype = torch.get_default_dtype()
     if attn_backend.name == 'FLASH_ATTN':
         torch.set_default_dtype(torch.bfloat16)
-    get_attn_backend.cache_clear()
+    _cached_get_attn_backend.cache_clear()
     yield
     # Reset the torch datatype to what it was before the test
     # so as not to impact the remaining tests.
diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py
index 991602da..664707e9 100644
--- a/vllm/attention/selector.py
+++ b/vllm/attention/selector.py
@@ -89,7 +89,6 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
     return forced_attn_backend
 
 
-@lru_cache(maxsize=None)
 def get_attn_backend(
     head_size: int,
     dtype: torch.dtype,
@@ -99,6 +98,31 @@ def get_attn_backend(
     is_blocksparse: bool = False,
 ) -> Type[AttentionBackend]:
     """Selects which attention backend to use and lazily imports it."""
+    # Accessing envs.* behind an @lru_cache decorator can cause the wrong
+    # value to be returned from the cache if the value changes between calls.
+    # To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the
+    # private function.
+    return _cached_get_attn_backend(
+        head_size=head_size,
+        dtype=dtype,
+        kv_cache_dtype=kv_cache_dtype,
+        block_size=block_size,
+        is_attention_free=is_attention_free,
+        is_blocksparse=is_blocksparse,
+        use_v1=envs.VLLM_USE_V1,
+    )
+
+
+@lru_cache(maxsize=None)
+def _cached_get_attn_backend(
+    head_size: int,
+    dtype: torch.dtype,
+    kv_cache_dtype: Optional[str],
+    block_size: int,
+    is_attention_free: bool,
+    is_blocksparse: bool = False,
+    use_v1: bool = False,
+) -> Type[AttentionBackend]:
     if is_blocksparse:
         logger.info("Using BlocksparseFlashAttention backend.")
         from vllm.attention.backends.blocksparse_attn import (
@@ -106,7 +130,7 @@ def get_attn_backend(
         return BlocksparseFlashAttentionBackend
 
     backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size,
-                                is_attention_free)
+                                is_attention_free, use_v1)
     if backend == _Backend.FLASH_ATTN:
         logger.info("Using Flash Attention backend.")
         from vllm.attention.backends.flash_attn import (  # noqa: F401
@@ -162,13 +186,12 @@ def get_attn_backend(
     raise ValueError("Invalid attention backend.")
 
 
-def which_attn_to_use(
-    head_size: int,
-    dtype: torch.dtype,
-    kv_cache_dtype: Optional[str],
-    block_size: int,
-    is_attention_free: bool,
-) -> _Backend:
+def which_attn_to_use(head_size: int,
+                      dtype: torch.dtype,
+                      kv_cache_dtype: Optional[str],
+                      block_size: int,
+                      is_attention_free: bool,
+                      use_v1: bool = False) -> _Backend:
     """Returns which flash attention backend to use."""
     # Default case.
     selected_backend = _Backend.FLASH_ATTN
@@ -228,7 +251,7 @@ def which_attn_to_use(
     if current_platform.is_hpu():
         return _Backend.HPU_ATTN
 
-    if envs.VLLM_USE_V1:
+    if use_v1:
         return _Backend.FLASH_ATTN_VLLM_V1
 
     # FlashAttn in NVIDIA GPUs.
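The `get_attn_backend` / `_cached_get_attn_backend` split above is the key pattern in this patch: an `@lru_cache`-decorated function that reads a module-level flag keeps replaying the result of its first call even after the flag changes, so the flag has to become an explicit argument (and part of the cache key). Below is a minimal, self-contained sketch of the pitfall and the fix; the names are illustrative only and are not vLLM code.

```python
from functools import lru_cache

FLAG = False  # stands in for a module-level env value such as envs.VLLM_USE_V1


@lru_cache(maxsize=None)
def pick_backend_reading_global() -> str:
    # Pitfall: the global is read inside the cached function, so the result
    # of the very first call is replayed forever, even after FLAG changes.
    return "FLASH_ATTN_VLLM_V1" if FLAG else "FLASH_ATTN"


@lru_cache(maxsize=None)
def _pick_backend_cached(use_v1: bool) -> str:
    # Fix: the flag is an argument, so it becomes part of the cache key.
    return "FLASH_ATTN_VLLM_V1" if use_v1 else "FLASH_ATTN"


def pick_backend() -> str:
    # Thin uncached wrapper that reads the flag at call time, mirroring the
    # get_attn_backend -> _cached_get_attn_backend split in the diff above.
    return _pick_backend_cached(use_v1=FLAG)


assert pick_backend_reading_global() == "FLASH_ATTN"
assert pick_backend() == "FLASH_ATTN"
FLAG = True
assert pick_backend_reading_global() == "FLASH_ATTN"  # stale cached value
assert pick_backend() == "FLASH_ATTN_VLLM_V1"         # sees the new value
```

The same reasoning is why the `run_with_both_engines` fixture in `tests/conftest.py` can toggle engines simply by patching `vllm.envs.VLLM_USE_V1`: the patch only takes effect if the flag is read at call time, which is also what the runtime engine/sampler selection below relies on instead of choosing an implementation at import time.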
diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py
index e1dcb828..889845ee 100644
--- a/vllm/engine/multiprocessing/engine.py
+++ b/vllm/engine/multiprocessing/engine.py
@@ -6,7 +6,9 @@ from typing import Iterator, List, Optional, Union
 import cloudpickle
 import zmq
 
+import vllm.envs
 from vllm import AsyncEngineArgs, SamplingParams
+from vllm.engine.llm_engine import LLMEngine
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
@@ -17,17 +19,11 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
                                          RPCStartupRequest, RPCStartupResponse,
                                          RPCUProfileRequest)
 # yapf: enable
-from vllm.envs import VLLM_USE_V1
 from vllm.executor.gpu_executor import GPUExecutor
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.usage.usage_lib import UsageContext
 
-if VLLM_USE_V1:
-    from vllm.v1.engine.llm_engine import LLMEngine
-else:
-    from vllm.engine.llm_engine import LLMEngine
-
 logger = init_logger(__name__)
 
 POLLING_TIMEOUT_MS = 10000
@@ -117,11 +113,17 @@ class MQLLMEngine:
             load_general_plugins()
 
         engine_config = engine_args.create_engine_config()
+        if vllm.envs.VLLM_USE_V1:
+            # Lazy import: the v1 package isn't distributed
+            from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
+            engine_class = V1LLMEngine
+        else:
+            engine_class = LLMEngine
 
-        executor_class = LLMEngine._get_executor_cls(engine_config)
+        executor_class = engine_class._get_executor_cls(engine_config)
 
         use_async_sockets = (engine_config.model_config.use_async_output_proc
-                             and not VLLM_USE_V1)
+                             and not vllm.envs.VLLM_USE_V1)
 
         return cls(ipc_path=ipc_path,
                    use_async_sockets=use_async_sockets,
diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py
index b18974c5..d8b60a5e 100644
--- a/vllm/entrypoints/llm.py
+++ b/vllm/entrypoints/llm.py
@@ -1,7 +1,7 @@
 import itertools
 import warnings
 from contextlib import contextmanager
-from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
+from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
                     Union, cast, overload)
 
 from tqdm import tqdm
@@ -10,6 +10,7 @@ from vllm import envs
 from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                               BeamSearchSequence, get_beam_search_score)
 from vllm.engine.arg_utils import EngineArgs, TaskOption
+from vllm.engine.llm_engine import LLMEngine
 from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                          apply_hf_chat_template,
                                          apply_mistral_chat_template,
@@ -31,11 +32,6 @@ from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
 
-if envs.VLLM_USE_V1:
-    from vllm.v1.engine.llm_engine import LLMEngine  # type: ignore
-else:
-    from vllm.engine.llm_engine import LLMEngine  # type: ignore
-
 logger = init_logger(__name__)
 
 
@@ -206,10 +202,21 @@ class LLM:
             pooling_returned_token_ids=pooling_returned_token_ids,
             **kwargs,
         )
-        self.llm_engine = LLMEngine.from_engine_args(
+        # Logic to switch between engines is done at runtime instead of import
+        # to avoid import order issues
+        self.engine_class = self.get_engine_class()
+        self.llm_engine = self.engine_class.from_engine_args(
             engine_args, usage_context=UsageContext.LLM_CLASS)
 
         self.request_counter = Counter()
 
+    @staticmethod
+    def get_engine_class() -> Type[LLMEngine]:
+        if envs.VLLM_USE_V1:
+            # Lazy import: the v1 package isn't distributed
+            from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
+            return V1LLMEngine  # type: ignore
+        return LLMEngine
+
     def get_tokenizer(self) -> AnyTokenizer:
         return self.llm_engine.get_tokenizer_group(TokenizerGroup).tokenizer
 
@@ -394,7 +401,7 @@ class LLM:
                 priority=priority)
 
         outputs = self._run_engine(use_tqdm=use_tqdm)
-        return LLMEngine.validate_outputs(outputs, RequestOutput)
+        return self.engine_class.validate_outputs(outputs, RequestOutput)
 
     def beam_search(
         self,
@@ -769,7 +776,8 @@ class LLM:
         )
 
         outputs = self._run_engine(use_tqdm=use_tqdm)
-        return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput)
+        return self.engine_class.validate_outputs(outputs,
+                                                  EmbeddingRequestOutput)
 
     def start_profile(self) -> None:
         self.llm_engine.start_profile()
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index f86c6ec3..c10efefe 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -30,6 +30,15 @@ if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
 else:
     flashinfer_top_k_top_p_sampling = None
 
+
+def get_sampler() -> torch.nn.Module:
+    if envs.VLLM_USE_V1:
+        # Lazy import: the v1 package isn't distributed
+        from vllm.v1.sample.sampler import Sampler as V1Sampler
+        return V1Sampler()
+    return Sampler()
+
+
 # (num_token_ids, num_parent_ids) per sequence group.
 SampleResultType = List[Tuple[List[int], List[int]]]
 
diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py
index 5b712ba8..4fec314a 100644
--- a/vllm/model_executor/models/arctic.py
+++ b/vllm/model_executor/models/arctic.py
@@ -23,7 +23,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.quantization.deepspeedfp import (
     DeepSpeedFPConfig, DeepSpeedFPParameter)
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -436,7 +436,7 @@ class ArcticForCausalLM(nn.Module, SupportsPP):
         self.unpadded_vocab_size = config.vocab_size
         self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                 config.vocab_size)
-        self.sampler = Sampler()
+        self.sampler = get_sampler()
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py
index 1fbf4135..cce182da 100644
--- a/vllm/model_executor/models/baichuan.py
+++ b/vllm/model_executor/models/baichuan.py
@@ -37,7 +37,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -352,7 +352,7 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
         if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 85de1a81..fd600adc 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -34,7 +34,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -838,7 +838,7 @@ class BartForConditionalGeneration(nn.Module): self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() def forward( self, diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index db1f9264..efd24e7c 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -13,7 +13,7 @@ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.utils import consecutive_placeholder_ranges @@ -525,7 +525,7 @@ class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): if hasattr(self.language_model, "sampler"): return self.language_model.sampler - return Sampler() + return get_sampler() def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index b2c109a2..c2440ee7 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -33,7 +33,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -298,7 +298,7 @@ class BloomForCausalLM(nn.Module, SupportsPP): self.config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py 
index 9f6c6786..58841f17 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -21,7 +21,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -946,7 +946,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 881b8656..032fa82a 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -24,7 +24,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -616,7 +616,7 @@ class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP, self.transformer.embedding.weight) self.lm_head = self.transformer.output_layer self.logits_processor = LogitsProcessor(config.padded_vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() def forward(self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 835682ca..718f26be 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -37,7 +37,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -355,7 +355,7 @@ class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP): cache_config, quant_config, lora_config=lora_config) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 3e60eee2..ae433831 100644 --- a/vllm/model_executor/models/dbrx.py +++ 
b/vllm/model_executor/models/dbrx.py @@ -14,7 +14,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -373,7 +373,7 @@ class DbrxForCausalLM(nn.Module, SupportsPP): ) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index d278ea5b..53a1c7cf 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -41,7 +41,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -399,7 +399,7 @@ class DeepseekForCausalLM(nn.Module, SupportsPP): if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 834be78b..95bbf4fb 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -42,7 +42,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -496,7 +496,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP): config.hidden_size, quant_config=quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index 23efe035..a8d591b9 100644 --- a/vllm/model_executor/models/exaone.py +++ 
b/vllm/model_executor/models/exaone.py @@ -42,7 +42,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( get_compressed_tensors_cache_scale) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -478,7 +478,7 @@ class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) - self.sampler = Sampler() + self.sampler = get_sampler() else: self.lm_head = PPMissingLayer() diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 6f8a7a70..daf49521 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -38,7 +38,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -426,7 +426,7 @@ class FalconForCausalLM(nn.Module, SupportsPP): quant_config=quant_config, ) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py index 6840ac8b..184bee5f 100644 --- a/vllm/model_executor/models/florence2.py +++ b/vllm/model_executor/models/florence2.py @@ -10,7 +10,7 @@ from vllm.config import CacheConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.bart import (BartDecoder, BartEncoder, BartParallelLMHead, @@ -112,7 +112,7 @@ class Florence2LanguageForConditionalGeneration(nn.Module): self.logits_processor = LogitsProcessor(self.vocab_size, config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() def forward( self, diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index fc3f5cb2..1cc3ea67 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -33,7 +33,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from 
vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -393,7 +393,7 @@ class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): quant_config, prefix=maybe_prefix(prefix, "model")) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index c3658801..16e0d6b3 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -33,7 +33,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -414,7 +414,7 @@ class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.model = Gemma2Model(config, cache_config, quant_config) self.logits_processor = LogitsProcessor( config.vocab_size, soft_cap=config.final_logit_softcapping) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 8147037e..7f81bbff 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -33,7 +33,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -259,7 +259,7 @@ class GPT2LMHeadModel(nn.Module, SupportsPP): self.lm_head = ParallelLMHead(self.config.vocab_size, self.config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 9f44fa76..4be8e419 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -33,7 +33,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from 
vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -285,7 +285,7 @@ class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 6fcccdfb..834b4aff 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -32,7 +32,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -247,7 +247,7 @@ class GPTJForCausalLM(nn.Module, SupportsPP): quant_config=quant_config, ) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index d3f86558..1903156d 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -32,7 +32,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -260,7 +260,7 @@ class GPTNeoXForCausalLM(nn.Module, SupportsPP): if self.config.tie_word_embeddings: self.embed_out.weight = self.gpt_neox.embed_in.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.gpt_neox.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index bee48f37..8a75b9cb 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -42,7 +42,7 @@ from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( get_compressed_tensors_cache_scale) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import 
SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -411,7 +411,7 @@ class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, scale=logit_scale) - self.sampler = Sampler() + self.sampler = get_sampler() else: self.lm_head = PPMissingLayer() diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index 691a6e77..b4da986e 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -39,7 +39,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -371,7 +371,7 @@ class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): scale=1 / self.config.logits_scaling) - self.sampler = Sampler() + self.sampler = get_sampler() def forward( self, diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index afefb6cd..7ddb1e2a 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -20,7 +20,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -338,7 +338,7 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP): if self.config.tie_word_embeddings: self.output.weight = self.model.tok_embeddings.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index d2ec0ff6..bb9d3888 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -21,7 +21,7 @@ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) from vllm.model_executor.layers.quantization import (AWQConfig, QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.models.intern_vit import (InternVisionModel, InternVisionPatchModel) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -467,7 +467,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): if 
hasattr(self.language_model, "sampler"): return self.language_model.sampler - return Sampler() + return get_sampler() def _init_vision_model( self, diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 301893f7..23fdca09 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -34,7 +34,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -308,7 +308,7 @@ class JAISLMHeadModel(nn.Module, SupportsPP): config.mup_width_scale) self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size, scale=self.output_logits_scale) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 81d88a47..9b18a1b6 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -17,7 +17,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -383,7 +383,7 @@ class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA): self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() def forward(self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index d768a57b..9e8a403b 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -42,7 +42,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( get_compressed_tensors_cache_scale) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -536,7 +536,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) - self.sampler = Sampler() + self.sampler = get_sampler() else: self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( diff --git 
a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 7fbd59eb..bdd67b12 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -14,7 +14,7 @@ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import IntermediateTensors @@ -302,7 +302,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): if hasattr(self.language_model, "sampler"): return self.language_model.sampler - return Sampler() + return get_sampler() def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 7a2c9559..37b8baa8 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -16,7 +16,7 @@ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext) from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -327,7 +327,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, if hasattr(self.language_model, "sampler"): return self.language_model.sampler - return Sampler() + return get_sampler() def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: expected_dims = (2, ) diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index b755e234..69bfc80a 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -15,7 +15,7 @@ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -289,7 +289,7 @@ class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, if hasattr(self.language_model, "sampler"): return self.language_model.sampler - return Sampler() + return get_sampler() def _validate_video_pixel_values( self, data: Union[torch.Tensor, List[torch.Tensor]] diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index f410d645..26ece819 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -19,7 +19,7 @@ from 
vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.utils import (cached_get_tokenizer, @@ -437,7 +437,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, if hasattr(self.language_model, "sampler"): return self.language_model.sampler - return Sampler() + return get_sampler() def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: expected_dims = (2, ) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index aac4b7aa..91161957 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -13,7 +13,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mamba.mamba_mixer import MambaMixer from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -169,7 +169,7 @@ class MambaForCausalLM(nn.Module, HasInnerState, IsAttentionFree): self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() def forward(self, input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index acf03cd8..7704431a 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -43,7 +43,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -496,7 +496,7 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.logits_processor = LogitsProcessor(unpadded_vocab_size, config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 5acd3f65..4ffe33bb 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -41,7 +41,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.resampler import (BaseResampler, 
Resampler2, get_2d_sincos_pos_embed) -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -420,7 +420,7 @@ class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): quant_config=quant_config, prefix="llm.lm_head") self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.llm.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index e9b9c4d8..f5c28e7d 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -38,7 +38,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -366,7 +366,7 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 9647d69b..007c4e2e 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -40,7 +40,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -366,7 +366,7 @@ class MixtralForCausalLM(nn.Module, SupportsPP): if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 5fa8d19b..d442ffe3 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -44,7 +44,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import 
LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -1141,7 +1141,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal): ) self.logits_processor = LogitsProcessor(config.output_hidden_states, config.text_config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() def compute_logits( self, diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py index ae218d74..fde44265 100644 --- a/vllm/model_executor/models/mlp_speculator.py +++ b/vllm/model_executor/models/mlp_speculator.py @@ -6,7 +6,7 @@ import torch.nn as nn from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -137,7 +137,7 @@ class MLPSpeculator(nn.Module): self.config = config self.logits_processor = LogitsProcessor(config.vocab_size, config.vocab_size, 1.0) - self.sampler = Sampler() + self.sampler = get_sampler() def generate_proposals( self, diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 785b5367..3a50923d 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -33,7 +33,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -1053,7 +1053,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): self.logits_processor = LogitsProcessor(config.embedding_size or config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 7f0658f4..b3977812 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -16,7 +16,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from 
vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -281,7 +281,7 @@ class MPTForCausalLM(nn.Module, SupportsPP): self.transformer = MPTModel(config, cache_config, quant_config) self.lm_head = self.transformer.wte self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index b6490645..8d128a42 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -36,7 +36,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -441,7 +441,7 @@ class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) - self.sampler = Sampler() + self.sampler = get_sampler() else: self.lm_head = PPMissingLayer() self.make_empty_intermediate_tensors = ( diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index dd3f5828..545d86ee 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -37,7 +37,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -309,7 +309,7 @@ class OlmoForCausalLM(nn.Module, SupportsPP): quant_config=quant_config, ) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index 374cbb8d..de30b527 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -28,7 +28,7 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import 
default_weight_loader @@ -323,7 +323,7 @@ class OlmoeForCausalLM(nn.Module, SupportsPP): config.hidden_size, quant_config=quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index d140f423..a453376d 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -33,7 +33,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -362,7 +362,7 @@ class OPTForCausalLM(nn.Module, SupportsPP): self.lm_head = ParallelLMHead(config.vocab_size, config.word_embed_proj_dim) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index a338a93c..d6ec1fb6 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -20,7 +20,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -284,7 +284,7 @@ class OrionForCausalLM(nn.Module, SupportsPP): if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index 112bf6f3..11e7c8ab 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -36,7 +36,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -279,7 +279,7 @@ class 
PersimmonForCausalLM(nn.Module, SupportsPP): config.hidden_size, bias=False) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index d308f491..4dae6e32 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -51,7 +51,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -300,7 +300,7 @@ class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP): bias=True, quant_config=quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index 3a7afc60..92bf0e61 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -15,7 +15,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -386,7 +386,7 @@ class Phi3SmallForCausalLM(nn.Module, SupportsPP): if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 1c41891c..a84d6b31 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -32,7 +32,7 @@ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.logger import init_logger from vllm.model_executor.layers.pooler import Pooler, PoolingType from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.models.clip import CLIPVisionModel @@ -570,7 +570,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): if hasattr(self.language_model, 
"sampler"): return self.language_model.sampler - return Sampler() + return get_sampler() def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: expected_dims = (2, ) diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 59843ae3..19e2621e 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -38,7 +38,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -562,7 +562,7 @@ class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 6e909243..facf1969 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -25,7 +25,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.utils import merge_multimodal_embeddings from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -190,7 +190,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, if hasattr(self.language_model, "sampler"): return self.language_model.sampler - return Sampler() + return get_sampler() def forward( self, diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 4044ddbb..c91c2caa 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -36,7 +36,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -884,7 +884,7 @@ class QWenBaseModel(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA): if self.config.tie_word_embeddings: self.lm_head.weight = self.transformer.wte.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.transformer.make_empty_intermediate_tensors) diff --git 
a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 49b3de13..1e99c1b1 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -39,7 +39,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -444,7 +444,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): prefix, "lm_head")) self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index 556c0940..54a7085f 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -36,7 +36,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) @@ -295,7 +295,7 @@ class Qwen2AudioForConditionalGeneration(nn.Module, SupportsMultiModal, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.text_config.vocab_size, logit_scale) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 98bb48a2..c8c48c08 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -44,7 +44,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -393,7 +393,7 @@ class Qwen2MoeForCausalLM(nn.Module, SupportsPP): if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 
fad9137d..af263262 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -52,7 +52,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.qwen2 import Qwen2Model @@ -990,7 +990,7 @@ class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, self.lm_head = PPMissingLayer() self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index 1b233ac7..931e48a4 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -42,7 +42,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( get_compressed_tensors_cache_scale) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -449,7 +449,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) - self.sampler = Sampler() + self.sampler = get_sampler() else: self.lm_head = PPMissingLayer() diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 34389b64..4cb55506 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -34,7 +34,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -261,7 +261,7 @@ class StablelmForCausalLM(nn.Module, SupportsPP): if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index a5e4155f..0b0e3f21 100644 --- 
a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -34,7 +34,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -269,7 +269,7 @@ class Starcoder2ForCausalLM(nn.Module, SupportsPP): ) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 749750fc..3a343986 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -21,7 +21,7 @@ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalInputs, @@ -379,7 +379,7 @@ class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP): if hasattr(self.language_model, "sampler"): return self.language_model.sampler - return Sampler() + return get_sampler() def _audio_features_to_embeddings( self, input_features: torch.Tensor) -> torch.Tensor: diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index e559988a..1d08b382 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -37,7 +37,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -334,7 +334,7 @@ class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): if self.config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + self.sampler = get_sampler() self.make_empty_intermediate_tensors = ( self.model.make_empty_intermediate_tensors) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 906f0677..e73a1e60 100644 --- 
a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -136,7 +136,7 @@ class FlashAttentionImpl(AttentionImpl): "key/v_scale is not supported in FlashAttention.") output = torch.empty_like(query) - torch.ops.vllm.unified_flash_attention( + torch.ops.vllm.unified_v1_flash_attention( output, query, key, @@ -156,7 +156,7 @@ class FlashAttentionImpl(AttentionImpl): return output -def unified_flash_attention( +def unified_v1_flash_attention( output: torch.Tensor, query: torch.Tensor, key: torch.Tensor, @@ -222,7 +222,7 @@ def unified_flash_attention( output[:num_actual_tokens].copy_(attn_output) -def unified_flash_attention_fake( +def unified_v1_flash_attention_fake( output: torch.Tensor, query: torch.Tensor, key: torch.Tensor, @@ -243,8 +243,8 @@ def unified_flash_attention_fake( direct_register_custom_op( - op_name="unified_flash_attention", - op_func=unified_flash_attention, + op_name="unified_v1_flash_attention", + op_func=unified_v1_flash_attention, mutates_args=["kv_cache", "output"], - fake_impl=unified_flash_attention_fake, + fake_impl=unified_v1_flash_attention_fake, ) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 64cc1814..5f572048 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -155,6 +155,12 @@ class LLMEngine: # GPU and CPU blocks, which are profiled in the distributed executor. self.scheduler = Scheduler(scheduler_config, cache_config, lora_config) + def __del__(self): + # Small hack: implicit cleanup of resources on garbage collection + # TODO: this should probably be explicitly invoked when we're done with + # the engine + self.terminate_detokenizer() + def _initialize_kv_caches(self) -> None: num_gpu_blocks, _ = self.model_executor.determine_num_available_blocks( ) diff --git a/vllm/v1/tokenizer/detokenizer.py b/vllm/v1/tokenizer/detokenizer.py index 4bbcf471..e485fcc3 100644 --- a/vllm/v1/tokenizer/detokenizer.py +++ b/vllm/v1/tokenizer/detokenizer.py @@ -73,7 +73,7 @@ class Detokenizer: return None def terminate(self) -> None: - self.push_socket.send(b"", flags=zmq.NOBLOCK) + self.detokenizer.kill() self.detokenizer.join() @@ -108,10 +108,10 @@ class DetokenizerProc(multiprocessing.Process): self.push_socket.bind(f"tcp://*:{self.push_port}") while True: + if self.pull_socket.poll(timeout=1000) == 0: + # Nothing to read + continue message = self.pull_socket.recv() - if message == b"": - # Terminate signal.
- break inputs = self.msgpack_decoder.decode(message) for req_id in inputs.free_req_ids: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 63bf7c2e..e6383b59 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2,7 +2,6 @@ import os import time from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Set -from unittest.mock import patch import numpy as np import torch @@ -26,7 +25,6 @@ from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, FlashAttentionMetadata) from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.sample.sampler import Sampler if TYPE_CHECKING: from vllm.v1.core.scheduler import SchedulerOutput @@ -418,8 +416,7 @@ class GPUModelRunner: logger.info("Starting to load model %s...", self.model_config.model) with DeviceMemoryProfiler() as m: # noqa: SIM117 - with patch("vllm.model_executor.layers.sampler.Sampler", Sampler): - self.model = get_model(vllm_config=self.vllm_config) + self.model = get_model(vllm_config=self.vllm_config) self.model_memory_usage = m.consumed_memory logger.info("Loading model weights took %.4f GB",
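A note on the recurring Sampler() -> get_sampler() substitution in the model files above: the helper's definition is not part of this patch, but since gpu_model_runner.py no longer monkey-patches the Sampler class while loading the model, get_sampler() presumably selects the V1 sampler when the V1 engine is enabled and falls back to the existing implementation otherwise. A minimal sketch under that assumption (the real definition in vllm/model_executor/layers/sampler.py may differ):

    import torch.nn as nn

    import vllm.envs as envs


    def get_sampler() -> nn.Module:
        # Assumed behavior: return the V1 sampler when the V1 engine is in
        # use, otherwise fall back to the existing Sampler implementation.
        if envs.VLLM_USE_V1:
            from vllm.v1.sample.sampler import Sampler as V1Sampler
            return V1Sampler()
        from vllm.model_executor.layers.sampler import Sampler
        return Sampler()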
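Relatedly, the __del__ hook added to the V1 LLMEngine stops the detokenizer only when the engine happens to be garbage-collected; the TODO in that hunk points toward an explicit teardown. A hedged sketch of what an explicit call site could look like (the wrapper below is hypothetical and not part of this change; only terminate_detokenizer() appears in the patch):

    from vllm.v1.engine.llm_engine import LLMEngine


    def run_and_shutdown(engine: LLMEngine) -> None:
        # Hypothetical explicit-teardown pattern: stop the background
        # detokenizer process deterministically instead of relying on
        # __del__ firing during garbage collection.
        try:
            ...  # submit requests / step the engine here
        finally:
            engine.terminate_detokenizer()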