[Bugfix] Marlin 2:4 temp fix for large M dim (>256) (#10464)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
commit d200972e7f
parent d5b68aba2f
@@ -910,13 +910,16 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C,
       // than better compute utilization
       thread_k = 128;
       thread_m = 128;
-    } else if (prob_n <= 256) {
+    } else {
       thread_k = 64;
       thread_m = 256;
-    } else {
-      thread_k = 32;
-      thread_m = 512;
     }
+    // Also had
+    // if prob_n > 256
+    //    thread_k = 32;
+    //    thread_m = 512;
+    // but this is broken,
+    // TODO(Lucas, Alex M): figure out why
   }
 
   int thread_k_blocks = thread_k / 32;  // 2:4 version with m16n8k32 instruction
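In effect, the hunk above stops picking the 32x512 tile shape for prob_n > 256 and falls back to the 64x256 configuration, since the larger tiles produced incorrect results. A minimal Python model of the selection logic after the change; the function name pick_tile_shape and the standalone form are illustrative only, and the note that the disabled branch is the one reached for large GEMM M is inferred from the PR title and the new test shapes:

def pick_tile_shape(prob_m: int, prob_n: int) -> tuple[int, int]:
    # Sketch of the (thread_k, thread_m) heuristic after this fix.
    if prob_m <= 16:
        # Small batches: finer partitioning matters more than compute utilization.
        return 128, 128
    # All remaining shapes, including prob_n > 256, now use the 64/256 tiles;
    # the 32/512 configuration is disabled until the miscompute is understood.
    return 64, 256

assert pick_tile_shape(8, 1024) == (128, 128)
assert pick_tile_shape(64, 512) == (64, 256)  # previously would have been (32, 512)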
@@ -1079,6 +1082,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   // Verify A device and strides
   TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
   TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
+  TORCH_CHECK(a.dtype() == torch::kFloat16,
+              "A is not float16, currently only float16 is supported");
 
   // Verify B device and strides
   TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
@@ -1091,6 +1096,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
   // Verify scales device and strides
   TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
   TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
+  TORCH_CHECK(b_scales.dtype() == torch::kFloat16,
+              "A is not float16, currently only float16 is supported");
 
   // Alloc C matrix
   const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
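The two hunks above add dtype guards at the C++ entry point so that non-float16 activations or scales are rejected up front instead of being passed to the kernel. A small Python sketch of the equivalent validation; the helper name check_marlin_24_inputs and the error wording for b_scales are illustrative, the real checks are the TORCH_CHECKs shown above:

import torch

def check_marlin_24_inputs(a: torch.Tensor, b_scales: torch.Tensor) -> None:
    # Python-level mirror of the new TORCH_CHECK guards (sketch only).
    if a.dtype != torch.float16:
        raise ValueError("A is not float16, currently only float16 is supported")
    if b_scales.dtype != torch.float16:
        raise ValueError("b_scales is not float16, currently only float16 is supported")

# Example: a bfloat16 activation tensor is now rejected before kernel launch.
try:
    check_marlin_24_inputs(torch.zeros(4, 8, dtype=torch.bfloat16),
                           torch.zeros(1, 8, dtype=torch.float16))
except ValueError as err:
    print(err)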
@@ -50,6 +50,8 @@ MNK_FACTORS = [
     (13, 17, 67),
     (26, 37, 13),
     (67, 13, 11),
+    (257, 13, 11),
+    (658, 13, 11),
 ]
 
 DTYPES = [torch.float16, torch.bfloat16]
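The new factor tuples extend the test matrix past the M = 256 boundary: 257 sits just above it and 658 is well beyond, so shapes that previously reached the broken 32x512 tile path are now covered. A quick sketch of the coverage, assuming (as the PR title suggests) that the first element of each tuple drives the GEMM M dimension:

MNK_FACTORS = [
    (13, 17, 67),
    (26, 37, 13),
    (67, 13, 11),
    (257, 13, 11),
    (658, 13, 11),
]

# Shapes whose M factor crosses the 256 threshold exercise the fixed path.
large_m = [f for f in MNK_FACTORS if f[0] > 256]
assert large_m == [(257, 13, 11), (658, 13, 11)]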