[Core] Interface for accessing model from VllmRunner (#10353)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Authored by Cyrus Leung on 2025-01-20 15:00:59 +08:00; committed by GitHub
parent 83609791d2
commit 59a0192fb9
35 changed files with 460 additions and 293 deletions

View File

@ -244,6 +244,7 @@ def video_assets() -> _VideoAssets:
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
_R = TypeVar("_R")
class HfRunner:
@ -930,6 +931,10 @@ class VllmRunner:
req_outputs = self.model.score(text_1, text_2)
return [req_output.outputs.score for req_output in req_outputs]
def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
executor = self.model.llm_engine.model_executor
return executor.apply_model(func)
def __enter__(self):
return self
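This helper gives tests a supported way to reach the underlying nn.Module instead of chaining through self.model.llm_engine.model_executor.driver_worker.model_runner.model. A minimal usage sketch (the fixture comes from conftest; the model name and assertion are illustrative, not part of this diff):

def test_parameter_count(vllm_runner):
    with vllm_runner("facebook/opt-125m") as vllm_model:
        # The callable runs against the nn.Module inside each worker;
        # apply_model returns one result per worker.
        counts = vllm_model.apply_model(
            lambda model: sum(p.numel() for p in model.parameters()))
        assert all(count > 0 for count in counts)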

View File

@ -51,7 +51,9 @@ def test_custom_executor(model, tmp_path):
assert not os.path.exists(".marker")
engine_args = EngineArgs(
model=model, distributed_executor_backend=CustomUniExecutor)
model=model,
distributed_executor_backend=CustomUniExecutor,
)
engine = LLMEngine.from_engine_args(engine_args)
sampling_params = SamplingParams(max_tokens=1)

View File

@ -25,13 +25,12 @@ def test_model_loading_with_params(vllm_runner):
with vllm_runner(model_name=MODEL_NAME,
revision=REVISION,
dtype="float16",
max_model_len=MAX_MODEL_LEN) as model:
output = model.encode("Write a short story about a robot that"
" dreams for the first time.\n")
max_model_len=MAX_MODEL_LEN) as vllm_model:
output = vllm_model.encode("Write a short story about a robot that"
" dreams for the first time.\n")
model_config = model.model.llm_engine.model_config
model_tokenizer = model.model.llm_engine.tokenizer
model_config = vllm_model.model.llm_engine.model_config
model_tokenizer = vllm_model.model.llm_engine.tokenizer
# asserts on the bert model config file
assert model_config.encoder_config["max_seq_length"] == 512
@ -46,11 +45,13 @@ def test_model_loading_with_params(vllm_runner):
assert model_tokenizer.tokenizer_config["do_lower_case"]
assert model_tokenizer.tokenizer.model_max_length == 512
model = model.model.llm_engine.model_executor\
.driver_worker.model_runner.model
assert isinstance(model, BertEmbeddingModel)
assert model._pooler.pooling_type == PoolingType.CLS
assert model._pooler.normalize
def check_model(model):
assert isinstance(model, BertEmbeddingModel)
assert model._pooler.pooling_type == PoolingType.CLS
assert model._pooler.normalize
vllm_model.apply_model(check_model)
# assert output
assert output
@ -64,13 +65,12 @@ def test_roberta_model_loading_with_params(vllm_runner):
with vllm_runner(model_name=MODEL_NAME_ROBERTA,
revision=REVISION_ROBERTA,
dtype="float16",
max_model_len=MAX_MODEL_LEN) as model:
output = model.encode("Write a short story about a robot that"
" dreams for the first time.\n")
max_model_len=MAX_MODEL_LEN) as vllm_model:
output = vllm_model.encode("Write a short story about a robot that"
" dreams for the first time.\n")
model_config = model.model.llm_engine.model_config
model_tokenizer = model.model.llm_engine.tokenizer
model_config = vllm_model.model.llm_engine.model_config
model_tokenizer = vllm_model.model.llm_engine.tokenizer
# asserts on the bert model config file
assert model_config.encoder_config["max_seq_length"] == 512
@ -84,11 +84,12 @@ def test_roberta_model_loading_with_params(vllm_runner):
assert model_tokenizer.tokenizer_id == "intfloat/multilingual-e5-large"
assert not model_tokenizer.tokenizer_config["do_lower_case"]
model = model.model.llm_engine.model_executor\
.driver_worker.model_runner.model
assert isinstance(model, RobertaEmbeddingModel)
assert model._pooler.pooling_type == PoolingType.MEAN
assert model._pooler.normalize
def check_model(model):
assert isinstance(model, RobertaEmbeddingModel)
assert model._pooler.pooling_type == PoolingType.MEAN
assert model._pooler.normalize
vllm_model.apply_model(check_model)
# assert output
assert output
@ -103,17 +104,18 @@ def test_facebook_roberta_model_loading_with_params(vllm_runner):
model_name = "FacebookAI/roberta-base"
with vllm_runner(model_name=model_name,
dtype="float16",
max_model_len=MAX_MODEL_LEN) as model:
output = model.encode("Write a short story about a robot that"
" dreams for the first time.\n")
max_model_len=MAX_MODEL_LEN) as vllm_model:
output = vllm_model.encode("Write a short story about a robot that"
" dreams for the first time.\n")
model_tokenizer = model.model.llm_engine.tokenizer
model_tokenizer = vllm_model.model.llm_engine.tokenizer
assert model_tokenizer.tokenizer_id == model_name
model = model.model.llm_engine.model_executor\
.driver_worker.model_runner.model
assert not hasattr(model, "lm_head")
assert isinstance(model, RobertaEmbeddingModel)
assert isinstance(model._pooler, CLSPool)
def check_model(model):
assert isinstance(model, RobertaEmbeddingModel)
assert not hasattr(model, "lm_head")
assert isinstance(model._pooler, CLSPool)
vllm_model.apply_model(check_model)
assert output

View File

@ -33,10 +33,13 @@ def test_models(
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
print(vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model)
def print_model(model):
print(model)
vllm_model.apply_model(print_model)
for i in range(len(example_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i]

View File

@ -51,10 +51,13 @@ def test_models(
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
print(vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model)
def print_model(model):
print(model)
vllm_model.apply_model(print_model)
for i in range(len(example_prompts)):
hf_output_ids, hf_output_str = hf_outputs[i]

View File

@ -73,10 +73,13 @@ def test_models(
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
print(vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model)
def print_model(model):
print(model)
vllm_model.apply_model(print_model)
check_logprobs_close(
outputs_0_lst=hf_outputs,

View File

@ -5,7 +5,6 @@ import pytest
import torch
from PIL import Image
from vllm.entrypoints.llm import LLM
from vllm.multimodal.image import rescale_image_size
from vllm.multimodal.video import rescale_video_size, sample_frames_from_video
@ -69,7 +68,7 @@ class Qwen2VLPromptVideoEmbeddingInput(TypedDict):
def batch_make_image_embeddings(
image_batches: List[Union[Image.Image, List[Image.Image]]], processor,
llm: LLM) -> List[Qwen2VLPromptImageEmbeddingInput]:
llm: VllmRunner) -> List[Qwen2VLPromptImageEmbeddingInput]:
"""batched image embeddings for Qwen2-VL
This will infer all images' embeddings in a single batch,
@ -106,16 +105,18 @@ def batch_make_image_embeddings(
image_grid_thw = preprocess_result["image_grid_thw"]
# pixel values to embeddings & grid_thws
with torch.no_grad():
visual = llm.llm_engine.model_executor.driver_worker. \
model_runner.model.visual
def get_image_embeds(model):
with torch.no_grad():
visual = model.visual
pixel_values_on_device = pixel_values.to(visual.device,
dtype=visual.dtype)
image_grid_thw_on_device = image_grid_thw.to(visual.device,
dtype=torch.int64)
image_embeds = visual(pixel_values_on_device,
grid_thw=image_grid_thw_on_device)
pixel_values_on_device = pixel_values.to(visual.device,
dtype=visual.dtype)
image_grid_thw_on_device = image_grid_thw.to(visual.device,
dtype=torch.int64)
return visual(pixel_values_on_device,
grid_thw=image_grid_thw_on_device)
image_embeds = torch.concat(llm.apply_model(get_image_embeds))
# split into original batches
result: List[Qwen2VLPromptImageEmbeddingInput] = []
@ -150,7 +151,7 @@ def batch_make_image_embeddings(
def batch_make_video_embeddings(
video_batches: PromptVideoInput, processor,
llm: LLM) -> List[Qwen2VLPromptVideoEmbeddingInput]:
llm: VllmRunner) -> List[Qwen2VLPromptVideoEmbeddingInput]:
"""batched video embeddings for Qwen2-VL
A NDArray represents a single video's all frames.
@ -187,16 +188,18 @@ def batch_make_video_embeddings(
video_grid_thw = preprocess_result["video_grid_thw"]
# pixel values to embeddings & grid_thws
with torch.no_grad():
visual = llm.llm_engine.model_executor.driver_worker.\
model_runner.model.visual
def get_image_embeds(model):
with torch.no_grad():
visual = model.visual
pixel_values_on_device = pixel_values.to(visual.device,
dtype=visual.dtype)
video_grid_thw_on_device = video_grid_thw.to(visual.device,
dtype=torch.int64)
video_embeds = visual(pixel_values_on_device,
grid_thw=video_grid_thw_on_device)
pixel_values_on_device = pixel_values.to(visual.device,
dtype=visual.dtype)
video_grid_thw_on_device = video_grid_thw.to(visual.device,
dtype=torch.int64)
return visual(pixel_values_on_device,
grid_thw=video_grid_thw_on_device)
video_embeds = torch.concat(llm.apply_model(get_image_embeds))
# split into original batches
result: List[Qwen2VLPromptVideoEmbeddingInput] = []
@ -278,9 +281,9 @@ def run_embedding_input_test(
max_tokens,
num_logprobs=num_logprobs,
images=batch_make_image_embeddings(
images, processor, vllm_model.model) if images else None,
images, processor, vllm_model) if images else None,
videos=batch_make_video_embeddings(
videos, processor, vllm_model.model) if videos else None)
videos, processor, vllm_model) if videos else None)
for prompts, images, videos in inputs
]

View File

@ -24,10 +24,13 @@ def test_classification_models(
) -> None:
with vllm_runner(model, dtype=dtype) as vllm_model:
vllm_outputs = vllm_model.classify(example_prompts)
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
print(vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model)
def print_model(model):
print(model)
vllm_model.apply_model(print_model)
with hf_runner(model,
dtype=dtype,

View File

@ -62,10 +62,13 @@ def test_models(
max_model_len=None,
**vllm_extra_kwargs) as vllm_model:
vllm_outputs = vllm_model.encode(example_prompts)
# This test is for verifying whether the model's extra_repr
# can be printed correctly.
print(vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model)
def print_model(model):
print(model)
vllm_model.apply_model(print_model)
check_embeddings_close(
embeddings_0_lst=hf_outputs,

View File

@ -30,50 +30,55 @@ from vllm.platforms import current_platform
def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
model_path, strategy, quant_type, shape_0, is_symmetric = model_args
with vllm_runner(model_path, enforce_eager=True) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
o_proj = layer.self_attn.o_proj
gate_up_proj = layer.mlp.gate_up_proj
down_proj = layer.mlp.down_proj
def check_model(model):
layer = model.model.layers[0]
# assert zp for symmetric and asymmetric cases
def zp_valid(zp: Optional[torch.Tensor]):
if is_symmetric:
return zp is None
qkv_proj = layer.self_attn.qkv_proj
o_proj = layer.self_attn.o_proj
gate_up_proj = layer.mlp.gate_up_proj
down_proj = layer.mlp.down_proj
return zp is not None and zp.dtype is torch.int32
# assert zp for symmetric and asymmetric cases
def zp_valid(zp: Optional[torch.Tensor]):
if is_symmetric:
return zp is None
assert zp_valid(qkv_proj.input_zero_point)
assert zp_valid(o_proj.input_zero_point)
assert zp_valid(gate_up_proj.input_zero_point)
assert zp_valid(down_proj.input_zero_point)
return zp is not None and zp.dtype is torch.int32
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(o_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(gate_up_proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(down_proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)
assert zp_valid(qkv_proj.input_zero_point)
assert zp_valid(o_proj.input_zero_point)
assert zp_valid(gate_up_proj.input_zero_point)
assert zp_valid(down_proj.input_zero_point)
assert qkv_proj.scheme.strategy == strategy
assert qkv_proj.scheme.is_static_input_scheme
expected_type = torch.int8
assert isinstance(qkv_proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(o_proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(gate_up_proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(down_proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)
assert qkv_proj.weight.dtype is expected_type
assert o_proj.weight.dtype is expected_type
assert gate_up_proj.weight.dtype is expected_type
assert qkv_proj.scheme.strategy == strategy
assert qkv_proj.scheme.is_static_input_scheme
expected_type = torch.int8
if qkv_proj.scheme.strategy == "tensor":
# Make sure it is a channelwise buffer
# After running process_weights_after_loading
assert len(qkv_proj.weight_scale.shape) == 2
assert qkv_proj.weight_scale.shape[0] == shape_0
assert qkv_proj.weight_scale.shape[1] == 1
assert qkv_proj.weight_scale.dtype is torch.float32
assert qkv_proj.input_scale.dtype is torch.float32
assert qkv_proj.weight.dtype is expected_type
assert o_proj.weight.dtype is expected_type
assert gate_up_proj.weight.dtype is expected_type
if qkv_proj.scheme.strategy == "tensor":
# Make sure it is a channelwise buffer
# After running process_weights_after_loading
assert len(qkv_proj.weight_scale.shape) == 2
assert qkv_proj.weight_scale.shape[0] == shape_0
assert qkv_proj.weight_scale.shape[1] == 1
assert qkv_proj.weight_scale.dtype is torch.float32
assert qkv_proj.input_scale.dtype is torch.float32
llm.apply_model(check_model)
output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
assert output
@ -129,16 +134,20 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):
def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
model_path, strategy = model_args
with vllm_runner(model_path, dtype=torch.float16) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
def check_model(model):
layer = model.model.layers[0]
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)
assert not qkv_proj.scheme.is_static_input_scheme
assert qkv_proj.scheme.strategy == strategy
assert qkv_proj.weight.dtype is torch.int8
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)
assert not qkv_proj.scheme.is_static_input_scheme
assert qkv_proj.scheme.strategy == strategy
assert qkv_proj.weight.dtype is torch.int8
llm.apply_model(check_model)
output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
assert output
@ -152,19 +161,24 @@ def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
model, strategy, group, pack_factor = wNa16_args
with vllm_runner(model) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsWNA16)
def check_model(model):
layer = model.model.layers[0]
assert qkv_proj.scheme.strategy == strategy
assert qkv_proj.scheme.group_size == (-1 if group is None else group)
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsWNA16)
assert qkv_proj.weight_packed.dtype is torch.int32
assert qkv_proj.weight_scale.dtype is torch.float16
assert qkv_proj.scheme.pack_factor == pack_factor
assert qkv_proj.scheme.strategy == strategy
assert qkv_proj.scheme.group_size == (-1
if group is None else group)
assert qkv_proj.weight_packed.dtype is torch.int32
assert qkv_proj.weight_scale.dtype is torch.float16
assert qkv_proj.scheme.pack_factor == pack_factor
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output
@ -173,14 +187,18 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
def test_compressed_tensors_w4a16_marlin24(vllm_runner):
model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
with vllm_runner(model_path) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
def check_model(model):
layer = model.model.layers[0]
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Sparse24)
assert qkv_proj.weight_packed.dtype is torch.int32
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Sparse24)
assert qkv_proj.weight_packed.dtype is torch.int32
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output
@ -189,23 +207,27 @@ def test_compressed_tensors_w4a16_marlin24(vllm_runner):
def test_compressed_tensors_fp8(vllm_runner):
model_path = "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
with vllm_runner(model_path) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
def check_model(model):
layer = model.model.layers[0]
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(
qkv_proj.scheme,
(CompressedTensorsW8A8Fp8, CompressedTensorsW8A16Fp8))
qkv_proj = layer.self_attn.qkv_proj
assert qkv_proj.input_scale.dtype is torch.float32
assert isinstance(qkv_proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(
qkv_proj.scheme,
(CompressedTensorsW8A8Fp8, CompressedTensorsW8A16Fp8))
if isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8):
assert len(qkv_proj.input_scale.shape) == 0
assert qkv_proj.weight.dtype is torch.float8_e4m3fn
assert qkv_proj.weight_scale.dtype is torch.float32
assert len(qkv_proj.weight_scale.shape) == 0
assert qkv_proj.input_scale.dtype is torch.float32
if isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8):
assert len(qkv_proj.input_scale.shape) == 0
assert qkv_proj.weight.dtype is torch.float8_e4m3fn
assert qkv_proj.weight_scale.dtype is torch.float32
assert len(qkv_proj.weight_scale.shape) == 0
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output
@ -248,12 +270,15 @@ def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy):
def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
model, weight_strategy, input_strategy = args_2of4
with vllm_runner(model) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn
_test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)
def check_model(model):
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn
_test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=20)
print(output)
@ -273,12 +298,15 @@ def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
model, weight_strategy, input_strategy = args_2of4
with vllm_runner(model) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert qkv_proj.scheme.weights_dtype == torch.int8
_test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)
def check_model(model):
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert qkv_proj.scheme.weights_dtype == torch.int8
_test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=20)
print(output)
@ -293,20 +321,24 @@ def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4):
model = args_2of4
with vllm_runner(model) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensors24)
def check_model(model):
layer = model.model.layers[0]
assert qkv_proj.scheme.weight_quant is None
assert qkv_proj.scheme.input_quant is None
assert not qkv_proj.scheme.quantized
assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map # noqa: E501
assert sparsity_map.get("Linear").format == "dense"
assert sparsity_map.get("Linear").sparsity_structure == "2:4"
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method,
CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensors24)
assert qkv_proj.scheme.weight_quant is None
assert qkv_proj.scheme.input_quant is None
assert not qkv_proj.scheme.quantized
assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map # noqa: E501
assert sparsity_map.get("Linear").format == "dense"
assert sparsity_map.get("Linear").sparsity_structure == "2:4"
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=20)
print(output)

View File

@ -49,13 +49,17 @@ KV_CACHE_MODELS = [
def test_kv_cache_model_load_and_run(vllm_runner, model_id: str):
with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
attn = model.model.layers[0].self_attn.attn
assert isinstance(attn.quant_method, Fp8KVCacheMethod)
# NOTE: it is valid for scales to be 1.0 (default value), but we know
# these checkpoints have scales < 1.0
assert 0.0 < attn._k_scale < 1.0
assert 0.0 < attn._v_scale < 1.0
def check_model(model):
attn = model.model.layers[0].self_attn.attn
assert isinstance(attn.quant_method, Fp8KVCacheMethod)
# NOTE: it is valid for scales to be 1.0 (default value), but
# we know these checkpoints have scales < 1.0
assert 0.0 < attn._k_scale < 1.0
assert 0.0 < attn._v_scale < 1.0
llm.apply_model(check_model)
# note: this does not test accuracy, just that we can run through
# see lm-eval tests for accuracy
@ -77,22 +81,24 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
quantization="fp8",
kv_cache_dtype=kv_cache_dtype) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
fc1 = model.model.decoder.layers[0].fc1
assert isinstance(fc1.quant_method, Fp8LinearMethod)
if kv_cache_dtype == "fp8":
attn = model.model.decoder.layers[0].self_attn.attn
assert isinstance(attn.quant_method, Fp8KVCacheMethod)
assert attn._k_scale == 1.0
assert attn._v_scale == 1.0
def check_model(model):
fc1 = model.model.decoder.layers[0].fc1
assert isinstance(fc1.quant_method, Fp8LinearMethod)
if kv_cache_dtype == "fp8":
attn = model.model.decoder.layers[0].self_attn.attn
assert isinstance(attn.quant_method, Fp8KVCacheMethod)
assert attn._k_scale == 1.0
assert attn._v_scale == 1.0
if current_platform.has_device_capability(89) and not force_marlin:
# For GPUs with hardware support, we keep weights in fp8
assert fc1.weight.dtype == torch.float8_e4m3fn
else:
# For GPUs without hardware support, we pack the fp8 weights
# for weight-only quantization using Marlin kernels
assert fc1.weight.dtype == torch.int32
if current_platform.has_device_capability(89) and not force_marlin:
# For GPUs with hardware support, we keep weights in fp8
assert fc1.weight.dtype == torch.float8_e4m3fn
else:
# For GPUs without hardware support, we pack the fp8 weights
# for weight-only quantization using Marlin kernels
assert fc1.weight.dtype == torch.int32
llm.apply_model(check_model)
@pytest.mark.skipif(not is_quant_method_supported("fp8"),

View File

@ -28,20 +28,23 @@ def test_lm_head(
model_lm_head_quant: Tuple[str, bool],
) -> None:
model, lm_head_quantized = model_lm_head_quant
vllm_model = vllm_runner(model, dtype=torch.float16, max_model_len=2048)
lm_head_layer = (vllm_model.model.llm_engine.model_executor.driver_worker.
model_runner.model.lm_head)
with vllm_runner(model, dtype=torch.float16,
max_model_len=2048) as vllm_model:
if lm_head_quantized:
assert isinstance(
lm_head_layer.linear_method,
(GPTQLinearMethod, GPTQMarlinLinearMethod, MarlinLinearMethod))
else:
assert isinstance(lm_head_layer.linear_method,
UnquantizedEmbeddingMethod)
def check_model(model):
lm_head_layer = model.lm_head
print(
vllm_model.generate_greedy(prompts=["Hello my name is"],
max_tokens=10)[0][1])
del vllm_model
if lm_head_quantized:
assert isinstance(lm_head_layer.linear_method,
(GPTQLinearMethod, GPTQMarlinLinearMethod,
MarlinLinearMethod))
else:
assert isinstance(lm_head_layer.linear_method,
UnquantizedEmbeddingMethod)
vllm_model.apply_model(check_model)
print(
vllm_model.generate_greedy(prompts=["Hello my name is"],
max_tokens=10)[0][1])

View File

@ -12,19 +12,22 @@ from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501
def test_quark_fp8(vllm_runner):
model_path = "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test"
with vllm_runner(model_path) as llm:
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
def check_model(model):
layer = model.model.layers[0]
assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
assert isinstance(qkv_proj.scheme, QuarkW8A8Fp8)
qkv_proj = layer.self_attn.qkv_proj
if isinstance(qkv_proj.scheme, QuarkW8A8Fp8):
assert len(qkv_proj.input_scale.shape) == 0
assert qkv_proj.weight.dtype is torch.float8_e4m3fn
#assert qkv_proj.weight.dtype is torch.float8_e4m3fnuz
assert len(qkv_proj.weight_scale.shape) == 0
assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
assert isinstance(qkv_proj.scheme, QuarkW8A8Fp8)
if isinstance(qkv_proj.scheme, QuarkW8A8Fp8):
assert len(qkv_proj.input_scale.shape) == 0
assert qkv_proj.weight.dtype is torch.float8_e4m3fn
#assert qkv_proj.weight.dtype is torch.float8_e4m3fnuz
assert len(qkv_proj.weight_scale.shape) == 0
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=20)
assert output

View File

@ -3,6 +3,7 @@ import json
import os
import pathlib
import subprocess
from functools import partial
from unittest.mock import MagicMock, patch
import openai
@ -24,7 +25,6 @@ from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
# yapf: enable
from vllm.utils import PlaceholderModule, import_from_path
from ..conftest import VllmRunner
from ..utils import VLLM_PATH, RemoteOpenAIServer
from .conftest import retry_until_skip
@ -58,16 +58,6 @@ def is_curl_installed():
return False
def get_torch_model(vllm_runner: VllmRunner):
return vllm_runner \
.model \
.llm_engine \
.model_executor \
.driver_worker \
.model_runner \
.model
def write_keyfile(keyfile_path: str):
encryption_params = EncryptionParams.random()
pathlib.Path(keyfile_path).parent.mkdir(parents=True, exist_ok=True)
@ -121,8 +111,10 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs(
config_for_serializing = TensorizerConfig(tensorizer_uri=model_path,
encryption_keyfile=key_path)
serialize_vllm_model(get_torch_model(vllm_model),
config_for_serializing)
vllm_model.apply_model(
partial(serialize_vllm_model,
tensorizer_config=config_for_serializing))
config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path,
encryption_keyfile=key_path)
@ -175,8 +167,10 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
with vllm_runner(model_ref, ) as vllm_model:
model_path = tmp_path / (model_ref + ".tensors")
serialize_vllm_model(get_torch_model(vllm_model),
TensorizerConfig(tensorizer_uri=model_path))
vllm_model.apply_model(
partial(
serialize_vllm_model,
tensorizer_config=TensorizerConfig(tensorizer_uri=model_path)))
with vllm_runner(
model_ref,
@ -215,8 +209,10 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
with vllm_runner(model_ref, ) as vllm_model:
model_path = tmp_path / (model_ref + ".tensors")
serialize_vllm_model(get_torch_model(vllm_model),
TensorizerConfig(tensorizer_uri=model_path))
vllm_model.apply_model(
partial(
serialize_vllm_model,
tensorizer_config=TensorizerConfig(tensorizer_uri=model_path)))
model_loader_extra_config = {
"tensorizer_uri": str(model_path),
@ -337,7 +333,9 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
with vllm_runner(model_ref) as vllm_model:
outputs = vllm_model.generate(prompts, sampling_params)
serialize_vllm_model(get_torch_model(vllm_model), config)
vllm_model.apply_model(
partial(serialize_vllm_model, tensorizer_config=config))
assert is_vllm_tensorized(config)
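The functools.partial pattern used throughout this file works because apply_model calls func(model) and serialize_vllm_model takes the module as its first positional argument, so binding tensorizer_config leaves exactly that slot open. Roughly, inside each worker:

# Illustrative expansion; `worker` stands for the engine worker object.
func = partial(serialize_vllm_model, tensorizer_config=config)
func(worker.get_model())  # == serialize_vllm_model(model, tensorizer_config=config)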

View File

@ -5,10 +5,10 @@ from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
Iterable, List, Mapping, NamedTuple, Optional)
from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
List, Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence
from typing import Set, Tuple, Type, Union, cast, overload
from typing import Set, Type, Union, cast, overload
import torch
from typing_extensions import TypeVar, deprecated
@ -1818,17 +1818,6 @@ class LLMEngine:
def stop_profile(self) -> None:
self.model_executor.stop_profile()
def collective_rpc(self,
method: Union[str, Callable],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
"""
See LLM.collective_rpc for more details.
"""
return self.model_executor.collective_rpc(method, timeout, args,
kwargs)
def check_health(self) -> None:
if self.tokenizer:
self.tokenizer.check_health()

View File

@ -5,8 +5,9 @@ from typing import (Any, Callable, ClassVar, Dict, List, Optional, Sequence,
Tuple, Type, Union, cast, overload)
import cloudpickle
import torch.nn as nn
from tqdm import tqdm
from typing_extensions import deprecated
from typing_extensions import TypeVar, deprecated
from vllm import envs
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
@ -42,6 +43,8 @@ from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
logger = init_logger(__name__)
_R = TypeVar("_R", default=Any)
class LLM:
"""An LLM for generating texts from given prompts and sampling parameters.
@ -464,25 +467,42 @@ class LLM:
return self.engine_class.validate_outputs(outputs, RequestOutput)
def collective_rpc(self,
method: Union[str, Callable],
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
kwargs: Optional[Dict[str, Any]] = None) -> List[_R]:
"""
Run a method on all workers, with homogeneous arguments.
The main extension point for the LLM entrypoint.
Users can provide custom worker class through `worker_cls`
argument, and implement new methods in the worker class.
Then, users can call the new methods through this API.
It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data.
The method can also be a callable, which will be serialized
and sent to all workers to execute.
If the method is a callable, it should accept an additional
`self` argument, in addition to the arguments passed in `args`
and `kwargs`. The `self` argument will be the worker object.
Execute an RPC call on all workers.
Args:
method: Name of the worker method to execute, or a callable that
is serialized and sent to all workers to execute.
If the method is a callable, it should accept an additional
`self` argument, in addition to the arguments passed in `args`
and `kwargs`. The `self` argument will be the worker object.
timeout: Maximum time in seconds to wait for execution. Raises a
:exc:`TimeoutError` on timeout. `None` means wait indefinitely.
args: Positional arguments to pass to the worker method.
kwargs: Keyword arguments to pass to the worker method.
Returns:
A list containing the results from each worker.
Note:
It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data.
"""
return self.llm_engine.collective_rpc(method, timeout, args, kwargs)
executor = self.llm_engine.model_executor
return executor.collective_rpc(method, timeout, args, kwargs)
def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
"""
Run a function directly on the model inside each worker,
returning the result for each of them.
"""
executor = self.llm_engine.model_executor
return executor.apply_model(func)
def beam_search(
self,
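A hedged sketch of the resulting entrypoint-level API; the model name is a placeholder and the function is illustrative only:

from vllm import LLM

llm = LLM(model="facebook/opt-125m")

def count_params(model):
    # Receives the nn.Module held by each worker; under tensor parallelism
    # each rank would report the size of its own shard.
    return sum(p.numel() for p in model.parameters())

print(llm.apply_model(count_params))  # one entry per worker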

View File

@ -3,6 +3,9 @@ from abc import ABC, abstractmethod
from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple,
Union)
import torch.nn as nn
from typing_extensions import TypeVar
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@ -11,9 +14,12 @@ from vllm.platforms import current_platform
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import ExecuteModelRequest, PoolerOutput
from vllm.utils import make_async
from vllm.worker.worker_base import WorkerBase
logger = init_logger(__name__)
_R = TypeVar("_R", default=Any)
class ExecutorBase(ABC):
"""Base class for all executors.
@ -44,22 +50,37 @@ class ExecutorBase(ABC):
@abstractmethod
def _init_executor(self) -> None:
pass
raise NotImplementedError
@abstractmethod
def collective_rpc(self,
method: Union[str, Callable],
method: Union[str, Callable[..., _R]],
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
kwargs: Optional[Dict[str, Any]] = None) -> List[_R]:
"""
The main interface of the executor to run a method on all workers,
with homogeneous arguments.
If the args are heterogeneous, then we can pack them into a list,
and unpack them in the method of every worker, because every worker
knows their own rank.
Execute an RPC call on all workers.
Args:
method: Name of the worker method to execute, or a callable that
is serialized and sent to all workers to execute.
If the method is a callable, it should accept an additional
`self` argument, in addition to the arguments passed in `args`
and `kwargs`. The `self` argument will be the worker object.
timeout: Maximum time in seconds to wait for execution. Raises a
:exc:`TimeoutError` on timeout. `None` means wait indefinitely.
args: Positional arguments to pass to the worker method.
kwargs: Keyword arguments to pass to the worker method.
Returns:
A list containing the results from each worker.
Note:
It is recommended to use this API to only pass control messages,
and set up data-plane communication to pass data.
"""
pass
raise NotImplementedError
def determine_num_available_blocks(self) -> Tuple[int, int]:
"""Determine the number of available blocks for the GPU KV cache and
@ -97,6 +118,17 @@ class ExecutorBase(ABC):
self.collective_rpc("initialize_cache",
args=(num_gpu_blocks, num_cpu_blocks))
def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
"""
Run a function directly on the model inside each worker,
returning the result for each of them.
"""
def rpc_func(worker: WorkerBase) -> _R:
return func(worker.get_model())
return self.collective_rpc(rpc_func)
def execute_model(
self, execute_model_req: ExecuteModelRequest
) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
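As the docstring notes, a callable passed to collective_rpc receives the worker object as its first argument, which is what the rpc_func closure above relies on. An illustrative sketch, assuming executor is any concrete ExecutorBase:

def report_device(worker):
    # Worker-local state (rank, device, model) is available here.
    model = worker.get_model()
    return str(next(model.parameters()).device)

devices = executor.collective_rpc(report_device)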

View File

@ -148,7 +148,7 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):
async_run_tensor_parallel_workers_only: bool = False,
max_concurrent_workers: Optional[int] = None,
**kwargs,
) -> Any:
) -> List[Any]:
"""Runs the given method on all workers.
Args:

View File

@ -459,16 +459,7 @@ def tensorize_vllm_model(engine_args: EngineArgs,
stream.write(encryption_params.key)
engine = LLMEngine.from_engine_args(engine_args)
if tensorizer_config._is_sharded:
# if the engine is a distributed engine (for tensor parallel) then each
# worker shard needs to serialize its part of the model.
engine.model_executor._run_workers(
"save_tensorized_model",
tensorizer_config=tensorizer_config,
)
else:
# with a single worker, we can get to the underlying model directly
serialize_vllm_model(
engine.model_executor.driver_worker.model_runner.model,
tensorizer_config,
)
engine.model_executor.collective_rpc(
"save_tensorized_model",
kwargs=dict(tensorizer_config=tensorizer_config),
)
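Both branches collapse into a single call because every worker now saves its own (possibly sharded) weights; by method-name dispatch this is roughly equivalent, per worker, to:

# Illustrative expansion of the RPC; `worker` is each engine worker.
worker.save_tensorized_model(tensorizer_config=tensorizer_config)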

View File

@ -2,6 +2,7 @@ import weakref
from typing import List, Optional, Set, Tuple
import torch
import torch.nn as nn
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.sequence import ExecuteModelRequest
@ -10,6 +11,10 @@ from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
from vllm.spec_decode.top1_proposer import Top1Proposer
class _DummyModel(nn.Module):
pass
class NGramWorker(NonLLMProposerWorkerBase):
"""NGramWorker provides a light drafter without need for model.
@ -36,7 +41,6 @@ class NGramWorker(NonLLMProposerWorkerBase):
def init_device(self):
self.device = torch.device(f"{self.device_type}:{self.local_rank}")
self.load_model = lambda *args, **kwargs: None
# Current NGramWorker only supports Top1Proposer
self._proposer = Top1Proposer(
@ -45,6 +49,12 @@ class NGramWorker(NonLLMProposerWorkerBase):
vocab_size=self.vocab_size,
)
def load_model(self) -> None:
pass # Dummy
def get_model(self) -> nn.Module:
return _DummyModel()
def sampler_output(
self,
execute_model_req: ExecuteModelRequest,

View File

@ -1,6 +1,7 @@
from typing import List, Optional, Set, Tuple
import torch
import torch.nn as nn
from vllm.distributed.parallel_state import (get_tp_group,
init_model_parallel_group,
@ -15,6 +16,10 @@ from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
logger = init_logger(__name__)
class _DummyModel(nn.Module):
pass
class SmallerTpProposerWorker(ProposerWorkerBase):
"""Class which allows a speculative draft model to run with smaller tensor
parallel degree than target model.
@ -139,6 +144,13 @@ class SmallerTpProposerWorker(ProposerWorkerBase):
return self._worker.get_spec_proposals(
execute_model_req, seq_ids_with_bonus_token_in_last_step)
def get_model(self) -> nn.Module:
if self._is_dummy:
return _DummyModel()
with self._patch_tensor_parallel_group():
return self._worker.get_model()
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None

View File

@ -4,6 +4,7 @@ from functools import cached_property
from typing import Any, Dict, List, Optional, Set, Tuple, Type
import torch
import torch.nn as nn
from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
from vllm.distributed.communication_op import broadcast_tensor_dict
@ -403,6 +404,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks)
def get_model(self) -> nn.Module:
return self.scorer_worker.get_model()
@torch.inference_mode()
def execute_model(
self,

View File

@ -94,22 +94,12 @@ class MultiprocExecutor(Executor):
timeout: Optional[float] = None,
args: Tuple = (),
kwargs: Optional[Dict] = None) -> List[Any]:
"""
Execute an RPC call on workers.
Args:
method: Name of the worker method to execute
timeout: Maximum time in seconds to wait for execution. Rases a
TimeoutError on timeout. None means wait indefinitely.
args: Positional arguments to pass to the worker method
kwargs: Keyword arguments to pass to the worker method
Returns:
List of results from each worker
"""
start_time = time.monotonic()
kwargs = kwargs or {}
# NOTE: If the args are heterogeneous, then we pack them into a list,
# and unpack them in the method of every worker, because every worker
# knows their own rank.
try:
if isinstance(method, str):
send_method = method

View File

@ -689,6 +689,9 @@ class GPUModelRunner:
encoder_outputs.append(encoder_output[start_idx:end_idx])
return encoder_outputs
def get_model(self) -> nn.Module:
return self.model
@torch.inference_mode()
def execute_model(
self,

View File

@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Optional
import torch
import torch.distributed
import torch.nn as nn
import vllm.envs as envs
from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig
@ -176,6 +177,9 @@ class Worker:
# the model initialization and profiling.
set_random_seed(self.model_config.seed)
def get_model(self) -> nn.Module:
return self.model_runner.get_model()
@torch.inference_mode()
def execute_model(
self,

View File

@ -509,6 +509,9 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
)
self.model = self.lora_manager.create_lora_manager(self.model)
def get_model(self) -> nn.Module:
return self.model
def _prepare_model_input_tensors(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],

View File

@ -21,6 +21,7 @@ from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
import habana_frameworks.torch as htorch
import habana_frameworks.torch.internal.bridge_config as bc
import torch
import torch.nn as nn
from vllm_hpu_extension.ops import LoraMask as LoraMask
from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
HabanaMemoryProfiler, format_bytes)
@ -676,6 +677,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
msg = f"Loading model weights took in total {m.get_summary_string()}"
logger.info(msg)
def get_model(self) -> nn.Module:
return self.model
def _use_graphs(self, batch_size, seq_len, is_prompt):
if self.enforce_eager:
return False

View File

@ -1176,6 +1176,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
backend=backend)
def get_model(self) -> nn.Module:
return self.model
def save_sharded_state(
self,
path: str,

View File

@ -7,6 +7,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
Optional, Type, TypeVar)
import torch
import torch.nn as nn
from torch import is_tensor
from vllm.config import VllmConfig
@ -264,6 +265,10 @@ class ModelRunnerBase(ABC, Generic[T]):
"""
raise NotImplementedError
@abstractmethod
def get_model(self) -> nn.Module:
raise NotImplementedError
def execute_model(
self,
model_input: T,
@ -297,9 +302,9 @@ class ModelRunnerWrapperBase:
def __init__(
self,
moderl_runner: ModelRunnerBase,
model_runner: ModelRunnerBase,
) -> None:
self.model_runner: ModelRunnerBase = moderl_runner
self.model_runner: ModelRunnerBase = model_runner
def __getattr__(self, attr):
return getattr(self.model_runner, attr)

View File

@ -113,6 +113,9 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
raise NotImplementedError(
"Supports only Transformer-NeuronX based models.")
def get_model(self) -> nn.Module:
return self.model
def _prepare_prompt(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],

View File

@ -84,6 +84,9 @@ class OpenVINOModelRunner(ModelRunnerBase):
kv_cache_dtype=self.kv_cache_dtype,
ov_core=self.ov_core)
def get_model(self) -> nn.Module:
return self.model
def _prepare_model_input(
self,
seq_group_metadata_list: List[SequenceGroupMetadata],

View File

@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple
import openvino as ov
import torch
import torch.distributed
import torch.nn as nn
import vllm.envs as envs
from vllm.attention import get_attn_backend
@ -362,6 +363,9 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
) -> None:
self.cache_engine.copy(blocks_to_copy) # type: ignore
def get_model(self) -> nn.Module:
return self.model_runner.get_model()
@torch.inference_mode()
def execute_model(
self,

View File

@ -158,6 +158,9 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
fullgraph=True,
dynamic=False)
def get_model(self) -> nn.Module:
return self.model.model
def _dummy_run(
self,
batch_size: int,

View File

@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
import cloudpickle
import torch
import torch.nn as nn
from vllm.config import ObservabilityConfig, VllmConfig
from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group
@ -90,6 +91,11 @@ class WorkerBase(ABC):
if output is None:
return None
@abstractmethod
def get_model(self) -> nn.Module:
raise NotImplementedError
@abstractmethod
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
@ -147,6 +153,9 @@ class DelegateWorkerBase(WorkerBase):
num_cpu_blocks: int) -> None:
self.worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
def get_model(self) -> nn.Module:
return self.worker.get_model()
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None
@ -363,6 +372,9 @@ class LocalOrDistributedWorkerBase(WorkerBase):
else:
return self._get_worker_input_from_broadcast()
def get_model(self) -> nn.Module:
return self.model_runner.get_model()
def execute_model(
self,
execute_model_req: Optional[ExecuteModelRequest] = None,
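Taken together, the get_model additions across workers and model runners form the resolution chain that apply_model relies on; schematically (comments only, not real call sites):

# ExecutorBase.apply_model(func)
#   -> collective_rpc(rpc_func)         # fan out to every worker
#     -> worker.get_model()             # WorkerBase / DelegateWorkerBase hook
#       -> model_runner.get_model()     # returns self.model (the nn.Module)
#   -> func(model)                      # user callable runs inside the worker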

View File

@ -416,6 +416,9 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
logger.info("Loading model weights took %.4f GB",
self.model_memory_usage / float(2**30))
def get_model(self) -> nn.Module:
return self.model
@property
def vocab_size(self) -> int:
return self.model_config.get_vocab_size()