#pragma once

#include <torch/extension.h>

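// Attention kernels operating on the paged (block-based) KV cache.
// V1 computes attention for each sequence in a single pass over its blocks.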
void paged_attention_v1(
  torch::Tensor& out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  int num_kv_heads,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& seq_lens,
  int block_size,
  int max_seq_len,
  const c10::optional<torch::Tensor>& alibi_slopes,
  const std::string& kv_cache_dtype,
  float kv_scale);

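// V2 splits long sequences into partitions, writing per-partition results to
// tmp_out (with exp_sums and max_logits) before reducing them into out.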
void paged_attention_v2(
  torch::Tensor& out,
  torch::Tensor& exp_sums,
  torch::Tensor& max_logits,
  torch::Tensor& tmp_out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  int num_kv_heads,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& seq_lens,
  int block_size,
  int max_seq_len,
  const c10::optional<torch::Tensor>& alibi_slopes,
  const std::string& kv_cache_dtype,
  float kv_scale);

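// RMSNorm: out = input / sqrt(mean(input^2, last dim) + epsilon) * weight.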
void rms_norm(
  torch::Tensor& out,
  torch::Tensor& input,
  torch::Tensor& weight,
  float epsilon);

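// Fused residual-add + RMSNorm; input and residual are updated in place
// (residual receives input + residual, input receives the normalized result).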
void fused_add_rms_norm(
  torch::Tensor& input,
  torch::Tensor& residual,
  torch::Tensor& weight,
  float epsilon);

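// Applies rotary positional embeddings to query and key in place, using the
// precomputed cos_sin_cache. is_neox selects GPT-NeoX-style (rotate-half)
// rotation rather than GPT-J-style (interleaved) rotation.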
void rotary_embedding(
  torch::Tensor& positions,
  torch::Tensor& query,
  torch::Tensor& key,
  int head_size,
  torch::Tensor& cos_sin_cache,
  bool is_neox);

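// Like rotary_embedding, but each token additionally indexes the cos/sin cache
// through cos_sin_cache_offsets, e.g. when several RoPE scaling factors are
// active within one batch.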
void batched_rotary_embedding(
  torch::Tensor& positions,
  torch::Tensor& query,
  torch::Tensor& key,
  int head_size,
  torch::Tensor& cos_sin_cache,
  bool is_neox,
  int rot_dim,
  torch::Tensor& cos_sin_cache_offsets);

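// Gated activation kernels: the last dimension of input holds two halves
// [x1, x2] and out = act(x1) * x2. SiLU, exact GELU, and tanh-approximated
// GELU variants follow.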
void silu_and_mul(
  torch::Tensor& out,
  torch::Tensor& input);

void gelu_and_mul(
  torch::Tensor& out,
  torch::Tensor& input);

void gelu_tanh_and_mul(
  torch::Tensor& out,
  torch::Tensor& input);

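// Plain element-wise GELU approximations (the "new" and "fast" variants used
// by some model checkpoints), applied to the whole input.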
void gelu_new(
  torch::Tensor& out,
  torch::Tensor& input);

void gelu_fast(
  torch::Tensor& out,
  torch::Tensor& input);

#ifndef USE_ROCM

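// The following quantization kernels are CUDA-only (excluded from ROCm builds).
// AQLM: matmul against additively quantized (codebook-based) weights.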
torch::Tensor aqlm_gemm(
  const torch::Tensor& input,
  const torch::Tensor& codes,
  const torch::Tensor& codebooks,
  const torch::Tensor& scales,
  const torch::Tensor& codebook_partition_sizes,
  const std::optional<torch::Tensor>& bias);

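// Decompresses AQLM codes back into a dense weight tensor.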
torch::Tensor aqlm_dequant(
  const torch::Tensor& codes,
  const torch::Tensor& codebooks,
  const torch::Tensor& codebook_partition_sizes);

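// AWQ 4-bit weight-only quantization: awq_gemm performs the quantized matmul,
// awq_dequantize expands the packed weights back to floating point.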
torch::Tensor awq_gemm(
  torch::Tensor _in_feats,
  torch::Tensor _kernel,
  torch::Tensor _scaling_factors,
  torch::Tensor _zeros,
  int split_k_iters);

torch::Tensor awq_dequantize(
  torch::Tensor _kernel,
  torch::Tensor _scaling_factors,
  torch::Tensor _zeros,
  int split_k_iters,
  int thx,
  int thy);

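// Marlin: FP16-activation x INT4-weight GEMM kernel.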
torch::Tensor marlin_gemm(
  torch::Tensor& a,
  torch::Tensor& b_q_weight,
  torch::Tensor& b_scales,
  torch::Tensor& workspace,
  int64_t size_m,
  int64_t size_n,
  int64_t size_k);

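// Marlin kernel for GPTQ weights with 2:4 structured sparsity; b_meta holds
// the sparsity metadata.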
torch::Tensor gptq_marlin_24_gemm(
  torch::Tensor& a,
  torch::Tensor& b_q_weight,
  torch::Tensor& b_meta,
  torch::Tensor& b_scales,
  torch::Tensor& workspace,
  int64_t num_bits,
  int64_t size_m,
  int64_t size_n,
  int64_t size_k);

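// Runs GPTQ-quantized weights through the Marlin kernel; g_idx and perm carry
// the activation reordering needed for act-order (desc_act) checkpoints.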
torch::Tensor gptq_marlin_gemm(
  torch::Tensor& a,
  torch::Tensor& b_q_weight,
  torch::Tensor& b_scales,
  torch::Tensor& g_idx,
  torch::Tensor& perm,
  torch::Tensor& workspace,
  int64_t num_bits,
  int64_t size_m,
  int64_t size_n,
  int64_t size_k,
  bool is_k_full);

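// Repacks the GPTQ weight layout into the format expected by the Marlin kernel.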
torch::Tensor gptq_marlin_repack(
  torch::Tensor& b_q_weight,
  torch::Tensor& perm,
  int64_t size_k,
  int64_t size_n,
  int64_t num_bits);

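// CUTLASS-based scaled matmul with dequantization: a and b are quantized
// inputs, and a_scales/b_scales are the dequantization scales applied to the
// result written to out.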
int cutlass_scaled_mm_dq(
  torch::Tensor& out,
  torch::Tensor const& a,
  torch::Tensor const& b,
  torch::Tensor const& a_scales,
  torch::Tensor const& b_scales);

#endif

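// SqueezeLLM: non-uniform (lookup-table based) quantized matrix-vector multiply.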
void squeezellm_gemm(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor lookup_table);

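// GPTQ matmul; use_exllama selects the exllama kernel path and bit is the
// weight bit-width.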
torch::Tensor gptq_gemm(
  torch::Tensor a,
  torch::Tensor b_q_weight,
  torch::Tensor b_gptq_qzeros,
  torch::Tensor b_gptq_scales,
  torch::Tensor b_g_idx,
  bool use_exllama,
  int bit);

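// Reorders GPTQ weights (applying q_perm) into the layout used by the exllama
// kernels.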
void gptq_shuffle(
  torch::Tensor q_weight,
  torch::Tensor q_perm,
  int bit);

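// FP8 quantization of activations: the static variant uses a caller-provided
// scale, while the dynamic variant computes a per-tensor scale on the fly.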
void static_scaled_fp8_quant(
  torch::Tensor& out,
  torch::Tensor& input,
  torch::Tensor& scale);

void dynamic_scaled_fp8_quant(
  torch::Tensor& out,
  torch::Tensor& input,
  torch::Tensor& scale);

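// Prepares token-to-expert routing for the fused MoE kernel: tokens are sorted
// by expert and padded so each expert's token count is a multiple of block_size.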
void moe_align_block_size(
  torch::Tensor topk_ids,
  int num_experts,
  int block_size,
  torch::Tensor sorted_token_ids,
  torch::Tensor experts_ids,
  torch::Tensor num_tokens_post_pad);

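// Custom all-reduce (CUDA-only): a fast intra-node all-reduce over CUDA IPC
// buffers that can be used in place of NCCL when peer access (e.g. full NVLink)
// and the tensor size make it profitable. init_custom_ar returns an opaque
// handle (fptr_t) that the remaining functions operate on; dispose releases it.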
#ifndef USE_ROCM
using fptr_t = uint64_t;
fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
                      const std::vector<std::string>& handles,
                      const std::vector<int64_t>& offsets, int rank,
                      bool full_nvlink);
bool should_custom_ar(torch::Tensor& inp, int max_size, int world_size,
                      bool full_nvlink);
void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer,
                      torch::Tensor& out);
void dispose(fptr_t _fa);
int meta_size();
void register_buffer(fptr_t _fa, torch::Tensor& t,
                     const std::vector<std::string>& handles,
                     const std::vector<int64_t>& offsets);
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
                            const std::vector<std::vector<int64_t>>& offsets);
#endif