[Misc] Consolidate ModelConfig code related to HF config (#10104)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 1fa020c539
commit db7db4aab9
@@ -359,7 +359,7 @@ Feature x Hardware
     - ✅
     - ✅
     - ✅
     - `✗ <https://github.com/vllm-project/vllm/blob/a84e598e2125960d3b4f716b78863f24ac562947/vllm/worker/cpu_model_runner.py#L125>`__
     - ✅
     - ✗
   * - :abbr:`logP (Logprobs)`
     - ✅
@@ -165,3 +165,41 @@ def test_rope_customization():
     assert getattr(longchat_model_config.hf_config, "rope_scaling",
                    None) == TEST_ROPE_SCALING
     assert longchat_model_config.max_model_len == 4096
+
+
+@pytest.mark.parametrize(("model_id", "is_encoder_decoder"), [
+    ("facebook/opt-125m", False),
+    ("facebook/bart-base", True),
+    ("meta-llama/Llama-3.2-1B", False),
+    ("meta-llama/Llama-3.2-11B-Vision", True),
+])
+def test_is_encoder_decoder(model_id, is_encoder_decoder):
+    config = ModelConfig(
+        model_id,
+        task="auto",
+        tokenizer=model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        dtype="float16",
+        seed=0,
+    )
+
+    assert config.is_encoder_decoder == is_encoder_decoder
+
+
+@pytest.mark.parametrize(("model_id", "uses_mrope"), [
+    ("facebook/opt-125m", False),
+    ("Qwen/Qwen2-VL-2B-Instruct", True),
+])
+def test_uses_mrope(model_id, uses_mrope):
+    config = ModelConfig(
+        model_id,
+        task="auto",
+        tokenizer=model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        dtype="float16",
+        seed=0,
+    )
+
+    assert config.uses_mrope == uses_mrope
@@ -15,7 +15,8 @@ from vllm.platforms import current_platform
 from vllm.tracing import is_otel_available, otel_import_error_traceback
 from vllm.transformers_utils.config import (ConfigFormat, get_config,
                                             get_hf_image_processor_config,
-                                            get_hf_text_config)
+                                            get_hf_text_config,
+                                            is_encoder_decoder, uses_mrope)
 from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
                         print_warning_once)
@@ -667,12 +668,13 @@ class ModelConfig:
         return self.multimodal_config
 
     @property
-    def is_encoder_decoder_model(self) -> bool:
+    def is_encoder_decoder(self) -> bool:
         """Extract the HF encoder/decoder model flag."""
-        return getattr(
-            self.hf_config, "is_encoder_decoder",
-            False) or (hasattr(self.hf_config, "text_config") and getattr(
-                self.hf_config.text_config, "is_encoder_decoder", False))
+        return is_encoder_decoder(self.hf_config)
+
+    @property
+    def uses_mrope(self) -> bool:
+        return uses_mrope(self.hf_config)
 
     @property
     def is_multimodal_model(self) -> bool:
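The new properties can be read straight off a constructed ModelConfig. The usage sketch below is not part of this commit; its constructor arguments mirror the tests added above, and the expected values follow the test parametrization (facebook/bart-base is an encoder/decoder model and has no mrope rope_scaling):

# Usage sketch (not part of the diff); arguments mirror the new tests above.
from vllm.config import ModelConfig

config = ModelConfig(
    "facebook/bart-base",
    task="auto",
    tokenizer="facebook/bart-base",
    tokenizer_mode="auto",
    trust_remote_code=False,
    dtype="float16",
    seed=0,
)

assert config.is_encoder_decoder        # BART is an encoder/decoder model
assert not config.uses_mrope            # BART's config has no "mrope_section"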
@@ -580,4 +580,4 @@ class InputPreprocessor:
         )
 
     def is_encoder_decoder_model(self):
-        return self.model_config.is_encoder_decoder_model
+        return self.model_config.is_encoder_decoder
@@ -129,6 +129,15 @@ def uses_mrope(config: PretrainedConfig) -> bool:
     return "mrope_section" in rope_scaling
 
 
+def is_encoder_decoder(config: PretrainedConfig) -> bool:
+    """Detect if the model with this config is used as an encoder/decoder."""
+    text_config = getattr(config, "text_config", None)
+    if text_config is not None:
+        return is_encoder_decoder(text_config)
+
+    return getattr(config, "is_encoder_decoder", False)
+
+
 def get_config(
     model: Union[str, Path],
     trust_remote_code: bool,
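A minimal sketch (not part of the diff) of how the new helper resolves nested configs, assuming it is importable from vllm.transformers_utils.config as added above; the toy PretrainedConfig objects stand in for real HF model configs:

# Sketch only: toy configs standing in for real HF model configs.
from transformers import PretrainedConfig

from vllm.transformers_utils.config import is_encoder_decoder

decoder_only = PretrainedConfig()                      # is_encoder_decoder defaults to False
bart_like = PretrainedConfig(is_encoder_decoder=True)  # e.g. a BART-style config

# Multimodal wrapper configs are resolved through their nested text_config.
mllama_like = PretrainedConfig()
mllama_like.text_config = bart_like

assert is_encoder_decoder(decoder_only) is False
assert is_encoder_decoder(bart_like) is True
assert is_encoder_decoder(mllama_like) is True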
@@ -88,9 +88,6 @@ STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not "
                                        "currently supported with encoder/"
                                        "decoder models.")
 
-STR_NOT_IMPL_ENC_DEC_CPU = ("CPU is not currently supported with "
-                            "encoder/decoder models.")
-
 # Efficiently import all enc/dec error strings
 # rather than having to import all of the above
 STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
@@ -105,7 +102,6 @@ STR_NOT_IMPL_ENC_DEC_ERR_STRS = {
     "STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC,
     "STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND,
     "STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER,
-    "STR_NOT_IMPL_ENC_DEC_CPU": STR_NOT_IMPL_ENC_DEC_CPU
 }
 
 # Constants related to forcing the attention backend selection
@@ -18,7 +18,6 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                              MultiModalInputs, MultiModalPlaceholderMap)
 from vllm.sequence import (IntermediateTensors, SequenceData,
                            SequenceGroupMetadata)
-from vllm.transformers_utils.config import uses_mrope
 from vllm.utils import make_tensor_with_pad
 from vllm.worker.model_runner_base import (
     ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
@@ -163,7 +162,7 @@ class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
 
         # special processing for mrope position deltas.
         mrope_positions = None
-        if self.runner.model_is_mrope:
+        if self.runner.model_config.uses_mrope:
             image_grid_thw = mm_kwargs.get("image_grid_thw", None)
             video_grid_thw = mm_kwargs.get("video_grid_thw", None)
             assert image_grid_thw is not None or video_grid_thw is not None, (
@@ -446,12 +445,6 @@ class CPUModelRunner(ModelRunnerBase[ModelInputForCPU]):
         # Lazy initialization.
         self.model: nn.Module  # Set after init_Model
 
-    @property
-    def model_is_mrope(self) -> bool:
-        """Detect if the model has "mrope" rope_scaling type.
-        mrope requires keep "rope_deltas" between prompt and decoding phases."""
-        return uses_mrope(self.model_config.hf_config)
-
     def load_model(self) -> None:
         self.model = get_model(vllm_config=self.vllm_config)
 
@@ -151,7 +151,7 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
             self.local_omp_cpuid = omp_cpuids.split("|")[rank]
 
         ModelRunnerClass: Type[CPUModelRunner] = CPUModelRunner
-        if self._is_encoder_decoder_model():
+        if self.model_config.is_encoder_decoder:
             ModelRunnerClass = CPUEncoderDecoderModelRunner
         self.model_runner: CPUModelRunner = ModelRunnerClass(
             vllm_config=vllm_config,
@@ -188,9 +188,6 @@ class CPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
             raise RuntimeError("Profiler is not enabled.")
         self.profiler.stop()
 
-    def _is_encoder_decoder_model(self):
-        return self.model_config.is_encoder_decoder_model
-
     def init_device(self) -> None:
         if self.local_omp_cpuid != "all":
             ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
@@ -47,7 +47,6 @@ from vllm.prompt_adapter.worker_manager import (
     LRUCacheWorkerPromptAdapterManager)
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
-from vllm.transformers_utils.config import uses_mrope
 from vllm.utils import (DeviceMemoryProfiler, GiB_bytes, PyObjectCache,
                         async_tensor_h2d, flatten_2d_lists,
                         is_pin_memory_available, supports_dynamo,
@@ -493,7 +492,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
             context_len = seq_data.get_num_computed_tokens()
             seq_len = min(seq_len, context_len + token_chunk_size)
         elif self.runner.scheduler_config.is_multi_step or \
-            self.runner.model_config.is_encoder_decoder_model:
+            self.runner.model_config.is_encoder_decoder:
             context_len = seq_len - 1
         else:
             context_len = seq_data.get_num_computed_tokens()
@@ -666,7 +665,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
         inter_data.multi_modal_placeholder_maps = placeholder_maps
 
         # special processing for mrope position deltas.
-        if self.runner.model_is_mrope:
+        if self.runner.model_config.uses_mrope:
             image_grid_thw = mm_kwargs.get("image_grid_thw", None)
             video_grid_thw = mm_kwargs.get("video_grid_thw", None)
             assert image_grid_thw is not None or video_grid_thw is not None, (
@@ -711,7 +710,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
 
         encoder_seq_len = 0
 
-        if self.runner.model_config.is_encoder_decoder_model:
+        if self.runner.model_config.is_encoder_decoder:
             encoder_seq_len = seq_group_metadata.encoder_seq_data.get_len()
 
         inter_data = self.init_cached_inter_data(
@@ -837,7 +836,7 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
             if not inter_data.is_prompt:
                 max_decode_seq_len = max(max_decode_seq_len,
                                          max(inter_data.seq_lens))
-                if self.runner.model_config.is_encoder_decoder_model:
+                if self.runner.model_config.is_encoder_decoder:
                     max_encoder_seq_len = max(max_encoder_seq_len,
                                               inter_data.encoder_seq_len)
 
@@ -1375,12 +1374,6 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
             raise RuntimeError("PromptAdapter is not enabled.")
         return self.prompt_adapter_manager.list_adapters()
 
-    @property
-    def model_is_mrope(self) -> bool:
-        """Detect if the model has "mrope" rope_scaling type.
-        mrope requires keep "rope_deltas" between prompt and decoding phases."""
-        return uses_mrope(self.model_config.hf_config)
-
     @torch.inference_mode()
     def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
         """Cuda graph capture a model.
@@ -1411,7 +1404,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         max_batch_size = self.max_batchsize_to_capture
         input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
         input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
-        if self.model_is_mrope:
+        if self.model_config.uses_mrope:
             input_positions = torch.tile(input_positions, (3, 1))
         # Prepare dummy previous_hidden_states only if needed by the model.
         # This is used by draft models such as EAGLE.
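The (3, 1) tile above reflects that mrope keeps three position ids per token rather than one; that interpretation is an assumption based on Qwen2-VL's multimodal rotary embedding (M-RoPE, with temporal/height/width indices) and is not spelled out in this diff. A standalone sketch of the shape change:

# Standalone sketch; the "3" is assumed to be M-RoPE's temporal/height/width components.
import torch

seq_len = 8
input_positions = torch.zeros(seq_len, dtype=torch.long)

# Decoder-only rope keeps one position id per token: shape (seq_len,).
# mrope keeps three position ids per token: shape (3, seq_len).
mrope_positions = torch.tile(input_positions, (3, 1))
assert mrope_positions.shape == (3, seq_len)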
@@ -1447,7 +1440,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                 self.attn_state.graph_capture_get_metadata_for_batch(
                     batch_size,
                     is_encoder_decoder_model=self.model_config.
-                    is_encoder_decoder_model))
+                    is_encoder_decoder))
 
             if self.lora_config:
                 lora_mapping = LoRAMapping(
@@ -1466,7 +1459,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
             graph_runner = CUDAGraphRunner(
                 self.model, self.attn_backend.get_name(),
                 self.attn_state.graph_clone(batch_size),
-                self.model_config.is_encoder_decoder_model)
+                self.model_config.is_encoder_decoder)
 
             capture_inputs = {
                 "input_ids":
@@ -1497,7 +1490,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                     self.model.get_seqlen_agnostic_capture_inputs(
                         batch_size)
                 })
-            if self.model_config.is_encoder_decoder_model:
+            if self.model_config.is_encoder_decoder:
                 # add the additional inputs to capture for
                 # encoder-decoder models.
                 self._update_inputs_to_capture_for_enc_dec_model(
@@ -77,7 +77,7 @@ class Worker(LocalOrDistributedWorkerBase):
             ModelRunnerClass = model_runner_cls
         elif model_config.task == "embedding":
             ModelRunnerClass = EmbeddingModelRunner
-        elif self._is_encoder_decoder_model():
+        elif self.model_config.is_encoder_decoder:
             ModelRunnerClass = EncoderDecoderModelRunner
         self.model_runner: GPUModelRunnerBase = ModelRunnerClass(
             vllm_config=self.vllm_config,
@@ -119,9 +119,6 @@ class Worker(LocalOrDistributedWorkerBase):
             raise RuntimeError("Profiler is not enabled.")
         self.profiler.stop()
 
-    def _is_encoder_decoder_model(self):
-        return self.model_config.is_encoder_decoder_model
-
     def init_device(self) -> None:
         if self.device_config.device.type == "cuda":
             # torch.distributed.all_reduce does not free the input tensor until