[Bugfix] Fix broken OpenAI tensorizer test (#8258)

Cyrus Leung 2024-09-07 16:02:39 +08:00, committed by GitHub
parent ce2702a923
commit 9f68e00d27
4 changed files with 81 additions and 40 deletions

tests/utils.py

@@ -20,7 +20,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.entrypoints.openai.cli_args import make_arg_parser
-from vllm.model_executor.model_loader.loader import DefaultModelLoader
+from vllm.model_executor.model_loader.loader import get_model_loader
 from vllm.platforms import current_platform
 from vllm.utils import FlexibleArgumentParser, get_open_port, is_hip
 
@@ -89,11 +89,11 @@ class RemoteOpenAIServer:
         is_local = os.path.isdir(model)
         if not is_local:
             engine_args = AsyncEngineArgs.from_cli_args(args)
-            engine_config = engine_args.create_engine_config()
-            dummy_loader = DefaultModelLoader(engine_config.load_config)
-            dummy_loader._prepare_weights(engine_config.model_config.model,
-                                          engine_config.model_config.revision,
-                                          fall_back_to_pt=True)
+            model_config = engine_args.create_model_config()
+            load_config = engine_args.create_load_config()
+
+            model_loader = get_model_loader(load_config)
+            model_loader.download_model(model_config)
 
         env = os.environ.copy()
         # the current process might initialize cuda,

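Note: the old test helper hard-coded DefaultModelLoader, whose constructor rejects any model_loader_extra_config, so pre-downloading broke whenever the server was started with a non-default load format such as tensorizer. Dispatching through get_model_loader makes the pre-download honor the configured format. A minimal sketch of the new flow outside the test harness (the model id is an illustrative placeholder, not from this commit):

    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.model_executor.model_loader.loader import get_model_loader

    # Hypothetical example model; any HF repo id works the same way.
    engine_args = AsyncEngineArgs(model="facebook/opt-125m")
    model_config = engine_args.create_model_config()
    load_config = engine_args.create_load_config()

    # get_model_loader returns the loader matching load_config.load_format
    # (default, tensorizer, sharded_state, bitsandbytes, gguf, ...), and
    # download_model fetches weights without loading them onto a device.
    loader = get_model_loader(load_config)
    loader.download_model(model_config)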
vllm/engine/arg_utils.py

@@ -771,33 +771,8 @@ class EngineArgs:
         engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
         return engine_args
 
-    def create_engine_config(self) -> EngineConfig:
-        # gguf file needs a specific model loader and doesn't use hf_repo
-        if check_gguf_file(self.model):
-            self.quantization = self.load_format = "gguf"
-
-        # bitsandbytes quantization needs a specific model loader
-        # so we make sure the quant method and the load format are consistent
-        if (self.quantization == "bitsandbytes" or
-           self.qlora_adapter_name_or_path is not None) and \
-           self.load_format != "bitsandbytes":
-            raise ValueError(
-                "BitsAndBytes quantization and QLoRA adapter only support "
-                f"'bitsandbytes' load format, but got {self.load_format}")
-
-        if (self.load_format == "bitsandbytes" or
-           self.qlora_adapter_name_or_path is not None) and \
-           self.quantization != "bitsandbytes":
-            raise ValueError(
-                "BitsAndBytes load format and QLoRA adapter only support "
-                f"'bitsandbytes' quantization, but got {self.quantization}")
-
-        assert self.cpu_offload_gb >= 0, (
-            "CPU offload space must be non-negative"
-            f", but got {self.cpu_offload_gb}")
-
-        device_config = DeviceConfig(device=self.device)
-        model_config = ModelConfig(
+    def create_model_config(self) -> ModelConfig:
+        return ModelConfig(
             model=self.model,
             tokenizer=self.tokenizer,
             tokenizer_mode=self.tokenizer_mode,
@@ -825,6 +800,42 @@ class EngineArgs:
             config_format=self.config_format,
         )
 
+    def create_load_config(self) -> LoadConfig:
+        return LoadConfig(
+            load_format=self.load_format,
+            download_dir=self.download_dir,
+            model_loader_extra_config=self.model_loader_extra_config,
+            ignore_patterns=self.ignore_patterns,
+        )
+
+    def create_engine_config(self) -> EngineConfig:
+        # gguf file needs a specific model loader and doesn't use hf_repo
+        if check_gguf_file(self.model):
+            self.quantization = self.load_format = "gguf"
+
+        # bitsandbytes quantization needs a specific model loader
+        # so we make sure the quant method and the load format are consistent
+        if (self.quantization == "bitsandbytes" or
+           self.qlora_adapter_name_or_path is not None) and \
+           self.load_format != "bitsandbytes":
+            raise ValueError(
+                "BitsAndBytes quantization and QLoRA adapter only support "
+                f"'bitsandbytes' load format, but got {self.load_format}")
+
+        if (self.load_format == "bitsandbytes" or
+           self.qlora_adapter_name_or_path is not None) and \
+           self.quantization != "bitsandbytes":
+            raise ValueError(
+                "BitsAndBytes load format and QLoRA adapter only support "
+                f"'bitsandbytes' quantization, but got {self.quantization}")
+
+        assert self.cpu_offload_gb >= 0, (
+            "CPU offload space must be non-negative"
+            f", but got {self.cpu_offload_gb}")
+
+        device_config = DeviceConfig(device=self.device)
+        model_config = self.create_model_config()
+
         cache_config = CacheConfig(
             block_size=self.block_size if self.device != "neuron" else
             self.max_model_len,  # neuron needs block_size = max_model_len
@@ -967,12 +978,7 @@ class EngineArgs:
             self.model_loader_extra_config[
                 "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path
 
-        load_config = LoadConfig(
-            load_format=self.load_format,
-            download_dir=self.download_dir,
-            model_loader_extra_config=self.model_loader_extra_config,
-            ignore_patterns=self.ignore_patterns,
-        )
+        load_config = self.create_load_config()
 
         prompt_adapter_config = PromptAdapterConfig(
             max_prompt_adapters=self.max_prompt_adapters,

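Note: create_model_config and create_load_config are pure extractions of code that previously lived inline in create_engine_config; the gguf/bitsandbytes consistency checks stay in create_engine_config, so the two small builders can be called on their own (as the test helper above now does). A hedged sketch of standalone use; the model id and tensorizer URI are made-up placeholders:

    from vllm.engine.arg_utils import EngineArgs

    args = EngineArgs(
        model="facebook/opt-125m",
        load_format="tensorizer",
        model_loader_extra_config={
            "tensorizer_uri": "s3://example-bucket/model.tensors",
        },
    )

    # Builds only the two configs; no DeviceConfig, CacheConfig, or
    # cross-field validation from create_engine_config is involved.
    model_config = args.create_model_config()
    load_config = args.create_load_config()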
vllm/model_executor/model_loader/loader.py

@@ -185,6 +185,11 @@ class BaseModelLoader(ABC):
     def __init__(self, load_config: LoadConfig):
         self.load_config = load_config
 
+    @abstractmethod
+    def download_model(self, model_config: ModelConfig) -> None:
+        """Download a model so that it can be immediately loaded."""
+        raise NotImplementedError
+
     @abstractmethod
     def load_model(self, *, model_config: ModelConfig,
                    device_config: DeviceConfig,
@@ -193,7 +198,7 @@ class BaseModelLoader(ABC):
                    scheduler_config: SchedulerConfig,
                    cache_config: CacheConfig) -> nn.Module:
         """Load a model with the given configurations."""
-        ...
+        raise NotImplementedError
 
 
 class DefaultModelLoader(BaseModelLoader):
@@ -335,6 +340,11 @@ class DefaultModelLoader(BaseModelLoader):
             weights_iterator = _xla_weights_iterator(weights_iterator)
         return weights_iterator
 
+    def download_model(self, model_config: ModelConfig) -> None:
+        self._prepare_weights(model_config.model,
+                              model_config.revision,
+                              fall_back_to_pt=True)
+
     def load_model(self, *, model_config: ModelConfig,
                    device_config: DeviceConfig,
                    lora_config: Optional[LoRAConfig],
@@ -377,6 +387,9 @@ class DummyModelLoader(BaseModelLoader):
             raise ValueError(f"Model loader extra config is not supported for "
                              f"load format {load_config.load_format}")
 
+    def download_model(self, model_config: ModelConfig) -> None:
+        pass  # Nothing to download
+
     def load_model(self, *, model_config: ModelConfig,
                    device_config: DeviceConfig,
                    lora_config: Optional[LoRAConfig],
@@ -467,6 +480,12 @@ class TensorizerLoader(BaseModelLoader):
             model = load_with_tensorizer(tensorizer_config, **extra_kwargs)
             return model.eval()
 
+    def download_model(self, model_config: ModelConfig) -> None:
+        self.tensorizer_config.verify_with_model_config(model_config)
+
+        with self.tensorizer_config.open_stream():
+            pass
+
     def load_model(self, *, model_config: ModelConfig,
                    device_config: DeviceConfig,
                    lora_config: Optional[LoRAConfig],
@@ -568,6 +587,9 @@ class ShardedStateLoader(BaseModelLoader):
             ignore_patterns=self.load_config.ignore_patterns,
         )
 
+    def download_model(self, model_config: ModelConfig) -> None:
+        self._prepare_weights(model_config.model, model_config.revision)
+
     def load_model(self, *, model_config: ModelConfig,
                    device_config: DeviceConfig,
                    lora_config: Optional[LoRAConfig],
@@ -995,6 +1017,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             set_weight_attrs(
                 param, {"matmul_state": [None] * len(quant_states)})
 
+    def download_model(self, model_config: ModelConfig) -> None:
+        self._prepare_weights(model_config.model, model_config.revision)
+
     def load_model(self, *, model_config: ModelConfig,
                    device_config: DeviceConfig,
                    lora_config: Optional[LoRAConfig],
@@ -1070,6 +1095,9 @@ class GGUFModelLoader(BaseModelLoader):
         return gguf_quant_weights_iterator(model_name_or_path,
                                            gguf_to_hf_name_map)
 
+    def download_model(self, model_config: ModelConfig) -> None:
+        self._prepare_weights(model_config.model)
+
     def load_model(self, *, model_config: ModelConfig,
                    device_config: DeviceConfig,
                    lora_config: Optional[LoRAConfig],

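Note: download_model is now part of the BaseModelLoader contract, so every loader (including out-of-tree subclasses) must implement it, and the base methods raise NotImplementedError instead of silently returning None via `...` if ever reached. A toy subclass illustrating the contract (hypothetical, not part of vLLM; load_model's keyword list is abbreviated, since Python does not force an override to repeat the full abstract signature):

    from torch import nn
    from vllm.config import ModelConfig
    from vllm.model_executor.model_loader.loader import BaseModelLoader

    class NoopLoader(BaseModelLoader):
        """Toy loader: nothing to download, returns a placeholder module."""

        def download_model(self, model_config: ModelConfig) -> None:
            pass  # no remote weights for this toy loader

        def load_model(self, *, model_config: ModelConfig,
                       **kwargs) -> nn.Module:
            # Real implementations receive device/parallel/scheduler/cache
            # configs as keyword arguments; they are ignored here.
            return nn.Identity()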
vllm/model_executor/model_loader/tensorizer.py

@@ -99,6 +99,13 @@ class TensorizerConfig:
                 "Loading a model using Tensorizer with quantization on vLLM"
                 " is unstable and may lead to errors.")
 
+    def open_stream(self, tensorizer_args: Optional["TensorizerArgs"] = None):
+        if tensorizer_args is None:
+            tensorizer_args = self._construct_tensorizer_args()
+
+        return open_stream(self.tensorizer_uri,
+                           **tensorizer_args.stream_params)
+
 
 def load_with_tensorizer(tensorizer_config: TensorizerConfig,
                          **extra_kwargs) -> nn.Module:
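Note: TensorizerConfig.open_stream wraps tensorizer's module-level open_stream with the stream parameters derived from this config, which is what lets TensorizerLoader.download_model validate the model by simply opening and closing the stream. A hedged sketch of direct use; the URI is a placeholder, not from this commit:

    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

    config = TensorizerConfig(tensorizer_uri="s3://example-bucket/model.tensors")

    # The returned object is a file-like stream; opening it verifies that
    # the URI is reachable and any credentials are valid.
    with config.open_stream() as stream:
        magic = stream.read(8)  # e.g. peek at the serialized tensor header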