Merge EmbeddedLLM/vllm-rocm into vLLM main (#1836)

Co-authored-by: Philipp Moritz <pcmoritz@gmail.com>
Co-authored-by: Amir Balwel <amoooori04@gmail.com>
Co-authored-by: root <kuanfu.liu@akirakan.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: kuanfu <kuanfu.liu@embeddedllm.com>
Co-authored-by: miloice <17350011+kliuae@users.noreply.github.com>
Author: TJian, 2023-12-08 15:16:52 +08:00 (committed by GitHub)
Parent: c8e7eb1eb3
Commit: 6ccc0bfffb
29 changed files with 873 additions and 118 deletions

.gitignore

@@ -177,3 +177,7 @@ _build/
# vim swap files
*.swo
*.swp
+
+# hip files generated by PyTorch
+*.hip
+*_hip*

Dockerfile.rocm (new file)

@@ -0,0 +1,62 @@
FROM rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1
# Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y
# Install some basic utilities
RUN apt-get update && apt-get install -y \
curl \
ca-certificates \
sudo \
git \
bzip2 \
libx11-6 \
build-essential \
wget \
unzip \
nvidia-cuda-toolkit \
tmux \
&& rm -rf /var/lib/apt/lists/*
### Mount Point ###
# When launching the container, mount the code directory to /app
ARG APP_MOUNT=/app
VOLUME [ ${APP_MOUNT} ]
WORKDIR ${APP_MOUNT}
RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas
ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
# Install ROCm flash-attention
RUN mkdir libs \
&& cd libs \
&& git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git \
&& cd flash-attention \
&& git checkout 3d2b6f5 \
&& git submodule update --init \
&& export GPU_ARCHS=$(/opt/rocm/llvm/bin/amdgpu-offload-arch) \
&& patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch \
&& python3 setup.py install \
&& cd ..
COPY ./ /app/vllm
RUN python3 -m pip install --upgrade pip
RUN pip install xformers==0.0.22.post7 --no-deps
RUN cd /app \
&& cd vllm \
&& pip install -U -r requirements-rocm.txt \
&& bash patch_xformers-0.0.22.post7.rocm.sh \
&& python3 setup.py install \
&& cd ..
RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir ray[all]
CMD ["/bin/bash"]

README.md
@@ -17,6 +17,7 @@ Easy, fast, and cheap LLM serving for everyone
---

*Latest News* 🔥
+- [2023/12] Added ROCm support to vLLM.
- [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing).
- [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there.
- [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv!

@@ -43,6 +44,7 @@ vLLM is flexible and easy to use with:
- Tensor parallelism support for distributed inference
- Streaming outputs
- OpenAI-compatible API server
+- Support NVIDIA CUDA and AMD ROCm.

vLLM seamlessly supports many Hugging Face models, including the following architectures:

csrc/activation_kernels.cu
@@ -1,6 +1,7 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
+#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace vllm {

@@ -18,8 +19,8 @@ __global__ void silu_and_mul_kernel(
  const int d) {
  const int64_t token_idx = blockIdx.x;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
-    const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
-    const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
+    const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
+    const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
    out[token_idx * d + idx] = silu(x) * y;
  }
}

@@ -57,7 +58,7 @@ __global__ void activation_kernel(
  const int d) {
  const int64_t token_idx = blockIdx.x;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
-    const scalar_t x = __ldg(&input[token_idx * d + idx]);
+    const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
    out[token_idx * d + idx] = ACT_FN(x);
  }
}

csrc/attention/attention_kernels.cu
@@ -15,6 +15,10 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
+#ifdef USE_ROCM
+#include <hip/hip_runtime.h>
+#endif
+
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

@@ -23,7 +27,11 @@
#include <algorithm>

+#ifndef USE_ROCM
#define WARP_SIZE 32
+#else
+#define WARP_SIZE warpSize
+#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

@@ -40,7 +48,7 @@ inline __device__ float block_sum(float* red_smem, float sum) {
  // Compute the sum per warp.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
-    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
+    sum += VLLM_SHFL_XOR_SYNC(sum, mask);
  }

  // Warp leaders store the data to shared memory.

@@ -59,11 +67,11 @@ inline __device__ float block_sum(float* red_smem, float sum) {
  // Parallel reduction inside the warp.
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
-    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
+    sum += VLLM_SHFL_XOR_SYNC(sum, mask);
  }

  // Broadcast to other threads.
-  return __shfl_sync(uint32_t(-1), sum, 0);
+  return VLLM_SHFL_SYNC(sum, 0);
}

// TODO(woosuk): Merge the last two dimensions of the grid.

@@ -223,7 +231,7 @@ __device__ void paged_attention_kernel(
    // The 0-th thread of each thread group already has its max qk value.
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
-      qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
+      qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
    }
    if (lane == 0) {
      red_smem[warp_idx] = qk_max;

@@ -235,10 +243,10 @@ __device__ void paged_attention_kernel(
  qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
-    qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
+    qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
  }
  // Broadcast the max qk value to all threads.
-  qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
+  qk_max = VLLM_SHFL_SYNC(qk_max, 0);

  // Get the sum of the exp values.
  float exp_sum = 0.f;

@@ -326,7 +334,7 @@ __device__ void paged_attention_kernel(
      float acc = accs[i];
#pragma unroll
      for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
-        acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
+        acc += VLLM_SHFL_XOR_SYNC(acc, mask);
      }
      accs[i] = acc;
    }

@@ -492,7 +500,7 @@ __global__ void paged_attention_v2_reduce_kernel(
  // Reduce within the warp.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
-    max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
+    max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = max_logit;

@@ -502,10 +510,10 @@ __global__ void paged_attention_v2_reduce_kernel(
  max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
-    max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
+    max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
  }
  // Broadcast the max value to all threads.
-  max_logit = __shfl_sync(uint32_t(-1), max_logit, 0);
+  max_logit = VLLM_SHFL_SYNC(max_logit, 0);

  // Load rescaled exp sums to shared memory.
  float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);

@@ -539,9 +547,9 @@ __global__ void paged_attention_v2_reduce_kernel(
} // namespace vllm

#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
-  cudaFuncSetAttribute( \
-    vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>, \
-    cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \
+  VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
+    ((void*)vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>), \
+    shared_mem_size); \
  vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
  <<<grid, block, shared_mem_size, stream>>>( \
    out_ptr, \

csrc/attention/attention_utils.cuh
@@ -17,6 +17,7 @@
 */
#pragma once

+#include "../cuda_compat.h"
#include "attention_dtypes.h"

#include <float.h>

@@ -39,7 +40,7 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
  float qk = sum(qk_vec);
#pragma unroll
  for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
-    qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
+    qk += VLLM_SHFL_XOR_SYNC(qk, mask);
  }
  return qk;
}

csrc/attention/dtype_bfloat16.cuh
@@ -21,8 +21,17 @@
#include "attention_generic.cuh"
#include "dtype_float32.cuh"

+#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
+#else
+#include <hip/hip_bf16.h>
+#include <hip/hip_fp16.h>
+
+typedef __hip_bfloat162 __nv_bfloat162;
+typedef __hip_bfloat16 __nv_bfloat16;
+#endif

#include <stdint.h>

namespace vllm {

@@ -98,7 +107,11 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
+#ifndef USE_ROCM
  return a + b;
+#else
+  return __hadd(a, b);
+#endif
#endif
}

csrc/attention/dtype_float16.cuh
@@ -21,6 +21,10 @@
#include "attention_generic.cuh"
#include "dtype_float32.cuh"

+#ifdef USE_ROCM
+#include <hip/hip_fp16.h>
+#endif
+
#include <stdint.h>

namespace vllm {

@@ -63,21 +67,47 @@ struct FloatVec<uint4> {

// Utility functions for type conversions.
inline __device__ uint32_t h0_h0(uint16_t a) {
+#ifndef USE_ROCM
  uint32_t b;
  asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
  return b;
+#else
+  union {
+    uint32_t u32;
+    uint16_t u16[2];
+  } tmp;
+  tmp.u16[0] = a;
+  tmp.u16[1] = a;
+  return tmp.u32;
+#endif
}

inline __device__ float half_to_float(uint16_t h) {
  float f;
+#ifndef USE_ROCM
  asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
+#else
+  asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h));
+#endif
  return f;
}

inline __device__ float2 half2_to_float2(uint32_t v) {
+#ifndef USE_ROCM
  uint16_t lo, hi;
  asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
  return make_float2(half_to_float(lo), half_to_float(hi));
+#else
+  union {
+    uint32_t u32;
+    uint16_t u16[2];
+  } tmp;
+  tmp.u32 = v;
+  float2 ret;
+  ret.x = half_to_float(tmp.u16[0]);
+  ret.y = half_to_float(tmp.u16[1]);
+  return ret;
+#endif
}

inline __device__ uint16_t float_to_half(float f) {

@@ -85,7 +115,11 @@ inline __device__ uint16_t float_to_half(float f) {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
+#ifndef USE_ROCM
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
+#else
+  asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f));
+#endif
  return tmp.u16[0];
}

@@ -94,26 +128,38 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
+#ifndef USE_ROCM
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
#else
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
#endif
+#else
+  tmp.u16[0] = float_to_half(f.x);
+  tmp.u16[1] = float_to_half(f.y);
+#endif
  return tmp.u32;
}

// Vector addition.
inline __device__ uint16_t add(uint16_t a, uint16_t b) {
  uint16_t c;
+#ifndef USE_ROCM
  asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
+#else
+  asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
+#endif
  return c;
}

inline __device__ uint32_t add(uint32_t a, uint32_t b) {
  uint32_t c;
+#ifndef USE_ROCM
  asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
+#else
+  asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
+#endif
  return c;
}

@@ -158,14 +204,22 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) {
template<>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
  uint16_t c;
+#ifndef USE_ROCM
  asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
+#else
+  asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
+#endif
  return c;
}

template<>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
  uint32_t c;
+#ifndef USE_ROCM
  asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
+#else
+  asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
+#endif
  return c;
}

@@ -272,7 +326,11 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) {
// Vector fused multiply-add.
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
  uint32_t d;
+#ifndef USE_ROCM
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
+#else
+  asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
+#endif
  return d;
}

csrc/cache_kernels.cu
@@ -1,6 +1,7 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

+#include "cuda_compat.h"
#include "dispatch_utils.h"

#include <algorithm>

@@ -28,8 +29,8 @@ void swap_blocks(
    TORCH_CHECK(false, "Invalid device combination");
  }

-  void *src_ptr = src.data_ptr();
-  void *dst_ptr = dst.data_ptr();
+  char *src_ptr = static_cast<char*>(src.data_ptr());
+  char *dst_ptr = static_cast<char*>(dst.data_ptr());

  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

@@ -267,8 +268,8 @@ __global__ void gather_cached_kv_kernel(
                                + head_offset * block_size
                                + block_offset;

-      key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]);
-      value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]);
+      key[tgt_key_idx] = VLLM_LDG(&key_cache[src_key_idx]);
+      value[tgt_value_idx] = VLLM_LDG(&value_cache[src_value_idx]);
    }
  }

@@ -333,8 +334,8 @@ __global__ void gather_cached_kv_kernel_optimized(
        src_key_indices[j] = src_key_idx;
        src_value_indices[j] = src_value_idx;

-        keys_to_store[j] = __ldg(&key_cache[src_key_idx]);
-        values_to_store[j] = __ldg(&value_cache[src_value_idx]);
+        keys_to_store[j] = VLLM_LDG(&key_cache[src_key_idx]);
+        values_to_store[j] = VLLM_LDG(&value_cache[src_value_idx]);
      }

#pragma unroll

csrc/cuda_compat.h (new file)

@@ -0,0 +1,28 @@
#pragma once
#ifndef USE_ROCM
#define VLLM_LDG(arg) __ldg(arg)
#else
#define VLLM_LDG(arg) *(arg)
#endif
#ifndef USE_ROCM
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
#else
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
#endif
#ifndef USE_ROCM
#define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
#else
#define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
#endif
#ifndef USE_ROCM
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
#else
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif

csrc/cuda_utils_kernels.cu
@@ -1,3 +1,6 @@
+#ifdef USE_ROCM
+#include <hip/hip_runtime.h>
+#endif
int get_device_attribute(
    int attribute,
    int device_id)

csrc/ops.h
@@ -61,12 +61,14 @@ void gelu_fast(
  torch::Tensor& out,
  torch::Tensor& input);

+#ifndef USE_ROCM
torch::Tensor awq_gemm(
  torch::Tensor _in_feats,
  torch::Tensor _kernel,
  torch::Tensor _scaling_factors,
  torch::Tensor _zeros,
  int split_k_iters);
+#endif

void squeezellm_gemm(
  torch::Tensor vec,

csrc/pos_encoding_kernels.cu
@@ -1,6 +1,7 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

+#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace vllm {

@@ -19,14 +20,14 @@ inline __device__ void apply_rotary_embedding(
    // GPT-NeoX style rotary embedding.
    x_index = rot_offset;
    y_index = embed_dim + rot_offset;
-    cos = __ldg(cos_ptr + x_index);
-    sin = __ldg(sin_ptr + x_index);
+    cos = VLLM_LDG(cos_ptr + x_index);
+    sin = VLLM_LDG(sin_ptr + x_index);
  } else {
    // GPT-J style rotary embedding.
    x_index = 2 * rot_offset;
    y_index = 2 * rot_offset + 1;
-    cos = __ldg(cos_ptr + x_index / 2);
-    sin = __ldg(sin_ptr + x_index / 2);
+    cos = VLLM_LDG(cos_ptr + x_index / 2);
+    sin = VLLM_LDG(sin_ptr + x_index / 2);
  }

  const scalar_t x = arr[x_index];

csrc/pybind.cpp
@@ -48,8 +48,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    &rotary_embedding,
    "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");

+#ifndef USE_ROCM
  // Quantization ops
  ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
+#endif
  ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");

  // Cache ops

csrc/quantization/squeezellm/quant_cuda_kernel.cu
@@ -20,9 +20,17 @@ __device__ inline unsigned int as_unsigned(int i) {

// 4-bit matvec kernel (LUT-based)
__global__ void NUQ4MatMulKernel(
+#ifndef USE_ROCM
    const half2* __restrict__ vec,
+#else
+    const __half2* __restrict__ vec,
+#endif
    const int* __restrict__ mat,
+#ifndef USE_ROCM
    half2* __restrict__ mul,
+#else
+    float2* __restrict__ mul,
+#endif
    const __half* __restrict__ lookup_table,
    int height,
    int width,

@@ -35,7 +43,11 @@ __global__ void NUQ4MatMulKernel(
  int row = BLOCKHEIGHT4 * blockIdx.x;
  int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;

+#ifndef USE_ROCM
  __shared__ half2 blockvec[blockwidth2];
+#else
+  __shared__ __half2 blockvec[blockwidth2];
+#endif

  __shared__ __half deq2[16][BLOCKWIDTH];
  int off = threadIdx.x;

@@ -46,8 +58,13 @@ __global__ void NUQ4MatMulKernel(
  }

  __half res;
+#ifndef USE_ROCM
  half2 res2;
  half2 tmp2;
+#else
+  __half2 res2;
+  __half2 tmp2;
+#endif

  int i;
  int k;

@@ -68,48 +85,96 @@ __global__ void NUQ4MatMulKernel(
    while (k < blockwidth2) {
      tmp1 = as_unsigned(mat[i]);

+#ifndef USE_ROCM
      res2 = {};
      tmp2 = {};
+#else
+      res2.x = __half_as_ushort(__float2half(0));
+      res2.y = __half_as_ushort(__float2half(0));
+      tmp2.x = __half_as_ushort(__float2half(0));
+      tmp2.y = __half_as_ushort(__float2half(0));
+#endif

      lut_index1 = tmp1 & 0xF;
      lut_index2 = (tmp1 >> 4) & 0xF;
+#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
+#else
+      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
+      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
+#endif
      res2 = __hfma2(tmp2, blockvec[k + 0], res2);

      lut_index1 = (tmp1 >> 8) & 0xF;
      lut_index2 = (tmp1 >> 12) & 0xF;
+#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
+#else
+      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
+      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
+#endif
      res2 = __hfma2(tmp2, blockvec[k + 1], res2);

      lut_index1 = (tmp1 >> 16) & 0xF;
      lut_index2 = (tmp1 >> 20) & 0xF;
+#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
+#else
+      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
+      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
+#endif
      res2 = __hfma2(tmp2, blockvec[k + 2], res2);

      lut_index1 = (tmp1 >> 24) & 0xF;
      lut_index2 = (tmp1 >> 28) & 0xF;
+#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
+#else
+      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
+      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
+#endif
      res2 = __hfma2(tmp2, blockvec[k + 3], res2);

+#ifndef USE_ROCM
      res = __hadd(__hadd(res2.x, res2.y), res);
+#else
+      res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), res);
+#endif

      i += width;
      k += 4;
    }

    // col%2 -> only set one of the two values
+#ifndef USE_ROCM
    half2 res3 = {};
    if (col % 2 == 0) {
      res3.x = res;
    } else {
      res3.y = res;
    }
+#else
+    __half2 res3;
+    res3.x = __half_as_ushort(__float2half(0));
+    res3.y = __half_as_ushort(__float2half(0));
+    if (col % 2 == 0) {
+      res3.x = __half_as_ushort(res);
+    } else {
+      res3.y = __half_as_ushort(res);
+    }
+#endif

+#ifndef USE_ROCM
    atomicAdd(&mul[b * width / 2 + col / 2], res3);
+#else
+    int tmp_addr = b * width / 2 + col / 2;
+    atomicAdd(&(mul[tmp_addr].x), __half2float(__ushort_as_half(res3.x)));
+    atomicAdd(&(mul[tmp_addr].y), __half2float(__ushort_as_half(res3.y)));
+#endif
  }
}

@@ -136,10 +201,19 @@ void squeezellm_gemm(
  dim3 threads(BLOCKWIDTH);
  vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads>>>(
+#ifndef USE_ROCM
    (half2*) vec.data<at::Half>(),
+#else
+    (__half2*) vec.data_ptr<at::Half>(),
+#endif
    mat.data_ptr<int>(),
+#ifndef USE_ROCM
    (half2*) mul.data<at::Half>(),
    (__half*) lookup_table.data<at::Half>(),
+#else
+    (float2*) mul.data_ptr<float>(),
+    (__half*) lookup_table.data_ptr<at::Half>(),
+#endif
    height, width, batch, vec_height
  );
}

csrc/reduction_utils.cuh
@@ -17,13 +17,15 @@
 */
#pragma once

+#include "cuda_compat.h"
+
namespace vllm {

template<typename T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
  for (int mask = 16; mask > 0; mask >>= 1)
-    val += __shfl_xor_sync(0xffffffff, val, mask, 32);
+    val += VLLM_SHFL_XOR_SYNC(val, mask);
  return val;
}

docs/source/getting_started/amd-installation.rst (new file)
@@ -0,0 +1,143 @@
.. _installation_rocm:

Installation with ROCm
======================

vLLM 0.2.x onwards supports model inference and serving on AMD GPUs with ROCm.
At the moment AWQ quantization is not supported in ROCm, but SqueezeLLM quantization has been ported.
Data types currently supported in ROCm are FP16 and BF16.

Requirements
------------

* OS: Linux
* Python: 3.8 -- 3.11 (Verified on 3.10)
* GPU: MI200s
* Pytorch 2.0.1/2.1.1/2.2
* ROCm 5.7

Installation options:

#. :ref:`(Recommended) Quick start with vLLM pre-installed in Docker Image <quick_start_docker_rocm>`
#. :ref:`Build from source <build_from_source_rocm>`
#. :ref:`Build from source with docker <build_from_source_docker_rocm>`

.. _quick_start_docker_rocm:

(Recommended) Option 1: Quick start with vLLM pre-installed in Docker Image
---------------------------------------------------------------------------

.. code-block:: console

    $ docker pull embeddedllminfo/vllm-rocm:vllm-v0.2.3
    $ docker run -it \
       --network=host \
       --group-add=video \
       --ipc=host \
       --cap-add=SYS_PTRACE \
       --security-opt seccomp=unconfined \
       --device /dev/kfd \
       --device /dev/dri \
       -v <path/to/model>:/app/model \
       embeddedllminfo/vllm-rocm \
       bash

.. _build_from_source_rocm:

Option 2: Build from source
---------------------------

You can build and install vLLM from source:

0. Install prerequisites (skip if you are already in an environment/docker with the following installed):

- `ROCm <https://rocm.docs.amd.com/en/latest/deploy/linux/index.html>`_
- `Pytorch <https://pytorch.org/>`_

   .. code-block:: console

      $ pip install torch==2.2.0.dev20231206+rocm5.7 --index-url https://download.pytorch.org/whl/nightly/rocm5.7 # tested version

1. Install `flash attention for ROCm <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm>`_

   Install ROCm's flash attention (v2.0.4) following the instructions from `ROCmSoftwarePlatform/flash-attention <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm#amd-gpurocm-support>`_

.. note::
    - If you are using rocm5.7 with pytorch 2.1.0 onwards, you don't need to apply the `hipify_python.patch`. You can build the ROCm flash attention directly.
    - If you fail to install `ROCmSoftwarePlatform/flash-attention`, try cloning from the commit `6fd2f8e572805681cd67ef8596c7e2ce521ed3c6`.
    - ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention.
    - You might need to downgrade the "ninja" version to 1.10, as it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`).

2. Setup `xformers==0.0.22.post7` without dependencies, and apply patches to adapt for ROCm flash attention

   .. code-block:: console

      $ pip install xformers==0.0.22.post7 --no-deps
      $ bash patch_xformers-0.0.22.post7.rocm.sh

3. Build vLLM.

   .. code-block:: console

      $ cd vllm
      $ pip install -U -r requirements-rocm.txt
      $ python setup.py install  # This may take 5-10 minutes. Currently, ``pip install .`` does not work for ROCm installation.

.. _build_from_source_docker_rocm:

Option 3: Build from source with docker
---------------------------------------

You can build and install vLLM from source:

Build a docker image from `Dockerfile.rocm`, and launch a docker container.

.. code-block:: console

    $ docker build -f Dockerfile.rocm -t vllm-rocm .
    $ docker run -it \
       --network=host \
       --group-add=video \
       --ipc=host \
       --cap-add=SYS_PTRACE \
       --security-opt seccomp=unconfined \
       --device /dev/kfd \
       --device /dev/dri \
       -v <path/to/model>:/app/model \
       vllm-rocm \
       bash

Alternatively, if you plan to install vLLM-ROCm on a local machine or start from a fresh docker image (e.g. rocm/pytorch), you can follow the steps below:

0. Install prerequisites (skip if you are already in an environment/docker with the following installed):

- `ROCm <https://rocm.docs.amd.com/en/latest/deploy/linux/index.html>`_
- `Pytorch <https://pytorch.org/>`_

1. Install `flash attention for ROCm <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm>`_

   Install ROCm's flash attention (v2.0.4) following the instructions from `ROCmSoftwarePlatform/flash-attention <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm#amd-gpurocm-support>`_

.. note::
    - If you are using rocm5.7 with pytorch 2.1.0 onwards, you don't need to apply the `hipify_python.patch`. You can build the ROCm flash attention directly.
    - If you fail to install `ROCmSoftwarePlatform/flash-attention`, try cloning from the commit `6fd2f8e572805681cd67ef8596c7e2ce521ed3c6`.
    - ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention.
    - You might need to downgrade the "ninja" version to 1.10, as it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`).

2. Setup `xformers==0.0.22.post7` without dependencies, and apply patches to adapt for ROCm flash attention

   .. code-block:: console

      $ pip install xformers==0.0.22.post7 --no-deps
      $ bash patch_xformers-0.0.22.post7.rocm.sh

3. Build vLLM.

   .. code-block:: console

      $ cd vllm
      $ pip install -U -r requirements-rocm.txt
      $ python setup.py install  # This may take 5-10 minutes.
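
Whichever option you used, a quick way to confirm the ROCm build works is a short offline generation with vLLM's Python API. This is a minimal sketch rather than part of the commit: the model name below is only an example and can be replaced with any supported model you have available locally or on the Hugging Face Hub.

.. code-block:: python

    # Minimal post-install sanity check (illustrative; not part of this commit).
    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m")  # example model; any supported model works
    sampling = SamplingParams(temperature=0.8, max_tokens=32)

    outputs = llm.generate(["Hello, my name is"], sampling)
    for output in outputs:
        print(output.outputs[0].text)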

docs/source/index.rst
@@ -39,6 +39,7 @@ vLLM is flexible and easy to use with:
* Tensor parallelism support for distributed inference
* Streaming outputs
* OpenAI-compatible API server
+* Support NVIDIA CUDA and AMD ROCm.

For more information, check out the following:

@@ -56,6 +57,7 @@ Documentation
   :caption: Getting Started

   getting_started/installation
+   getting_started/amd-installation
   getting_started/quickstart

.. toctree::

patch_xformers-0.0.22.post7.rocm.sh (new file)
@@ -0,0 +1,22 @@
#!/bin/bash
export XFORMERS_FMHA_FLASH_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.flash.__file__)')
export XFORMERS_FMHA_COMMON_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.common.__file__)')
echo $XFORMERS_FMHA_FLASH_PATH
echo $XFORMERS_FMHA_COMMON_PATH
if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch"; then
echo "Applying patch to ${XFORMERS_FMHA_FLASH_PATH}"
patch -p0 $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch"
echo "Successfully patch ${XFORMERS_FMHA_FLASH_PATH}"
else
echo "${XFORMERS_FMHA_FLASH_PATH} was patched before"
fi
if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch"; then
echo "Applying patch to ${XFORMERS_FMHA_COMMON_PATH}"
patch -p0 $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch"
echo "Successfully patch ${XFORMERS_FMHA_COMMON_PATH}"
else
echo "${XFORMERS_FMHA_COMMON_PATH} was patched before"
fi

requirements-rocm.txt (new file)

@@ -0,0 +1,17 @@
ninja # For faster builds.
typing-extensions>=4.8.0
starlette
psutil
ray >= 2.5.1
pandas # Required for Ray data.
pyarrow # Required for Ray data.
sentencepiece # Required for LLaMA tokenizer.
numpy
tokenizers>=0.15.0
huggingface_hub<0.18,>=0.16.4
einops # Required for phi-1_5
transformers >= 4.34.0 # Required for Mistral.
fastapi
uvicorn[standard]
pydantic == 1.10.13 # Required for OpenAI server.
aioprometheus[starlette]

rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch (new file)
@@ -0,0 +1,13 @@
--- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/common.py 2023-11-29 03:17:03.930103539 +0000
+++ common.py 2023-11-28 16:14:19.846233146 +0000
@@ -298,8 +298,8 @@
dtype = d.query.dtype
if device_type not in cls.SUPPORTED_DEVICES:
reasons.append(f"device={device_type} (supported: {cls.SUPPORTED_DEVICES})")
- if device_type == "cuda" and not _built_with_cuda:
- reasons.append("xFormers wasn't build with CUDA support")
+ #if device_type == "cuda" and not _built_with_cuda:
+ # reasons.append("xFormers wasn't build with CUDA support")
if device_type == "cuda":
device_capability = torch.cuda.get_device_capability(d.device)
if device_capability < cls.CUDA_MINIMUM_COMPUTE_CAPABILITY:

rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch (new file)
@@ -0,0 +1,134 @@
--- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/flash.py 2023-11-29 03:17:03.930103539 +0000
+++ flash.py 2023-11-28 16:14:25.206128903 +0000
@@ -31,39 +31,39 @@
FLASH_VERSION = "0.0.0"
try:
- try:
- from ... import _C_flashattention # type: ignore[attr-defined]
- from ..._cpp_lib import _build_metadata
-
- if _build_metadata is not None:
- FLASH_VERSION = _build_metadata.flash_version
- except ImportError:
- import flash_attn
- from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
-
- FLASH_VERSION = flash_attn.__version__
- flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
- if flash_ver_parsed < (2, 3):
- raise ImportError("Requires 2.3 for sliding window support")
+ #try:
+ # from ... import _C_flashattention # type: ignore[attr-defined]
+ # from ..._cpp_lib import _build_metadata
+
+ # if _build_metadata is not None:
+ # FLASH_VERSION = _build_metadata.flash_version
+ #except ImportError:
+ import flash_attn
+ from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
+
+ FLASH_VERSION = flash_attn.__version__
+ # flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
+ # if flash_ver_parsed < (2, 3):
+ # raise ImportError("Requires 2.3 for sliding window support")
# create library so that flash-attn goes through the PyTorch Dispatcher
- _flash_lib = torch.library.Library("xformers_flash", "DEF")
+ #_flash_lib = torch.library.Library("xformers_flash", "DEF")
- _flash_lib.define(
- "flash_fwd(Tensor query, Tensor key, Tensor value, "
- "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
- "int max_seqlen_q, int max_seqlen_k, "
- "float p, float softmax_scale, "
- "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)"
- )
-
- _flash_lib.define(
- "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
- "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
- "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
- "int max_seqlen_q, int max_seqlen_k, "
- "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
- )
+ #_flash_lib.define(
+ # "flash_fwd(Tensor query, Tensor key, Tensor value, "
+ # "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
+ # "int max_seqlen_q, int max_seqlen_k, "
+ # "float p, float softmax_scale, "
+ # "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)"
+ #)
+
+ #_flash_lib.define(
+ # "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
+ # "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
+ # "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
+ # "int max_seqlen_q, int max_seqlen_k, "
+ # "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
+ #)
def _flash_fwd(
query,
@@ -98,8 +98,8 @@
p,
softmax_scale,
is_causal,
- window_size - 1, # window_size_left
- -1, # window_size_right
+ # window_size - 1, # window_size_left
+ # -1, # window_size_right
return_softmax,
None, # rng
)
@@ -127,8 +127,8 @@
softmax_scale,
False,
is_causal,
- window_size - 1, # window_size_left
- -1, # window_size_right
+ # window_size - 1, # window_size_left
+ # -1, # window_size_right
return_softmax,
None,
)
@@ -169,8 +169,8 @@
p,
softmax_scale,
is_causal,
- window_size - 1, # window_size_left
- -1, # window_size_right
+ # window_size - 1, # window_size_left
+ # -1, # window_size_right
None,
rng_state,
)
@@ -193,15 +193,15 @@
softmax_scale,
False, # zero_tensors
is_causal,
- window_size - 1, # window_size_left
- -1, # window_size_right
+ # window_size - 1, # window_size_left
+ # -1, # window_size_right
None,
rng_state,
)
return dq, dk, dv
- _flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
- _flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
+ #_flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
+ #_flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
except ImportError:
pass
@@ -348,7 +348,7 @@
implementation.
"""
- OPERATOR = get_operator("xformers_flash", "flash_fwd")
+ OPERATOR = _flash_fwd # get_operator("xformers_flash", "flash_fwd")
SUPPORTED_DEVICES: Set[str] = {"cuda"}
CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}

setup.py

@@ -8,27 +8,83 @@ import warnings
from packaging.version import parse, Version
import setuptools
import torch
-from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME

ROOT_DIR = os.path.dirname(__file__)

MAIN_CUDA_VERSION = "12.1"

# Supported NVIDIA GPU architectures.
-SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
+NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
+ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx1030", "gfx1100"}
+# SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS)
+
+
+def _is_hip() -> bool:
+    return torch.version.hip is not None
+
+
+def _is_cuda() -> bool:
+    return torch.version.cuda is not None
+

# Compiler flags.
CXX_FLAGS = ["-g", "-O2", "-std=c++17"]
# TODO(woosuk): Should we use -O3?
NVCC_FLAGS = ["-O2", "-std=c++17"]

+if _is_hip():
+    if ROCM_HOME is None:
+        raise RuntimeError(
+            "Cannot find ROCM_HOME. ROCm must be available to build the package."
+        )
+    NVCC_FLAGS += ["-DUSE_ROCM"]
+
+if _is_cuda() and CUDA_HOME is None:
+    raise RuntimeError(
+        "Cannot find CUDA_HOME. CUDA must be available to build the package.")
+
ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]

-if CUDA_HOME is None:
-    raise RuntimeError(
-        "Cannot find CUDA_HOME. CUDA must be available to build the package.")
+
+def get_amdgpu_offload_arch():
+    command = "/opt/rocm/llvm/bin/amdgpu-offload-arch"
+    try:
+        output = subprocess.check_output([command])
+        return output.decode('utf-8').strip()
+    except subprocess.CalledProcessError as e:
+        error_message = f"Error: {e}"
+        raise RuntimeError(error_message) from e
+    except FileNotFoundError as e:
+        # If the command is not found, print an error message
+        error_message = f"The command {command} was not found."
+        raise RuntimeError(error_message) from e
+
+    return None
+
+
+def get_hipcc_rocm_version():
+    # Run the hipcc --version command
+    result = subprocess.run(['hipcc', '--version'],
+                            stdout=subprocess.PIPE,
+                            stderr=subprocess.STDOUT,
+                            text=True)
+
+    # Check if the command was executed successfully
+    if result.returncode != 0:
+        print("Error running 'hipcc --version'")
+        return None
+
+    # Extract the version using a regular expression
+    match = re.search(r'HIP version: (\S+)', result.stdout)
+    if match:
+        # Return the version string
+        return match.group(1)
+    else:
+        print("Could not find HIP version in the output")
+        return None


def get_nvcc_cuda_version(cuda_dir: str) -> Version:
@@ -61,20 +117,22 @@ def get_torch_arch_list() -> Set[str]:
        return set()

    # Filter out the invalid architectures and print a warning.
-    valid_archs = SUPPORTED_ARCHS.union({s + "+PTX" for s in SUPPORTED_ARCHS})
+    valid_archs = NVIDIA_SUPPORTED_ARCHS.union(
+        {s + "+PTX"
+         for s in NVIDIA_SUPPORTED_ARCHS})
    arch_list = torch_arch_list.intersection(valid_archs)
    # If none of the specified architectures are valid, raise an error.
    if not arch_list:
        raise RuntimeError(
-            "None of the CUDA architectures in `TORCH_CUDA_ARCH_LIST` env "
+            "None of the CUDA/ROCM architectures in `TORCH_CUDA_ARCH_LIST` env "
            f"variable ({env_arch_list}) is supported. "
-            f"Supported CUDA architectures are: {valid_archs}.")
+            f"Supported CUDA/ROCM architectures are: {valid_archs}.")
    invalid_arch_list = torch_arch_list - valid_archs
    if invalid_arch_list:
        warnings.warn(
-            f"Unsupported CUDA architectures ({invalid_arch_list}) are "
+            f"Unsupported CUDA/ROCM architectures ({invalid_arch_list}) are "
            "excluded from the `TORCH_CUDA_ARCH_LIST` env variable "
-            f"({env_arch_list}). Supported CUDA architectures are: "
+            f"({env_arch_list}). Supported CUDA/ROCM architectures are: "
            f"{valid_archs}.",
            stacklevel=2)
    return arch_list
@@ -82,7 +140,7 @@ def get_torch_arch_list() -> Set[str]:
# First, check the TORCH_CUDA_ARCH_LIST environment variable.
compute_capabilities = get_torch_arch_list()
-if not compute_capabilities:
+if _is_cuda() and not compute_capabilities:
    # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
    # GPUs on the current machine.
    device_count = torch.cuda.device_count()

@@ -93,20 +151,21 @@ if not compute_capabilities:
                "GPUs with compute capability below 7.0 are not supported.")
        compute_capabilities.add(f"{major}.{minor}")

+if _is_cuda():
    nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
    if not compute_capabilities:
        # If no GPU is specified nor available, add all supported architectures
        # based on the NVCC CUDA version.
-        compute_capabilities = SUPPORTED_ARCHS.copy()
+        compute_capabilities = NVIDIA_SUPPORTED_ARCHS.copy()
        if nvcc_cuda_version < Version("11.1"):
            compute_capabilities.remove("8.6")
        if nvcc_cuda_version < Version("11.8"):
            compute_capabilities.remove("8.9")
            compute_capabilities.remove("9.0")

    # Validate the NVCC CUDA version.
    if nvcc_cuda_version < Version("11.0"):
-        raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
+        raise RuntimeError(
+            "CUDA 11.0 or higher is required to build the package.")
    if (nvcc_cuda_version < Version("11.1")
            and any(cc.startswith("8.6") for cc in compute_capabilities)):
        raise RuntimeError(
@@ -134,7 +193,9 @@ for capability in compute_capabilities:
    num = capability[0] + capability[2]
    NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
    if capability.endswith("+PTX"):
-        NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
+        NVCC_FLAGS += [
+            "-gencode", f"arch=compute_{num},code=compute_{num}"
+        ]

# Use NVCC threads to parallelize the build.
if nvcc_cuda_version >= Version("11.2"):

@@ -142,20 +203,32 @@ if nvcc_cuda_version >= Version("11.2"):
    num_threads = min(os.cpu_count(), nvcc_threads)
    NVCC_FLAGS += ["--threads", str(num_threads)]

+elif _is_hip():
+    amd_arch = get_amdgpu_offload_arch()
+    if amd_arch not in ROCM_SUPPORTED_ARCHS:
+        raise RuntimeError(
+            f"Only the following arch is supported: {ROCM_SUPPORTED_ARCHS}"
+            f"amdgpu_arch_found: {amd_arch}")
+
ext_modules = []

-vllm_extension = CUDAExtension(
-    name="vllm._C",
-    sources=[
+vllm_extension_sources = [
    "csrc/cache_kernels.cu",
    "csrc/attention/attention_kernels.cu",
    "csrc/pos_encoding_kernels.cu",
    "csrc/activation_kernels.cu",
    "csrc/layernorm_kernels.cu",
-    "csrc/quantization/awq/gemm_kernels.cu",
    "csrc/quantization/squeezellm/quant_cuda_kernel.cu",
    "csrc/cuda_utils_kernels.cu",
    "csrc/pybind.cpp",
-    ],
+]
+
+if _is_cuda():
+    vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")
+
+vllm_extension = CUDAExtension(
+    name="vllm._C",
+    sources=vllm_extension_sources,
    extra_compile_args={
        "cxx": CXX_FLAGS,
        "nvcc": NVCC_FLAGS,
@@ -183,10 +256,19 @@ def find_version(filepath: str) -> str:

def get_vllm_version() -> str:
    version = find_version(get_path("vllm", "__init__.py"))

+    if _is_hip():
+        # Get the HIP version
+        hipcc_version = get_hipcc_rocm_version()
+        if hipcc_version != MAIN_CUDA_VERSION:
+            rocm_version_str = hipcc_version.replace(".", "")[:3]
+            version += f"+rocm{rocm_version_str}"
+    else:
        cuda_version = str(nvcc_cuda_version)
        if cuda_version != MAIN_CUDA_VERSION:
            cuda_version_str = cuda_version.replace(".", "")[:3]
            version += f"+cu{cuda_version_str}"

    return version

@@ -201,6 +283,10 @@ def read_readme() -> str:

def get_requirements() -> List[str]:
    """Get Python package dependencies from requirements.txt."""
+    if _is_hip():
+        with open(get_path("requirements-rocm.txt")) as f:
+            requirements = f.read().strip().split("\n")
+    else:
        with open(get_path("requirements.txt")) as f:
            requirements = f.read().strip().split("\n")
    return requirements
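
Taken together, setup.py now chooses between the CUDA and ROCm toolchains based solely on how the installed PyTorch was built. The standalone sketch below is not part of the commit; the helper names mirror the ones added above, and the printed messages merely summarize what the build script does with the result.

    # Condensed sketch of the backend detection used by setup.py in this commit.
    # torch.version reports the toolkit the installed PyTorch was compiled against.
    import torch


    def _is_hip() -> bool:
        # A ROCm build of PyTorch exposes a version string here; other builds return None.
        return torch.version.hip is not None


    def _is_cuda() -> bool:
        return torch.version.cuda is not None


    if __name__ == "__main__":
        if _is_hip():
            print("ROCm build: add -DUSE_ROCM and skip csrc/quantization/awq/gemm_kernels.cu")
        elif _is_cuda():
            print("CUDA build: compile the AWQ kernels and pick -gencode flags from the arch list")
        else:
            print("Neither CUDA nor ROCm detected in this PyTorch build")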

vllm/config.py
@@ -6,7 +6,7 @@ from transformers import PretrainedConfig

from vllm.logger import init_logger
from vllm.transformers_utils.config import get_config
-from vllm.utils import get_cpu_memory
+from vllm.utils import get_cpu_memory, is_hip

logger = init_logger(__name__)

@@ -98,12 +98,27 @@ class ModelConfig:

    def _verify_load_format(self) -> None:
        load_format = self.load_format.lower()
-        if load_format not in [
+        supported_load_format = [
            "auto", "pt", "safetensors", "npcache", "dummy"
-        ]:
+        ]
+        rocm_not_supported_load_format = ["safetensors"]
+        if load_format not in supported_load_format:
            raise ValueError(
                f"Unknown load format: {self.load_format}. Must be one of "
                "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
+        if is_hip():
+            if load_format in ["safetensors"]:
+                rocm_supported_load_format = [
+                    f for f in supported_load_format
+                    if (f not in rocm_not_supported_load_format)
+                ]
+                raise ValueError(
+                    f"load format \'{load_format}\' is not supported in ROCm. "
+                    f"Supported load format are "
+                    f"{rocm_supported_load_format}")
+            # Force ROCm to load from pt weights if nothing specific is set
+            if load_format == "auto":
+                load_format = "pt"
        self.load_format = load_format

    def _verify_tokenizer_mode(self) -> None:

@@ -116,6 +131,7 @@ class ModelConfig:

    def _verify_quantization(self) -> None:
        supported_quantization = ["awq", "squeezellm"]
+        rocm_not_supported_quantization = ["awq"]
        if self.quantization is not None:
            self.quantization = self.quantization.lower()

@@ -137,6 +153,11 @@ class ModelConfig:
                raise ValueError(
                    f"Unknown quantization method: {self.quantization}. Must "
                    f"be one of {supported_quantization}.")
+            if is_hip(
+            ) and self.quantization in rocm_not_supported_quantization:
+                raise ValueError(
+                    f"{self.quantization} quantization is currently not supported "
+                    f"in ROCm.")
            logger.warning(f"{self.quantization} quantization is not fully "
                           "optimized yet. The speed can be slower than "
                           "non-quantized models.")

@@ -364,6 +385,8 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
    "bfloat16": torch.bfloat16,
}

+_ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"]
+

def _get_and_verify_dtype(
    config: PretrainedConfig,

@@ -393,6 +416,14 @@ def _get_and_verify_dtype(
        else:
            raise ValueError(f"Unknown dtype: {dtype}")

+    if is_hip() and torch_dtype == torch.float32:
+        rocm_supported_dtypes = [
+            k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
+            if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
+        ]
+        raise ValueError(f"dtype \'{dtype}\' is not supported in ROCm. "
+                         f"Supported dtypes are {rocm_supported_dtypes}")
+
    # Verify the dtype.
    if torch_dtype != config_dtype:
        if torch_dtype == torch.float32:

vllm/engine/ray_utils.py
@@ -3,6 +3,7 @@ from typing import Optional, Tuple, TYPE_CHECKING

from vllm.config import ParallelConfig
from vllm.logger import init_logger
+from vllm.utils import is_hip

logger = init_logger(__name__)

@@ -73,6 +74,11 @@ def initialize_cluster(
            "Ray is not installed. Please install Ray to use distributed "
            "serving.")
    # Connect to a ray cluster.
+    if is_hip():
+        ray.init(address=ray_address,
+                 ignore_reinit_error=True,
+                 num_gpus=parallel_config.world_size)
+    else:
        ray.init(address=ray_address, ignore_reinit_error=True)

    if not parallel_config.worker_use_ray:

vllm/model_executor/layers/attention.py
@@ -10,6 +10,7 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
from vllm._C import ops
from vllm._C import cache_ops
from vllm.model_executor.input_metadata import InputMetadata
+from vllm.utils import is_hip

_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.

@@ -160,6 +161,8 @@ class PagedAttention(nn.Module):
                attn_bias=input_metadata.attn_bias,
                p=0.0,
                scale=self.scale,
+                op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
+                (is_hip()) else None,
            )
            output = out.view_as(query)
        else:

vllm/model_executor/layers/quantization/squeezellm.py
@@ -7,6 +7,7 @@ from vllm._C import ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.utils import is_hip


class SqueezeLLMConfig(QuantizationConfig):

@@ -114,6 +115,11 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
        lookup_table = weights["lookup_table"]
        out_shape = x.shape[:-1] + (qweight.shape[-1], )
        reshaped_x = x.reshape(-1, x.shape[-1])
+        if is_hip():
+            out_f = torch.zeros(out_shape, device="cuda", dtype=torch.float)
+            ops.squeezellm_gemm(reshaped_x, qweight, out_f, lookup_table)
+            out = out_f.to(dtype=torch.float16)
+        else:
            # NOTE: The output tensor should be zero-initialized.
            out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
            ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)

vllm/model_executor/model_loader.py
@@ -10,6 +10,10 @@ from vllm.config import ModelConfig
from vllm.model_executor.models import *
from vllm.model_executor.weight_utils import (get_quant_config,
                                              initialize_dummy_weights)
+from vllm.utils import is_hip
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)

# TODO(woosuk): Lazy-load the model classes.
_MODEL_REGISTRY = {

@@ -39,6 +43,18 @@ _MODEL_REGISTRY = {
    "YiForCausalLM": YiForCausalLM,
}

+# Models to be disabled in ROCm
+_ROCM_UNSUPPORTED_MODELS = []
+if is_hip():
+    for rocm_model in _ROCM_UNSUPPORTED_MODELS:
+        del _MODEL_REGISTRY[rocm_model]
+
+# Models partially supported in ROCm
+_ROCM_PARTIALLY_SUPPORTED_MODELS = {
+    "MistralForCausalLM":
+    "Sliding window attention is not supported in ROCm's flash attention",
+}
+

@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):

@@ -53,7 +69,15 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
    architectures = getattr(config, "architectures", [])
    for arch in architectures:
        if arch in _MODEL_REGISTRY:
+            if is_hip() and arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
+                logger.warning(
+                    f"{arch} is not fully supported in ROCm. Reason: "
+                    f"{_ROCM_PARTIALLY_SUPPORTED_MODELS[arch]}")
            return _MODEL_REGISTRY[arch]
+        elif arch in _ROCM_UNSUPPORTED_MODELS:
+            raise ValueError(
+                f"Model architecture {arch} is not supported by ROCm for now. \n"
+                f"Supported architectures {list(_MODEL_REGISTRY.keys())}")
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")

vllm/utils.py
@@ -27,10 +27,14 @@ class Counter:
        self.counter = 0


+def is_hip() -> bool:
+    return torch.version.hip is not None
+
+
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
    """Returns the maximum shared memory per thread block in bytes."""
    # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
-    cudaDevAttrMaxSharedMemoryPerBlockOptin = 97
+    cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
    max_shared_mem = cuda_utils.get_device_attribute(
        cudaDevAttrMaxSharedMemoryPerBlockOptin, gpu)
    return int(max_shared_mem)
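
The same is_hip() probe drives every Python-side branch in this commit: the config checks, Ray initialization, the SqueezeLLM output dtype, and the device-attribute ID above. The snippet below is an illustrative summary rather than part of the commit; the attribute IDs are the ones used by get_max_shared_memory_bytes above (97 on CUDA, 74 on HIP).

    # Illustrative only: the runtime ROCm/CUDA gating pattern shared by the
    # Python changes in this commit; it relies solely on torch.version.
    import torch


    def is_hip() -> bool:
        return torch.version.hip is not None


    # Pick the device-attribute ID for "max shared memory per block",
    # as get_max_shared_memory_bytes() does above.
    attr_id = 97 if not is_hip() else 74
    print(f"Running on {'ROCm' if is_hip() else 'CUDA'}; querying attribute {attr_id}")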