[Misc] fix line length for entire codebase (#3444)

Simon Mo 2024-03-16 00:36:29 -07:00 committed by GitHub
parent ad50bf4b25
commit 8e67598aa6
9 changed files with 174 additions and 128 deletions


@@ -28,7 +28,7 @@ jobs:
pip install ruff==0.1.5 codespell==2.2.6 tomli==2.0.1
- name: Analysing the code with ruff
run: |
ruff vllm tests
ruff .
- name: Spelling check with codespell
run: |
codespell --toml pyproject.toml
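Switching the ruff invocation from 'ruff vllm tests' to 'ruff .' lints the entire repository, which is what surfaces the E501 (line too long) violations addressed in the files below. A minimal, self-contained sketch of the main wrapping idiom the commit uses, mirroring the setup.py change further down:

# Adjacent string literals are concatenated at compile time, so the
# wrapped form produces exactly the same message as the one-line form.
message = ("Cannot find ROCM_HOME. "
           "ROCm must be available to build the package.")
print(message)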


@@ -110,7 +110,7 @@ async def async_request_vllm(
output.ttft = ttft
output.latency = time.perf_counter() - st
# When streaming, '\0' is appended to the end of the response.
# When streaming, '\0' is appended to the end of response.
body = data.decode("utf-8").strip("\0")
output.generated_text = json.loads(
body)["text"][0][len(request_func_input.prompt):]
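A self-contained sketch (payload invented) of why the trailing NUL must be stripped before the body is parsed as JSON:

import json

raw = b'{"text": ["prompt plus completion"]}\x00'  # hypothetical final chunk
body = raw.decode("utf-8").strip("\0")  # remove the trailing '\0'
print(json.loads(body)["text"][0])  # parses cleanly once the NUL is gone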
@@ -192,7 +192,8 @@ async def async_request_deepspeed_mii(
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len
# DeepSpeed-MII doesn't support streaming as of Jan 28 2024, will use 0 as placeholder.
# DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
# will use 0 as placeholder.
# https://github.com/microsoft/DeepSpeed-MII/pull/311
output.ttft = 0
@@ -344,7 +345,8 @@ async def async_request_openai_chat_completions(
return output
# Since vllm must support Python 3.8, we can't use str.removeprefix(prefix) introduced in Python 3.9
# Since vllm must support Python 3.8, we can't use str.removeprefix(prefix)
# introduced in Python 3.9
def remove_prefix(text: str, prefix: str) -> str:
if text.startswith(prefix):
return text[len(prefix):]
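A quick usage sketch of the helper (assuming its fall-through returns text unchanged when the prefix is absent, matching the Python 3.9 built-in it emulates):

assert remove_prefix("data: chunk", "data: ") == "chunk"
assert remove_prefix("chunk", "data: ") == "chunk"  # prefix absent: no-op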


@@ -4,7 +4,7 @@ import time
from vllm import LLM
from vllm import SamplingParams
PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n"
PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n" # noqa: E501
def test_prefix(llm=None, sampling_params=None, prompts=None):
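When a long line cannot be wrapped, as with this single string literal whose exact contents the test depends on, the commit appends '# noqa: E501' instead, telling ruff to skip the line-length check for that one line. A minimal illustration (string invented):

STATIC_PROMPT = "an intentionally long literal that must stay on one physical line because its byte-for-byte contents are significant to the test"  # noqa: E501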


@@ -293,7 +293,9 @@ def main(args: argparse.Namespace):
# Save to file
base_model_id = model_id.split("/")[-1]
file_name = f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
file_name = (
f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
)
with open(file_name, "w") as outfile:
json.dump(result_json, outfile)
@@ -341,7 +343,7 @@ if __name__ == "__main__":
"--tokenizer",
type=str,
help=
"Name or path of the tokenizer, if not using the default model tokenizer.",
"Name or path of the tokenizer, if not using the default tokenizer.",
)
parser.add_argument(
"--best-of",


@@ -1,3 +1,4 @@
# ruff: noqa
# code borrowed from https://github.com/pytorch/pytorch/blob/main/torch/utils/collect_env.py
# Unlike the rest of the PyTorch this file must be python2 compliant.
@@ -11,7 +12,6 @@ import sys
import os
from collections import namedtuple
try:
import torch
TORCH_AVAILABLE = True
@@ -19,7 +19,9 @@ except (ImportError, NameError, AttributeError, OSError):
TORCH_AVAILABLE = False
# System Environment Information
SystemEnv = namedtuple('SystemEnv', [
SystemEnv = namedtuple(
'SystemEnv',
[
'torch_version',
'is_debug_build',
'cuda_compiled_version',
@@ -50,7 +52,7 @@ SystemEnv = namedtuple('SystemEnv', [
'vllm_version', # vllm specific field
'vllm_build_flags', # vllm specific field
'gpu_topo', # vllm specific field
])
])
DEFAULT_CONDA_PATTERNS = {
"torch",
@@ -77,8 +79,10 @@ DEFAULT_PIP_PATTERNS = {
def run(command):
"""Return (return-code, stdout, stderr)."""
shell = True if type(command) is str else False
p = subprocess.Popen(command, stdout=subprocess.PIPE,
stderr=subprocess.PIPE, shell=shell)
p = subprocess.Popen(command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
shell=shell)
raw_output, raw_err = p.communicate()
rc = p.returncode
if get_platform() == 'win32':
@@ -108,6 +112,7 @@ def run_and_parse_first_match(run_lambda, command, regex):
return None
return match.group(1)
def run_and_return_first_line(run_lambda, command):
"""Run command using run_lambda and returns first line if output is not empty."""
rc, out, _ = run_lambda(command)
@@ -124,22 +129,23 @@ def get_conda_packages(run_lambda, patterns=None):
if out is None:
return out
return "\n".join(
line
for line in out.splitlines()
if not line.startswith("#")
and any(name in line for name in patterns)
)
return "\n".join(line for line in out.splitlines()
if not line.startswith("#") and any(name in line
for name in patterns))
def get_gcc_version(run_lambda):
return run_and_parse_first_match(run_lambda, 'gcc --version', r'gcc (.*)')
def get_clang_version(run_lambda):
return run_and_parse_first_match(run_lambda, 'clang --version', r'clang version (.*)')
return run_and_parse_first_match(run_lambda, 'clang --version',
r'clang version (.*)')
def get_cmake_version(run_lambda):
return run_and_parse_first_match(run_lambda, 'cmake --version', r'cmake (.*)')
return run_and_parse_first_match(run_lambda, 'cmake --version',
r'cmake (.*)')
def get_nvidia_driver_version(run_lambda):
@@ -148,11 +154,13 @@ def get_nvidia_driver_version(run_lambda):
return run_and_parse_first_match(run_lambda, cmd,
r'com[.]nvidia[.]CUDA [(](.*?)[)]')
smi = get_nvidia_smi()
return run_and_parse_first_match(run_lambda, smi, r'Driver Version: (.*?) ')
return run_and_parse_first_match(run_lambda, smi,
r'Driver Version: (.*?) ')
def get_gpu_info(run_lambda):
if get_platform() == 'darwin' or (TORCH_AVAILABLE and hasattr(torch.version, 'hip') and torch.version.hip is not None):
if get_platform() == 'darwin' or (TORCH_AVAILABLE and hasattr(
torch.version, 'hip') and torch.version.hip is not None):
if TORCH_AVAILABLE and torch.cuda.is_available():
if torch.version.hip is not None:
prop = torch.cuda.get_device_properties(0)
@@ -174,7 +182,8 @@ def get_gpu_info(run_lambda):
def get_running_cuda_version(run_lambda):
return run_and_parse_first_match(run_lambda, 'nvcc --version', r'release .+ V(.*)')
return run_and_parse_first_match(run_lambda, 'nvcc --version',
r'release .+ V(.*)')
def get_cudnn_version(run_lambda):
@@ -219,8 +228,10 @@ def get_nvidia_smi():
smi = 'nvidia-smi'
if get_platform() == 'win32':
system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows')
program_files_root = os.environ.get('PROGRAMFILES', 'C:\\Program Files')
legacy_path = os.path.join(program_files_root, 'NVIDIA Corporation', 'NVSMI', smi)
program_files_root = os.environ.get('PROGRAMFILES',
'C:\\Program Files')
legacy_path = os.path.join(program_files_root, 'NVIDIA Corporation',
'NVSMI', smi)
new_path = os.path.join(system_root, 'System32', smi)
smis = [new_path, legacy_path]
for candidate_smi in smis:
@@ -232,7 +243,8 @@ def get_nvidia_smi():
def get_rocm_version(run_lambda):
"""Returns the ROCm version if available, otherwise 'N/A'."""
return run_and_parse_first_match(run_lambda, 'hipcc --version', r'HIP version: (\S+)')
return run_and_parse_first_match(run_lambda, 'hipcc --version',
r'HIP version: (\S+)')
def get_neuron_sdk_version(run_lambda):
@@ -342,13 +354,16 @@ def get_gpu_topo(run_lambda):
# ProcessorType=3
# Revision=27142
def get_cpu_info(run_lambda):
rc, out, err = 0, '', ''
if get_platform() == 'linux':
rc, out, err = run_lambda('lscpu')
elif get_platform() == 'win32':
rc, out, err = run_lambda('wmic cpu get Name,Manufacturer,Family,Architecture,ProcessorType,DeviceID, \
CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision /VALUE')
rc, out, err = run_lambda(
'wmic cpu get Name,Manufacturer,Family,Architecture,ProcessorType,DeviceID, \
CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision /VALUE'
)
elif get_platform() == 'darwin':
rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string")
cpu_info = 'None'
@@ -373,18 +388,22 @@ def get_platform():
def get_mac_version(run_lambda):
return run_and_parse_first_match(run_lambda, 'sw_vers -productVersion', r'(.*)')
return run_and_parse_first_match(run_lambda, 'sw_vers -productVersion',
r'(.*)')
def get_windows_version(run_lambda):
system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows')
wmic_cmd = os.path.join(system_root, 'System32', 'Wbem', 'wmic')
findstr_cmd = os.path.join(system_root, 'System32', 'findstr')
return run_and_read_all(run_lambda, '{} os get Caption | {} /v Caption'.format(wmic_cmd, findstr_cmd))
return run_and_read_all(
run_lambda,
'{} os get Caption | {} /v Caption'.format(wmic_cmd, findstr_cmd))
def get_lsb_version(run_lambda):
return run_and_parse_first_match(run_lambda, 'lsb_release -a', r'Description:\t(.*)')
return run_and_parse_first_match(run_lambda, 'lsb_release -a',
r'Description:\t(.*)')
def check_release_file(run_lambda):
@@ -443,11 +462,8 @@ def get_pip_packages(run_lambda, patterns=None):
# But here it is invoked as `python -mpip`
def run_with_pip(pip):
out = run_and_read_all(run_lambda, pip + ["list", "--format=freeze"])
return "\n".join(
line
for line in out.splitlines()
if any(name in line for name in patterns)
)
return "\n".join(line for line in out.splitlines()
if any(name in line for name in patterns))
pip_version = 'pip3' if sys.version[0] == '3' else 'pip'
out = run_with_pip([sys.executable, '-mpip'])
@@ -472,10 +488,12 @@ def get_cuda_module_loading_config():
def is_xnnpack_available():
if TORCH_AVAILABLE:
import torch.backends.xnnpack
return str(torch.backends.xnnpack.enabled) # type: ignore[attr-defined]
return str(
torch.backends.xnnpack.enabled) # type: ignore[attr-defined]
else:
return "N/A"
def get_env_info():
run_lambda = run
pip_version, pip_list_output = get_pip_packages(run_lambda)
@@ -485,9 +503,11 @@ def get_env_info():
debug_mode_str = str(torch.version.debug)
cuda_available_str = str(torch.cuda.is_available())
cuda_version_str = torch.version.cuda
if not hasattr(torch.version, 'hip') or torch.version.hip is None: # cuda version
if not hasattr(torch.version,
'hip') or torch.version.hip is None: # cuda version
hip_compiled_version = hip_runtime_version = miopen_runtime_version = 'N/A'
else: # HIP version
def get_version_or_na(cfg, prefix):
_lst = [s.rsplit(None, 1)[-1] for s in cfg if prefix in s]
return _lst[0] if _lst else 'N/A'
@@ -514,7 +534,9 @@ def get_env_info():
return SystemEnv(
torch_version=version_str,
is_debug_build=debug_mode_str,
python_version='{} ({}-bit runtime)'.format(sys_version, sys.maxsize.bit_length() + 1),
python_version='{} ({}-bit runtime)'.format(
sys_version,
sys.maxsize.bit_length() + 1),
python_platform=get_python_platform(),
is_cuda_available=cuda_available_str,
cuda_compiled_version=cuda_version_str,
@@ -544,6 +566,7 @@ def get_env_info():
gpu_topo=gpu_topo,
)
env_info_fmt = """
PyTorch version: {torch_version}
Is debug build: {is_debug_build}
@@ -588,6 +611,7 @@ GPU Topology:
def pretty_str(envinfo):
def replace_nones(dct, replacement='Could not collect'):
for key in dct.keys():
if dct[key] is not None:
@@ -632,9 +656,10 @@ def pretty_str(envinfo):
'nvidia_driver_version',
]
all_cuda_fields = dynamic_cuda_fields + ['cudnn_version']
all_dynamic_cuda_fields_missing = all(
mutable_dict[field] is None for field in dynamic_cuda_fields)
if TORCH_AVAILABLE and not torch.cuda.is_available() and all_dynamic_cuda_fields_missing:
all_dynamic_cuda_fields_missing = all(mutable_dict[field] is None
for field in dynamic_cuda_fields)
if TORCH_AVAILABLE and not torch.cuda.is_available(
) and all_dynamic_cuda_fields_missing:
for field in all_cuda_fields:
mutable_dict[field] = 'No CUDA'
if envinfo.cuda_compiled_version is None:
@@ -647,17 +672,19 @@ def pretty_str(envinfo):
mutable_dict = replace_nones(mutable_dict)
# If either of these are '', replace with 'No relevant packages'
mutable_dict['pip_packages'] = replace_if_empty(mutable_dict['pip_packages'])
mutable_dict['conda_packages'] = replace_if_empty(mutable_dict['conda_packages'])
mutable_dict['pip_packages'] = replace_if_empty(
mutable_dict['pip_packages'])
mutable_dict['conda_packages'] = replace_if_empty(
mutable_dict['conda_packages'])
# Tag conda and pip packages with a prefix
# If they were previously None, they'll show up as ie '[conda] Could not collect'
if mutable_dict['pip_packages']:
mutable_dict['pip_packages'] = prepend(mutable_dict['pip_packages'],
'[{}] '.format(envinfo.pip_version))
mutable_dict['pip_packages'] = prepend(
mutable_dict['pip_packages'], '[{}] '.format(envinfo.pip_version))
if mutable_dict['conda_packages']:
mutable_dict['conda_packages'] = prepend(mutable_dict['conda_packages'],
'[conda] ')
mutable_dict['conda_packages'] = prepend(
mutable_dict['conda_packages'], '[conda] ')
mutable_dict['cpu_info'] = envinfo.cpu_info
return env_info_fmt.format(**mutable_dict)
@@ -671,18 +698,22 @@ def main():
output = get_pretty_env_info()
print(output)
if TORCH_AVAILABLE and hasattr(torch, 'utils') and hasattr(torch.utils, '_crash_handler'):
if TORCH_AVAILABLE and hasattr(torch, 'utils') and hasattr(
torch.utils, '_crash_handler'):
minidump_dir = torch.utils._crash_handler.DEFAULT_MINIDUMP_DIR
if sys.platform == "linux" and os.path.exists(minidump_dir):
dumps = [os.path.join(minidump_dir, dump) for dump in os.listdir(minidump_dir)]
dumps = [
os.path.join(minidump_dir, dump)
for dump in os.listdir(minidump_dir)
]
latest = max(dumps, key=os.path.getctime)
ctime = os.path.getctime(latest)
creation_time = datetime.datetime.fromtimestamp(ctime).strftime('%Y-%m-%d %H:%M:%S')
creation_time = datetime.datetime.fromtimestamp(ctime).strftime(
'%Y-%m-%d %H:%M:%S')
msg = "\n*** Detected a minidump at {} created on {}, ".format(latest, creation_time) + \
"if this is related to your bug please include it when you file a report ***"
print(msg, file=sys.stderr)
if __name__ == '__main__':
main()


@@ -10,7 +10,7 @@ TEMPLATE = """
#include "bgmv_impl.cuh"
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
""".lstrip()
""".lstrip() # noqa: E501
for input_dtype in DTYPES:
for output_dtype in DTYPES:


@@ -1,5 +1,6 @@
"""
This example shows how to use the multi-LoRA functionality for offline inference.
This example shows how to use the multi-LoRA functionality
for offline inference.
Requires HuggingFace credentials for access to Llama2.
"""
@@ -34,14 +35,16 @@ def create_test_prompts(
top_k=5,
presence_penalty=0.2,
max_tokens=128), None),
("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",
(
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
SamplingParams(temperature=0.0,
logprobs=1,
prompt_logprobs=1,
max_tokens=128,
stop_token_ids=[32003]),
LoRARequest("sql-lora", 1, lora_path)),
("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",
(
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501
SamplingParams(n=3,
best_of=3,
use_beam_search=True,
@@ -49,14 +52,16 @@ def create_test_prompts(
max_tokens=128,
stop_token_ids=[32003]),
LoRARequest("sql-lora", 1, lora_path)),
("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",
(
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501
SamplingParams(temperature=0.0,
logprobs=1,
prompt_logprobs=1,
max_tokens=128,
stop_token_ids=[32003]),
LoRARequest("sql-lora2", 2, lora_path)),
("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",
(
"[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501
SamplingParams(n=3,
best_of=3,
use_beam_search=True,


@@ -37,9 +37,10 @@ for output in outputs:
print("-" * 80)
# The llm.generate call will batch all prompts and send the batch at once if resources allow.
# The prefix will only be cached after the first batch is processed, so we need to call generate once
# to calculate the prefix and cache it.
# The llm.generate call will batch all prompts and send the batch at once
# if resources allow. The prefix will only be cached after the first batch
# is processed, so we need to call generate once to calculate the prefix
# and cache it.
outputs = llm.generate(generating_prompts[0], sampling_params)
# Subsequent batches can leverage the cached prefix
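A hedged sketch of the follow-up call the last comment refers to (the actual statement falls outside this hunk; the names follow the example above):

# The warm-up call primed the prefix cache, so the full batch can now
# reuse the KV entries computed for the shared prefix.
outputs = llm.generate(generating_prompts, sampling_params)
for output in outputs:
    print(output.outputs[0].text)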


@@ -12,7 +12,12 @@ import setuptools
import sys
import torch
import torch.utils.cpp_extension as torch_cpp_ext
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME
from torch.utils.cpp_extension import (
BuildExtension,
CUDAExtension,
CUDA_HOME,
ROCM_HOME,
)
ROOT_DIR = os.path.dirname(__file__)
@@ -57,9 +62,8 @@ NVCC_FLAGS = ["-O2", "-std=c++17"]
if _is_hip():
if ROCM_HOME is None:
raise RuntimeError(
"Cannot find ROCM_HOME. ROCm must be available to build the package."
)
raise RuntimeError("Cannot find ROCM_HOME. "
"ROCm must be available to build the package.")
NVCC_FLAGS += ["-DUSE_ROCM"]
NVCC_FLAGS += ["-U__HIP_NO_HALF_CONVERSIONS__"]
NVCC_FLAGS += ["-U__HIP_NO_HALF_OPERATORS__"]
@@ -144,7 +148,8 @@ def get_pytorch_rocm_arch() -> Set[str]:
"""
env_arch_list = os.environ.get("PYTORCH_ROCM_ARCH", None)
# If we don't have PYTORCH_ROCM_ARCH specified pull the list from rocm_agent_enumerator
# If we don't have PYTORCH_ROCM_ARCH specified pull the list from
# rocm_agent_enumerator
if env_arch_list is None:
command = "rocm_agent_enumerator"
env_arch_list = (subprocess.check_output(
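A hedged sketch of this fallback (the check_output call is truncated by the hunk, so the arguments shown here are assumptions):

import subprocess

env_arch_list = None  # hypothetical: PYTORCH_ROCM_ARCH is unset
if env_arch_list is None:
    # Ask the ROCm toolchain which GFX targets this machine supports.
    env_arch_list = subprocess.check_output(
        ["rocm_agent_enumerator"]).decode().strip()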
@@ -255,11 +260,11 @@ if _is_cuda():
"CUDA 11.1 or higher is required for compute capability 8.6.")
if nvcc_cuda_version < Version("11.8"):
if any(cc.startswith("8.9") for cc in compute_capabilities):
# CUDA 11.8 is required to generate the code targeting compute capability 8.9.
# However, GPUs with compute capability 8.9 can also run the code generated by
# the previous versions of CUDA 11 and targeting compute capability 8.0.
# Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
# instead of 8.9.
# CUDA 11.8 is required to generate the code targeting compute
# capability 8.9. However, GPUs with compute capability 8.9 can
# also run the code generated by the previous versions of CUDA 11
# and targeting compute capability 8.0. Therefore, if CUDA 11.8
# is not available, we target compute capability 8.0 instead of 8.9.
warnings.warn(
"CUDA 11.8 or higher is required for compute capability 8.9. "
"Targeting compute capability 8.0 instead.",