[Kernel] FP8 support for MoE kernel / Mixtral (#4244)
This PR is the first step towards fixing https://github.com/vllm-project/vllm/pull/3208
It implements dynamic per-tensor scaling (see https://github.com/vllm-project/vllm/pull/4118), so users neither need to compute activation scales on a calibration dataset nor convert their model checkpoints; it is enough to specify the `quantization="fp8"` argument. You can try out the PR like this:
```python
from vllm import LLM, SamplingParams
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", tensor_parallel_size=2, quantization="fp8")
outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
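For intuition, the dynamic per-tensor path boils down to taking the activation tensor's absolute maximum, dividing it by the FP8 E4M3 maximum to obtain the scale, and then dividing by that scale and clamping before the FP8 cast. The snippet below is only a minimal PyTorch sketch of that math (it assumes a recent PyTorch build with `torch.float8_e4m3fn`; the helper name is illustrative), not the CUDA kernel code itself:

```python
import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def dynamic_per_tensor_fp8_quant(x: torch.Tensor):
    # scale = absmax / fp8_max; quantize = clamp(x / scale), then cast to fp8
    scale = x.abs().max().float() / FP8_E4M3_MAX
    q = (x.float() / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)
    return q.to(torch.float8_e4m3fn), scale

x = torch.randn(4, 8, dtype=torch.float16)
x_fp8, scale = dynamic_per_tensor_fp8_quant(x)
# Dequantize for a quick sanity check: x ≈ x_fp8.float() * scale
```

The kernels implement the same computation, fused and vectorized on the GPU.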
**Performance**: For this PR, the focus is on keeping the code clean while still getting reasonable performance. There is a set of optimizations that we will submit as a follow-up PR; they significantly improve performance (similar to the numbers in https://github.com/vllm-project/vllm/pull/3954). With this PR, the results are as follows:
<img width="725" alt="Screenshot 2024-04-21 at 1 31 50 PM" src="https://github.com/vllm-project/vllm/assets/113316/d8fe1118-07a0-4d4e-8530-37a77d465a03">
**Accuracy**: The accuracy with this PR on MMLU on `mistralai/Mixtral-8x7B-v0.1` is as follows:
```
| Groups |Version|Filter|n-shot|Metric|Value | |Stderr|
|------------------|-------|------|-----:|------|-----:|---|-----:|
|mmlu |N/A |none | 0|acc |0.7018|± |0.0036|
| - humanities |N/A |none | 5|acc |0.6472|± |0.0065|
| - other |N/A |none | 5|acc |0.7673|± |0.0072|
| - social_sciences|N/A |none | 5|acc |0.8099|± |0.0070|
| - stem |N/A |none | 5|acc |0.6131|± |0.0083|
```
This compares favorably with the fp16 results, which are:
```
| Groups |Version|Filter|n-shot|Metric|Value | |Stderr|
|------------------|-------|------|-----:|------|-----:|---|-----:|
|mmlu |N/A |none | 0|acc |0.7020|± |0.1313|
| - humanities |N/A |none | 5|acc |0.6425|± |0.1349|
| - other |N/A |none | 5|acc |0.7744|± |0.1038|
| - social_sciences|N/A |none | 5|acc |0.8131|± |0.0695|
| - stem |N/A |none | 5|acc |0.6108|± |0.1383|
```
Happy hacking!

```cpp
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>

#include <cmath>

#include "cuda_compat.h"
#include "dispatch_utils.h"

#include "../../reduction_utils.cuh"

#ifndef USE_ROCM
using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
    std::numeric_limits<FP8_TYPE>::max();
#else
  #include "amd/hip_float8.h"
using FP8_TYPE = c10::Float8_e4m3fnuz;
// Using the default max value from pytorch (240.0) will cause accuracy
// issues when running dynamic quantization. Here we use 224.0f for ROCm.
constexpr auto FP8_E4M3_MAX = 224.0f;
#endif

namespace vllm {

__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
  float old;
  old = (value >= 0)
            ? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
            : __uint_as_float(
                  atomicMin((unsigned int*)addr, __float_as_uint(value)));

  return old;
}

template <bool is_scale_inverted>
__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
                                                          float const scale) {
  float x = 0.0f;
  if constexpr (is_scale_inverted) {
    x = val * scale;
  } else {
    x = val / scale;
  }

  float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
#ifndef USE_ROCM
  return static_cast<c10::Float8_e4m3fn>(r);
#else
  // Use hardware cvt instruction for fp8 on rocm
  return c10::Float8_e4m3fnuz(hip_fp8(r).data,
                              c10::Float8_e4m3fnuz::from_bits());
#endif
}

// Compute the absolute maximum m of the input tensor and store
// m / float8_e4m3::max() in *scale. Each thread block performs a
// reduction tree and the memory in scale is atomically updated.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
template <typename scalar_t>
__global__ void segmented_max_reduction(float* __restrict__ scale,
                                        const scalar_t* __restrict__ input,
                                        int64_t num_elems) {
  __shared__ float cache[1024];
  int64_t i = blockDim.x * blockIdx.x + threadIdx.x;

  // First store the maximum of all values processed by
  // the current thread in cache[threadIdx.x]
  scalar_t tmp = 0.0;
  while (i < num_elems) {
    float x = static_cast<float>(input[i]);
    tmp = max(tmp, fabs(x));
    i += blockDim.x * gridDim.x;
  }
  cache[threadIdx.x] = tmp;

  __syncthreads();

  // Now perform parallel reduction within the thread block
  int ib = blockDim.x / 2;
  while (ib != 0) {
    if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
      cache[threadIdx.x] = cache[threadIdx.x + ib];
    }
    __syncthreads();
    ib /= 2;
  }
  // Finally, since cache[0] contains the maximum for this thread block,
  // atomically write the max to the target location
  if (threadIdx.x == 0) {
    atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
  }
}

template <typename scalar_t>
struct __align__(8) vec4_t {
  scalar_t x;
  scalar_t y;
  scalar_t z;
  scalar_t w;
};

typedef struct __align__(4) {
  FP8_TYPE x;
  FP8_TYPE y;
  FP8_TYPE z;
  FP8_TYPE w;
} float8x4_t;

template <typename scalar_t>
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
                                int64_t const num_elems, int const tid,
                                int const step) {
  // Vectorized input/output to better utilize memory bandwidth.
  vec4_t<scalar_t> const* vectorized_in =
      reinterpret_cast<vec4_t<scalar_t> const*>(input);

  int64_t const num_vec_elems = num_elems >> 2;
  float absmax_val = 0.0f;

#pragma unroll 4
  for (int64_t i = tid; i < num_vec_elems; i += step) {
    vec4_t<scalar_t> in_vec = vectorized_in[i];
    absmax_val = max(absmax_val, fabs(in_vec.x));
    absmax_val = max(absmax_val, fabs(in_vec.y));
    absmax_val = max(absmax_val, fabs(in_vec.z));
    absmax_val = max(absmax_val, fabs(in_vec.w));
  }

  // Handle the remaining elements if num_elems is not divisible by 4
  for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
    absmax_val = max(absmax_val, fabs(input[i]));
  }

  return absmax_val;
}

template <typename scalar_t, bool is_scale_inverted>
__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
                                          scalar_t const* __restrict__ input,
                                          float const scale,
                                          int64_t const num_elems,
                                          int const tid, int const step) {
  // Vectorized input/output to better utilize memory bandwidth.
  vec4_t<scalar_t> const* vectorized_in =
      reinterpret_cast<vec4_t<scalar_t> const*>(input);
  float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);

  int64_t const num_vec_elems = num_elems >> 2;

#pragma unroll 4
  for (int64_t i = tid; i < num_vec_elems; i += step) {
    vec4_t<scalar_t> in_vec = vectorized_in[i];
    float8x4_t out_vec;

    out_vec.x = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(in_vec.x), scale);
    out_vec.y = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(in_vec.y), scale);
    out_vec.z = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(in_vec.z), scale);
    out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(in_vec.w), scale);
    vectorized_out[i] = out_vec;
  }

  // Handle the remaining elements if num_elems is not divisible by 4
  for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
    out[i] = scaled_fp8_conversion<is_scale_inverted>(
        static_cast<float>(input[i]), scale);
  }
}

template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
                                        const scalar_t* __restrict__ input,
                                        const float* __restrict__ scale,
                                        int64_t num_elems) {
  int tid = blockDim.x * blockIdx.x + threadIdx.x;

  // Invert the scale so that we can use multiplications to avoid expensive
  // division.
  const float inverted_scale = 1.0f / (*scale);
  scaled_fp8_conversion_vec<scalar_t, true>(
      out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
}

template <typename scalar_t>
__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
    FP8_TYPE* __restrict__ out, float* __restrict__ scale,
    scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
    const int hidden_size) {
  float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);

  int const tid = threadIdx.x;
  int const token_idx = blockIdx.x;

  scalar_t const* __restrict__ token_input = &input[token_idx * hidden_size];
  FP8_TYPE* __restrict__ token_output = &out[token_idx * hidden_size];

  // For vectorization, token_input and token_output pointers need to be
  // aligned at 8-byte and 4-byte addresses respectively.
  bool const can_vectorize = hidden_size % 4 == 0;

  float absmax_val = 0.0f;
  if (can_vectorize) {
    absmax_val = thread_max_vec(token_input, hidden_size, tid, blockDim.x);
  } else {
    for (int i = tid; i < hidden_size; i += blockDim.x) {
      float const x = static_cast<float>(token_input[i]);
      absmax_val = max(absmax_val, fabs(x));
    }
  }

  float const block_absmax_val_maybe = blockReduceMax(absmax_val);
  __shared__ float token_scale;
  if (tid == 0) {
    if (scale_ub) {
      token_scale = min(block_absmax_val_maybe, *scale_ub);
    } else {
      token_scale = block_absmax_val_maybe;
    }
    // token scale computation
    token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor);
    scale[token_idx] = token_scale;
  }
  __syncthreads();

  // Note that we don't use inverted scales so we can match FBGemm impl.
  if (can_vectorize) {
    scaled_fp8_conversion_vec<scalar_t, false>(
        token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
  } else {
    for (int i = tid; i < hidden_size; i += blockDim.x) {
      token_output[i] = scaled_fp8_conversion<false>(
          static_cast<float>(token_input[i]), token_scale);
    }
  }
}

}  // namespace vllm

void static_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
                             torch::Tensor const& input,  // [..., d]
                             torch::Tensor const& scale)  // [1]
{
  int64_t num_tokens = input.numel() / input.size(-1);
  int64_t num_elems = input.numel();
  dim3 grid(num_tokens);
  dim3 block(1024);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
        vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
            out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
            scale.data_ptr<float>(), num_elems);
      });
}

void dynamic_scaled_fp8_quant(torch::Tensor& out,          // [..., d]
                              torch::Tensor const& input,  // [..., d]
                              torch::Tensor& scale)        // [1]
{
  int64_t num_tokens = input.numel() / input.size(-1);
  int64_t num_elems = input.numel();
  dim3 grid(num_tokens);
  dim3 block(1024);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "scaled_fp8_quant_kernel", [&] {
        vllm::segmented_max_reduction<scalar_t><<<grid, block, 0, stream>>>(
            scale.data_ptr<float>(), input.data_ptr<scalar_t>(), num_elems);
        vllm::scaled_fp8_quant_kernel<scalar_t><<<grid, block, 0, stream>>>(
            out.data_ptr<FP8_TYPE>(), input.data_ptr<scalar_t>(),
            scale.data_ptr<float>(), num_elems);
      });
}

void dynamic_per_token_scaled_fp8_quant(
    torch::Tensor& out,          // [..., d]
    torch::Tensor const& input,  // [..., d]
    torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(out.is_contiguous());

  int const hidden_size = input.size(-1);
  int const num_tokens = input.numel() / hidden_size;
  dim3 const grid(num_tokens);
  dim3 const block(std::min(hidden_size, 1024));

  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
        vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
            <<<grid, block, 0, stream>>>(
                out.data_ptr<FP8_TYPE>(), scales.data_ptr<float>(),
                input.data_ptr<scalar_t>(),
                scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
                hidden_size);
      });
}
```
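For the per-token path (`dynamic_per_token_scaled_fp8_quant`), a pure-PyTorch reference of the scale logic can be handy when eyeballing kernel outputs. This is only a sketch under the same assumptions as above (a recent PyTorch with `torch.float8_e4m3fn`); the helper name is made up for illustration and is not part of the kernel code:

```python
import torch

FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def dynamic_per_token_fp8_quant_ref(x: torch.Tensor, scale_ub=None):
    # Per-row absmax, optionally clamped by scale_ub, with the same
    # minimum-scale floor the kernel uses: 1 / (FP8_E4M3_MAX * 512).
    absmax = x.abs().amax(dim=-1, keepdim=True).float()
    if scale_ub is not None:
        absmax = absmax.clamp(max=scale_ub)
    min_scaling_factor = 1.0 / (FP8_E4M3_MAX * 512.0)
    scales = (absmax / FP8_E4M3_MAX).clamp(min=min_scaling_factor)
    q = (x.float() / scales).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)
    return q.to(torch.float8_e4m3fn), scales
```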