[Bugfix][AMD] Update torch_bindings so that scaled_fp4_quant isn't built on ROCm (#13235)
This commit is contained in:
parent
0c73026844
commit
c9f9d5b397
@ -177,6 +177,10 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
|
||||
|
||||
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
|
||||
torch::Tensor& output_scale,
|
||||
torch::Tensor const& input_scale);
|
||||
#endif
|
||||
|
||||
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
@ -194,10 +198,6 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
|
||||
|
||||
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
|
||||
|
||||
void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
|
||||
torch::Tensor& output_scale,
|
||||
torch::Tensor const& input_scale);
|
||||
|
||||
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
torch::Tensor const& scale);
|
||||
|
||||
|
@ -385,6 +385,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"bool silu_activation,"
|
||||
"int pad_slot_id) -> ()");
|
||||
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
|
||||
|
||||
// Compute NVFP4 block quantized tensor.
|
||||
ops.def(
|
||||
"scaled_fp4_quant(Tensor! output, Tensor input,"
|
||||
" Tensor! output_scale, Tensor input_scale) -> ()");
|
||||
ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
|
||||
|
||||
#endif
|
||||
|
||||
// Quantized GEMM for GPTQ.
|
||||
@ -421,12 +428,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
|
||||
&dynamic_per_token_scaled_fp8_quant);
|
||||
|
||||
// Compute NVFP4 block quantized tensor.
|
||||
ops.def(
|
||||
"scaled_fp4_quant(Tensor! output, Tensor input,"
|
||||
" Tensor! output_scale, Tensor input_scale) -> ()");
|
||||
ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
|
||||
|
||||
// Compute int8 quantized tensor for given scaling factor.
|
||||
ops.def(
|
||||
"static_scaled_int8_quant(Tensor! result, Tensor input, Tensor scale,"
|
||||
|
@ -774,6 +774,7 @@ def scaled_fp4_quant(
|
||||
two values are packed into a uint8 and float8_e4m3 scaling factors
|
||||
in the swizzled layout.
|
||||
"""
|
||||
assert not current_platform.is_rocm()
|
||||
assert input.ndim >= 1, (
|
||||
f'input.ndim needs to be >= 1, but got {input.ndim}.')
|
||||
other_dims = 1 if input.ndim == 1 else -1
|
||||
|
Loading…
x
Reference in New Issue
Block a user