[Misc] Upgrade to pytorch 2.5 (#9588)

Signed-off-by: Bill Nell <bill@neuralmagic.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>

commit 3cb07a36a2 (parent 8549c82660)

@@ -49,7 +49,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11
 # requirements.txt files and should be kept consistent. The ROCm torch
 # versions are derived from Dockerfile.rocm
 #
-set(TORCH_SUPPORTED_VERSION_CUDA "2.4.0")
+set(TORCH_SUPPORTED_VERSION_CUDA "2.5.0")
 set(TORCH_SUPPORTED_VERSION_ROCM "2.5.0")

 #

@@ -507,7 +507,7 @@ else()
   FetchContent_Declare(
         vllm-flash-attn
         GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-        GIT_TAG 013f0c4fc47e6574060879d9734c1df8c5c273bd
+        GIT_TAG 5259c586c403a4e4d8bf69973c159b40cc346fb9
         GIT_PROGRESS TRUE
         # Don't share the vllm-flash-attn build between build types
         BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn

@@ -424,11 +424,7 @@ function (define_gpu_extension_target GPU_MOD_NAME)
   # Don't use `TORCH_LIBRARIES` for CUDA since it pulls in a bunch of
   # dependencies that are not necessary and may not be installed.
   if (GPU_LANGUAGE STREQUAL "CUDA")
-    if ("${CUDA_CUDA_LIB}" STREQUAL "")
-      set(CUDA_CUDA_LIB "${CUDA_CUDA_LIBRARY}")
-    endif()
-    target_link_libraries(${GPU_MOD_NAME} PRIVATE ${CUDA_CUDA_LIB}
-      ${CUDA_LIBRARIES})
+    target_link_libraries(${GPU_MOD_NAME} PRIVATE CUDA::cudart CUDA::cuda_driver)
   else()
     target_link_libraries(${GPU_MOD_NAME} PRIVATE ${TORCH_LIBRARIES})
   endif()

@@ -6,7 +6,7 @@ requires = [
     "packaging",
     "setuptools>=61",
     "setuptools-scm>=8.0",
-    "torch == 2.4.0",
+    "torch == 2.5.0",
     "wheel",
     "jinja2",
 ]

@@ -4,6 +4,6 @@ ninja
 packaging
 setuptools>=61
 setuptools-scm>=8
-torch==2.4.0
+torch==2.5.0
 wheel
 jinja2

@@ -4,7 +4,7 @@
 # Dependencies for NVIDIA GPUs
 ray >= 2.9
 nvidia-ml-py # for pynvml package
-torch == 2.4.0
+torch == 2.5.0
 # These must be updated alongside torch
-torchvision == 0.19 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
-xformers == 0.0.27.post2; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch 2.4.0
+torchvision == 0.20 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
+xformers == 0.0.28.post2; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch 2.5.0

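Editor's note: the three pins above move in lockstep; torchvision 0.20 and xformers 0.0.28.post2 are the releases built against torch 2.5.0. As an editor-added sketch (not part of the commit), a quick check like the following can confirm an environment actually matches these pins; the expected version strings are copied from the requirement lines above.

# Editor's sketch, not from the repo: verify installed versions match the pins.
from importlib.metadata import PackageNotFoundError, version

expected_prefixes = {
    "torch": "2.5.0",
    "torchvision": "0.20",
    "xformers": "0.0.28.post2",
}

for pkg, want in expected_prefixes.items():
    try:
        have = version(pkg)
    except PackageNotFoundError:
        print(f"{pkg}: not installed")
        continue
    status = "OK" if have.startswith(want) else f"MISMATCH (expected {want}*)"
    print(f"{pkg} {have}: {status}")
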
@@ -1,7 +1,7 @@
 # Common dependencies
 -r requirements-common.txt

-torch == 2.4.0 # should be aligned with "common" vLLM torch version
+torch == 2.5.0 # should be aligned with "common" vLLM torch version
 openvino >= 2024.4.0 # since 2024.4.0 both CPU and GPU support Paged Attention

 optimum @ git+https://github.com/huggingface/optimum.git@main # latest optimum is used to support latest transformers version

@@ -8,7 +8,7 @@ import pytest

 from vllm.platforms import current_platform

-from ...utils import check_outputs_equal
+from ...utils import check_logprobs_close, check_outputs_equal

 MODELS = [
     "meta-llama/Llama-2-7b-hf",

@@ -43,18 +43,40 @@ def test_models(
     dtype: str,
     max_tokens: int,
 ) -> None:
-    with hf_runner(model, dtype=dtype) as hf_model:
-        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

-    with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model:
-        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
-
-    check_outputs_equal(
-        outputs_0_lst=hf_outputs,
-        outputs_1_lst=vllm_outputs,
-        name_0="hf",
-        name_1="vllm",
-    )
+    if model == "openbmb/MiniCPM3-4B":
+        # the output becomes slightly different when upgrading to
+        # pytorch 2.5 . Changing to logprobs checks instead of exact
+        # output checks.
+        NUM_LOG_PROBS = 8
+        with hf_runner(model, dtype=dtype) as hf_model:
+            hf_outputs = hf_model.generate_greedy_logprobs_limit(
+                example_prompts, max_tokens, NUM_LOG_PROBS)
+
+        with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model:
+            vllm_outputs = vllm_model.generate_greedy_logprobs(
+                example_prompts, max_tokens, NUM_LOG_PROBS)
+
+        check_logprobs_close(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=vllm_outputs,
+            name_0="hf",
+            name_1="vllm",
+        )
+    else:
+        with hf_runner(model, dtype=dtype) as hf_model:
+            hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+
+        with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model:
+            vllm_outputs = vllm_model.generate_greedy(example_prompts,
+                                                      max_tokens)
+
+        check_outputs_equal(
+            outputs_0_lst=hf_outputs,
+            outputs_1_lst=vllm_outputs,
+            name_0="hf",
+            name_1="vllm",
+        )


 @pytest.mark.parametrize("model", MODELS)

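Editor's note on the hunk above: check_logprobs_close is vLLM's existing test helper; the sketch below is an editor-added simplification, not the real helper, illustrating what a logprobs-based comparison tolerates that an exact-output check does not. Where the two runs emit different tokens at a step, each run's token must still appear in the other run's top-k logprobs, which absorbs the small numerical drift MiniCPM3-4B shows after the PyTorch 2.5 upgrade.

# Illustration only (editor-added), not vLLM's actual check_logprobs_close.
# Each output pairs sampled token ids with per-step top-k logprob dicts
# ({token_id: logprob}); divergent tokens must appear in the other run's top-k.
from typing import Dict, List, Sequence, Tuple

StepLogprobs = Dict[int, float]                 # token_id -> logprob at one step
Output = Tuple[List[int], List[StepLogprobs]]   # (token_ids, per-step top-k)


def simple_check_logprobs_close(outputs_0: Sequence[Output],
                                outputs_1: Sequence[Output]) -> None:
    for idx, ((ids_0, lp_0), (ids_1, lp_1)) in enumerate(
            zip(outputs_0, outputs_1)):
        for step, (t0, t1) in enumerate(zip(ids_0, ids_1)):
            if t0 == t1:
                continue  # exact match at this step, nothing more to check
            # Divergent token: each side must rank the other's token in top-k.
            assert t0 in lp_1[step] and t1 in lp_0[step], (
                f"prompt {idx}, step {step}: tokens {t0} vs {t1} "
                "not within each other's top-k logprobs")
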
@@ -7,6 +7,7 @@ from functools import lru_cache, wraps
 from typing import Callable, List, Tuple, TypeVar

 import pynvml
+import torch
 from typing_extensions import ParamSpec

 from vllm.logger import init_logger

@@ -26,6 +27,10 @@ if pynvml.__file__.endswith("__init__.py"):
         " and cause errors. See https://pypi.org/project/pynvml "
         "for more information.")

+# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
+# see https://github.com/huggingface/diffusers/issues/9704 for details
+torch.backends.cuda.enable_cudnn_sdp(False)
+
 # NVML utils
 # Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
 # all the related functions work on real physical device ids.
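
Editor's note: the module-level call added above switches the cuDNN SDPA backend off globally as soon as vLLM's CUDA platform module is imported. As an editor-added illustration (assuming PyTorch 2.5 and a CUDA device; this code is not part of the commit), the same backend can also be inspected or restricted per call instead of globally:

# Editor's sketch: global disable vs. per-call backend selection for SDPA.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Global switch, mirroring what the diff above does at import time.
torch.backends.cuda.enable_cudnn_sdp(False)
print("cudnn sdpa enabled:", torch.backends.cuda.cudnn_sdp_enabled())

q = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Scoped alternative: allow only the flash / memory-efficient backends here.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v)
print(out.shape)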