[Misc] Upgrade to pytorch 2.5 (#9588)
Signed-off-by: Bill Nell <bill@neuralmagic.com> Signed-off-by: youkaichao <youkaichao@gmail.com> Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
parent
8549c82660
commit
3cb07a36a2
@ -49,7 +49,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11
|
||||
# requirements.txt files and should be kept consistent. The ROCm torch
|
||||
# versions are derived from Dockerfile.rocm
|
||||
#
|
||||
set(TORCH_SUPPORTED_VERSION_CUDA "2.4.0")
|
||||
set(TORCH_SUPPORTED_VERSION_CUDA "2.5.0")
|
||||
set(TORCH_SUPPORTED_VERSION_ROCM "2.5.0")
|
||||
|
||||
#
|
||||
@ -507,7 +507,7 @@ else()
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn
|
||||
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
||||
GIT_TAG 013f0c4fc47e6574060879d9734c1df8c5c273bd
|
||||
GIT_TAG 5259c586c403a4e4d8bf69973c159b40cc346fb9
|
||||
GIT_PROGRESS TRUE
|
||||
# Don't share the vllm-flash-attn build between build types
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
|
@ -424,11 +424,7 @@ function (define_gpu_extension_target GPU_MOD_NAME)
|
||||
# Don't use `TORCH_LIBRARIES` for CUDA since it pulls in a bunch of
|
||||
# dependencies that are not necessary and may not be installed.
|
||||
if (GPU_LANGUAGE STREQUAL "CUDA")
|
||||
if ("${CUDA_CUDA_LIB}" STREQUAL "")
|
||||
set(CUDA_CUDA_LIB "${CUDA_CUDA_LIBRARY}")
|
||||
endif()
|
||||
target_link_libraries(${GPU_MOD_NAME} PRIVATE ${CUDA_CUDA_LIB}
|
||||
${CUDA_LIBRARIES})
|
||||
target_link_libraries(${GPU_MOD_NAME} PRIVATE CUDA::cudart CUDA::cuda_driver)
|
||||
else()
|
||||
target_link_libraries(${GPU_MOD_NAME} PRIVATE ${TORCH_LIBRARIES})
|
||||
endif()
|
||||
|
@ -6,7 +6,7 @@ requires = [
|
||||
"packaging",
|
||||
"setuptools>=61",
|
||||
"setuptools-scm>=8.0",
|
||||
"torch == 2.4.0",
|
||||
"torch == 2.5.0",
|
||||
"wheel",
|
||||
"jinja2",
|
||||
]
|
||||
|
@ -4,6 +4,6 @@ ninja
|
||||
packaging
|
||||
setuptools>=61
|
||||
setuptools-scm>=8
|
||||
torch==2.4.0
|
||||
torch==2.5.0
|
||||
wheel
|
||||
jinja2
|
||||
|
@ -4,7 +4,7 @@
|
||||
# Dependencies for NVIDIA GPUs
|
||||
ray >= 2.9
|
||||
nvidia-ml-py # for pynvml package
|
||||
torch == 2.4.0
|
||||
torch == 2.5.0
|
||||
# These must be updated alongside torch
|
||||
torchvision == 0.19 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
|
||||
xformers == 0.0.27.post2; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch 2.4.0
|
||||
torchvision == 0.20 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
|
||||
xformers == 0.0.28.post2; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch 2.5.0
|
||||
|
@ -1,7 +1,7 @@
|
||||
# Common dependencies
|
||||
-r requirements-common.txt
|
||||
|
||||
torch == 2.4.0 # should be aligned with "common" vLLM torch version
|
||||
torch == 2.5.0 # should be aligned with "common" vLLM torch version
|
||||
openvino >= 2024.4.0 # since 2024.4.0 both CPU and GPU support Paged Attention
|
||||
|
||||
optimum @ git+https://github.com/huggingface/optimum.git@main # latest optimum is used to support latest transformers version
|
||||
|
@ -8,7 +8,7 @@ import pytest
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from ...utils import check_outputs_equal
|
||||
from ...utils import check_logprobs_close, check_outputs_equal
|
||||
|
||||
MODELS = [
|
||||
"meta-llama/Llama-2-7b-hf",
|
||||
@ -43,18 +43,40 @@ def test_models(
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
with hf_runner(model, dtype=dtype) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
|
||||
with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
if model == "openbmb/MiniCPM3-4B":
|
||||
# the output becomes slightly different when upgrading to
|
||||
# pytorch 2.5 . Changing to logprobs checks instead of exact
|
||||
# output checks.
|
||||
NUM_LOG_PROBS = 8
|
||||
with hf_runner(model, dtype=dtype) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy_logprobs_limit(
|
||||
example_prompts, max_tokens, NUM_LOG_PROBS)
|
||||
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy_logprobs(
|
||||
example_prompts, max_tokens, NUM_LOG_PROBS)
|
||||
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
else:
|
||||
with hf_runner(model, dtype=dtype) as hf_model:
|
||||
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
|
||||
|
||||
with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model:
|
||||
vllm_outputs = vllm_model.generate_greedy(example_prompts,
|
||||
max_tokens)
|
||||
|
||||
check_outputs_equal(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
|
@ -7,6 +7,7 @@ from functools import lru_cache, wraps
|
||||
from typing import Callable, List, Tuple, TypeVar
|
||||
|
||||
import pynvml
|
||||
import torch
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from vllm.logger import init_logger
|
||||
@ -26,6 +27,10 @@ if pynvml.__file__.endswith("__init__.py"):
|
||||
" and cause errors. See https://pypi.org/project/pynvml "
|
||||
"for more information.")
|
||||
|
||||
# pytorch 2.5 uses cudnn sdpa by default, which will cause crash on some models
|
||||
# see https://github.com/huggingface/diffusers/issues/9704 for details
|
||||
torch.backends.cuda.enable_cudnn_sdp(False)
|
||||
|
||||
# NVML utils
|
||||
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
|
||||
# all the related functions work on real physical device ids.
|
||||
|
Loading…
x
Reference in New Issue
Block a user