[ci][distributed] fix device count call

[ci][distributed] fix some cuda init that makes it necessary to use spawn (#5991)
Authored by youkaichao on 2024-06-30 01:06:13 -07:00; committed by GitHub
parent 9d47f64eb6
commit 2be6955a3f
6 changed files with 85 additions and 53 deletions
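
The reasoning behind the change, as stated in the comments added throughout the diff: vLLM's multiprocessing backend forks worker processes, and fork is only safe while the parent process has not initialized CUDA; once something like torch.cuda.device_count() runs at import time, CI had to fall back to the slower spawn start method (the VLLM_WORKER_MULTIPROC_METHOD=spawn lines removed below). The following minimal repro is illustrative only, not part of the commit; it assumes Linux, PyTorch, and at least one CUDA GPU.

# Illustrative repro, not vLLM code: initializing CUDA in the parent
# breaks CUDA use in fork-started children.
import multiprocessing as mp

import torch


def child() -> None:
    # In a forked child of a CUDA-initialized parent this raises
    # "Cannot re-initialize CUDA in forked subprocess".
    torch.zeros(1, device="cuda")


if __name__ == "__main__":
    torch.cuda.init()  # stands in for an eager import that touches CUDA
    p = mp.get_context("fork").Process(target=child)
    p.start()
    p.join()
    print("child exit code:", p.exitcode)  # non-zero once the parent touched CUDA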

View File

@@ -45,9 +45,6 @@ steps:
   num_gpus: 2
   commands:
   - bash ../.buildkite/download-images.sh
-  # FIXIT: find out which code initialize cuda before running the test
-  # before the fix, we need to use spawn to test it
-  - export VLLM_WORKER_MULTIPROC_METHOD=spawn
   - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
@@ -60,8 +57,7 @@ steps:
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
   - TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py
-  # FIXIT: find out why TP is failing with mp backend on phi3-v
-  # - TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py
+  - TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py
   - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
   - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
   - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py
@@ -71,9 +67,6 @@ steps:
   working_dir: "/vllm-workspace/tests"
   num_gpus: 4
   commands:
-  # FIXIT: find out which code initialize cuda before running the test
-  # before the fix, we need to use spawn to test it
-  - export VLLM_WORKER_MULTIPROC_METHOD=spawn
   - pytest -v -s distributed/test_pynccl.py
   # We want to test that models which use 2 GPUs work with 4 GPUs, which is why we duplicate them here.
   # See https://github.com/vllm-project/vllm/pull/5473#issuecomment-2166601837 for context.
@@ -225,9 +218,6 @@ steps:
   gpu: a100
   num_gpus: 4
   commands:
-  # FIXIT: find out which code initialize cuda before running the test
-  # before the fix, we need to use spawn to test it
-  - export VLLM_WORKER_MULTIPROC_METHOD=spawn
   # NOTE: don't test llama model here, it seems hf implementation is buggy
   # see https://github.com/vllm-project/vllm/pull/5689 for details
   - pytest -v -s distributed/test_custom_all_reduce.py
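
The removed FIXIT comments asked which code initializes CUDA before the tests run. A hedged way to chase that down (not from this PR) is to import suspects one at a time and watch torch.cuda.is_initialized(); note this only reflects PyTorch's own lazy initialization, so lower-level driver initialization may not show up here.

import importlib

import torch

# The two imports this commit defers are natural first suspects; the list is
# illustrative, not a claim about the actual culprit.
SUSPECTS = ["transformers", "vllm.multimodal.image"]

for name in SUSPECTS:
    importlib.import_module(name)
    if torch.cuda.is_initialized():
        print(f"CUDA became initialized after importing {name!r}")
        break
else:
    print("none of the listed imports initialized CUDA (by PyTorch's account)")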

View File

@@ -5,8 +5,8 @@ from collections import UserList
 from dataclasses import dataclass
 from functools import cached_property
 from pathlib import Path
-from typing import (Any, Dict, List, Literal, Optional, Tuple, TypedDict,
-                    TypeVar)
+from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple,
+                    TypedDict, TypeVar)
 
 import pytest
 import torch
@@ -14,7 +14,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from PIL import Image
 from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
-                          AutoProcessor, AutoTokenizer, BatchEncoding)
+                          AutoTokenizer, BatchEncoding)
 
 from vllm import LLM, SamplingParams
 from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
@@ -22,8 +22,12 @@ from vllm.distributed import (destroy_distributed_environment,
                               destroy_model_parallel)
 from vllm.inputs import TextPrompt
 from vllm.logger import init_logger
-from vllm.multimodal import MultiModalData
-from vllm.multimodal.image import ImageFeatureData, ImagePixelData
+
+if TYPE_CHECKING:
+    from vllm.multimodal import MultiModalData
+else:
+    # it will call torch.cuda.device_count()
+    MultiModalData = None
 from vllm.sequence import SampleLogprobs
 from vllm.utils import cuda_device_count_stateless, is_cpu
@@ -63,6 +67,10 @@ class ImageAsset:
         return self.pil_image
 
     def for_vllm(self, vision_config: VisionLanguageConfig) -> MultiModalData:
+        # don't put this import at the top level
+        # it will call torch.cuda.device_count()
+        from vllm.multimodal.image import ImageFeatureData  # noqa: F401
+        from vllm.multimodal.image import ImagePixelData
         image_input_type = vision_config.image_input_type
         ImageInputType = VisionLanguageConfig.ImageInputType
@@ -216,6 +224,9 @@ class HfRunner:
         )
 
         try:
+            # don't put this import at the top level
+            # it will call torch.cuda.device_count()
+            from transformers import AutoProcessor  # noqa: F401
             self.processor = AutoProcessor.from_pretrained(
                 model_name,
                 torch_dtype=torch_dtype,
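
The conftest.py changes combine two tricks: an annotation-only import guarded by TYPE_CHECKING, plus run-time imports deferred into the function bodies that need them, so importing the test helpers no longer calls torch.cuda.device_count(). A generic sketch of that pattern, with heavy_gpu_module as a hypothetical stand-in for a module whose import touches CUDA:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # evaluated only by type checkers, never at runtime
    from heavy_gpu_module import MultiModalData
else:
    # keep the name defined so annotations still resolve at runtime
    MultiModalData = None


def build_data(raw) -> "MultiModalData":
    # the real import happens at call time, after any forking has been decided
    from heavy_gpu_module import MultiModalData
    return MultiModalData(raw)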

View File

@@ -15,7 +15,8 @@ TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \
 import os
 
 import pytest
-import torch
+
+from vllm.utils import cuda_device_count_stateless
 
 from ..models.utils import check_outputs_equal
@@ -25,7 +26,7 @@ MODELS = [
 DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND"
 
-@pytest.mark.skipif(torch.cuda.device_count() < 2,
+@pytest.mark.skipif(cuda_device_count_stateless() < 2,
                     reason="Need at least 2 GPUs to run the test.")
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
@@ -40,9 +41,10 @@ def test_models(
 ) -> None:
     distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND)
 
-    with hf_runner(model, dtype=dtype) as hf_model:
-        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+    # NOTE: take care of the order. run vLLM first, and then run HF.
+    # vLLM needs a fresh new process without cuda initialization.
+    # if we run HF first, the cuda initialization will be done and it
+    # will hurt multiprocessing backend with fork method (the default method).
     with vllm_runner(model,
                      dtype=dtype,
                      tensor_parallel_size=2,
@@ -50,6 +52,9 @@ def test_models(
                      ) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
 
+    with hf_runner(model, dtype=dtype) as hf_model:
+        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+
     check_outputs_equal(
         outputs_0_lst=hf_outputs,
         outputs_1_lst=vllm_outputs,
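
The decorator now uses cuda_device_count_stateless() instead of torch.cuda.device_count(), so evaluating the skipif guard at collection time does not initialize CUDA in the test process. The sketch below is not vLLM's implementation; it only illustrates the property being bought, by pushing the query into a short-lived child process so the parent stays clean and fork-safe.

import subprocess
import sys


def device_count_out_of_process() -> int:
    # The child may initialize CUDA freely; it exits immediately and the
    # parent process never touches the driver itself.
    code = "import torch; print(torch.cuda.device_count())"
    result = subprocess.run([sys.executable, "-c", code],
                            capture_output=True, text=True, check=True)
    return int(result.stdout.strip())


if __name__ == "__main__":
    print("visible GPUs:", device_count_out_of_process())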

View File

@@ -14,7 +14,8 @@ TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \
 import os
 
 import pytest
-import torch
+
+from vllm.utils import cuda_device_count_stateless
 
 from ..models.utils import check_outputs_equal
@@ -24,7 +25,7 @@ MODELS = [
 DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND"
 
-@pytest.mark.skipif(torch.cuda.device_count() < 2,
+@pytest.mark.skipif(cuda_device_count_stateless() < 2,
                     reason="Need at least 2 GPUs to run the test.")
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
@@ -47,8 +48,10 @@ def test_models(
         enable_chunked_prefill = True
         max_num_batched_tokens = chunked_prefill_token_size
 
-    with hf_runner(model, dtype=dtype) as hf_model:
-        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+    # NOTE: take care of the order. run vLLM first, and then run HF.
+    # vLLM needs a fresh new process without cuda initialization.
+    # if we run HF first, the cuda initialization will be done and it
+    # will hurt multiprocessing backend with fork method (the default method).
 
     with vllm_runner(
             model,
@@ -61,6 +64,9 @@ def test_models(
     ) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
 
+    with hf_runner(model, dtype=dtype) as hf_model:
+        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+
     check_outputs_equal(
         outputs_0_lst=hf_outputs,
         outputs_1_lst=vllm_outputs,

View File

@@ -88,17 +88,11 @@ def run_test(
     """
     model_id, vlm_config = model_and_config
     hf_images = [asset.for_hf() for asset in image_assets]
-    vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets]
-
-    with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model:
-        hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS,
-                                              max_tokens,
-                                              images=hf_images)
-
-    vllm_image_prompts = [
-        p.replace("<image>", "<image>" * vlm_config.image_feature_size)
-        for p in HF_IMAGE_PROMPTS
-    ]
+
+    # NOTE: take care of the order. run vLLM first, and then run HF.
+    # vLLM needs a fresh new process without cuda initialization.
+    # if we run HF first, the cuda initialization will be done and it
+    # will hurt multiprocessing backend with fork method (the default method).
 
     with vllm_runner(model_id,
                      dtype=dtype,
@@ -106,10 +100,26 @@ def run_test(
                      distributed_executor_backend=distributed_executor_backend,
                      enforce_eager=True,
                      **vlm_config.as_cli_args_dict()) as vllm_model:
+        # NOTE: `asset.for_vllm` will call `torch.cuda.device_count()`
+        # we must put it inside the vllm_runner context manager
+        # i.e. after creating vLLM instance.
+        vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets]
+
+        vllm_image_prompts = [
+            p.replace("<image>", "<image>" * vlm_config.image_feature_size)
+            for p in HF_IMAGE_PROMPTS
+        ]
+
         vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
                                                   max_tokens,
                                                   images=vllm_images)
 
+    with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model:
+        hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS,
+                                              max_tokens,
+                                              images=hf_images)
+
     check_outputs_equal(
         hf_outputs,
         [
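
One detail in the test above worth spelling out: the single <image> placeholder in each prompt is repeated image_feature_size times before the prompt is handed to vLLM, which seems to reserve one prompt position per image feature. A tiny worked example with a made-up prompt and an illustratively small size:

# Worked example of the prompt expansion used above; the size is illustrative,
# real vision configs use far larger values.
image_feature_size = 3
prompt = "USER: <image>\nWhat is in this picture? ASSISTANT:"
expanded = prompt.replace("<image>", "<image>" * image_feature_size)
assert expanded.count("<image>") == image_feature_size
print(expanded)  # the placeholder now appears image_feature_size times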

View File

@@ -96,8 +96,35 @@ def run_test(
     """
     model_id, vlm_config = model_and_config
     hf_images = [asset.for_hf() for asset in image_assets]
-    vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets]
+
+    # NOTE: take care of the order. run vLLM first, and then run HF.
+    # vLLM needs a fresh new process without cuda initialization.
+    # if we run HF first, the cuda initialization will be done and it
+    # will hurt multiprocessing backend with fork method (the default method).
+
+    with vllm_runner(model_id,
+                     max_model_len=2048,
+                     dtype=dtype,
+                     tensor_parallel_size=tensor_parallel_size,
+                     enforce_eager=True,
+                     distributed_executor_backend=distributed_executor_backend,
+                     **vlm_config.as_cli_args_dict()) as vllm_model:
+        # NOTE: `asset.for_vllm` will call `torch.cuda.device_count()`
+        # we must put it inside the vllm_runner context manager
+        # i.e. after creating vLLM instance.
+        vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets]
+
+        vllm_image_prompts = [
+            p.replace("<|image_1|>",
+                      "<|image|>" * vlm_config.image_feature_size + "<s>")
+            for p in HF_IMAGE_PROMPTS
+        ]
+
+        vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
+                                                  max_tokens,
+                                                  images=vllm_images)
 
     # use eager mode for hf runner, since phi3_v didn't work with flash_attn
     hf_model_kwargs = {"_attn_implementation": "eager"}
     with hf_runner(model_id, dtype=dtype,
@@ -108,23 +135,6 @@ def run_test(
                                               images=hf_images,
                                               eos_token_id=hf_model.processor.tokenizer.eos_token_id)
 
-    vllm_image_prompts = [
-        p.replace("<|image_1|>",
-                  "<|image|>" * vlm_config.image_feature_size + "<s>")
-        for p in HF_IMAGE_PROMPTS
-    ]
-
-    with vllm_runner(model_id,
-                     max_model_len=2048,
-                     dtype=dtype,
-                     tensor_parallel_size=tensor_parallel_size,
-                     enforce_eager=True,
-                     distributed_executor_backend=distributed_executor_backend,
-                     **vlm_config.as_cli_args_dict()) as vllm_model:
-        vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
-                                                  max_tokens,
-                                                  images=vllm_images)
-
     check_outputs_equal(
         hf_outputs,
         [