[Core] Interface for accessing model from VllmRunner (#10353)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

parent 83609791d2
commit 59a0192fb9
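This commit replaces direct digging through `llm_engine.model_executor.driver_worker.model_runner.model` with a new `apply_model(func)` entry point on `LLM`, `ExecutorBase`, and the test-side `VllmRunner`, built on top of `collective_rpc`. A minimal usage sketch, not part of the commit itself (the model name is a placeholder):

    from vllm import LLM

    llm = LLM(model="facebook/opt-125m")  # placeholder model for illustration

    # apply_model() runs the given function on the torch.nn.Module inside each
    # worker and returns a list with one result per worker.
    param_counts = llm.apply_model(
        lambda model: sum(p.numel() for p in model.parameters()))
    print(param_counts)  # single-element list when there is only one worker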
@@ -244,6 +244,7 @@ def video_assets() -> _VideoAssets:


 _T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
+_R = TypeVar("_R")


 class HfRunner:

@@ -930,6 +931,10 @@ class VllmRunner:
         req_outputs = self.model.score(text_1, text_2)
         return [req_output.outputs.score for req_output in req_outputs]

+    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
+        executor = self.model.llm_engine.model_executor
+        return executor.apply_model(func)
+
     def __enter__(self):
         return self

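The test changes that follow all migrate to the same pattern: build a small `check_model` (or `print_model`) callback and hand it to `VllmRunner.apply_model`, instead of reaching into executor internals. A sketch of that pattern, assuming the `vllm_runner` pytest fixture used throughout these tests (test name and model are illustrative):

    import torch

    def test_model_internals(vllm_runner):
        with vllm_runner("facebook/opt-125m", dtype="float16") as vllm_model:

            def check_model(model):
                # `model` is the torch.nn.Module held by each worker
                assert next(model.parameters()).dtype == torch.float16

            vllm_model.apply_model(check_model)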
@@ -51,7 +51,9 @@ def test_custom_executor(model, tmp_path):
        assert not os.path.exists(".marker")

        engine_args = EngineArgs(
-            model=model, distributed_executor_backend=CustomUniExecutor)
+            model=model,
+            distributed_executor_backend=CustomUniExecutor,
+        )
        engine = LLMEngine.from_engine_args(engine_args)
        sampling_params = SamplingParams(max_tokens=1)

@@ -25,13 +25,12 @@ def test_model_loading_with_params(vllm_runner):
     with vllm_runner(model_name=MODEL_NAME,
                      revision=REVISION,
                      dtype="float16",
-                     max_model_len=MAX_MODEL_LEN) as model:
-        output = model.encode("Write a short story about a robot that"
+                     max_model_len=MAX_MODEL_LEN) as vllm_model:
+        output = vllm_model.encode("Write a short story about a robot that"
                               " dreams for the first time.\n")

-        model_config = model.model.llm_engine.model_config
-
-        model_tokenizer = model.model.llm_engine.tokenizer
+        model_config = vllm_model.model.llm_engine.model_config
+        model_tokenizer = vllm_model.model.llm_engine.tokenizer

         # asserts on the bert model config file
         assert model_config.encoder_config["max_seq_length"] == 512

@@ -46,11 +45,13 @@ def test_model_loading_with_params(vllm_runner):
         assert model_tokenizer.tokenizer_config["do_lower_case"]
         assert model_tokenizer.tokenizer.model_max_length == 512

-        model = model.model.llm_engine.model_executor\
-            .driver_worker.model_runner.model
-        assert isinstance(model, BertEmbeddingModel)
-        assert model._pooler.pooling_type == PoolingType.CLS
-        assert model._pooler.normalize
+        def check_model(model):
+            assert isinstance(model, BertEmbeddingModel)
+            assert model._pooler.pooling_type == PoolingType.CLS
+            assert model._pooler.normalize
+
+        vllm_model.apply_model(check_model)
+
         # assert output
         assert output

@@ -64,13 +65,12 @@ def test_roberta_model_loading_with_params(vllm_runner):
     with vllm_runner(model_name=MODEL_NAME_ROBERTA,
                      revision=REVISION_ROBERTA,
                      dtype="float16",
-                     max_model_len=MAX_MODEL_LEN) as model:
-        output = model.encode("Write a short story about a robot that"
+                     max_model_len=MAX_MODEL_LEN) as vllm_model:
+        output = vllm_model.encode("Write a short story about a robot that"
                               " dreams for the first time.\n")

-        model_config = model.model.llm_engine.model_config
-
-        model_tokenizer = model.model.llm_engine.tokenizer
+        model_config = vllm_model.model.llm_engine.model_config
+        model_tokenizer = vllm_model.model.llm_engine.tokenizer

         # asserts on the bert model config file
         assert model_config.encoder_config["max_seq_length"] == 512

@@ -84,11 +84,12 @@ def test_roberta_model_loading_with_params(vllm_runner):
         assert model_tokenizer.tokenizer_id == "intfloat/multilingual-e5-large"
         assert not model_tokenizer.tokenizer_config["do_lower_case"]

-        model = model.model.llm_engine.model_executor\
-            .driver_worker.model_runner.model
-        assert isinstance(model, RobertaEmbeddingModel)
-        assert model._pooler.pooling_type == PoolingType.MEAN
-        assert model._pooler.normalize
+        def check_model(model):
+            assert isinstance(model, RobertaEmbeddingModel)
+            assert model._pooler.pooling_type == PoolingType.MEAN
+            assert model._pooler.normalize
+
+        vllm_model.apply_model(check_model)
+
         # assert output
         assert output

@@ -103,17 +104,18 @@ def test_facebook_roberta_model_loading_with_params(vllm_runner):
     model_name = "FacebookAI/roberta-base"
     with vllm_runner(model_name=model_name,
                      dtype="float16",
-                     max_model_len=MAX_MODEL_LEN) as model:
-        output = model.encode("Write a short story about a robot that"
+                     max_model_len=MAX_MODEL_LEN) as vllm_model:
+        output = vllm_model.encode("Write a short story about a robot that"
                               " dreams for the first time.\n")

-        model_tokenizer = model.model.llm_engine.tokenizer
+        model_tokenizer = vllm_model.model.llm_engine.tokenizer
         assert model_tokenizer.tokenizer_id == model_name

-        model = model.model.llm_engine.model_executor\
-            .driver_worker.model_runner.model
-        assert not hasattr(model, "lm_head")
-        assert isinstance(model, RobertaEmbeddingModel)
-        assert isinstance(model._pooler, CLSPool)
+        def check_model(model):
+            assert isinstance(model, RobertaEmbeddingModel)
+            assert not hasattr(model, "lm_head")
+            assert isinstance(model._pooler, CLSPool)
+
+        vllm_model.apply_model(check_model)

         assert output

@@ -33,10 +33,13 @@ def test_models(

     with vllm_runner(model, dtype=dtype) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

         # This test is for verifying whether the model's extra_repr
         # can be printed correctly.
-        print(vllm_model.model.llm_engine.model_executor.driver_worker.
-              model_runner.model)
+        def print_model(model):
+            print(model)
+
+        vllm_model.apply_model(print_model)

     for i in range(len(example_prompts)):
         hf_output_ids, hf_output_str = hf_outputs[i]

@@ -51,10 +51,13 @@ def test_models(

     with vllm_runner(model, dtype=dtype) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

         # This test is for verifying whether the model's extra_repr
         # can be printed correctly.
-        print(vllm_model.model.llm_engine.model_executor.driver_worker.
-              model_runner.model)
+        def print_model(model):
+            print(model)
+
+        vllm_model.apply_model(print_model)

     for i in range(len(example_prompts)):
         hf_output_ids, hf_output_str = hf_outputs[i]

@@ -73,10 +73,13 @@ def test_models(
     with vllm_runner(model, dtype=dtype) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy_logprobs(
             example_prompts, max_tokens, num_logprobs)

         # This test is for verifying whether the model's extra_repr
         # can be printed correctly.
-        print(vllm_model.model.llm_engine.model_executor.driver_worker.
-              model_runner.model)
+        def print_model(model):
+            print(model)
+
+        vllm_model.apply_model(print_model)

     check_logprobs_close(
         outputs_0_lst=hf_outputs,

@@ -5,7 +5,6 @@ import pytest
 import torch
 from PIL import Image

-from vllm.entrypoints.llm import LLM
 from vllm.multimodal.image import rescale_image_size
 from vllm.multimodal.video import rescale_video_size, sample_frames_from_video

@@ -69,7 +68,7 @@ class Qwen2VLPromptVideoEmbeddingInput(TypedDict):

 def batch_make_image_embeddings(
         image_batches: List[Union[Image.Image, List[Image.Image]]], processor,
-        llm: LLM) -> List[Qwen2VLPromptImageEmbeddingInput]:
+        llm: VllmRunner) -> List[Qwen2VLPromptImageEmbeddingInput]:
     """batched image embeddings for Qwen2-VL

     This will infer all images' embeddings in a single batch,

@@ -106,16 +105,18 @@ def batch_make_image_embeddings(
     image_grid_thw = preprocess_result["image_grid_thw"]

     # pixel values to embeddings & grid_thws
-    with torch.no_grad():
-        visual = llm.llm_engine.model_executor.driver_worker. \
-            model_runner.model.visual
-
-        pixel_values_on_device = pixel_values.to(visual.device,
-                                                 dtype=visual.dtype)
-        image_grid_thw_on_device = image_grid_thw.to(visual.device,
-                                                     dtype=torch.int64)
-        image_embeds = visual(pixel_values_on_device,
-                              grid_thw=image_grid_thw_on_device)
+    def get_image_embeds(model):
+        with torch.no_grad():
+            visual = model.visual
+
+            pixel_values_on_device = pixel_values.to(visual.device,
+                                                     dtype=visual.dtype)
+            image_grid_thw_on_device = image_grid_thw.to(visual.device,
+                                                         dtype=torch.int64)
+            return visual(pixel_values_on_device,
+                          grid_thw=image_grid_thw_on_device)
+
+    image_embeds = torch.concat(llm.apply_model(get_image_embeds))

     # split into original batches
     result: List[Qwen2VLPromptImageEmbeddingInput] = []

@@ -150,7 +151,7 @@ def batch_make_image_embeddings(

 def batch_make_video_embeddings(
         video_batches: PromptVideoInput, processor,
-        llm: LLM) -> List[Qwen2VLPromptVideoEmbeddingInput]:
+        llm: VllmRunner) -> List[Qwen2VLPromptVideoEmbeddingInput]:
     """batched video embeddings for Qwen2-VL

     A NDArray represents a single video's all frames.

@@ -187,16 +188,18 @@ def batch_make_video_embeddings(
     video_grid_thw = preprocess_result["video_grid_thw"]

     # pixel values to embeddings & grid_thws
-    with torch.no_grad():
-        visual = llm.llm_engine.model_executor.driver_worker.\
-            model_runner.model.visual
-
-        pixel_values_on_device = pixel_values.to(visual.device,
-                                                 dtype=visual.dtype)
-        video_grid_thw_on_device = video_grid_thw.to(visual.device,
-                                                     dtype=torch.int64)
-        video_embeds = visual(pixel_values_on_device,
-                              grid_thw=video_grid_thw_on_device)
+    def get_image_embeds(model):
+        with torch.no_grad():
+            visual = model.visual
+
+            pixel_values_on_device = pixel_values.to(visual.device,
+                                                     dtype=visual.dtype)
+            video_grid_thw_on_device = video_grid_thw.to(visual.device,
+                                                         dtype=torch.int64)
+            return visual(pixel_values_on_device,
+                          grid_thw=video_grid_thw_on_device)
+
+    video_embeds = torch.concat(llm.apply_model(get_image_embeds))

     # split into original batches
     result: List[Qwen2VLPromptVideoEmbeddingInput] = []

@@ -278,9 +281,9 @@ def run_embedding_input_test(
                 max_tokens,
                 num_logprobs=num_logprobs,
                 images=batch_make_image_embeddings(
-                    images, processor, vllm_model.model) if images else None,
+                    images, processor, vllm_model) if images else None,
                 videos=batch_make_video_embeddings(
-                    videos, processor, vllm_model.model) if videos else None)
+                    videos, processor, vllm_model) if videos else None)
             for prompts, images, videos in inputs
         ]

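Because `apply_model` forwards to `collective_rpc`, it returns a list with one entry per worker; that is why the Qwen2-VL helpers above wrap the call in `torch.concat(...)`. A short sketch of the same idea, with a placeholder callback instead of the real `visual(...)` computation (`vllm_model` is assumed to be a `VllmRunner` instance):

    import torch

    def get_dummy_embeds(model):
        # stands in for the real encoder call; executed once per worker
        return torch.zeros(4, 8)

    per_worker = vllm_model.apply_model(get_dummy_embeds)  # one tensor per worker
    embeds = torch.concat(per_worker)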
@@ -24,10 +24,13 @@ def test_classification_models(
 ) -> None:
     with vllm_runner(model, dtype=dtype) as vllm_model:
         vllm_outputs = vllm_model.classify(example_prompts)

         # This test is for verifying whether the model's extra_repr
         # can be printed correctly.
-        print(vllm_model.model.llm_engine.model_executor.driver_worker.
-              model_runner.model)
+        def print_model(model):
+            print(model)
+
+        vllm_model.apply_model(print_model)

     with hf_runner(model,
                    dtype=dtype,

@@ -62,10 +62,13 @@ def test_models(
                      max_model_len=None,
                      **vllm_extra_kwargs) as vllm_model:
         vllm_outputs = vllm_model.encode(example_prompts)

         # This test is for verifying whether the model's extra_repr
         # can be printed correctly.
-        print(vllm_model.model.llm_engine.model_executor.driver_worker.
-              model_runner.model)
+        def print_model(model):
+            print(model)
+
+        vllm_model.apply_model(print_model)

     check_embeddings_close(
         embeddings_0_lst=hf_outputs,

@@ -30,50 +30,55 @@ from vllm.platforms import current_platform
 def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
     model_path, strategy, quant_type, shape_0, is_symmetric = model_args
     with vllm_runner(model_path, enforce_eager=True) as llm:
-        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
-        layer = model.model.layers[0]
-
-        qkv_proj = layer.self_attn.qkv_proj
-        o_proj = layer.self_attn.o_proj
-        gate_up_proj = layer.mlp.gate_up_proj
-        down_proj = layer.mlp.down_proj
-
-        # assert zp for symmetric and asymmetric cases
-        def zp_valid(zp: Optional[torch.Tensor]):
-            if is_symmetric:
-                return zp is None
-
-            return zp is not None and zp.dtype is torch.int32
-
-        assert zp_valid(qkv_proj.input_zero_point)
-        assert zp_valid(o_proj.input_zero_point)
-        assert zp_valid(gate_up_proj.input_zero_point)
-        assert zp_valid(down_proj.input_zero_point)
-
-        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
-        assert isinstance(o_proj.quant_method, CompressedTensorsLinearMethod)
-        assert isinstance(gate_up_proj.quant_method,
-                          CompressedTensorsLinearMethod)
-        assert isinstance(down_proj.quant_method,
-                          CompressedTensorsLinearMethod)
-        assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)
-
-        assert qkv_proj.scheme.strategy == strategy
-        assert qkv_proj.scheme.is_static_input_scheme
-        expected_type = torch.int8
-
-        assert qkv_proj.weight.dtype is expected_type
-        assert o_proj.weight.dtype is expected_type
-        assert gate_up_proj.weight.dtype is expected_type
-
-        if qkv_proj.scheme.strategy == "tensor":
-            # Make sure it is a channelwise buffer
-            # After running process_weights_after_loading
-            assert len(qkv_proj.weight_scale.shape) == 2
-            assert qkv_proj.weight_scale.shape[0] == shape_0
-            assert qkv_proj.weight_scale.shape[1] == 1
-        assert qkv_proj.weight_scale.dtype is torch.float32
-        assert qkv_proj.input_scale.dtype is torch.float32
+
+        def check_model(model):
+            layer = model.model.layers[0]
+
+            qkv_proj = layer.self_attn.qkv_proj
+            o_proj = layer.self_attn.o_proj
+            gate_up_proj = layer.mlp.gate_up_proj
+            down_proj = layer.mlp.down_proj
+
+            # assert zp for symmetric and asymmetric cases
+            def zp_valid(zp: Optional[torch.Tensor]):
+                if is_symmetric:
+                    return zp is None
+
+                return zp is not None and zp.dtype is torch.int32
+
+            assert zp_valid(qkv_proj.input_zero_point)
+            assert zp_valid(o_proj.input_zero_point)
+            assert zp_valid(gate_up_proj.input_zero_point)
+            assert zp_valid(down_proj.input_zero_point)
+
+            assert isinstance(qkv_proj.quant_method,
+                              CompressedTensorsLinearMethod)
+            assert isinstance(o_proj.quant_method,
+                              CompressedTensorsLinearMethod)
+            assert isinstance(gate_up_proj.quant_method,
+                              CompressedTensorsLinearMethod)
+            assert isinstance(down_proj.quant_method,
+                              CompressedTensorsLinearMethod)
+            assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)
+
+            assert qkv_proj.scheme.strategy == strategy
+            assert qkv_proj.scheme.is_static_input_scheme
+            expected_type = torch.int8
+
+            assert qkv_proj.weight.dtype is expected_type
+            assert o_proj.weight.dtype is expected_type
+            assert gate_up_proj.weight.dtype is expected_type
+
+            if qkv_proj.scheme.strategy == "tensor":
+                # Make sure it is a channelwise buffer
+                # After running process_weights_after_loading
+                assert len(qkv_proj.weight_scale.shape) == 2
+                assert qkv_proj.weight_scale.shape[0] == shape_0
+                assert qkv_proj.weight_scale.shape[1] == 1
+            assert qkv_proj.weight_scale.dtype is torch.float32
+            assert qkv_proj.input_scale.dtype is torch.float32
+
+        llm.apply_model(check_model)

         output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
         assert output
@@ -129,16 +134,20 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):
 def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
     model_path, strategy = model_args
     with vllm_runner(model_path, dtype=torch.float16) as llm:
-        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
-        layer = model.model.layers[0]
-
-        qkv_proj = layer.self_attn.qkv_proj
-
-        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
-        assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)
-        assert not qkv_proj.scheme.is_static_input_scheme
-        assert qkv_proj.scheme.strategy == strategy
-        assert qkv_proj.weight.dtype is torch.int8
+
+        def check_model(model):
+            layer = model.model.layers[0]
+
+            qkv_proj = layer.self_attn.qkv_proj
+
+            assert isinstance(qkv_proj.quant_method,
+                              CompressedTensorsLinearMethod)
+            assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)
+            assert not qkv_proj.scheme.is_static_input_scheme
+            assert qkv_proj.scheme.strategy == strategy
+            assert qkv_proj.weight.dtype is torch.int8
+
+        llm.apply_model(check_model)

         output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
         assert output
@@ -152,19 +161,24 @@ def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
 def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
     model, strategy, group, pack_factor = wNa16_args
     with vllm_runner(model) as llm:
-        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
-        layer = model.model.layers[0]
-
-        qkv_proj = layer.self_attn.qkv_proj
-        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
-        assert isinstance(qkv_proj.scheme, CompressedTensorsWNA16)
-
-        assert qkv_proj.scheme.strategy == strategy
-        assert qkv_proj.scheme.group_size == (-1 if group is None else group)
-
-        assert qkv_proj.weight_packed.dtype is torch.int32
-        assert qkv_proj.weight_scale.dtype is torch.float16
-        assert qkv_proj.scheme.pack_factor == pack_factor
+
+        def check_model(model):
+            layer = model.model.layers[0]
+
+            qkv_proj = layer.self_attn.qkv_proj
+            assert isinstance(qkv_proj.quant_method,
+                              CompressedTensorsLinearMethod)
+            assert isinstance(qkv_proj.scheme, CompressedTensorsWNA16)
+
+            assert qkv_proj.scheme.strategy == strategy
+            assert qkv_proj.scheme.group_size == (-1
+                                                  if group is None else group)
+
+            assert qkv_proj.weight_packed.dtype is torch.int32
+            assert qkv_proj.weight_scale.dtype is torch.float16
+            assert qkv_proj.scheme.pack_factor == pack_factor
+
+        llm.apply_model(check_model)

         output = llm.generate_greedy("Hello my name is", max_tokens=20)
         assert output
@@ -173,14 +187,18 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
 def test_compressed_tensors_w4a16_marlin24(vllm_runner):
     model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
     with vllm_runner(model_path) as llm:
-        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
-        layer = model.model.layers[0]
-
-        qkv_proj = layer.self_attn.qkv_proj
-
-        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
-        assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Sparse24)
-        assert qkv_proj.weight_packed.dtype is torch.int32
+
+        def check_model(model):
+            layer = model.model.layers[0]
+
+            qkv_proj = layer.self_attn.qkv_proj
+
+            assert isinstance(qkv_proj.quant_method,
+                              CompressedTensorsLinearMethod)
+            assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Sparse24)
+            assert qkv_proj.weight_packed.dtype is torch.int32
+
+        llm.apply_model(check_model)

         output = llm.generate_greedy("Hello my name is", max_tokens=20)
         assert output
@@ -189,23 +207,27 @@ def test_compressed_tensors_w4a16_marlin24(vllm_runner):
 def test_compressed_tensors_fp8(vllm_runner):
     model_path = "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
     with vllm_runner(model_path) as llm:
-        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
-        layer = model.model.layers[0]
-
-        qkv_proj = layer.self_attn.qkv_proj
-
-        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
-        assert isinstance(
-            qkv_proj.scheme,
-            (CompressedTensorsW8A8Fp8, CompressedTensorsW8A16Fp8))
-
-        assert qkv_proj.input_scale.dtype is torch.float32
-
-        if isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8):
-            assert len(qkv_proj.input_scale.shape) == 0
-            assert qkv_proj.weight.dtype is torch.float8_e4m3fn
-            assert qkv_proj.weight_scale.dtype is torch.float32
-            assert len(qkv_proj.weight_scale.shape) == 0
+
+        def check_model(model):
+            layer = model.model.layers[0]
+
+            qkv_proj = layer.self_attn.qkv_proj
+
+            assert isinstance(qkv_proj.quant_method,
+                              CompressedTensorsLinearMethod)
+            assert isinstance(
+                qkv_proj.scheme,
+                (CompressedTensorsW8A8Fp8, CompressedTensorsW8A16Fp8))
+
+            assert qkv_proj.input_scale.dtype is torch.float32
+
+            if isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8):
+                assert len(qkv_proj.input_scale.shape) == 0
+                assert qkv_proj.weight.dtype is torch.float8_e4m3fn
+                assert qkv_proj.weight_scale.dtype is torch.float32
+                assert len(qkv_proj.weight_scale.shape) == 0
+
+        llm.apply_model(check_model)

         output = llm.generate_greedy("Hello my name is", max_tokens=20)
         assert output
@@ -248,12 +270,15 @@ def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy):
 def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
     model, weight_strategy, input_strategy = args_2of4
     with vllm_runner(model) as llm:
-        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
-        layer = model.model.layers[0]
-
-        qkv_proj = layer.self_attn.qkv_proj
-        assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn
-        _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)
+
+        def check_model(model):
+            layer = model.model.layers[0]
+
+            qkv_proj = layer.self_attn.qkv_proj
+            assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn
+            _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)
+
+        llm.apply_model(check_model)

         output = llm.generate_greedy("Hello my name is", max_tokens=20)
         print(output)
@@ -273,12 +298,15 @@ def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
 def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
     model, weight_strategy, input_strategy = args_2of4
     with vllm_runner(model) as llm:
-        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
-        layer = model.model.layers[0]
-
-        qkv_proj = layer.self_attn.qkv_proj
-        assert qkv_proj.scheme.weights_dtype == torch.int8
-        _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)
+
+        def check_model(model):
+            layer = model.model.layers[0]
+
+            qkv_proj = layer.self_attn.qkv_proj
+            assert qkv_proj.scheme.weights_dtype == torch.int8
+            _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)
+
+        llm.apply_model(check_model)

         output = llm.generate_greedy("Hello my name is", max_tokens=20)
         print(output)
@@ -293,20 +321,24 @@ def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
 def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4):
     model = args_2of4
     with vllm_runner(model) as llm:
-        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
-        layer = model.model.layers[0]
-
-        qkv_proj = layer.self_attn.qkv_proj
-        assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
-        assert isinstance(qkv_proj.scheme, CompressedTensors24)
-
-        assert qkv_proj.scheme.weight_quant is None
-        assert qkv_proj.scheme.input_quant is None
-        assert not qkv_proj.scheme.quantized
-        assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
-        sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map  # noqa: E501
-        assert sparsity_map.get("Linear").format == "dense"
-        assert sparsity_map.get("Linear").sparsity_structure == "2:4"
+
+        def check_model(model):
+            layer = model.model.layers[0]
+
+            qkv_proj = layer.self_attn.qkv_proj
+            assert isinstance(qkv_proj.quant_method,
+                              CompressedTensorsLinearMethod)
+            assert isinstance(qkv_proj.scheme, CompressedTensors24)
+
+            assert qkv_proj.scheme.weight_quant is None
+            assert qkv_proj.scheme.input_quant is None
+            assert not qkv_proj.scheme.quantized
+            assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
+            sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map  # noqa: E501
+            assert sparsity_map.get("Linear").format == "dense"
+            assert sparsity_map.get("Linear").sparsity_structure == "2:4"
+
+        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        print(output)

@ -49,13 +49,17 @@ KV_CACHE_MODELS = [
|
|||||||
def test_kv_cache_model_load_and_run(vllm_runner, model_id: str):
|
def test_kv_cache_model_load_and_run(vllm_runner, model_id: str):
|
||||||
with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:
|
with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:
|
||||||
|
|
||||||
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
|
def check_model(model):
|
||||||
attn = model.model.layers[0].self_attn.attn
|
attn = model.model.layers[0].self_attn.attn
|
||||||
assert isinstance(attn.quant_method, Fp8KVCacheMethod)
|
|
||||||
# NOTE: it is valid for scales to be 1.0 (default value), but we know
|
assert isinstance(attn.quant_method, Fp8KVCacheMethod)
|
||||||
# these checkpoints have scales < 1.0
|
|
||||||
assert 0.0 < attn._k_scale < 1.0
|
# NOTE: it is valid for scales to be 1.0 (default value), but
|
||||||
assert 0.0 < attn._v_scale < 1.0
|
# we know these checkpoints have scales < 1.0
|
||||||
|
assert 0.0 < attn._k_scale < 1.0
|
||||||
|
assert 0.0 < attn._v_scale < 1.0
|
||||||
|
|
||||||
|
llm.apply_model(check_model)
|
||||||
|
|
||||||
# note: this does not test accuracy, just that we can run through
|
# note: this does not test accuracy, just that we can run through
|
||||||
# see lm-eval tests for accuracy
|
# see lm-eval tests for accuracy
|
||||||
@ -77,22 +81,24 @@ def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
|
|||||||
quantization="fp8",
|
quantization="fp8",
|
||||||
kv_cache_dtype=kv_cache_dtype) as llm:
|
kv_cache_dtype=kv_cache_dtype) as llm:
|
||||||
|
|
||||||
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
|
def check_model(model):
|
||||||
fc1 = model.model.decoder.layers[0].fc1
|
fc1 = model.model.decoder.layers[0].fc1
|
||||||
assert isinstance(fc1.quant_method, Fp8LinearMethod)
|
assert isinstance(fc1.quant_method, Fp8LinearMethod)
|
||||||
if kv_cache_dtype == "fp8":
|
if kv_cache_dtype == "fp8":
|
||||||
attn = model.model.decoder.layers[0].self_attn.attn
|
attn = model.model.decoder.layers[0].self_attn.attn
|
||||||
assert isinstance(attn.quant_method, Fp8KVCacheMethod)
|
assert isinstance(attn.quant_method, Fp8KVCacheMethod)
|
||||||
assert attn._k_scale == 1.0
|
assert attn._k_scale == 1.0
|
||||||
assert attn._v_scale == 1.0
|
assert attn._v_scale == 1.0
|
||||||
|
|
||||||
if current_platform.has_device_capability(89) and not force_marlin:
|
if current_platform.has_device_capability(89) and not force_marlin:
|
||||||
# For GPUs with hardware support, we keep weights in fp8
|
# For GPUs with hardware support, we keep weights in fp8
|
||||||
assert fc1.weight.dtype == torch.float8_e4m3fn
|
assert fc1.weight.dtype == torch.float8_e4m3fn
|
||||||
else:
|
else:
|
||||||
# For GPUs without hardware support, we pack the fp8 weights
|
# For GPUs without hardware support, we pack the fp8 weights
|
||||||
# for weight-only quantization using Marlin kernels
|
# for weight-only quantization using Marlin kernels
|
||||||
assert fc1.weight.dtype == torch.int32
|
assert fc1.weight.dtype == torch.int32
|
||||||
|
|
||||||
|
llm.apply_model(check_model)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
|
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
|
||||||
|
@ -28,20 +28,23 @@ def test_lm_head(
|
|||||||
model_lm_head_quant: Tuple[str, bool],
|
model_lm_head_quant: Tuple[str, bool],
|
||||||
) -> None:
|
) -> None:
|
||||||
model, lm_head_quantized = model_lm_head_quant
|
model, lm_head_quantized = model_lm_head_quant
|
||||||
vllm_model = vllm_runner(model, dtype=torch.float16, max_model_len=2048)
|
|
||||||
|
|
||||||
lm_head_layer = (vllm_model.model.llm_engine.model_executor.driver_worker.
|
with vllm_runner(model, dtype=torch.float16,
|
||||||
model_runner.model.lm_head)
|
max_model_len=2048) as vllm_model:
|
||||||
|
|
||||||
if lm_head_quantized:
|
def check_model(model):
|
||||||
assert isinstance(
|
lm_head_layer = model.lm_head
|
||||||
lm_head_layer.linear_method,
|
|
||||||
(GPTQLinearMethod, GPTQMarlinLinearMethod, MarlinLinearMethod))
|
|
||||||
else:
|
|
||||||
assert isinstance(lm_head_layer.linear_method,
|
|
||||||
UnquantizedEmbeddingMethod)
|
|
||||||
|
|
||||||
print(
|
if lm_head_quantized:
|
||||||
vllm_model.generate_greedy(prompts=["Hello my name is"],
|
assert isinstance(lm_head_layer.linear_method,
|
||||||
max_tokens=10)[0][1])
|
(GPTQLinearMethod, GPTQMarlinLinearMethod,
|
||||||
del vllm_model
|
MarlinLinearMethod))
|
||||||
|
else:
|
||||||
|
assert isinstance(lm_head_layer.linear_method,
|
||||||
|
UnquantizedEmbeddingMethod)
|
||||||
|
|
||||||
|
vllm_model.apply_model(check_model)
|
||||||
|
|
||||||
|
print(
|
||||||
|
vllm_model.generate_greedy(prompts=["Hello my name is"],
|
||||||
|
max_tokens=10)[0][1])
|
||||||
|
@ -12,19 +12,22 @@ from vllm.model_executor.layers.quantization.quark.quark import ( # noqa: E501
|
|||||||
def test_quark_fp8(vllm_runner):
|
def test_quark_fp8(vllm_runner):
|
||||||
model_path = "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test"
|
model_path = "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test"
|
||||||
with vllm_runner(model_path) as llm:
|
with vllm_runner(model_path) as llm:
|
||||||
model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501
|
|
||||||
layer = model.model.layers[0]
|
|
||||||
|
|
||||||
qkv_proj = layer.self_attn.qkv_proj
|
def check_model(model):
|
||||||
|
layer = model.model.layers[0]
|
||||||
|
|
||||||
assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
|
qkv_proj = layer.self_attn.qkv_proj
|
||||||
assert isinstance(qkv_proj.scheme, QuarkW8A8Fp8)
|
|
||||||
|
|
||||||
if isinstance(qkv_proj.scheme, QuarkW8A8Fp8):
|
assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
|
||||||
assert len(qkv_proj.input_scale.shape) == 0
|
assert isinstance(qkv_proj.scheme, QuarkW8A8Fp8)
|
||||||
assert qkv_proj.weight.dtype is torch.float8_e4m3fn
|
|
||||||
#assert qkv_proj.weight.dtype is torch.float8_e4m3fnuz
|
if isinstance(qkv_proj.scheme, QuarkW8A8Fp8):
|
||||||
assert len(qkv_proj.weight_scale.shape) == 0
|
assert len(qkv_proj.input_scale.shape) == 0
|
||||||
|
assert qkv_proj.weight.dtype is torch.float8_e4m3fn
|
||||||
|
#assert qkv_proj.weight.dtype is torch.float8_e4m3fnuz
|
||||||
|
assert len(qkv_proj.weight_scale.shape) == 0
|
||||||
|
|
||||||
|
llm.apply_model(check_model)
|
||||||
|
|
||||||
output = llm.generate_greedy("Hello my name is", max_tokens=20)
|
output = llm.generate_greedy("Hello my name is", max_tokens=20)
|
||||||
assert output
|
assert output
|
||||||
|
@ -3,6 +3,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
import subprocess
|
import subprocess
|
||||||
|
from functools import partial
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import openai
|
import openai
|
||||||
@ -24,7 +25,6 @@ from vllm.model_executor.model_loader.tensorizer import (TensorizerConfig,
|
|||||||
# yapf: enable
|
# yapf: enable
|
||||||
from vllm.utils import PlaceholderModule, import_from_path
|
from vllm.utils import PlaceholderModule, import_from_path
|
||||||
|
|
||||||
from ..conftest import VllmRunner
|
|
||||||
from ..utils import VLLM_PATH, RemoteOpenAIServer
|
from ..utils import VLLM_PATH, RemoteOpenAIServer
|
||||||
from .conftest import retry_until_skip
|
from .conftest import retry_until_skip
|
||||||
|
|
||||||
@ -58,16 +58,6 @@ def is_curl_installed():
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def get_torch_model(vllm_runner: VllmRunner):
|
|
||||||
return vllm_runner \
|
|
||||||
.model \
|
|
||||||
.llm_engine \
|
|
||||||
.model_executor \
|
|
||||||
.driver_worker \
|
|
||||||
.model_runner \
|
|
||||||
.model
|
|
||||||
|
|
||||||
|
|
||||||
def write_keyfile(keyfile_path: str):
|
def write_keyfile(keyfile_path: str):
|
||||||
encryption_params = EncryptionParams.random()
|
encryption_params = EncryptionParams.random()
|
||||||
pathlib.Path(keyfile_path).parent.mkdir(parents=True, exist_ok=True)
|
pathlib.Path(keyfile_path).parent.mkdir(parents=True, exist_ok=True)
|
||||||
@ -121,8 +111,10 @@ def test_deserialized_encrypted_vllm_model_has_same_outputs(
|
|||||||
|
|
||||||
config_for_serializing = TensorizerConfig(tensorizer_uri=model_path,
|
config_for_serializing = TensorizerConfig(tensorizer_uri=model_path,
|
||||||
encryption_keyfile=key_path)
|
encryption_keyfile=key_path)
|
||||||
serialize_vllm_model(get_torch_model(vllm_model),
|
|
||||||
config_for_serializing)
|
vllm_model.apply_model(
|
||||||
|
partial(serialize_vllm_model,
|
||||||
|
tensorizer_config=config_for_serializing))
|
||||||
|
|
||||||
config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path,
|
config_for_deserializing = TensorizerConfig(tensorizer_uri=model_path,
|
||||||
encryption_keyfile=key_path)
|
encryption_keyfile=key_path)
|
||||||
@ -175,8 +167,10 @@ def test_vllm_model_can_load_with_lora(vllm_runner, tmp_path):
|
|||||||
with vllm_runner(model_ref, ) as vllm_model:
|
with vllm_runner(model_ref, ) as vllm_model:
|
||||||
model_path = tmp_path / (model_ref + ".tensors")
|
model_path = tmp_path / (model_ref + ".tensors")
|
||||||
|
|
||||||
serialize_vllm_model(get_torch_model(vllm_model),
|
vllm_model.apply_model(
|
||||||
TensorizerConfig(tensorizer_uri=model_path))
|
partial(
|
||||||
|
serialize_vllm_model,
|
||||||
|
tensorizer_config=TensorizerConfig(tensorizer_uri=model_path)))
|
||||||
|
|
||||||
with vllm_runner(
|
with vllm_runner(
|
||||||
model_ref,
|
model_ref,
|
||||||
@ -215,8 +209,10 @@ def test_openai_apiserver_with_tensorizer(vllm_runner, tmp_path):
|
|||||||
with vllm_runner(model_ref, ) as vllm_model:
|
with vllm_runner(model_ref, ) as vllm_model:
|
||||||
model_path = tmp_path / (model_ref + ".tensors")
|
model_path = tmp_path / (model_ref + ".tensors")
|
||||||
|
|
||||||
serialize_vllm_model(get_torch_model(vllm_model),
|
vllm_model.apply_model(
|
||||||
TensorizerConfig(tensorizer_uri=model_path))
|
partial(
|
||||||
|
serialize_vllm_model,
|
||||||
|
tensorizer_config=TensorizerConfig(tensorizer_uri=model_path)))
|
||||||
|
|
||||||
model_loader_extra_config = {
|
model_loader_extra_config = {
|
||||||
"tensorizer_uri": str(model_path),
|
"tensorizer_uri": str(model_path),
|
||||||
@ -337,7 +333,9 @@ def test_vllm_tensorized_model_has_same_outputs(vllm_runner, tmp_path):
|
|||||||
|
|
||||||
with vllm_runner(model_ref) as vllm_model:
|
with vllm_runner(model_ref) as vllm_model:
|
||||||
outputs = vllm_model.generate(prompts, sampling_params)
|
outputs = vllm_model.generate(prompts, sampling_params)
|
||||||
serialize_vllm_model(get_torch_model(vllm_model), config)
|
|
||||||
|
vllm_model.apply_model(
|
||||||
|
partial(serialize_vllm_model, tensorizer_config=config))
|
||||||
|
|
||||||
assert is_vllm_tensorized(config)
|
assert is_vllm_tensorized(config)
|
||||||
|
|
||||||
|
@ -5,10 +5,10 @@ from collections import deque
|
|||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
|
from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
|
||||||
Iterable, List, Mapping, NamedTuple, Optional)
|
List, Mapping, NamedTuple, Optional)
|
||||||
from typing import Sequence as GenericSequence
|
from typing import Sequence as GenericSequence
|
||||||
from typing import Set, Tuple, Type, Union, cast, overload
|
from typing import Set, Type, Union, cast, overload
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from typing_extensions import TypeVar, deprecated
|
from typing_extensions import TypeVar, deprecated
|
||||||
@ -1818,17 +1818,6 @@ class LLMEngine:
|
|||||||
def stop_profile(self) -> None:
|
def stop_profile(self) -> None:
|
||||||
self.model_executor.stop_profile()
|
self.model_executor.stop_profile()
|
||||||
|
|
||||||
def collective_rpc(self,
|
|
||||||
method: Union[str, Callable],
|
|
||||||
timeout: Optional[float] = None,
|
|
||||||
args: Tuple = (),
|
|
||||||
kwargs: Optional[Dict] = None) -> List[Any]:
|
|
||||||
"""
|
|
||||||
See LLM.collective_rpc for more details.
|
|
||||||
"""
|
|
||||||
return self.model_executor.collective_rpc(method, timeout, args,
|
|
||||||
kwargs)
|
|
||||||
|
|
||||||
def check_health(self) -> None:
|
def check_health(self) -> None:
|
||||||
if self.tokenizer:
|
if self.tokenizer:
|
||||||
self.tokenizer.check_health()
|
self.tokenizer.check_health()
|
||||||
|
@ -5,8 +5,9 @@ from typing import (Any, Callable, ClassVar, Dict, List, Optional, Sequence,
|
|||||||
Tuple, Type, Union, cast, overload)
|
Tuple, Type, Union, cast, overload)
|
||||||
|
|
||||||
import cloudpickle
|
import cloudpickle
|
||||||
|
import torch.nn as nn
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from typing_extensions import deprecated
|
from typing_extensions import TypeVar, deprecated
|
||||||
|
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
|
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
|
||||||
@ -42,6 +43,8 @@ from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
_R = TypeVar("_R", default=Any)
|
||||||
|
|
||||||
|
|
||||||
class LLM:
|
class LLM:
|
||||||
"""An LLM for generating texts from given prompts and sampling parameters.
|
"""An LLM for generating texts from given prompts and sampling parameters.
|
||||||
@ -464,25 +467,42 @@ class LLM:
|
|||||||
return self.engine_class.validate_outputs(outputs, RequestOutput)
|
return self.engine_class.validate_outputs(outputs, RequestOutput)
|
||||||
|
|
||||||
def collective_rpc(self,
|
def collective_rpc(self,
|
||||||
method: Union[str, Callable],
|
method: Union[str, Callable[..., _R]],
|
||||||
timeout: Optional[float] = None,
|
timeout: Optional[float] = None,
|
||||||
args: Tuple = (),
|
args: Tuple = (),
|
||||||
kwargs: Optional[Dict] = None) -> List[Any]:
|
kwargs: Optional[Dict[str, Any]] = None) -> List[_R]:
|
||||||
"""
|
"""
|
||||||
Run a method on all workers, with homogeneous arguments.
|
Execute an RPC call on all workers.
|
||||||
The main extension point for the LLM entrypoint.
|
|
||||||
Users can provide custom worker class through `worker_cls`
|
Args:
|
||||||
argument, and implement new methods in the worker class.
|
method: Name of the worker method to execute, or a callable that
|
||||||
Then, users can call the new methods through this API.
|
is serialized and sent to all workers to execute.
|
||||||
It is recommended to use this API to only pass control messages,
|
|
||||||
and set up data-plane communication to pass data.
|
If the method is a callable, it should accept an additional
|
||||||
The method can also be a callable, which will be serialized
|
`self` argument, in addition to the arguments passed in `args`
|
||||||
and sent to all workers to execute.
|
and `kwargs`. The `self` argument will be the worker object.
|
||||||
If the method is a callable, it should accept an additional
|
timeout: Maximum time in seconds to wait for execution. Raises a
|
||||||
`self` argument, in addition to the arguments passed in `args`
|
:exc:`TimeoutError` on timeout. `None` means wait indefinitely.
|
||||||
and `kwargs`. The `self` argument will be the worker object.
|
args: Positional arguments to pass to the worker method.
|
||||||
|
kwargs: Keyword arguments to pass to the worker method.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list containing the results from each worker.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
It is recommended to use this API to only pass control messages,
|
||||||
|
and set up data-plane communication to pass data.
|
||||||
"""
|
"""
|
||||||
return self.llm_engine.collective_rpc(method, timeout, args, kwargs)
|
executor = self.llm_engine.model_executor
|
||||||
|
return executor.collective_rpc(method, timeout, args, kwargs)
|
||||||
|
|
||||||
|
def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
|
||||||
|
"""
|
||||||
|
Run a function directly on the model inside each worker,
|
||||||
|
returning the result for each of them.
|
||||||
|
"""
|
||||||
|
executor = self.llm_engine.model_executor
|
||||||
|
return executor.apply_model(func)
|
||||||
|
|
||||||
def beam_search(
|
def beam_search(
|
||||||
self,
|
self,
|
||||||
|
@ -3,6 +3,9 @@ from abc import ABC, abstractmethod
|
|||||||
from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple,
|
from typing import (Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple,
|
||||||
Union)
|
Union)
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
from typing_extensions import TypeVar
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
@ -11,9 +14,12 @@ from vllm.platforms import current_platform
|
|||||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||||
from vllm.sequence import ExecuteModelRequest, PoolerOutput
|
from vllm.sequence import ExecuteModelRequest, PoolerOutput
|
||||||
from vllm.utils import make_async
|
from vllm.utils import make_async
|
||||||
|
from vllm.worker.worker_base import WorkerBase
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
_R = TypeVar("_R", default=Any)
|
||||||
|
|
||||||
|
|
||||||
class ExecutorBase(ABC):
|
class ExecutorBase(ABC):
|
||||||
"""Base class for all executors.
|
"""Base class for all executors.
|
||||||
@ -44,22 +50,37 @@ class ExecutorBase(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _init_executor(self) -> None:
|
def _init_executor(self) -> None:
|
||||||
pass
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def collective_rpc(self,
|
def collective_rpc(self,
|
||||||
method: Union[str, Callable],
|
method: Union[str, Callable[..., _R]],
|
||||||
timeout: Optional[float] = None,
|
timeout: Optional[float] = None,
|
||||||
args: Tuple = (),
|
args: Tuple = (),
|
||||||
kwargs: Optional[Dict] = None) -> List[Any]:
|
kwargs: Optional[Dict[str, Any]] = None) -> List[_R]:
|
||||||
"""
|
"""
|
||||||
The main interface of the executor to run a method on all workers,
|
Execute an RPC call on all workers.
|
||||||
with homogeneous arguments.
|
|
||||||
If the args are heterogeneous, then we can pack them into a list,
|
Args:
|
||||||
and unpack them in the method of every worker, because every worker
|
method: Name of the worker method to execute, or a callable that
|
||||||
knows their own rank.
|
is serialized and sent to all workers to execute.
|
||||||
|
|
||||||
|
If the method is a callable, it should accept an additional
|
||||||
|
`self` argument, in addition to the arguments passed in `args`
|
||||||
|
and `kwargs`. The `self` argument will be the worker object.
|
||||||
|
timeout: Maximum time in seconds to wait for execution. Raises a
|
||||||
|
:exc:`TimeoutError` on timeout. `None` means wait indefinitely.
|
||||||
|
args: Positional arguments to pass to the worker method.
|
||||||
|
kwargs: Keyword arguments to pass to the worker method.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list containing the results from each worker.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
It is recommended to use this API to only pass control messages,
|
||||||
|
and set up data-plane communication to pass data.
|
||||||
"""
|
"""
|
||||||
pass
|
raise NotImplementedError
|
||||||
|
|
||||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||||
"""Determine the number of available blocks for the GPU KV cache and
|
"""Determine the number of available blocks for the GPU KV cache and
|
||||||
@ -97,6 +118,17 @@ class ExecutorBase(ABC):
|
|||||||
self.collective_rpc("initialize_cache",
|
self.collective_rpc("initialize_cache",
|
||||||
args=(num_gpu_blocks, num_cpu_blocks))
|
args=(num_gpu_blocks, num_cpu_blocks))
|
||||||
|
|
||||||
|
def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
|
||||||
|
"""
|
||||||
|
Run a function directly on the model inside each worker,
|
||||||
|
returning the result for each of them.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def rpc_func(worker: WorkerBase) -> _R:
|
||||||
|
return func(worker.get_model())
|
||||||
|
|
||||||
|
return self.collective_rpc(rpc_func)
|
||||||
|
|
||||||
def execute_model(
|
def execute_model(
|
||||||
self, execute_model_req: ExecuteModelRequest
|
self, execute_model_req: ExecuteModelRequest
|
||||||
) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
|
) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
|
||||||
|
@ -148,7 +148,7 @@ class MultiprocessingDistributedExecutor(DistributedExecutorBase):
|
|||||||
async_run_tensor_parallel_workers_only: bool = False,
|
async_run_tensor_parallel_workers_only: bool = False,
|
||||||
max_concurrent_workers: Optional[int] = None,
|
max_concurrent_workers: Optional[int] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Any:
|
) -> List[Any]:
|
||||||
"""Runs the given method on all workers.
|
"""Runs the given method on all workers.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -459,16 +459,7 @@ def tensorize_vllm_model(engine_args: EngineArgs,
|
|||||||
stream.write(encryption_params.key)
|
stream.write(encryption_params.key)
|
||||||
|
|
||||||
engine = LLMEngine.from_engine_args(engine_args)
|
engine = LLMEngine.from_engine_args(engine_args)
|
||||||
if tensorizer_config._is_sharded:
|
engine.model_executor.collective_rpc(
|
||||||
# if the engine is a distributed engine (for tensor parallel) then each
|
"save_tensorized_model",
|
||||||
# worker shard needs to serialize its part of the model.
|
kwargs=dict(tensorizer_config=tensorizer_config),
|
||||||
engine.model_executor._run_workers(
|
)
|
||||||
"save_tensorized_model",
|
|
||||||
tensorizer_config=tensorizer_config,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# with a single worker, we can get to the underlying model directly
|
|
||||||
serialize_vllm_model(
|
|
||||||
engine.model_executor.driver_worker.model_runner.model,
|
|
||||||
tensorizer_config,
|
|
||||||
)
|
|
||||||
|
@@ -2,6 +2,7 @@ import weakref
 from typing import List, Optional, Set, Tuple

 import torch
+import torch.nn as nn

 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.sequence import ExecuteModelRequest

@@ -10,6 +11,10 @@ from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase
 from vllm.spec_decode.top1_proposer import Top1Proposer


+class _DummyModel(nn.Module):
+    pass
+
+
 class NGramWorker(NonLLMProposerWorkerBase):
     """NGramWorker provides a light drafter without need for model.


@@ -36,7 +41,6 @@ class NGramWorker(NonLLMProposerWorkerBase):

     def init_device(self):
         self.device = torch.device(f"{self.device_type}:{self.local_rank}")
-        self.load_model = lambda *args, **kwargs: None

         # Current NGramWorker only supports Top1Proposer
         self._proposer = Top1Proposer(

@@ -45,6 +49,12 @@ class NGramWorker(NonLLMProposerWorkerBase):
             vocab_size=self.vocab_size,
         )

+    def load_model(self) -> None:
+        pass  # Dummy
+
+    def get_model(self) -> nn.Module:
+        return _DummyModel()
+
     def sampler_output(
         self,
         execute_model_req: ExecuteModelRequest,
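The n-gram drafter carries no weights, so instead of raising it satisfies the new interface with a parameter-free placeholder module. A standalone check of what a caller would see; describe is a hypothetical inspection helper of the kind one might pass to apply_model:

import torch.nn as nn

class _DummyModel(nn.Module):
    pass

def describe(model: nn.Module) -> str:
    # Placeholder modules simply report zero parameters.
    n_params = sum(p.numel() for p in model.parameters())
    return f"{type(model).__name__}: {n_params} parameters"

print(describe(_DummyModel()))  # "_DummyModel: 0 parameters"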
@@ -1,6 +1,7 @@
 from typing import List, Optional, Set, Tuple

 import torch
+import torch.nn as nn

 from vllm.distributed.parallel_state import (get_tp_group,
                                              init_model_parallel_group,

@@ -15,6 +16,10 @@ from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase
 logger = init_logger(__name__)


+class _DummyModel(nn.Module):
+    pass
+
+
 class SmallerTpProposerWorker(ProposerWorkerBase):
     """Class which allows a speculative draft model to run with smaller tensor
     parallel degree than target model.

@@ -139,6 +144,13 @@ class SmallerTpProposerWorker(ProposerWorkerBase):
         return self._worker.get_spec_proposals(
             execute_model_req, seq_ids_with_bonus_token_in_last_step)

+    def get_model(self) -> nn.Module:
+        if self._is_dummy:
+            return _DummyModel()
+
+        with self._patch_tensor_parallel_group():
+            return self._worker.get_model()
+
     def execute_model(
         self,
         execute_model_req: Optional[ExecuteModelRequest] = None

@@ -4,6 +4,7 @@ from functools import cached_property
 from typing import Any, Dict, List, Optional, Set, Tuple, Type

 import torch
+import torch.nn as nn

 from vllm.config import ParallelConfig, SpeculativeConfig, VllmConfig
 from vllm.distributed.communication_op import broadcast_tensor_dict

@@ -403,6 +404,9 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
         self.proposer_worker.initialize_cache(num_gpu_blocks=num_gpu_blocks,
                                               num_cpu_blocks=num_cpu_blocks)

+    def get_model(self) -> nn.Module:
+        return self.scorer_worker.get_model()
+
     @torch.inference_mode()
     def execute_model(
         self,
@@ -94,22 +94,12 @@ class MultiprocExecutor(Executor):
                        timeout: Optional[float] = None,
                        args: Tuple = (),
                        kwargs: Optional[Dict] = None) -> List[Any]:
-        """
-        Execute an RPC call on workers.
-
-        Args:
-            method: Name of the worker method to execute
-            timeout: Maximum time in seconds to wait for execution. Rases a
-                TimeoutError on timeout. None means wait indefinitely.
-            args: Positional arguments to pass to the worker method
-            kwargs: Keyword arguments to pass to the worker method
-
-        Returns:
-            List of results from each worker
-        """
         start_time = time.monotonic()
         kwargs = kwargs or {}

+        # NOTE: If the args are heterogeneous, then we pack them into a list,
+        # and unpack them in the method of every worker, because every worker
+        # knows their own rank.
         try:
             if isinstance(method, str):
                 send_method = method
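The new NOTE describes the convention for per-worker arguments: the caller packs one entry per rank into a single list, and each worker indexes it with its own rank. A toy illustration of that convention (names here are illustrative, not the executor's real API):

from typing import List

def unpack_for_rank(packed: List[str], rank: int) -> str:
    # Every worker knows its own rank, so it takes only its slice.
    return packed[rank]

packed_args = ["state-for-rank-0", "state-for-rank-1"]
print([unpack_for_rank(packed_args, rank) for rank in range(len(packed_args))])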
@@ -689,6 +689,9 @@ class GPUModelRunner:
                 encoder_outputs.append(encoder_output[start_idx:end_idx])
         return encoder_outputs

+    def get_model(self) -> nn.Module:
+        return self.model
+
     @torch.inference_mode()
     def execute_model(
         self,

@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Optional

 import torch
 import torch.distributed
+import torch.nn as nn

 import vllm.envs as envs
 from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig

@@ -176,6 +177,9 @@ class Worker:
         # the model initialization and profiling.
         set_random_seed(self.model_config.seed)

+    def get_model(self) -> nn.Module:
+        return self.model_runner.get_model()
+
     @torch.inference_mode()
     def execute_model(
         self,

@@ -509,6 +509,9 @@ class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
             )
             self.model = self.lora_manager.create_lora_manager(self.model)

+    def get_model(self) -> nn.Module:
+        return self.model
+
     def _prepare_model_input_tensors(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],

@@ -21,6 +21,7 @@ from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
 import habana_frameworks.torch as htorch
 import habana_frameworks.torch.internal.bridge_config as bc
 import torch
+import torch.nn as nn
 from vllm_hpu_extension.ops import LoraMask as LoraMask
 from vllm_hpu_extension.profiler import (HabanaHighLevelProfiler,
                                          HabanaMemoryProfiler, format_bytes)

@@ -676,6 +677,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
         msg = f"Loading model weights took in total {m.get_summary_string()}"
         logger.info(msg)

+    def get_model(self) -> nn.Module:
+        return self.model
+
     def _use_graphs(self, batch_size, seq_len, is_prompt):
         if self.enforce_eager:
             return False

@@ -1176,6 +1176,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                 fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
                 backend=backend)

+    def get_model(self) -> nn.Module:
+        return self.model
+
     def save_sharded_state(
         self,
         path: str,
@@ -7,6 +7,7 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, Iterable, List,
                     Optional, Type, TypeVar)

 import torch
+import torch.nn as nn
 from torch import is_tensor

 from vllm.config import VllmConfig

@@ -264,6 +265,10 @@ class ModelRunnerBase(ABC, Generic[T]):
         """
         raise NotImplementedError

+    @abstractmethod
+    def get_model(self) -> nn.Module:
+        raise NotImplementedError
+
     def execute_model(
         self,
         model_input: T,
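With get_model declared abstract on the runner base class, every concrete runner has to expose the nn.Module it drives, and a runner that forgets fails at instantiation rather than at call time. A self-contained sketch of that contract using toy classes, not vLLM's real base class:

from abc import ABC, abstractmethod

import torch.nn as nn

class _RunnerContract(ABC):
    @abstractmethod
    def get_model(self) -> nn.Module:
        raise NotImplementedError

class _ToyRunner(_RunnerContract):
    def __init__(self) -> None:
        self.model = nn.Linear(8, 8)

    def get_model(self) -> nn.Module:
        return self.model

print(type(_ToyRunner().get_model()).__name__)  # Linear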
@@ -297,9 +302,9 @@ class ModelRunnerWrapperBase:

     def __init__(
         self,
-        moderl_runner: ModelRunnerBase,
+        model_runner: ModelRunnerBase,
     ) -> None:
-        self.model_runner: ModelRunnerBase = moderl_runner
+        self.model_runner: ModelRunnerBase = model_runner

     def __getattr__(self, attr):
         return getattr(self.model_runner, attr)

@@ -113,6 +113,9 @@ class NeuronModelRunner(ModelRunnerBase[ModelInputForNeuron]):
             raise NotImplementedError(
                 "Supports only Transformer-NeuronX based models.")

+    def get_model(self) -> nn.Module:
+        return self.model
+
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],

@@ -84,6 +84,9 @@ class OpenVINOModelRunner(ModelRunnerBase):
             kv_cache_dtype=self.kv_cache_dtype,
             ov_core=self.ov_core)

+    def get_model(self) -> nn.Module:
+        return self.model
+
     def _prepare_model_input(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],

@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Tuple
 import openvino as ov
 import torch
 import torch.distributed
+import torch.nn as nn

 import vllm.envs as envs
 from vllm.attention import get_attn_backend

@@ -362,6 +363,9 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase):
     ) -> None:
         self.cache_engine.copy(blocks_to_copy)  # type: ignore

+    def get_model(self) -> nn.Module:
+        return self.model_runner.get_model()
+
     @torch.inference_mode()
     def execute_model(
         self,

@@ -158,6 +158,9 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
                                    fullgraph=True,
                                    dynamic=False)

+    def get_model(self) -> nn.Module:
+        return self.model.model
+
     def _dummy_run(
         self,
         batch_size: int,

@@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union

 import cloudpickle
 import torch
+import torch.nn as nn

 from vllm.config import ObservabilityConfig, VllmConfig
 from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group

@@ -90,6 +91,11 @@ class WorkerBase(ABC):
                 if output is None:
                     return None

+    @abstractmethod
+    def get_model(self) -> nn.Module:
+        raise NotImplementedError
+
+    @abstractmethod
     def execute_model(
         self,
         execute_model_req: Optional[ExecuteModelRequest] = None

@@ -147,6 +153,9 @@ class DelegateWorkerBase(WorkerBase):
                          num_cpu_blocks: int) -> None:
         self.worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)

+    def get_model(self) -> nn.Module:
+        return self.worker.get_model()
+
     def execute_model(
         self,
         execute_model_req: Optional[ExecuteModelRequest] = None

@@ -363,6 +372,9 @@ class LocalOrDistributedWorkerBase(WorkerBase):
         else:
             return self._get_worker_input_from_broadcast()

+    def get_model(self) -> nn.Module:
+        return self.model_runner.get_model()
+
     def execute_model(
         self,
         execute_model_req: Optional[ExecuteModelRequest] = None,
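Taken together, the worker hunks form a simple delegation chain: the executor asks its workers, each worker forwards to its model runner, and the runner returns the module it owns. A toy sketch of that chain (illustrative classes, not vLLM's):

import torch.nn as nn

class _ToyRunner:
    def __init__(self) -> None:
        self.model = nn.Embedding(16, 4)

    def get_model(self) -> nn.Module:
        return self.model

class _ToyWorker:
    def __init__(self) -> None:
        self.model_runner = _ToyRunner()

    def get_model(self) -> nn.Module:
        # Mirrors the worker-level implementations: delegate to the runner.
        return self.model_runner.get_model()

print(isinstance(_ToyWorker().get_model(), nn.Module))  # True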
@@ -416,6 +416,9 @@ class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
         logger.info("Loading model weights took %.4f GB",
                     self.model_memory_usage / float(2**30))

+    def get_model(self) -> nn.Module:
+        return self.model
+
     @property
     def vocab_size(self) -> int:
         return self.model_config.get_vocab_size()