[Bugfix] Fix broken OpenAI tensorizer test (#8258)

parent ce2702a923
commit 9f68e00d27
@@ -20,7 +20,7 @@ from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.entrypoints.openai.cli_args import make_arg_parser
-from vllm.model_executor.model_loader.loader import DefaultModelLoader
+from vllm.model_executor.model_loader.loader import get_model_loader
 from vllm.platforms import current_platform
 from vllm.utils import FlexibleArgumentParser, get_open_port, is_hip

@@ -89,11 +89,11 @@ class RemoteOpenAIServer:
         is_local = os.path.isdir(model)
         if not is_local:
             engine_args = AsyncEngineArgs.from_cli_args(args)
-            engine_config = engine_args.create_engine_config()
-            dummy_loader = DefaultModelLoader(engine_config.load_config)
-            dummy_loader._prepare_weights(engine_config.model_config.model,
-                                          engine_config.model_config.revision,
-                                          fall_back_to_pt=True)
+            model_config = engine_args.create_model_config()
+            load_config = engine_args.create_load_config()
+            model_loader = get_model_loader(load_config)
+            model_loader.download_model(model_config)

         env = os.environ.copy()
         # the current process might initialize cuda,
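The test no longer reaches into DefaultModelLoader internals via the private `_prepare_weights`; it asks get_model_loader for whichever loader matches the configured load format, which is what fixes the tensorizer case. A minimal sketch of the new pre-download flow, assuming a plain HF model (the example model name is illustrative, not from the diff):

    # Sketch: pre-download a model the same way the fixed test does.
    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.model_executor.model_loader.loader import get_model_loader

    args = AsyncEngineArgs(model="facebook/opt-125m")  # hypothetical example
    model_config = args.create_model_config()
    load_config = args.create_load_config()

    # get_model_loader dispatches on load_config.load_format, so a
    # tensorizer-format model would be handled by TensorizerLoader here.
    loader = get_model_loader(load_config)
    loader.download_model(model_config)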
@@ -771,33 +771,8 @@ class EngineArgs:
         engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
         return engine_args

-    def create_engine_config(self) -> EngineConfig:
-        # gguf file needs a specific model loader and doesn't use hf_repo
-        if check_gguf_file(self.model):
-            self.quantization = self.load_format = "gguf"
-
-        # bitsandbytes quantization needs a specific model loader
-        # so we make sure the quant method and the load format are consistent
-        if (self.quantization == "bitsandbytes" or
-           self.qlora_adapter_name_or_path is not None) and \
-           self.load_format != "bitsandbytes":
-            raise ValueError(
-                "BitsAndBytes quantization and QLoRA adapter only support "
-                f"'bitsandbytes' load format, but got {self.load_format}")
-
-        if (self.load_format == "bitsandbytes" or
-           self.qlora_adapter_name_or_path is not None) and \
-           self.quantization != "bitsandbytes":
-            raise ValueError(
-                "BitsAndBytes load format and QLoRA adapter only support "
-                f"'bitsandbytes' quantization, but got {self.quantization}")
-
-        assert self.cpu_offload_gb >= 0, (
-            "CPU offload space must be non-negative"
-            f", but got {self.cpu_offload_gb}")
-
-        device_config = DeviceConfig(device=self.device)
-        model_config = ModelConfig(
+    def create_model_config(self) -> ModelConfig:
+        return ModelConfig(
             model=self.model,
             tokenizer=self.tokenizer,
             tokenizer_mode=self.tokenizer_mode,
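With this split, create_model_config becomes a pure ModelConfig factory; the gguf/bitsandbytes/cpu-offload validation it used to run moves into the create_engine_config re-added in the next hunk. A hedged sketch of calling the factory directly (the model name is an illustrative example):

    from vllm.engine.arg_utils import EngineArgs

    args = EngineArgs(model="facebook/opt-125m")  # illustrative model
    model_config = args.create_model_config()     # builds ModelConfig only;
                                                  # no load-format validation runs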
@@ -825,6 +800,42 @@ class EngineArgs:
             config_format=self.config_format,
         )

+    def create_load_config(self) -> LoadConfig:
+        return LoadConfig(
+            load_format=self.load_format,
+            download_dir=self.download_dir,
+            model_loader_extra_config=self.model_loader_extra_config,
+            ignore_patterns=self.ignore_patterns,
+        )
+
+    def create_engine_config(self) -> EngineConfig:
+        # gguf file needs a specific model loader and doesn't use hf_repo
+        if check_gguf_file(self.model):
+            self.quantization = self.load_format = "gguf"
+
+        # bitsandbytes quantization needs a specific model loader
+        # so we make sure the quant method and the load format are consistent
+        if (self.quantization == "bitsandbytes" or
+           self.qlora_adapter_name_or_path is not None) and \
+           self.load_format != "bitsandbytes":
+            raise ValueError(
+                "BitsAndBytes quantization and QLoRA adapter only support "
+                f"'bitsandbytes' load format, but got {self.load_format}")
+
+        if (self.load_format == "bitsandbytes" or
+           self.qlora_adapter_name_or_path is not None) and \
+           self.quantization != "bitsandbytes":
+            raise ValueError(
+                "BitsAndBytes load format and QLoRA adapter only support "
+                f"'bitsandbytes' quantization, but got {self.quantization}")
+
+        assert self.cpu_offload_gb >= 0, (
+            "CPU offload space must be non-negative"
+            f", but got {self.cpu_offload_gb}")
+
+        device_config = DeviceConfig(device=self.device)
+        model_config = self.create_model_config()
+
         cache_config = CacheConfig(
             block_size=self.block_size if self.device != "neuron" else
             self.max_model_len,  # neuron needs block_size = max_model_len
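create_load_config lets callers such as the test harness above obtain a LoadConfig without paying for a full EngineConfig, and create_engine_config now composes the two factories. A small sketch of the intended reuse, under the same illustrative model name:

    from vllm.engine.arg_utils import EngineArgs

    args = EngineArgs(model="facebook/opt-125m")  # illustrative model
    load_config = args.create_load_config()       # weight-loading settings only
    print(load_config.load_format)                # e.g. "auto" by default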
@@ -967,12 +978,7 @@ class EngineArgs:
             self.model_loader_extra_config[
                 "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path

-        load_config = LoadConfig(
-            load_format=self.load_format,
-            download_dir=self.download_dir,
-            model_loader_extra_config=self.model_loader_extra_config,
-            ignore_patterns=self.ignore_patterns,
-        )
+        load_config = self.create_load_config()

         prompt_adapter_config = PromptAdapterConfig(
             max_prompt_adapters=self.max_prompt_adapters,
@@ -185,6 +185,11 @@ class BaseModelLoader(ABC):
     def __init__(self, load_config: LoadConfig):
         self.load_config = load_config

+    @abstractmethod
+    def download_model(self, model_config: ModelConfig) -> None:
+        """Download a model so that it can be immediately loaded."""
+        raise NotImplementedError
+
     @abstractmethod
     def load_model(self, *, model_config: ModelConfig,
                    device_config: DeviceConfig,
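Every concrete loader must now implement download_model alongside load_model. A self-contained sketch of the enforced pattern (class names and method bodies here are illustrative, not vLLM's):

    from abc import ABC, abstractmethod

    class LoaderSketch(ABC):
        """Illustrative mirror of BaseModelLoader's new contract."""

        @abstractmethod
        def download_model(self, model_name: str) -> None:
            """Fetch weights so a later load needs no network access."""
            raise NotImplementedError

        @abstractmethod
        def load_model(self, model_name: str) -> object:
            raise NotImplementedError

    class NoopLoader(LoaderSketch):
        def download_model(self, model_name: str) -> None:
            pass  # nothing to fetch, like DummyModelLoader below

        def load_model(self, model_name: str) -> object:
            return object()  # stand-in for an nn.Module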
@@ -193,7 +198,7 @@ class BaseModelLoader(ABC):
                    scheduler_config: SchedulerConfig,
                    cache_config: CacheConfig) -> nn.Module:
         """Load a model with the given configurations."""
-        ...
+        raise NotImplementedError


 class DefaultModelLoader(BaseModelLoader):
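Replacing the `...` body with raise NotImplementedError makes a delegating subclass fail loudly instead of silently returning None. A small standalone demonstration of the difference (not vLLM code):

    from abc import ABC, abstractmethod

    class Base(ABC):
        @abstractmethod
        def load_model(self):
            raise NotImplementedError  # a `...` body would return None instead

    class Broken(Base):
        def load_model(self):
            return super().load_model()  # forgot a real implementation

    try:
        Broken().load_model()
    except NotImplementedError:
        print("caught: load_model was never actually implemented")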
@@ -335,6 +340,11 @@ class DefaultModelLoader(BaseModelLoader):
             weights_iterator = _xla_weights_iterator(weights_iterator)
         return weights_iterator

+    def download_model(self, model_config: ModelConfig) -> None:
+        self._prepare_weights(model_config.model,
+                              model_config.revision,
+                              fall_back_to_pt=True)
+
     def load_model(self, *, model_config: ModelConfig,
                    device_config: DeviceConfig,
                    lora_config: Optional[LoRAConfig],
@@ -377,6 +387,9 @@ class DummyModelLoader(BaseModelLoader):
             raise ValueError(f"Model loader extra config is not supported for "
                              f"load format {load_config.load_format}")

+    def download_model(self, model_config: ModelConfig) -> None:
+        pass  # Nothing to download
+
     def load_model(self, *, model_config: ModelConfig,
                    device_config: DeviceConfig,
                    lora_config: Optional[LoRAConfig],
@@ -467,6 +480,12 @@ class TensorizerLoader(BaseModelLoader):
             model = load_with_tensorizer(tensorizer_config, **extra_kwargs)
         return model.eval()

+    def download_model(self, model_config: ModelConfig) -> None:
+        self.tensorizer_config.verify_with_model_config(model_config)
+
+        with self.tensorizer_config.open_stream():
+            pass
+
     def load_model(self, *, model_config: ModelConfig,
                    device_config: DeviceConfig,
                    lora_config: Optional[LoRAConfig],
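TensorizerLoader.download_model fetches nothing up front; it validates the config against the model, then opens and immediately closes a stream, which is enough to prove the tensorizer URI is reachable before the server starts. A hedged standalone sketch of that probe idea (names and URI are hypothetical):

    from contextlib import contextmanager

    @contextmanager
    def probe_stream(uri: str):
        # Stand-in for tensorizer's open_stream: opening succeeds only if
        # the URI is reachable with the configured credentials.
        print(f"opened {uri}")
        try:
            yield
        finally:
            print(f"closed {uri}")

    with probe_stream("s3://example-bucket/model.tensors"):  # hypothetical URI
        pass  # no bytes are deserialized; reachability is all that's checked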
@@ -568,6 +587,9 @@ class ShardedStateLoader(BaseModelLoader):
             ignore_patterns=self.load_config.ignore_patterns,
         )

+    def download_model(self, model_config: ModelConfig) -> None:
+        self._prepare_weights(model_config.model, model_config.revision)
+
     def load_model(self, *, model_config: ModelConfig,
                    device_config: DeviceConfig,
                    lora_config: Optional[LoRAConfig],
@@ -995,6 +1017,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                 set_weight_attrs(
                     param, {"matmul_state": [None] * len(quant_states)})

+    def download_model(self, model_config: ModelConfig) -> None:
+        self._prepare_weights(model_config.model, model_config.revision)
+
     def load_model(self, *, model_config: ModelConfig,
                    device_config: DeviceConfig,
                    lora_config: Optional[LoRAConfig],
@@ -1070,6 +1095,9 @@ class GGUFModelLoader(BaseModelLoader):
         return gguf_quant_weights_iterator(model_name_or_path,
                                            gguf_to_hf_name_map)

+    def download_model(self, model_config: ModelConfig) -> None:
+        self._prepare_weights(model_config.model)
+
     def load_model(self, *, model_config: ModelConfig,
                    device_config: DeviceConfig,
                    lora_config: Optional[LoRAConfig],
@@ -99,6 +99,13 @@ class TensorizerConfig:
                 "Loading a model using Tensorizer with quantization on vLLM"
                 " is unstable and may lead to errors.")

+    def open_stream(self, tensorizer_args: Optional["TensorizerArgs"] = None):
+        if tensorizer_args is None:
+            tensorizer_args = self._construct_tensorizer_args()
+
+        return open_stream(self.tensorizer_uri,
+                           **tensorizer_args.stream_params)
+

 def load_with_tensorizer(tensorizer_config: TensorizerConfig,
                          **extra_kwargs) -> nn.Module:
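The new open_stream method wraps tensorizer's module-level open_stream with the config's own URI and stream parameters, so callers like download_model above don't have to rebuild TensorizerArgs themselves. A hedged usage sketch (the URI is illustrative, and the config is assumed constructible from just a URI):

    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

    config = TensorizerConfig(
        tensorizer_uri="s3://example-bucket/model.tensors")  # hypothetical URI
    with config.open_stream():
        pass  # succeeds iff the serialized model is reachable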