[Neuron] Add an option to build with neuron (#2065)
parent 4df417d059
commit 18473cf498
requirements-neuron.txt (new file, 9 lines)

@@ -0,0 +1,9 @@
+sentencepiece  # Required for LLaMA tokenizer.
+numpy
+transformers-neuronx >= 0.9.0
+torch-neuronx >= 2.1.0
+neuronx-cc
+fastapi
+uvicorn[standard]
+pydantic == 1.10.13  # Required for OpenAI server.
+aioprometheus[starlette]
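Note (not part of the commit): a minimal sketch of how this file is consumed, mirroring the get_requirements() change further down in setup.py. It assumes requirements-neuron.txt sits in the current working directory.

# Sketch: read the Neuron requirements the same way setup.py does.
with open("requirements-neuron.txt") as f:
    requirements = f.read().strip().split("\n")
print(requirements)  # one entry per dependency line above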
setup.py (62 changed lines)
@@ -24,8 +24,17 @@ def _is_hip() -> bool:
     return torch.version.hip is not None


+def _is_neuron() -> bool:
+    torch_neuronx_installed = True
+    try:
+        subprocess.run(["neuron-ls"], capture_output=True, check=True)
+    except FileNotFoundError as e:
+        torch_neuronx_installed = False
+    return torch_neuronx_installed
+
+
 def _is_cuda() -> bool:
-    return torch.version.cuda is not None
+    return (torch.version.cuda is not None) and not _is_neuron()


 # Compiler flags.
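Aside (not from the diff): the detection above treats a missing neuron-ls binary as "no Neuron device present". A standalone sketch of the same probe, with a placeholder tool name, so the failure path is easy to see:

import subprocess

def has_tool(name: str) -> bool:
    # check=True turns a non-zero exit into CalledProcessError, which is not
    # caught here; only an absent binary (FileNotFoundError) means "not installed".
    try:
        subprocess.run([name], capture_output=True, check=True)
        return True
    except FileNotFoundError:
        return False

print(has_tool("definitely-not-a-real-tool"))  # expected: False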
@@ -87,6 +96,24 @@ def get_hipcc_rocm_version():
         return None


+def get_neuronxcc_version():
+    import sysconfig
+    site_dir = sysconfig.get_paths()["purelib"]
+    version_file = os.path.join(site_dir, "neuronxcc", "version", "__init__.py")
+
+    # Read the version file shipped with the installed neuronx-cc package.
+    with open(version_file, "rt") as fp:
+        content = fp.read()
+
+    # Extract the version using a regular expression.
+    match = re.search(r"__version__ = '(\S+)'", content)
+    if match:
+        # Return the version string.
+        return match.group(1)
+    else:
+        raise RuntimeError("Could not find Neuron version in the output")
+
+
 def get_nvcc_cuda_version(cuda_dir: str) -> Version:
     """Get the CUDA version from nvcc.
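Aside (assumed file contents, not from the diff): the parser above expects neuronxcc/version/__init__.py to contain a single-quoted __version__ assignment. A self-contained check of the regular expression against such a line (the version value is made up):

import re

sample = "__version__ = '2.11.0.34'\n"  # hypothetical file contents
match = re.search(r"__version__ = '(\S+)'", sample)
print(match.group(1) if match else None)  # -> 2.11.0.34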
@@ -210,6 +237,9 @@ elif _is_hip():
             f"Only the following arch is supported: {ROCM_SUPPORTED_ARCHS}"
             f"amdgpu_arch_found: {amd_arch}")

+elif _is_neuron():
+    neuronxcc_version = get_neuronxcc_version()
+
 ext_modules = []

 vllm_extension_sources = [
@@ -227,15 +257,16 @@ vllm_extension_sources = [
 if _is_cuda():
     vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")

-vllm_extension = CUDAExtension(
-    name="vllm._C",
-    sources=vllm_extension_sources,
-    extra_compile_args={
-        "cxx": CXX_FLAGS,
-        "nvcc": NVCC_FLAGS,
-    },
-)
-ext_modules.append(vllm_extension)
+if not _is_neuron():
+    vllm_extension = CUDAExtension(
+        name="vllm._C",
+        sources=vllm_extension_sources,
+        extra_compile_args={
+            "cxx": CXX_FLAGS,
+            "nvcc": NVCC_FLAGS,
+        },
+    )
+    ext_modules.append(vllm_extension)


 def get_path(*filepath) -> str:
@@ -264,6 +295,12 @@ def get_vllm_version() -> str:
         if hipcc_version != MAIN_CUDA_VERSION:
             rocm_version_str = hipcc_version.replace(".", "")[:3]
             version += f"+rocm{rocm_version_str}"
+    elif _is_neuron():
+        # Get the Neuron version.
+        neuron_version = str(neuronxcc_version)
+        if neuron_version != MAIN_CUDA_VERSION:
+            neuron_version_str = neuron_version.replace(".", "")[:3]
+            version += f"+neuron{neuron_version_str}"
     else:
         cuda_version = str(nvcc_cuda_version)
         if cuda_version != MAIN_CUDA_VERSION:
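Aside (hypothetical values): how the local version suffix above comes together. A base vLLM version string and a neuronx-cc version are combined into a PEP 440 local version:

version = "0.2.6"             # made-up base version
neuron_version = "2.11.0.34"  # made-up neuronx-cc version
neuron_version_str = neuron_version.replace(".", "")[:3]  # "211"
print(version + f"+neuron{neuron_version_str}")  # -> 0.2.6+neuron211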
@@ -287,6 +324,9 @@ def get_requirements() -> List[str]:
     if _is_hip():
         with open(get_path("requirements-rocm.txt")) as f:
             requirements = f.read().strip().split("\n")
+    elif _is_neuron():
+        with open(get_path("requirements-neuron.txt")) as f:
+            requirements = f.read().strip().split("\n")
     else:
         with open(get_path("requirements.txt")) as f:
             requirements = f.read().strip().split("\n")
@@ -325,6 +365,6 @@ setuptools.setup(
     python_requires=">=3.8",
     install_requires=get_requirements(),
     ext_modules=ext_modules,
-    cmdclass={"build_ext": BuildExtension},
+    cmdclass={"build_ext": BuildExtension} if not _is_neuron() else {},
     package_data=package_data,
 )
vllm/utils.py

@@ -8,8 +8,6 @@ from typing import List
 import psutil
 import torch

-from vllm._C import cuda_utils
-

 class Device(enum.Enum):
     GPU = enum.auto()
@@ -36,6 +34,10 @@ def is_hip() -> bool:

 def get_max_shared_memory_bytes(gpu: int = 0) -> int:
     """Returns the maximum shared memory per thread block in bytes."""
+    # NOTE: This import statement should be executed lazily since
+    # the Neuron-X backend does not have the `cuda_utils` module.
+    from vllm._C import cuda_utils
+
     # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
     cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
     max_shared_mem = cuda_utils.get_device_attribute(
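Aside (not from the diff): the function-level import above is the usual way to defer an optional native dependency. The same pattern with a standard-library module, so the example runs anywhere:

def sqlite_version() -> str:
    # Resolved only when the function is called, so importing the enclosing
    # module never requires this dependency.
    import sqlite3
    return sqlite3.sqlite_version

print(sqlite_version())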