# SPDX-License-Identifier: Apache-2.0
import torch


def native_w8a8_block_matmul(A: torch.Tensor, B: torch.Tensor,
                             As: torch.Tensor, Bs: torch.Tensor, block_size,
                             output_dtype) -> torch.Tensor:
    """Multiply block-wise quantized matrices using plain torch ops.

    Computes ``A @ B.T`` where both operands carry per-block scales. The
    inputs are cast to float32 before any arithmetic, so the function does
    not depend on the quantized storage type (e.g. int8 or fp8).

    Args:
        A: Quantized activations of shape ``(..., K)``.
        B: Quantized weights of shape ``(N, K)``; must be 2-D and contiguous.
        As: Scales for ``A`` of shape ``(..., ceil(K / block_k))``, with the
            same leading dimensions as ``A`` (one scale per row and K-block).
        Bs: Scales for ``B`` of shape
            ``(ceil(N / block_n), ceil(K / block_k))``.
        block_size: Two-element sequence ``(block_n, block_k)``.
        output_dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``(..., N)`` in ``output_dtype``. Accumulation is
        done in float32 and converted once at the end.

    Raises:
        AssertionError: If the operand/scale shapes are inconsistent with
            each other or with ``block_size``.
    """
    A = A.to(torch.float32)
    B = B.to(torch.float32)

    assert A.shape[-1] == B.shape[-1]
    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
    assert len(block_size) == 2
    block_n, block_k = block_size[0], block_size[1]
    assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
    assert A.shape[:-1] == As.shape[:-1]

    N, K = B.shape
    origin_C_shape = A.shape[:-1] + (N, )

    # Row count is the product of A's leading dims. Computing it this way
    # (rather than numel() // K) keeps an empty reduction dim (K == 0) from
    # raising ZeroDivisionError; such inputs yield an all-zero result.
    M = 1
    for dim in A.shape[:-1]:
        M *= dim

    A = A.reshape(M, K)
    As = As.reshape(M, As.shape[-1])

    n_tiles = (N + block_n - 1) // block_n
    k_tiles = (K + block_k - 1) // block_k
    assert n_tiles == Bs.shape[0]
    assert k_tiles == Bs.shape[1]

    C = torch.zeros((M, N), dtype=torch.float32, device=A.device)

    for i in range(k_tiles):
        # The last tile along K may be narrower than block_k.
        k_lo = i * block_k
        k_hi = min(k_lo + block_k, K)
        a = A[:, k_lo:k_hi]
        # Column slice keeps shape (M, 1) so it broadcasts over the N tile.
        a_scale = As[:, i:i + 1]
        for j in range(n_tiles):
            # The last tile along N may be narrower than block_n.
            n_lo = j * block_n
            n_hi = min(n_lo + block_n, N)
            b = B[n_lo:n_hi, k_lo:k_hi]
            s = a_scale * Bs[j][i]
            # Partial product of this (K-tile, N-tile) pair, dequantized by
            # the combined row/block scale and accumulated in float32.
            C[:, n_lo:n_hi] += torch.matmul(a, b.t()) * s

    C = C.reshape(origin_C_shape).to(output_dtype)
    return C