[Kernel] Add Exllama as a backend for compressed-tensors (#9395)
parent dbfa8d31d5
commit e312e52b44
vllm/envs.py
@@ -66,6 +66,7 @@ if TYPE_CHECKING:
     VLLM_SKIP_P2P_CHECK: bool = False
     VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1: bool = False
     VLLM_TORCH_COMPILE_LEVEL: int = 0
+    VLLM_DISABLED_KERNELS: List[str] = []
 
 
 def get_default_cache_root():
@@ -430,6 +431,14 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     "VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1":
     lambda: os.environ.get("VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1", "0"
                            ) == "1",
+
+    # List of quantization kernels that should be disabled, used for testing
+    # and performance comparisons. Currently only affects MPLinearKernel
+    # selection
+    # (kernels: MacheteLinearKernel, MarlinLinearKernel, ExllamaLinearKernel)
+    "VLLM_DISABLED_KERNELS":
+    lambda: [] if "VLLM_DISABLED_KERNELS" not in os.environ else os.environ[
+        "VLLM_DISABLED_KERNELS"].split(","),
 }
 
 # end-env-vars-definition
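The lambda above is only evaluated when the attribute is read through vllm.envs, so the variable simply needs to be set before kernel selection runs. An illustrative sketch (the kernel class names come from the comment above; disabling Machete and Marlin this way lets selection fall through to Exllama):

    import os

    # Comma-separated kernel class names, as listed in the comment above.
    os.environ["VLLM_DISABLED_KERNELS"] = "MacheteLinearKernel,MarlinLinearKernel"

    import vllm.envs as envs
    print(envs.VLLM_DISABLED_KERNELS)  # ['MacheteLinearKernel', 'MarlinLinearKernel']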
vllm/model_executor/layers/quantization/kernels/MPLinearKernel.py
@@ -42,6 +42,10 @@ class MPLinearKernel(ABC):
         self.config = c
         self.w_q_name = w_q_param_name
         self.w_s_name = w_s_param_name
+        if c.zero_points:
+            assert w_zp_param_name is not None
+        if c.has_g_idx:
+            assert w_gidx_param_name is not None
         self.w_zp_name = w_zp_param_name
         self.w_gidx_name = w_gidx_param_name
 
vllm/model_executor/layers/quantization/kernels/__init__.py
@@ -1,6 +1,8 @@
-import os
 from typing import List, Optional, Type
 
+import vllm.envs as envs
+from vllm.model_executor.layers.quantization.kernels.exllama import (
+    ExllamaLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.machete import (
     MacheteLinearKernel)
 from vllm.model_executor.layers.quantization.kernels.marlin import (
@@ -13,6 +15,7 @@ from vllm.platforms import current_platform
 _POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
     MacheteLinearKernel,
     MarlinLinearKernel,
+    ExllamaLinearKernel,
 ]
 
 
@@ -45,8 +48,7 @@ def choose_mp_linear_kernel(
 
     failure_reasons = []
     for kernel in _POSSIBLE_KERNELS:
-        if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\
-            .split(","):
+        if kernel.__name__ in envs.VLLM_DISABLED_KERNELS:
             failure_reasons.append(
                 f' {kernel.__name__} disabled by environment variable')
             continue
vllm/model_executor/layers/quantization/kernels/exllama.py (new file, 140 lines)
@@ -0,0 +1,140 @@
+from typing import Optional, Tuple
+
+import torch
+
+from vllm import _custom_ops as ops
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    pack_quantized_values_into_int32)
+from vllm.model_executor.parameter import (BasevLLMParameter,
+                                           permute_param_layout_)
+from vllm.scalar_type import scalar_types
+
+from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
+
+
+class ExllamaLinearKernel(MPLinearKernel):
+    SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
+    # In theory supports `scalar_types.uint2b2, scalar_types.uint3b4` too but
+    # currently untested so not added to the list
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 60
+
+    @classmethod
+    def can_implement(cls,
+                      c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
+        if c.has_g_idx and\
+            c.partition_weight_shape[0] != c.full_weight_shape[0]:
+            return False, "Act reordering currently not supported by Exllama, "\
+                          "when the input features are partitioned across "\
+                          "devices"
+
+        if c.partition_weight_shape[1] % (32 // c.weight_type.size_bits) != 0:
+            return False, "Output features must be a multiple of the pack " \
+                          "factor (32 / num_bits) so that we can correctly " \
+                          "pack the zero points"
+
+        if c.act_type != torch.float16:
+            return False, "Exllama only supports float16 activations"
+
+        if c.weight_type not in cls.SUPPORTED_QUANT_TYPES:
+            return False, f"Quant type ({c.weight_type}) not supported by "\
+                          "Exllama, supported types are: "\
+                          f"{cls.SUPPORTED_QUANT_TYPES}"
+
+        if c.full_weight_shape[0] % c.group_size != 0:
+            return False, f"Group size ({c.group_size}) does not evenly divide"\
+                          " the number of input features "\
+                          f"({c.full_weight_shape[0]})"
+
+        return True, None
+
+    def process_weights_after_loading(self, layer: torch.nn.Module):
+        c = self.config
+
+        # For Exllama, we need to set a zero-point tensor if there is not one
+        if not c.zero_points:
+            self.w_zp_name = "qzeros"
+            device = getattr(layer, self.w_q_name).device
+            groups = c.partition_weight_shape[0] // c.group_size
+            out_features = c.partition_weight_shape[1]
+
+            if c.weight_type.has_bias():
+                # if the type has a bias we have to create a zeros tensor that
+                # contains the bias values repeated for each group (-1 due to
+                # a bug in the original GPTQ checkpoint format leading to
+                # exllama kernel adding 1 to the zero points during inference)
+                # Documentation of the bug can be found here:
+                # https://garden.danieldk.eu/GPTQ-Checkpoint-Format
+                zeros = torch.full((groups, out_features),
+                                   c.weight_type.bias - 1,
+                                   dtype=torch.int32,
+                                   device=device)
+            else:
+                raise NotImplementedError(
+                    "A 0 zero-point is not supported by Exllama due to "
+                    "a bug in the original GPTQ checkpoint format leading to "
+                    "exllama kernel adding 1 to the zero points during "
+                    "inference")
+            zeros = pack_quantized_values_into_int32(zeros,
+                                                     c.weight_type,
+                                                     packed_dim=1)
+            setattr(layer, self.w_zp_name,
+                    torch.nn.Parameter(zeros, requires_grad=False))
+
+        if c.has_g_idx:
+
+            def transform_w_g_idx(x):
+                # Exllama wants the permutation array instead of the group
+                # indices
+                return torch.argsort(x).to(torch.int)
+
+            self._transform_param(layer, self.w_gidx_name, transform_w_g_idx)
+        else:
+            self.w_gidx_name = "g_idx"
+            empty_g_idx = torch.nn.Parameter(torch.empty((0, ),
+                                                         dtype=torch.int,
+                                                         device=device),
+                                             requires_grad=False)
+            setattr(layer, self.w_gidx_name, empty_g_idx)
+
+        def transform_w_q(x):
+            assert isinstance(x, BasevLLMParameter)
+            assert self.w_gidx_name is not None
+            g_idx = getattr(layer, self.w_gidx_name)
+
+            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
+            x_cont = x.data.contiguous()
+            ops.gptq_shuffle(x_cont, g_idx, c.weight_type.size_bits)
+            return x_cont
+
+        def transform_w_s(x):
+            assert isinstance(x, BasevLLMParameter)
+            permute_param_layout_(x, input_dim=0, output_dim=1)
+            x.data = x.data.contiguous()
+            return x.to(dtype=c.act_type)
+
+        # Repack weights and scales for Machete
+        self._transform_param(layer, self.w_q_name, transform_w_q)
+        self._transform_param(layer, self.w_s_name, transform_w_s)
+
+    def apply_weights(self,
+                      layer: torch.nn.Module,
+                      x: torch.Tensor,
+                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+        c = self.config
+
+        x_2d = x.reshape(-1, x.shape[-1])
+        out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
+
+        w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer)
+
+        assert w_zp is not None, "Zero points are required by Exllama"
+        assert w_g_idx is not None, "Group index is required by Exllama"
+        output = ops.gptq_gemm(x_2d, w_q, w_zp, w_s, w_g_idx, True,
+                               c.weight_type.size_bits)
+
+        if bias is not None:
+            output.add_(bias)
+        return output.reshape(out_shape)
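A quick numeric check of the synthesized zero points above, using only attributes that already appear in this diff (has_bias(), bias, scalar_types): for uint4b8 the bias is 8, so qzeros is filled with 7 and the exllama GPTQ kernel's historical off-by-one re-add restores an effective zero point of 8 at inference time.

    from vllm.scalar_type import scalar_types

    wtype = scalar_types.uint4b8            # 4-bit unsigned storage with bias 8
    assert wtype.has_bias() and wtype.bias == 8

    qzeros_fill = wtype.bias - 1            # 7: the value written into qzeros
    effective_zero_point = qzeros_fill + 1  # the kernel adds 1 back at inference
    assert effective_zero_point == wtype.bias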
vllm/model_executor/layers/quantization/kernels/machete.py
@@ -8,7 +8,7 @@ from vllm.model_executor.layers.quantization.utils.machete_utils import (
     MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape,
     query_machete_supported_quant_types)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
-    pack_weights_into_int32, unpack_weights_into_int32)
+    pack_quantized_values_into_int32, unpack_quantized_values_into_int32)
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            permute_param_layout_)
 
@@ -71,11 +71,11 @@ class MacheteLinearKernel(MPLinearKernel):
             assert isinstance(x, BasevLLMParameter)
             permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
             if c.has_g_idx:
-                x_unpacked = unpack_weights_into_int32(x.data,
-                                                       c.weight_type,
-                                                       packed_dim=0)
+                x_unpacked = unpack_quantized_values_into_int32(x.data,
+                                                                c.weight_type,
+                                                                packed_dim=0)
                 x_perm = x_unpacked[perm, :]
-                x.data = pack_weights_into_int32(x_perm,
-                                                 c.weight_type,
-                                                 packed_dim=0)
+                x.data = pack_quantized_values_into_int32(x_perm,
+                                                          c.weight_type,
+                                                          packed_dim=0)
             x.data = ops.machete_prepack_B(x.data.t().contiguous().t(),
vllm/model_executor/layers/quantization/utils/quant_utils.py
@@ -20,7 +20,7 @@ FUSED_LAYER_NAME_MAPPING = {
 }
 
 
-def pack_weights_into_int32(w_q: torch.Tensor,
-                            wtype: ScalarType,
-                            packed_dim: int = 0):
+def pack_quantized_values_into_int32(w_q: torch.Tensor,
+                                     wtype: ScalarType,
+                                     packed_dim: int = 0):
     # move dim to pack to the end
@@ -42,7 +42,7 @@ def pack_weights_into_int32(w_q: torch.Tensor,
     return res.permute(inv_perm)
 
 
-def unpack_weights_into_int32(w_q: torch.Tensor,
-                              wtype: ScalarType,
-                              packed_dim: int = 0):
+def unpack_quantized_values_into_int32(w_q: torch.Tensor,
+                                       wtype: ScalarType,
+                                       packed_dim: int = 0):
     # move dim to pack to the end
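The renamed helpers keep the same signature (values, wtype, packed_dim); the new names reflect that they pack any already-quantized integer values (exllama.py uses them for zero points), not just weights. A minimal round-trip sketch, assuming the pack factor is 32 // wtype.size_bits as the Exllama shape check above relies on:

    import torch

    from vllm.model_executor.layers.quantization.utils.quant_utils import (
        pack_quantized_values_into_int32, unpack_quantized_values_into_int32)
    from vllm.scalar_type import scalar_types

    # 4-bit quantized (already biased) values in [0, 15] for a 128x256 weight
    w_q = torch.randint(0, 16, (128, 256), dtype=torch.int32)

    packed = pack_quantized_values_into_int32(w_q, scalar_types.uint4b8,
                                              packed_dim=0)
    assert packed.shape == (128 // 8, 256)  # 32 // 4 = 8 values per int32

    unpacked = unpack_quantized_values_into_int32(packed, scalar_types.uint4b8,
                                                  packed_dim=0)
    assert torch.equal(unpacked, w_q)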
vllm/scalar_type.py
@@ -27,6 +27,8 @@ class scalar_types:
     float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE.value)
 
     # "gptq" types
+    uint2b2 = ScalarType.uint(2, 2)
+    uint3b4 = ScalarType.uint(3, 4)
     uint4b8 = ScalarType.uint(4, 8)
     uint8b128 = ScalarType.uint(8, 128)
 
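In this naming scheme, uintNbB is an N-bit unsigned storage type with bias B: a stored value q represents q - B, and ScalarType.uint(bits, bias) mirrors the existing uint4b8 / uint8b128 definitions. A small sketch of the ranges the newly exposed GPTQ types can represent (size_bits and bias are the same attributes the Exllama kernel reads):

    from vllm.scalar_type import scalar_types

    for t in (scalar_types.uint2b2, scalar_types.uint3b4,
              scalar_types.uint4b8, scalar_types.uint8b128):
        lo = 0 - t.bias                       # smallest representable value
        hi = (2**t.size_bits - 1) - t.bias    # largest representable value
        print(t.size_bits, t.bias, (lo, hi))  # uint4b8 -> 4 8 (-8, 7)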