[ci][distributed] fix device count call
[ci][distributed] fix some cuda init that makes it necessary to use spawn (#5991)
parent 9d47f64eb6
commit 2be6955a3f
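Summary of the change, as reflected in the diff below: GPU-count checks in the tests move from `torch.cuda.device_count()` to `vllm.utils.cuda_device_count_stateless()`, and imports that end up calling `torch.cuda.device_count()` are deferred or reordered so the parent pytest process no longer initializes CUDA. That in turn lets the CI drop the `export VLLM_WORKER_MULTIPROC_METHOD=spawn` workaround. A minimal sketch of the skip-check pattern the test files converge on (the test function here is a hypothetical placeholder, and the description of `cuda_device_count_stateless()` as avoiding CUDA initialization is my reading of the commit message, not something stated in the diff):

import pytest

from vllm.utils import cuda_device_count_stateless


# Checking the GPU count this way sidesteps torch.cuda.device_count(), which
# the commit treats as a source of unwanted CUDA initialization/caching in the
# parent test process.
@pytest.mark.skipif(cuda_device_count_stateless() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_requires_two_gpus():  # hypothetical test, not part of the diff
    ...
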
@@ -45,9 +45,6 @@ steps:
   num_gpus: 2
   commands:
   - bash ../.buildkite/download-images.sh
-  # FIXIT: find out which code initialize cuda before running the test
-  # before the fix, we need to use spawn to test it
-  - export VLLM_WORKER_MULTIPROC_METHOD=spawn
   - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
@@ -60,8 +57,7 @@ steps:
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
   - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
   - TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py
-  # FIXIT: find out why TP is failing with mp backend on phi3-v
-  # - TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py
+  - TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py
   - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
   - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
   - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py
@@ -71,9 +67,6 @@ steps:
   working_dir: "/vllm-workspace/tests"
   num_gpus: 4
   commands:
-  # FIXIT: find out which code initialize cuda before running the test
-  # before the fix, we need to use spawn to test it
-  - export VLLM_WORKER_MULTIPROC_METHOD=spawn
   - pytest -v -s distributed/test_pynccl.py
   # We want to test that models which use 2 GPUs work with 4 GPUs, which is why we duplicate them here.
   # See https://github.com/vllm-project/vllm/pull/5473#issuecomment-2166601837 for context.
@@ -225,9 +218,6 @@ steps:
   gpu: a100
   num_gpus: 4
   commands:
-  # FIXIT: find out which code initialize cuda before running the test
-  # before the fix, we need to use spawn to test it
-  - export VLLM_WORKER_MULTIPROC_METHOD=spawn
   # NOTE: don't test llama model here, it seems hf implementation is buggy
   # see https://github.com/vllm-project/vllm/pull/5689 for details
   - pytest -v -s distributed/test_custom_all_reduce.py
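The `export VLLM_WORKER_MULTIPROC_METHOD=spawn` lines removed above appear to have been a workaround (per the deleted FIXIT comments) for CUDA getting initialized in the parent process before worker processes were started. A minimal standalone sketch of that failure mode, not part of the commit; it assumes Linux, PyTorch, and at least one visible GPU:

import multiprocessing as mp

import torch


def worker() -> None:
    # In a child forked from a CUDA-initialized parent, PyTorch refuses with
    # "Cannot re-initialize CUDA in forked subprocess ...". The usual fixes are
    # the spawn start method or keeping the parent CUDA-free; this commit takes
    # the second route for the tests.
    torch.zeros(1, device="cuda")


if __name__ == "__main__":
    torch.cuda.init()  # parent touches CUDA first (the situation the fixed tests avoid)
    proc = mp.get_context("fork").Process(target=worker)
    proc.start()
    proc.join()
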
@@ -5,8 +5,8 @@ from collections import UserList
 from dataclasses import dataclass
 from functools import cached_property
 from pathlib import Path
-from typing import (Any, Dict, List, Literal, Optional, Tuple, TypedDict,
-                    TypeVar)
+from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple,
+                    TypedDict, TypeVar)
 
 import pytest
 import torch
@@ -14,7 +14,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from PIL import Image
 from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
-                          AutoProcessor, AutoTokenizer, BatchEncoding)
+                          AutoTokenizer, BatchEncoding)
 
 from vllm import LLM, SamplingParams
 from vllm.config import TokenizerPoolConfig, VisionLanguageConfig
@@ -22,8 +22,12 @@ from vllm.distributed import (destroy_distributed_environment,
                               destroy_model_parallel)
 from vllm.inputs import TextPrompt
 from vllm.logger import init_logger
-from vllm.multimodal import MultiModalData
-from vllm.multimodal.image import ImageFeatureData, ImagePixelData
+if TYPE_CHECKING:
+    from vllm.multimodal import MultiModalData
+else:
+    # it will call torch.cuda.device_count()
+    MultiModalData = None
 from vllm.sequence import SampleLogprobs
 from vllm.utils import cuda_device_count_stateless, is_cpu
 
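The conftest hunk above wraps the `MultiModalData` import in the standard `typing.TYPE_CHECKING` guard: type checkers still see the real import, but at runtime the name is simply bound to None, so importing the test helpers no longer pulls in `vllm.multimodal` (which, per the in-diff comment, ends up calling `torch.cuda.device_count()`). A generic sketch of the pattern, using placeholder names rather than vLLM ones:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers (mypy, pyright); never executed at
    # runtime, so any import-time side effects of the heavy module are avoided.
    from heavy_module import HeavyType  # placeholder module and class
else:
    # Keep the name defined at runtime so annotations that mention it still
    # evaluate without a NameError.
    HeavyType = None


def consume(value: "HeavyType") -> None:
    ...
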
@@ -63,6 +67,10 @@ class ImageAsset:
         return self.pil_image
 
     def for_vllm(self, vision_config: VisionLanguageConfig) -> MultiModalData:
+        # don't put this import at the top level
+        # it will call torch.cuda.device_count()
+        from vllm.multimodal.image import ImageFeatureData  # noqa: F401
+        from vllm.multimodal.image import ImagePixelData
         image_input_type = vision_config.image_input_type
         ImageInputType = VisionLanguageConfig.ImageInputType
 
@@ -216,6 +224,9 @@ class HfRunner:
         )
 
         try:
+            # don't put this import at the top level
+            # it will call torch.cuda.device_count()
+            from transformers import AutoProcessor  # noqa: F401
             self.processor = AutoProcessor.from_pretrained(
                 model_name,
                 torch_dtype=torch_dtype,

@@ -15,7 +15,8 @@ TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \
 import os
 
 import pytest
-import torch
+
+from vllm.utils import cuda_device_count_stateless
 
 from ..models.utils import check_outputs_equal
 
@@ -25,7 +26,7 @@ MODELS = [
 DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND"
 
 
-@pytest.mark.skipif(torch.cuda.device_count() < 2,
+@pytest.mark.skipif(cuda_device_count_stateless() < 2,
                     reason="Need at least 2 GPUs to run the test.")
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
@@ -40,9 +41,10 @@ def test_models(
 ) -> None:
     distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND)
 
-    with hf_runner(model, dtype=dtype) as hf_model:
-        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
-
+    # NOTE: take care of the order. run vLLM first, and then run HF.
+    # vLLM needs a fresh new process without cuda initialization.
+    # if we run HF first, the cuda initialization will be done and it
+    # will hurt multiprocessing backend with fork method (the default method).
     with vllm_runner(model,
                      dtype=dtype,
                      tensor_parallel_size=2,
@@ -50,6 +52,9 @@ def test_models(
                      ) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
 
+    with hf_runner(model, dtype=dtype) as hf_model:
+        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+
     check_outputs_equal(
         outputs_0_lst=hf_outputs,
         outputs_1_lst=vllm_outputs,

@@ -14,7 +14,8 @@ TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf \
 import os
 
 import pytest
-import torch
+
+from vllm.utils import cuda_device_count_stateless
 
 from ..models.utils import check_outputs_equal
 
@@ -24,7 +25,7 @@ MODELS = [
 DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND"
 
 
-@pytest.mark.skipif(torch.cuda.device_count() < 2,
+@pytest.mark.skipif(cuda_device_count_stateless() < 2,
                     reason="Need at least 2 GPUs to run the test.")
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
@@ -47,8 +48,10 @@ def test_models(
     enable_chunked_prefill = True
     max_num_batched_tokens = chunked_prefill_token_size
 
-    with hf_runner(model, dtype=dtype) as hf_model:
-        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+    # NOTE: take care of the order. run vLLM first, and then run HF.
+    # vLLM needs a fresh new process without cuda initialization.
+    # if we run HF first, the cuda initialization will be done and it
+    # will hurt multiprocessing backend with fork method (the default method).
 
     with vllm_runner(
         model,
@@ -61,6 +64,9 @@ def test_models(
     ) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
 
+    with hf_runner(model, dtype=dtype) as hf_model:
+        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+
     check_outputs_equal(
         outputs_0_lst=hf_outputs,
         outputs_1_lst=vllm_outputs,

@@ -88,17 +88,11 @@ def run_test(
     """
     model_id, vlm_config = model_and_config
     hf_images = [asset.for_hf() for asset in image_assets]
-    vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets]
 
-    with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model:
-        hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS,
-                                              max_tokens,
-                                              images=hf_images)
+    # NOTE: take care of the order. run vLLM first, and then run HF.
+    # vLLM needs a fresh new process without cuda initialization.
+    # if we run HF first, the cuda initialization will be done and it
+    # will hurt multiprocessing backend with fork method (the default method).
 
-    vllm_image_prompts = [
-        p.replace("<image>", "<image>" * vlm_config.image_feature_size)
-        for p in HF_IMAGE_PROMPTS
-    ]
-
     with vllm_runner(model_id,
                      dtype=dtype,
@@ -106,10 +100,26 @@
                      distributed_executor_backend=distributed_executor_backend,
                      enforce_eager=True,
                      **vlm_config.as_cli_args_dict()) as vllm_model:
+
+        # NOTE: `asset.for_vllm` will call `torch.cuda.device_count()`
+        # we must put it inside the vllm_runner context manager
+        # i.e. after creating vLLM instance.
+        vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets]
+
+        vllm_image_prompts = [
+            p.replace("<image>", "<image>" * vlm_config.image_feature_size)
+            for p in HF_IMAGE_PROMPTS
+        ]
+
         vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
                                                   max_tokens,
                                                   images=vllm_images)
 
+    with hf_runner(model_id, dtype=dtype, is_vision_model=True) as hf_model:
+        hf_outputs = hf_model.generate_greedy(HF_IMAGE_PROMPTS,
+                                              max_tokens,
+                                              images=hf_images)
+
     check_outputs_equal(
         hf_outputs,
         [

@@ -96,7 +96,34 @@ def run_test(
     """
     model_id, vlm_config = model_and_config
     hf_images = [asset.for_hf() for asset in image_assets]
-    vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets]
 
+    # NOTE: take care of the order. run vLLM first, and then run HF.
+    # vLLM needs a fresh new process without cuda initialization.
+    # if we run HF first, the cuda initialization will be done and it
+    # will hurt multiprocessing backend with fork method (the default method).
+
+    with vllm_runner(model_id,
+                     max_model_len=2048,
+                     dtype=dtype,
+                     tensor_parallel_size=tensor_parallel_size,
+                     enforce_eager=True,
+                     distributed_executor_backend=distributed_executor_backend,
+                     **vlm_config.as_cli_args_dict()) as vllm_model:
+        # NOTE: `asset.for_vllm` will call `torch.cuda.device_count()`
+        # we must put it inside the vllm_runner context manager
+        # i.e. after creating vLLM instance.
+
+        vllm_images = [asset.for_vllm(vlm_config) for asset in image_assets]
+
+        vllm_image_prompts = [
+            p.replace("<|image_1|>",
+                      "<|image|>" * vlm_config.image_feature_size + "<s>")
+            for p in HF_IMAGE_PROMPTS
+        ]
+
+        vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
+                                                  max_tokens,
+                                                  images=vllm_images)
+
     # use eager mode for hf runner, since phi3_v didn't work with flash_attn
     hf_model_kwargs = {"_attn_implementation": "eager"}
@@ -108,23 +135,6 @@ def run_test(
             images=hf_images,
             eos_token_id=hf_model.processor.tokenizer.eos_token_id)
 
-    vllm_image_prompts = [
-        p.replace("<|image_1|>",
-                  "<|image|>" * vlm_config.image_feature_size + "<s>")
-        for p in HF_IMAGE_PROMPTS
-    ]
-
-    with vllm_runner(model_id,
-                     max_model_len=2048,
-                     dtype=dtype,
-                     tensor_parallel_size=tensor_parallel_size,
-                     enforce_eager=True,
-                     distributed_executor_backend=distributed_executor_backend,
-                     **vlm_config.as_cli_args_dict()) as vllm_model:
-        vllm_outputs = vllm_model.generate_greedy(vllm_image_prompts,
-                                                  max_tokens,
-                                                  images=vllm_images)
-
     check_outputs_equal(
         hf_outputs,
         [