Merge EmbeddedLLM/vllm-rocm into vLLM main (#1836)
Co-authored-by: Philipp Moritz <pcmoritz@gmail.com>
Co-authored-by: Amir Balwel <amoooori04@gmail.com>
Co-authored-by: root <kuanfu.liu@akirakan.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: kuanfu <kuanfu.liu@embeddedllm.com>
Co-authored-by: miloice <17350011+kliuae@users.noreply.github.com>
parent c8e7eb1eb3
commit 6ccc0bfffb
.gitignore (vendored, 4 lines changed)

@@ -177,3 +177,7 @@ _build/
# vim swap files
*.swo
*.swp

# hip files generated by PyTorch
*.hip
*_hip*
Dockerfile.rocm (new file, 62 lines)

@@ -0,0 +1,62 @@
FROM rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1

# Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y

# Install some basic utilities
RUN apt-get update && apt-get install -y \
    curl \
    ca-certificates \
    sudo \
    git \
    bzip2 \
    libx11-6 \
    build-essential \
    wget \
    unzip \
    nvidia-cuda-toolkit \
    tmux \
 && rm -rf /var/lib/apt/lists/*

### Mount Point ###
# When launching the container, mount the code directory to /app
ARG APP_MOUNT=/app
VOLUME [ ${APP_MOUNT} ]
WORKDIR ${APP_MOUNT}

RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas

ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:

# Install ROCm flash-attention
RUN mkdir libs \
 && cd libs \
 && git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git \
 && cd flash-attention \
 && git checkout 3d2b6f5 \
 && git submodule update --init \
 && export GPU_ARCHS=$(/opt/rocm/llvm/bin/amdgpu-offload-arch) \
 && patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch \
 && python3 setup.py install \
 && cd ..

COPY ./ /app/vllm

RUN python3 -m pip install --upgrade pip
RUN pip install xformers==0.0.22.post7 --no-deps

RUN cd /app \
 && cd vllm \
 && pip install -U -r requirements-rocm.txt \
 && bash patch_xformers-0.0.22.post7.rocm.sh \
 && python3 setup.py install \
 && cd ..

RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir ray[all]

CMD ["/bin/bash"]
README.md

@@ -17,6 +17,7 @@ Easy, fast, and cheap LLM serving for everyone
---

*Latest News* 🔥
- [2023/12] Added ROCm support to vLLM.
- [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing).
- [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there.
- [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv!

@@ -43,6 +44,7 @@ vLLM is flexible and easy to use with:
- Tensor parallelism support for distributed inference
- Streaming outputs
- OpenAI-compatible API server
- Support for NVIDIA CUDA and AMD ROCm

vLLM seamlessly supports many Hugging Face models, including the following architectures:
csrc/activation_kernels.cu

@@ -1,6 +1,7 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace vllm {

@@ -18,8 +19,8 @@ __global__ void silu_and_mul_kernel(
  const int d) {
  const int64_t token_idx = blockIdx.x;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
    const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
    const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
    const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
    out[token_idx * d + idx] = silu(x) * y;
  }
}

@@ -57,7 +58,7 @@ __global__ void activation_kernel(
  const int d) {
  const int64_t token_idx = blockIdx.x;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = __ldg(&input[token_idx * d + idx]);
    const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
    out[token_idx * d + idx] = ACT_FN(x);
  }
}
csrc/attention/attention_kernels.cu

@@ -15,6 +15,10 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#endif

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

@@ -23,7 +27,11 @@

#include <algorithm>

#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))

@@ -40,7 +48,7 @@ inline __device__ float block_sum(float* red_smem, float sum) {
  // Compute the sum per warp.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
    sum += VLLM_SHFL_XOR_SYNC(sum, mask);
  }

  // Warp leaders store the data to shared memory.

@@ -59,11 +67,11 @@ inline __device__ float block_sum(float* red_smem, float sum) {
  // Parallel reduction inside the warp.
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
    sum += VLLM_SHFL_XOR_SYNC(sum, mask);
  }

  // Broadcast to other threads.
  return __shfl_sync(uint32_t(-1), sum, 0);
  return VLLM_SHFL_SYNC(sum, 0);
}

// TODO(woosuk): Merge the last two dimensions of the grid.

@@ -223,7 +231,7 @@ __device__ void paged_attention_kernel(
  // The 0-th thread of each thread group already has its max qk value.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
    qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
    qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = qk_max;

@@ -235,10 +243,10 @@ __device__ void paged_attention_kernel(
  qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
    qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
  }
  // Broadcast the max qk value to all threads.
  qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
  qk_max = VLLM_SHFL_SYNC(qk_max, 0);

  // Get the sum of the exp values.
  float exp_sum = 0.f;

@@ -326,7 +334,7 @@ __device__ void paged_attention_kernel(
    float acc = accs[i];
#pragma unroll
    for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
      acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
      acc += VLLM_SHFL_XOR_SYNC(acc, mask);
    }
    accs[i] = acc;
  }

@@ -492,7 +500,7 @@ __global__ void paged_attention_v2_reduce_kernel(
  // Reduce within the warp.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
    max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = max_logit;

@@ -502,10 +510,10 @@ __global__ void paged_attention_v2_reduce_kernel(
  max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
    max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
  }
  // Broadcast the max value to all threads.
  max_logit = __shfl_sync(uint32_t(-1), max_logit, 0);
  max_logit = VLLM_SHFL_SYNC(max_logit, 0);

  // Load rescaled exp sums to shared memory.
  float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);

@@ -539,9 +547,9 @@ __global__ void paged_attention_v2_reduce_kernel(
} // namespace vllm

#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
  cudaFuncSetAttribute( \
    vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>, \
    cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \
  VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
    ((void*)vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>), \
    shared_mem_size); \
  vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
    <<<grid, block, shared_mem_size, stream>>>( \
      out_ptr, \
csrc/attention/attention_utils.cuh

@@ -17,6 +17,7 @@
 */
#pragma once

#include "../cuda_compat.h"
#include "attention_dtypes.h"

#include <float.h>

@@ -39,7 +40,7 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
  float qk = sum(qk_vec);
#pragma unroll
  for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
    qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
    qk += VLLM_SHFL_XOR_SYNC(qk, mask);
  }
  return qk;
}
csrc/attention/dtype_bfloat16.cuh

@@ -21,8 +21,17 @@
#include "attention_generic.cuh"
#include "dtype_float32.cuh"

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#ifndef USE_ROCM
  #include <cuda_bf16.h>
  #include <cuda_fp16.h>
#else
  #include <hip/hip_bf16.h>
  #include <hip/hip_fp16.h>

  typedef __hip_bfloat162 __nv_bfloat162;
  typedef __hip_bfloat16 __nv_bfloat16;
#endif

#include <stdint.h>

namespace vllm {

@@ -98,7 +107,11 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  return a + b;
  #ifndef USE_ROCM
    return a + b;
  #else
    return __hadd(a, b);
  #endif
#endif
}
csrc/attention/dtype_float16.cuh

@@ -21,6 +21,10 @@
#include "attention_generic.cuh"
#include "dtype_float32.cuh"

#ifdef USE_ROCM
  #include <hip/hip_fp16.h>
#endif

#include <stdint.h>

namespace vllm {

@@ -63,21 +67,47 @@ struct FloatVec<uint4> {

// Utility functions for type conversions.
inline __device__ uint32_t h0_h0(uint16_t a) {
#ifndef USE_ROCM
  uint32_t b;
  asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
  return b;
#else
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
  tmp.u16[0] = a;
  tmp.u16[1] = a;
  return tmp.u32;
#endif
}

inline __device__ float half_to_float(uint16_t h) {
  float f;
#ifndef USE_ROCM
  asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
#else
  asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h));
#endif
  return f;
}

inline __device__ float2 half2_to_float2(uint32_t v) {
#ifndef USE_ROCM
  uint16_t lo, hi;
  asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
  return make_float2(half_to_float(lo), half_to_float(hi));
#else
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
  tmp.u32 = v;
  float2 ret;
  ret.x = half_to_float(tmp.u16[0]);
  ret.y = half_to_float(tmp.u16[1]);
  return ret;
#endif
}

inline __device__ uint16_t float_to_half(float f) {

@@ -85,7 +115,11 @@ inline __device__ uint16_t float_to_half(float f) {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
#ifndef USE_ROCM
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
#else
  asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f));
#endif
  return tmp.u16[0];
}

@@ -94,12 +128,16 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
#ifndef USE_ROCM
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
  #else
    asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
    asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
  #endif
#else
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
  tmp.u16[0] = float_to_half(f.x);
  tmp.u16[1] = float_to_half(f.y);
#endif
  return tmp.u32;
}

@@ -107,13 +145,21 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
// Vector addition.
inline __device__ uint16_t add(uint16_t a, uint16_t b) {
  uint16_t c;
#ifndef USE_ROCM
  asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
#else
  asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

inline __device__ uint32_t add(uint32_t a, uint32_t b) {
  uint32_t c;
#ifndef USE_ROCM
  asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
#else
  asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

@@ -158,14 +204,22 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) {
template<>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
  uint16_t c;
#ifndef USE_ROCM
  asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
#else
  asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

template<>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
  uint32_t c;
#ifndef USE_ROCM
  asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
#else
  asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

@@ -272,7 +326,11 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) {
// Vector fused multiply-add.
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
  uint32_t d;
#ifndef USE_ROCM
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
#else
  asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
#endif
  return d;
}
csrc/cache_kernels.cu

@@ -1,6 +1,7 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"

#include <algorithm>

@@ -28,8 +29,8 @@ void swap_blocks(
    TORCH_CHECK(false, "Invalid device combination");
  }

  void *src_ptr = src.data_ptr();
  void *dst_ptr = dst.data_ptr();
  char *src_ptr = static_cast<char*>(src.data_ptr());
  char *dst_ptr = static_cast<char*>(dst.data_ptr());

  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

@@ -267,8 +268,8 @@ __global__ void gather_cached_kv_kernel(
      + head_offset * block_size
      + block_offset;

    key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]);
    value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]);
    key[tgt_key_idx] = VLLM_LDG(&key_cache[src_key_idx]);
    value[tgt_value_idx] = VLLM_LDG(&value_cache[src_value_idx]);
  }
}

@@ -333,8 +334,8 @@ __global__ void gather_cached_kv_kernel_optimized(
      src_key_indices[j] = src_key_idx;
      src_value_indices[j] = src_value_idx;

      keys_to_store[j] = __ldg(&key_cache[src_key_idx]);
      values_to_store[j] = __ldg(&value_cache[src_value_idx]);
      keys_to_store[j] = VLLM_LDG(&key_cache[src_key_idx]);
      values_to_store[j] = VLLM_LDG(&value_cache[src_value_idx]);
    }

    #pragma unroll
csrc/cuda_compat.h (new file, 28 lines)

@@ -0,0 +1,28 @@
#pragma once

#ifndef USE_ROCM
  #define VLLM_LDG(arg) __ldg(arg)
#else
  #define VLLM_LDG(arg) *(arg)
#endif

#ifndef USE_ROCM
  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
#else
  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
#endif

#ifndef USE_ROCM
  #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
#else
  #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
#endif

#ifndef USE_ROCM
  #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
    cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
#else
  #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
    hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif
csrc/cuda_utils_kernels.cu

@@ -1,3 +1,6 @@
#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#endif
int get_device_attribute(
    int attribute,
    int device_id)
csrc/ops.h

@@ -61,12 +61,14 @@ void gelu_fast(
  torch::Tensor& out,
  torch::Tensor& input);

#ifndef USE_ROCM
torch::Tensor awq_gemm(
  torch::Tensor _in_feats,
  torch::Tensor _kernel,
  torch::Tensor _scaling_factors,
  torch::Tensor _zeros,
  int split_k_iters);
#endif

void squeezellm_gemm(
  torch::Tensor vec,
csrc/pos_encoding_kernels.cu

@@ -1,6 +1,7 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace vllm {

@@ -19,14 +20,14 @@ inline __device__ void apply_rotary_embedding(
    // GPT-NeoX style rotary embedding.
    x_index = rot_offset;
    y_index = embed_dim + rot_offset;
    cos = __ldg(cos_ptr + x_index);
    sin = __ldg(sin_ptr + x_index);
    cos = VLLM_LDG(cos_ptr + x_index);
    sin = VLLM_LDG(sin_ptr + x_index);
  } else {
    // GPT-J style rotary embedding.
    x_index = 2 * rot_offset;
    y_index = 2 * rot_offset + 1;
    cos = __ldg(cos_ptr + x_index / 2);
    sin = __ldg(sin_ptr + x_index / 2);
    cos = VLLM_LDG(cos_ptr + x_index / 2);
    sin = VLLM_LDG(sin_ptr + x_index / 2);
  }

  const scalar_t x = arr[x_index];
csrc/pybind.cpp

@@ -48,8 +48,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    &rotary_embedding,
    "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");

#ifndef USE_ROCM
  // Quantization ops
  ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
#endif

  ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");

  // Cache ops
csrc/quantization/squeezellm/quant_cuda_kernel.cu

@@ -20,9 +20,17 @@ __device__ inline unsigned int as_unsigned(int i) {

// 4-bit matvec kernel (LUT-based)
__global__ void NUQ4MatMulKernel(
#ifndef USE_ROCM
    const half2* __restrict__ vec,
#else
    const __half2* __restrict__ vec,
#endif
    const int* __restrict__ mat,
#ifndef USE_ROCM
    half2* __restrict__ mul,
#else
    float2* __restrict__ mul,
#endif
    const __half* __restrict__ lookup_table,
    int height,
    int width,

@@ -35,7 +43,11 @@ __global__ void NUQ4MatMulKernel(
  int row = BLOCKHEIGHT4 * blockIdx.x;
  int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;

#ifndef USE_ROCM
  __shared__ half2 blockvec[blockwidth2];
#else
  __shared__ __half2 blockvec[blockwidth2];
#endif

  __shared__ __half deq2[16][BLOCKWIDTH];
  int off = threadIdx.x;

@@ -46,8 +58,13 @@ __global__ void NUQ4MatMulKernel(
  }

  __half res;
#ifndef USE_ROCM
  half2 res2;
  half2 tmp2;
#else
  __half2 res2;
  __half2 tmp2;
#endif

  int i;
  int k;

@@ -68,48 +85,96 @@ __global__ void NUQ4MatMulKernel(
    while (k < blockwidth2) {
      tmp1 = as_unsigned(mat[i]);

#ifndef USE_ROCM
      res2 = {};
      tmp2 = {};
#else
      res2.x = __half_as_ushort(__float2half(0));
      res2.y = __half_as_ushort(__float2half(0));
      tmp2.x = __half_as_ushort(__float2half(0));
      tmp2.y = __half_as_ushort(__float2half(0));
#endif

      lut_index1 = tmp1 & 0xF;
      lut_index2 = (tmp1 >> 4) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 0], res2);

      lut_index1 = (tmp1 >> 8) & 0xF;
      lut_index2 = (tmp1 >> 12) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 1], res2);

      lut_index1 = (tmp1 >> 16) & 0xF;
      lut_index2 = (tmp1 >> 20) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 2], res2);

      lut_index1 = (tmp1 >> 24) & 0xF;
      lut_index2 = (tmp1 >> 28) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 3], res2);

#ifndef USE_ROCM
      res = __hadd(__hadd(res2.x, res2.y), res);
#else
      res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), res);
#endif

      i += width;
      k += 4;
    }

    // col%2 -> only set one of the two values
#ifndef USE_ROCM
    half2 res3 = {};
    if (col % 2 == 0) {
      res3.x = res;
    } else {
      res3.y = res;
    }
#else
    __half2 res3;
    res3.x = __half_as_ushort(__float2half(0));
    res3.y = __half_as_ushort(__float2half(0));
    if (col % 2 == 0) {
      res3.x = __half_as_ushort(res);
    } else {
      res3.y = __half_as_ushort(res);
    }
#endif

#ifndef USE_ROCM
    atomicAdd(&mul[b * width / 2 + col / 2], res3);
#else
    int tmp_addr = b * width / 2 + col / 2;
    atomicAdd(&(mul[tmp_addr].x), __half2float(__ushort_as_half(res3.x)));
    atomicAdd(&(mul[tmp_addr].y), __half2float(__ushort_as_half(res3.y)));
#endif
  }
}

@@ -136,10 +201,19 @@ void squeezellm_gemm(
  dim3 threads(BLOCKWIDTH);

  vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads>>>(
#ifndef USE_ROCM
    (half2*) vec.data<at::Half>(),
#else
    (__half2*) vec.data_ptr<at::Half>(),
#endif
    mat.data_ptr<int>(),
#ifndef USE_ROCM
    (half2*) mul.data<at::Half>(),
    (__half*) lookup_table.data<at::Half>(),
#else
    (float2*) mul.data_ptr<float>(),
    (__half*) lookup_table.data_ptr<at::Half>(),
#endif
    height, width, batch, vec_height
  );
}
csrc/reduction_utils.cuh

@@ -17,13 +17,15 @@
 */
#pragma once

#include "cuda_compat.h"

namespace vllm {

template<typename T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
  for (int mask = 16; mask > 0; mask >>= 1)
    val += __shfl_xor_sync(0xffffffff, val, mask, 32);
    val += VLLM_SHFL_XOR_SYNC(val, mask);
  return val;
}
docs/source/getting_started/amd-installation.rst (new file, 143 lines)

@@ -0,0 +1,143 @@
.. _installation_rocm:

Installation with ROCm
======================

vLLM 0.2.x onwards supports model inferencing and serving on AMD GPUs with ROCm.
At the moment AWQ quantization is not supported in ROCm, but SqueezeLLM quantization has been ported.
Data types currently supported in ROCm are FP16 and BF16.

Requirements
------------

* OS: Linux
* Python: 3.8 -- 3.11 (Verified on 3.10)
* GPU: MI200s
* Pytorch 2.0.1/2.1.1/2.2
* ROCm 5.7

Installation options:

#. :ref:`(Recommended) Quick start with vLLM pre-installed in Docker Image <quick_start_docker_rocm>`
#. :ref:`Build from source <build_from_source_rocm>`
#. :ref:`Build from source with docker <build_from_source_docker_rocm>`

.. _quick_start_docker_rocm:

(Recommended) Option 1: Quick start with vLLM pre-installed in Docker Image
---------------------------------------------------------------------------

.. code-block:: console

   $ docker pull embeddedllminfo/vllm-rocm:vllm-v0.2.3
   $ docker run -it \
      --network=host \
      --group-add=video \
      --ipc=host \
      --cap-add=SYS_PTRACE \
      --security-opt seccomp=unconfined \
      --device /dev/kfd \
      --device /dev/dri \
      -v <path/to/model>:/app/model \
      embeddedllminfo/vllm-rocm \
      bash

.. _build_from_source_rocm:

Option 2: Build from source
---------------------------

You can build and install vLLM from source:

0. Install prerequisites (skip if you are already in an environment/docker with the following installed):

   - `ROCm <https://rocm.docs.amd.com/en/latest/deploy/linux/index.html>`_
   - `Pytorch <https://pytorch.org/>`_

   .. code-block:: console

      $ pip install torch==2.2.0.dev20231206+rocm5.7 --index-url https://download.pytorch.org/whl/nightly/rocm5.7 # tested version

1. Install `flash attention for ROCm <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm>`_

   Install ROCm's flash attention (v2.0.4) following the instructions from `ROCmSoftwarePlatform/flash-attention <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm#amd-gpurocm-support>`_

.. note::
   - If you are using rocm5.7 with pytorch 2.1.0 onwards, you don't need to apply the `hipify_python.patch`. You can build the ROCm flash attention directly.
   - If you fail to install `ROCmSoftwarePlatform/flash-attention`, try cloning from the commit `6fd2f8e572805681cd67ef8596c7e2ce521ed3c6`.
   - ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention.
   - You might need to downgrade the "ninja" version to 1.10, as it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)

2. Setup `xformers==0.0.22.post7` without dependencies, and apply patches to adapt for ROCm flash attention

   .. code-block:: console

      $ pip install xformers==0.0.22.post7 --no-deps
      $ bash patch_xformers-0.0.22.post7.rocm.sh

3. Build vLLM.

   .. code-block:: console

      $ cd vllm
      $ pip install -U -r requirements-rocm.txt
      $ python setup.py install # This may take 5-10 minutes. Currently, ``pip install .`` does not work for ROCm installation

.. _build_from_source_docker_rocm:

Option 3: Build from source with docker
-----------------------------------------------------

You can build and install vLLM from source:

Build a docker image from `Dockerfile.rocm`, and launch a docker container.

.. code-block:: console

   $ docker build -f Dockerfile.rocm -t vllm-rocm .
   $ docker run -it \
      --network=host \
      --group-add=video \
      --ipc=host \
      --cap-add=SYS_PTRACE \
      --security-opt seccomp=unconfined \
      --device /dev/kfd \
      --device /dev/dri \
      -v <path/to/model>:/app/model \
      vllm-rocm \
      bash

Alternatively, if you plan to install vLLM-ROCm on a local machine or start from a fresh docker image (e.g. rocm/pytorch), you can follow the steps below:

0. Install prerequisites (skip if you are already in an environment/docker with the following installed):

   - `ROCm <https://rocm.docs.amd.com/en/latest/deploy/linux/index.html>`_
   - `Pytorch <https://pytorch.org/>`_

1. Install `flash attention for ROCm <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm>`_

   Install ROCm's flash attention (v2.0.4) following the instructions from `ROCmSoftwarePlatform/flash-attention <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm#amd-gpurocm-support>`_

.. note::
   - If you are using rocm5.7 with pytorch 2.1.0 onwards, you don't need to apply the `hipify_python.patch`. You can build the ROCm flash attention directly.
   - If you fail to install `ROCmSoftwarePlatform/flash-attention`, try cloning from the commit `6fd2f8e572805681cd67ef8596c7e2ce521ed3c6`.
   - ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention.
   - You might need to downgrade the "ninja" version to 1.10, as it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)

2. Setup `xformers==0.0.22.post7` without dependencies, and apply patches to adapt for ROCm flash attention

   .. code-block:: console

      $ pip install xformers==0.0.22.post7 --no-deps
      $ bash patch_xformers-0.0.22.post7.rocm.sh

3. Build vLLM.

   .. code-block:: console

      $ cd vllm
      $ pip install -U -r requirements-rocm.txt
      $ python setup.py install # This may take 5-10 minutes.
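Once either install path completes, a short offline-inference run is a quick way to confirm the ROCm build works end to end. The sketch below is not part of the commit; the model name and sampling settings are placeholders, and it only assumes the public `LLM`/`SamplingParams` API that vLLM 0.2.x already exposes.

    # Minimal smoke test for a ROCm build of vLLM (sketch; model and prompt are placeholders).
    from vllm import LLM, SamplingParams

    llm = LLM(model="facebook/opt-125m", dtype="float16")  # FP16/BF16 are the dtypes supported on ROCm
    params = SamplingParams(temperature=0.8, max_tokens=32)
    outputs = llm.generate(["Hello, my name is"], params)
    for out in outputs:
        print(out.outputs[0].text)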
docs/source/index.rst

@@ -39,6 +39,7 @@ vLLM is flexible and easy to use with:
* Tensor parallelism support for distributed inference
* Streaming outputs
* OpenAI-compatible API server
* Support for NVIDIA CUDA and AMD ROCm

For more information, check out the following:

@@ -56,6 +57,7 @@ Documentation
   :caption: Getting Started

   getting_started/installation
   getting_started/amd-installation
   getting_started/quickstart

.. toctree::
patch_xformers-0.0.22.post7.rocm.sh (new file, 22 lines)

@@ -0,0 +1,22 @@
#!/bin/bash
export XFORMERS_FMHA_FLASH_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.flash.__file__)')
export XFORMERS_FMHA_COMMON_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.common.__file__)')

echo $XFORMERS_FMHA_FLASH_PATH
echo $XFORMERS_FMHA_COMMON_PATH

if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch"; then
    echo "Applying patch to ${XFORMERS_FMHA_FLASH_PATH}"
    patch -p0 $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch"
    echo "Successfully patch ${XFORMERS_FMHA_FLASH_PATH}"
else
    echo "${XFORMERS_FMHA_FLASH_PATH} was patched before"
fi

if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch"; then
    echo "Applying patch to ${XFORMERS_FMHA_COMMON_PATH}"
    patch -p0 $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch"
    echo "Successfully patch ${XFORMERS_FMHA_COMMON_PATH}"
else
    echo "${XFORMERS_FMHA_COMMON_PATH} was patched before"
fi
requirements-rocm.txt (new file, 17 lines)

@@ -0,0 +1,17 @@
ninja  # For faster builds.
typing-extensions>=4.8.0
starlette
psutil
ray >= 2.5.1
pandas  # Required for Ray data.
pyarrow  # Required for Ray data.
sentencepiece  # Required for LLaMA tokenizer.
numpy
tokenizers>=0.15.0
huggingface_hub<0.18,>=0.16.4
einops  # Required for phi-1_5
transformers >= 4.34.0  # Required for Mistral.
fastapi
uvicorn[standard]
pydantic == 1.10.13  # Required for OpenAI server.
aioprometheus[starlette]
rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch (new file, 13 lines)

@@ -0,0 +1,13 @@
--- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/common.py	2023-11-29 03:17:03.930103539 +0000
+++ common.py	2023-11-28 16:14:19.846233146 +0000
@@ -298,8 +298,8 @@
         dtype = d.query.dtype
         if device_type not in cls.SUPPORTED_DEVICES:
             reasons.append(f"device={device_type} (supported: {cls.SUPPORTED_DEVICES})")
-        if device_type == "cuda" and not _built_with_cuda:
-            reasons.append("xFormers wasn't build with CUDA support")
+        #if device_type == "cuda" and not _built_with_cuda:
+        #    reasons.append("xFormers wasn't build with CUDA support")
         if device_type == "cuda":
             device_capability = torch.cuda.get_device_capability(d.device)
             if device_capability < cls.CUDA_MINIMUM_COMPUTE_CAPABILITY:
rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch (new file, 134 lines)

@@ -0,0 +1,134 @@
--- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/flash.py	2023-11-29 03:17:03.930103539 +0000
+++ flash.py	2023-11-28 16:14:25.206128903 +0000
@@ -31,39 +31,39 @@

 FLASH_VERSION = "0.0.0"
 try:
-    try:
-        from ... import _C_flashattention  # type: ignore[attr-defined]
-        from ..._cpp_lib import _build_metadata
-
-        if _build_metadata is not None:
-            FLASH_VERSION = _build_metadata.flash_version
-    except ImportError:
-        import flash_attn
-        from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
-
-        FLASH_VERSION = flash_attn.__version__
-        flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
-        if flash_ver_parsed < (2, 3):
-            raise ImportError("Requires 2.3 for sliding window support")
+    #try:
+    #    from ... import _C_flashattention  # type: ignore[attr-defined]
+    #    from ..._cpp_lib import _build_metadata
+
+    #    if _build_metadata is not None:
+    #        FLASH_VERSION = _build_metadata.flash_version
+    #except ImportError:
+    import flash_attn
+    from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
+
+    FLASH_VERSION = flash_attn.__version__
+    #    flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2])
+    #    if flash_ver_parsed < (2, 3):
+    #        raise ImportError("Requires 2.3 for sliding window support")

     # create library so that flash-attn goes through the PyTorch Dispatcher
-    _flash_lib = torch.library.Library("xformers_flash", "DEF")
+    #_flash_lib = torch.library.Library("xformers_flash", "DEF")

-    _flash_lib.define(
-        "flash_fwd(Tensor query, Tensor key, Tensor value, "
-        "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
-        "int max_seqlen_q, int max_seqlen_k, "
-        "float p, float softmax_scale, "
-        "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)"
-    )
-
-    _flash_lib.define(
-        "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
-        "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
-        "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
-        "int max_seqlen_q, int max_seqlen_k, "
-        "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
-    )
+    #_flash_lib.define(
+    #    "flash_fwd(Tensor query, Tensor key, Tensor value, "
+    #    "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, "
+    #    "int max_seqlen_q, int max_seqlen_k, "
+    #    "float p, float softmax_scale, "
+    #    "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)"
+    #)
+
+    #_flash_lib.define(
+    #    "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
+    #    "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
+    #    "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
+    #    "int max_seqlen_q, int max_seqlen_k, "
+    #    "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
+    #)

     def _flash_fwd(
         query,
@@ -98,8 +98,8 @@
             p,
             softmax_scale,
             is_causal,
-            window_size - 1,  # window_size_left
-            -1,  # window_size_right
+            # window_size - 1,  # window_size_left
+            # -1,  # window_size_right
             return_softmax,
             None,  # rng
         )
@@ -127,8 +127,8 @@
             softmax_scale,
             False,
             is_causal,
-            window_size - 1,  # window_size_left
-            -1,  # window_size_right
+            # window_size - 1,  # window_size_left
+            # -1,  # window_size_right
             return_softmax,
             None,
         )
@@ -169,8 +169,8 @@
             p,
             softmax_scale,
             is_causal,
-            window_size - 1,  # window_size_left
-            -1,  # window_size_right
+            # window_size - 1,  # window_size_left
+            # -1,  # window_size_right
             None,
             rng_state,
         )
@@ -193,15 +193,15 @@
             softmax_scale,
             False,  # zero_tensors
             is_causal,
-            window_size - 1,  # window_size_left
-            -1,  # window_size_right
+            # window_size - 1,  # window_size_left
+            # -1,  # window_size_right
             None,
             rng_state,
         )
         return dq, dk, dv

-    _flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
-    _flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
+    #_flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
+    #_flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
 except ImportError:
     pass

@@ -348,7 +348,7 @@
     implementation.
     """

-    OPERATOR = get_operator("xformers_flash", "flash_fwd")
+    OPERATOR = _flash_fwd  # get_operator("xformers_flash", "flash_fwd")
     SUPPORTED_DEVICES: Set[str] = {"cuda"}
     CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
     SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}
setup.py (232 lines changed)

@@ -8,27 +8,83 @@ import warnings
from packaging.version import parse, Version
import setuptools
import torch
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME

ROOT_DIR = os.path.dirname(__file__)

MAIN_CUDA_VERSION = "12.1"

# Supported NVIDIA GPU architectures.
SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx1030", "gfx1100"}
# SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS)


def _is_hip() -> bool:
    return torch.version.hip is not None


def _is_cuda() -> bool:
    return torch.version.cuda is not None


# Compiler flags.
CXX_FLAGS = ["-g", "-O2", "-std=c++17"]
# TODO(woosuk): Should we use -O3?
NVCC_FLAGS = ["-O2", "-std=c++17"]

if _is_hip():
    if ROCM_HOME is None:
        raise RuntimeError(
            "Cannot find ROCM_HOME. ROCm must be available to build the package."
        )
    NVCC_FLAGS += ["-DUSE_ROCM"]

if _is_cuda() and CUDA_HOME is None:
    raise RuntimeError(
        "Cannot find CUDA_HOME. CUDA must be available to build the package.")

ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]

if CUDA_HOME is None:
    raise RuntimeError(
        "Cannot find CUDA_HOME. CUDA must be available to build the package.")

def get_amdgpu_offload_arch():
    command = "/opt/rocm/llvm/bin/amdgpu-offload-arch"
    try:
        output = subprocess.check_output([command])
        return output.decode('utf-8').strip()
    except subprocess.CalledProcessError as e:
        error_message = f"Error: {e}"
        raise RuntimeError(error_message) from e
    except FileNotFoundError as e:
        # If the command is not found, print an error message
        error_message = f"The command {command} was not found."
        raise RuntimeError(error_message) from e

    return None


def get_hipcc_rocm_version():
    # Run the hipcc --version command
    result = subprocess.run(['hipcc', '--version'],
                            stdout=subprocess.PIPE,
                            stderr=subprocess.STDOUT,
                            text=True)

    # Check if the command was executed successfully
    if result.returncode != 0:
        print("Error running 'hipcc --version'")
        return None

    # Extract the version using a regular expression
    match = re.search(r'HIP version: (\S+)', result.stdout)
    if match:
        # Return the version string
        return match.group(1)
    else:
        print("Could not find HIP version in the output")
        return None


def get_nvcc_cuda_version(cuda_dir: str) -> Version:

@@ -61,20 +117,22 @@ def get_torch_arch_list() -> Set[str]:
        return set()

    # Filter out the invalid architectures and print a warning.
    valid_archs = SUPPORTED_ARCHS.union({s + "+PTX" for s in SUPPORTED_ARCHS})
    valid_archs = NVIDIA_SUPPORTED_ARCHS.union(
        {s + "+PTX"
         for s in NVIDIA_SUPPORTED_ARCHS})
    arch_list = torch_arch_list.intersection(valid_archs)
    # If none of the specified architectures are valid, raise an error.
    if not arch_list:
        raise RuntimeError(
            "None of the CUDA architectures in `TORCH_CUDA_ARCH_LIST` env "
            "None of the CUDA/ROCM architectures in `TORCH_CUDA_ARCH_LIST` env "
            f"variable ({env_arch_list}) is supported. "
            f"Supported CUDA architectures are: {valid_archs}.")
            f"Supported CUDA/ROCM architectures are: {valid_archs}.")
    invalid_arch_list = torch_arch_list - valid_archs
    if invalid_arch_list:
        warnings.warn(
            f"Unsupported CUDA architectures ({invalid_arch_list}) are "
            f"Unsupported CUDA/ROCM architectures ({invalid_arch_list}) are "
            "excluded from the `TORCH_CUDA_ARCH_LIST` env variable "
            f"({env_arch_list}). Supported CUDA architectures are: "
            f"({env_arch_list}). Supported CUDA/ROCM architectures are: "
            f"{valid_archs}.",
            stacklevel=2)
    return arch_list

@@ -82,7 +140,7 @@ def get_torch_arch_list() -> Set[str]:

# First, check the TORCH_CUDA_ARCH_LIST environment variable.
compute_capabilities = get_torch_arch_list()
if not compute_capabilities:
if _is_cuda() and not compute_capabilities:
    # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
    # GPUs on the current machine.
    device_count = torch.cuda.device_count()

@@ -93,69 +151,84 @@ if not compute_capabilities:
            "GPUs with compute capability below 7.0 are not supported.")
        compute_capabilities.add(f"{major}.{minor}")

nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
if not compute_capabilities:
    # If no GPU is specified nor available, add all supported architectures
    # based on the NVCC CUDA version.
    compute_capabilities = SUPPORTED_ARCHS.copy()
    if nvcc_cuda_version < Version("11.1"):
        compute_capabilities.remove("8.6")
    if nvcc_cuda_version < Version("11.8"):
        compute_capabilities.remove("8.9")
        compute_capabilities.remove("9.0")

# Validate the NVCC CUDA version.
if nvcc_cuda_version < Version("11.0"):
    raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
if (nvcc_cuda_version < Version("11.1")
        and any(cc.startswith("8.6") for cc in compute_capabilities)):
    raise RuntimeError(
        "CUDA 11.1 or higher is required for compute capability 8.6.")
if nvcc_cuda_version < Version("11.8"):
    if any(cc.startswith("8.9") for cc in compute_capabilities):
        # CUDA 11.8 is required to generate the code targeting compute capability 8.9.
        # However, GPUs with compute capability 8.9 can also run the code generated by
        # the previous versions of CUDA 11 and targeting compute capability 8.0.
        # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
        # instead of 8.9.
        warnings.warn(
            "CUDA 11.8 or higher is required for compute capability 8.9. "
            "Targeting compute capability 8.0 instead.",
            stacklevel=2)
        compute_capabilities = set(cc for cc in compute_capabilities
                                   if not cc.startswith("8.9"))
        compute_capabilities.add("8.0+PTX")
    if any(cc.startswith("9.0") for cc in compute_capabilities):
if _is_cuda():
    nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
    if not compute_capabilities:
        # If no GPU is specified nor available, add all supported architectures
        # based on the NVCC CUDA version.
        compute_capabilities = NVIDIA_SUPPORTED_ARCHS.copy()
        if nvcc_cuda_version < Version("11.1"):
            compute_capabilities.remove("8.6")
        if nvcc_cuda_version < Version("11.8"):
            compute_capabilities.remove("8.9")
            compute_capabilities.remove("9.0")
    # Validate the NVCC CUDA version.
    if nvcc_cuda_version < Version("11.0"):
        raise RuntimeError(
            "CUDA 11.8 or higher is required for compute capability 9.0.")
            "CUDA 11.0 or higher is required to build the package.")
    if (nvcc_cuda_version < Version("11.1")
            and any(cc.startswith("8.6") for cc in compute_capabilities)):
        raise RuntimeError(
            "CUDA 11.1 or higher is required for compute capability 8.6.")
    if nvcc_cuda_version < Version("11.8"):
        if any(cc.startswith("8.9") for cc in compute_capabilities):
            # CUDA 11.8 is required to generate the code targeting compute capability 8.9.
            # However, GPUs with compute capability 8.9 can also run the code generated by
            # the previous versions of CUDA 11 and targeting compute capability 8.0.
            # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
            # instead of 8.9.
            warnings.warn(
                "CUDA 11.8 or higher is required for compute capability 8.9. "
                "Targeting compute capability 8.0 instead.",
                stacklevel=2)
            compute_capabilities = set(cc for cc in compute_capabilities
                                       if not cc.startswith("8.9"))
            compute_capabilities.add("8.0+PTX")
        if any(cc.startswith("9.0") for cc in compute_capabilities):
            raise RuntimeError(
                "CUDA 11.8 or higher is required for compute capability 9.0.")

# Add target compute capabilities to NVCC flags.
for capability in compute_capabilities:
    num = capability[0] + capability[2]
    NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
    if capability.endswith("+PTX"):
        NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
    # Add target compute capabilities to NVCC flags.
    for capability in compute_capabilities:
        num = capability[0] + capability[2]
        NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
        if capability.endswith("+PTX"):
            NVCC_FLAGS += [
                "-gencode", f"arch=compute_{num},code=compute_{num}"
            ]

# Use NVCC threads to parallelize the build.
if nvcc_cuda_version >= Version("11.2"):
    nvcc_threads = int(os.getenv("NVCC_THREADS", 8))
    num_threads = min(os.cpu_count(), nvcc_threads)
    NVCC_FLAGS += ["--threads", str(num_threads)]
    # Use NVCC threads to parallelize the build.
    if nvcc_cuda_version >= Version("11.2"):
        nvcc_threads = int(os.getenv("NVCC_THREADS", 8))
        num_threads = min(os.cpu_count(), nvcc_threads)
        NVCC_FLAGS += ["--threads", str(num_threads)]

elif _is_hip():
    amd_arch = get_amdgpu_offload_arch()
    if amd_arch not in ROCM_SUPPORTED_ARCHS:
        raise RuntimeError(
            f"Only the following arch is supported: {ROCM_SUPPORTED_ARCHS}"
            f"amdgpu_arch_found: {amd_arch}")

ext_modules = []

vllm_extension_sources = [
    "csrc/cache_kernels.cu",
    "csrc/attention/attention_kernels.cu",
    "csrc/pos_encoding_kernels.cu",
    "csrc/activation_kernels.cu",
    "csrc/layernorm_kernels.cu",
    "csrc/quantization/squeezellm/quant_cuda_kernel.cu",
    "csrc/cuda_utils_kernels.cu",
    "csrc/pybind.cpp",
]

if _is_cuda():
    vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")

vllm_extension = CUDAExtension(
    name="vllm._C",
    sources=[
        "csrc/cache_kernels.cu",
        "csrc/attention/attention_kernels.cu",
        "csrc/pos_encoding_kernels.cu",
        "csrc/activation_kernels.cu",
        "csrc/layernorm_kernels.cu",
        "csrc/quantization/awq/gemm_kernels.cu",
        "csrc/quantization/squeezellm/quant_cuda_kernel.cu",
        "csrc/cuda_utils_kernels.cu",
        "csrc/pybind.cpp",
    ],
    sources=vllm_extension_sources,
    extra_compile_args={
        "cxx": CXX_FLAGS,
        "nvcc": NVCC_FLAGS,

@@ -183,10 +256,19 @@ def find_version(filepath: str) -> str:

def get_vllm_version() -> str:
    version = find_version(get_path("vllm", "__init__.py"))
    cuda_version = str(nvcc_cuda_version)
    if cuda_version != MAIN_CUDA_VERSION:
        cuda_version_str = cuda_version.replace(".", "")[:3]
        version += f"+cu{cuda_version_str}"

    if _is_hip():
        # Get the HIP version
        hipcc_version = get_hipcc_rocm_version()
        if hipcc_version != MAIN_CUDA_VERSION:
            rocm_version_str = hipcc_version.replace(".", "")[:3]
            version += f"+rocm{rocm_version_str}"
    else:
        cuda_version = str(nvcc_cuda_version)
        if cuda_version != MAIN_CUDA_VERSION:
            cuda_version_str = cuda_version.replace(".", "")[:3]
            version += f"+cu{cuda_version_str}"

    return version

@@ -201,8 +283,12 @@ def read_readme() -> str:

def get_requirements() -> List[str]:
    """Get Python package dependencies from requirements.txt."""
    with open(get_path("requirements.txt")) as f:
        requirements = f.read().strip().split("\n")
    if _is_hip():
        with open(get_path("requirements-rocm.txt")) as f:
            requirements = f.read().strip().split("\n")
    else:
        with open(get_path("requirements.txt")) as f:
            requirements = f.read().strip().split("\n")
    return requirements
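For reference, the `_is_hip()` / `_is_cuda()` checks in setup.py rely only on how the installed PyTorch wheel was built, so a quick way to see which backend the build will target on a given machine is a sketch like this (not part of the commit):

    # Mirrors setup.py's backend detection: a ROCm PyTorch wheel reports
    # torch.version.hip, a CUDA wheel reports torch.version.cuda.
    import torch

    print("HIP runtime version:", torch.version.hip)    # e.g. "5.7.x" on ROCm wheels, None otherwise
    print("CUDA runtime version:", torch.version.cuda)  # e.g. "12.1" on CUDA wheels, None otherwise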
vllm/config.py

@@ -6,7 +6,7 @@ from transformers import PretrainedConfig

from vllm.logger import init_logger
from vllm.transformers_utils.config import get_config
from vllm.utils import get_cpu_memory
from vllm.utils import get_cpu_memory, is_hip

logger = init_logger(__name__)

@@ -98,12 +98,27 @@ class ModelConfig:

    def _verify_load_format(self) -> None:
        load_format = self.load_format.lower()
        if load_format not in [
                "auto", "pt", "safetensors", "npcache", "dummy"
        ]:
        supported_load_format = [
            "auto", "pt", "safetensors", "npcache", "dummy"
        ]
        rocm_not_supported_load_format = ["safetensors"]
        if load_format not in supported_load_format:
            raise ValueError(
                f"Unknown load format: {self.load_format}. Must be one of "
                "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
        if is_hip():
            if load_format in ["safetensors"]:
                rocm_supported_load_format = [
                    f for f in supported_load_format
                    if (f not in rocm_not_supported_load_format)
                ]
                raise ValueError(
                    f"load format \'{load_format}\' is not supported in ROCm. "
                    f"Supported load format are "
                    f"{rocm_supported_load_format}")
            # Force ROCm to load from pt weights if nothing specific is set
            if load_format == "auto":
                load_format = "pt"
        self.load_format = load_format

    def _verify_tokenizer_mode(self) -> None:

@@ -116,6 +131,7 @@ class ModelConfig:

    def _verify_quantization(self) -> None:
        supported_quantization = ["awq", "squeezellm"]
        rocm_not_supported_quantization = ["awq"]
        if self.quantization is not None:
            self.quantization = self.quantization.lower()

@@ -137,6 +153,11 @@ class ModelConfig:
                raise ValueError(
                    f"Unknown quantization method: {self.quantization}. Must "
                    f"be one of {supported_quantization}.")
            if is_hip(
            ) and self.quantization in rocm_not_supported_quantization:
                raise ValueError(
                    f"{self.quantization} quantization is currently not supported "
                    f"in ROCm.")
            logger.warning(f"{self.quantization} quantization is not fully "
                           "optimized yet. The speed can be slower than "
                           "non-quantized models.")

@@ -364,6 +385,8 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
    "bfloat16": torch.bfloat16,
}

_ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"]


def _get_and_verify_dtype(
    config: PretrainedConfig,

@@ -393,6 +416,14 @@ def _get_and_verify_dtype(
        else:
            raise ValueError(f"Unknown dtype: {dtype}")

    if is_hip() and torch_dtype == torch.float32:
        rocm_supported_dtypes = [
            k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
            if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
        ]
        raise ValueError(f"dtype \'{dtype}\' is not supported in ROCm. "
                         f"Supported dtypes are {rocm_supported_dtypes}")

    # Verify the dtype.
    if torch_dtype != config_dtype:
        if torch_dtype == torch.float32:
vllm/engine/ray_utils.py

@@ -3,6 +3,7 @@ from typing import Optional, Tuple, TYPE_CHECKING

from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.utils import is_hip

logger = init_logger(__name__)

@@ -73,7 +74,12 @@ def initialize_cluster(
            "Ray is not installed. Please install Ray to use distributed "
            "serving.")
    # Connect to a ray cluster.
    ray.init(address=ray_address, ignore_reinit_error=True)
    if is_hip():
        ray.init(address=ray_address,
                 ignore_reinit_error=True,
                 num_gpus=parallel_config.world_size)
    else:
        ray.init(address=ray_address, ignore_reinit_error=True)

    if not parallel_config.worker_use_ray:
        # Initialize cluster locally.
vllm/model_executor/layers/attention.py

@@ -10,6 +10,7 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
from vllm._C import ops
from vllm._C import cache_ops
from vllm.model_executor.input_metadata import InputMetadata
from vllm.utils import is_hip

_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.

@@ -160,6 +161,8 @@ class PagedAttention(nn.Module):
                attn_bias=input_metadata.attn_bias,
                p=0.0,
                scale=self.scale,
                op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
                (is_hip()) else None,
            )
            output = out.view_as(query)
        else:
vllm/model_executor/layers/quantization/squeezellm.py

@@ -7,6 +7,7 @@ from vllm._C import ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.utils import is_hip


class SqueezeLLMConfig(QuantizationConfig):

@@ -114,9 +115,14 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
        lookup_table = weights["lookup_table"]
        out_shape = x.shape[:-1] + (qweight.shape[-1], )
        reshaped_x = x.reshape(-1, x.shape[-1])
        # NOTE: The output tensor should be zero-initialized.
        out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
        ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)
        if is_hip():
            out_f = torch.zeros(out_shape, device="cuda", dtype=torch.float)
            ops.squeezellm_gemm(reshaped_x, qweight, out_f, lookup_table)
            out = out_f.to(dtype=torch.float16)
        else:
            # NOTE: The output tensor should be zero-initialized.
            out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
            ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)

        if bias is not None:
            out = out + bias
vllm/model_executor/model_loader.py

@@ -10,6 +10,10 @@ from vllm.config import ModelConfig
from vllm.model_executor.models import *
from vllm.model_executor.weight_utils import (get_quant_config,
                                              initialize_dummy_weights)
from vllm.utils import is_hip
from vllm.logger import init_logger

logger = init_logger(__name__)

# TODO(woosuk): Lazy-load the model classes.
_MODEL_REGISTRY = {

@@ -39,6 +43,18 @@ _MODEL_REGISTRY = {
    "YiForCausalLM": YiForCausalLM,
}

# Models to be disabled in ROCm
_ROCM_UNSUPPORTED_MODELS = []
if is_hip():
    for rocm_model in _ROCM_UNSUPPORTED_MODELS:
        del _MODEL_REGISTRY[rocm_model]

# Models partially supported in ROCm
_ROCM_PARTIALLY_SUPPORTED_MODELS = {
    "MistralForCausalLM":
    "Sliding window attention is not supported in ROCm's flash attention",
}


@contextlib.contextmanager
def _set_default_torch_dtype(dtype: torch.dtype):

@@ -53,7 +69,15 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
    architectures = getattr(config, "architectures", [])
    for arch in architectures:
        if arch in _MODEL_REGISTRY:
            if is_hip() and arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
                logger.warning(
                    f"{arch} is not fully supported in ROCm. Reason: "
                    f"{_ROCM_PARTIALLY_SUPPORTED_MODELS[arch]}")
            return _MODEL_REGISTRY[arch]
        elif arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(
                f"Model architecture {arch} is not supported by ROCm for now. \n"
                f"Supported architectures {list(_MODEL_REGISTRY.keys())}")
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
vllm/utils.py

@@ -27,10 +27,14 @@ class Counter:
        self.counter = 0


def is_hip() -> bool:
    return torch.version.hip is not None


def get_max_shared_memory_bytes(gpu: int = 0) -> int:
    """Returns the maximum shared memory per thread block in bytes."""
    # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
    cudaDevAttrMaxSharedMemoryPerBlockOptin = 97
    cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
    max_shared_mem = cuda_utils.get_device_attribute(
        cudaDevAttrMaxSharedMemoryPerBlockOptin, gpu)
    return int(max_shared_mem)
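A small usage sketch of the helpers added to vllm/utils.py (not part of the commit; it assumes the compiled vllm._C extension is importable, since get_max_shared_memory_bytes queries the device through it):

    # Query the new ROCm-aware helpers from vllm/utils.py.
    from vllm.utils import is_hip, get_max_shared_memory_bytes

    print("Running on ROCm:", is_hip())
    print("Max shared memory per block:", get_max_shared_memory_bytes(), "bytes")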