[Bugfix] Fix multi nodes TP+PP for XPU (#8884)

Signed-off-by: YiSheng5 <syhm@mail.ustc.edu.cn>
Signed-off-by: yan ma <yan.ma@intel.com>
Co-authored-by: YiSheng5 <syhm@mail.ustc.edu.cn>
Yan Ma 2024-10-30 12:34:45 +08:00, committed by GitHub
parent 62fac4b9aa
commit 04a3ae0aca
7 changed files with 63 additions and 11 deletions

View File

@@ -60,3 +60,21 @@ Build from source
 - FP16 is the default data type in the current XPU backend. The BF16 data
   type will be supported in the future.
+
+Distributed inference and serving
+---------------------------------
+
+The XPU platform supports tensor-parallel inference/serving, and also supports pipeline parallelism as a beta feature for online serving. We require Ray as the distributed runtime backend. For example, a reference execution looks like the following:
+
+.. code-block:: console
+
+    $ python -m vllm.entrypoints.openai.api_server \
+    $     --model=facebook/opt-13b \
+    $     --dtype=bfloat16 \
+    $     --device=xpu \
+    $     --max_model_len=1024 \
+    $     --distributed-executor-backend=ray \
+    $     --pipeline-parallel-size=2 \
+    $     -tp=8
+
+By default, a Ray instance is launched automatically if no existing one is detected on the system, with ``num-gpus`` equal to ``parallel_config.world_size``. We recommend properly starting a Ray cluster before execution; see the helper `script <https://github.com/vllm-project/vllm/tree/main/examples/run_cluster.sh>`_.
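As a rough sketch (the port and head-node address below are placeholders, not values taken from this patch), a Ray cluster can be brought up manually before launching the server:

.. code-block:: console

    $ # On the head node:
    $ ray start --head --port=6379
    $ # On every other node, pointing at the head node:
    $ ray start --address='<head-node-ip>:6379'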

View File

@@ -13,4 +13,4 @@ torch == 2.3.1+cxx11.abi
 intel-extension-for-pytorch == 2.3.110+xpu
 oneccl_bind_pt == 2.3.100+xpu
-triton-xpu == 3.0.0b2
+triton-xpu == 3.0.0b1

View File

@@ -431,6 +431,28 @@ class GroupCoordinator:
         if dim < 0:
             # Convert negative dim to positive.
             dim += input_.dim()
+        # For xpu path, gather doesn't work properly together with ray
+        # cluster so we use all_gather instead for now.
+        if current_platform.is_xpu():
+            input_size = input_.size()
+            # Allocate output tensor.
+            output_tensor = torch.empty((world_size, ) + input_size,
+                                        dtype=input_.dtype,
+                                        device=input_.device)
+            # All-gather.
+            torch.distributed.all_gather_into_tensor(output_tensor,
+                                                     input_,
+                                                     group=self.device_group)
+            if self.rank_in_group == dst:
+                # Reshape
+                output_tensor = output_tensor.movedim(0, dim)
+                output_tensor = output_tensor.reshape(input_size[:dim] +
+                                                      (world_size *
+                                                       input_size[dim], ) +
+                                                      input_size[dim + 1:])
+            else:
+                output_tensor = None
+            return output_tensor
         # Allocate output tensor.
         if self.rank_in_group == dst:
             gather_list = [torch.empty_like(input_) for _ in range(world_size)]
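To see why the movedim/reshape above is equivalent to a plain gather along ``dim``, here is a minimal single-process sketch (no process group needed; the tensor shapes are arbitrary illustration values). ``torch.stack`` stands in for the ``(world_size, *input_size)`` output that ``all_gather_into_tensor`` would produce:

.. code-block:: python

    import torch

    # Stand-ins for what each rank would contribute (same shape on every rank).
    world_size, dim = 4, 1
    chunks = [torch.randn(2, 3, 5) for _ in range(world_size)]
    input_size = chunks[0].size()

    # all_gather_into_tensor fills a tensor of shape (world_size, *input_size);
    # stacking reproduces that layout without a process group.
    gathered = torch.stack(chunks, dim=0)

    # Move the world dimension next to `dim` and fold it in, as in the XPU path.
    folded = gathered.movedim(0, dim).reshape(
        input_size[:dim] + (world_size * input_size[dim], ) + input_size[dim + 1:])

    # Same result as concatenating the per-rank chunks along `dim`,
    # i.e. what a rank-wise gather would hand back on the destination rank.
    assert torch.equal(folded, torch.cat(chunks, dim=dim))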

View File

@@ -44,7 +44,7 @@ class XPUExecutor(GPUExecutor):
         self.cache_config = cache_config
         self.load_config = load_config
         self.lora_config = lora_config
-        self.parallel_config = parallel_config
+        self.parallel_config = _verify_and_get_parallel_config(parallel_config)
         self.scheduler_config = scheduler_config
         self.device_config = device_config
         self.prompt_adapter_config = prompt_adapter_config
@@ -94,3 +94,13 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
             "mode.")
         config.enforce_eager = True
     return config
+
+
+def _verify_and_get_parallel_config(config: ParallelConfig) -> ParallelConfig:
+    if (config.distributed_executor_backend is not None
+            and config.distributed_executor_backend != "ray"):
+        logger.warning(
+            "%s is not supported on XPU, fallback to ray distributed executor "
+            "backend.", config.distributed_executor_backend)
+        config.distributed_executor_backend = "ray"
+    return config

View File

@@ -45,6 +45,9 @@ except Exception:
 is_xpu = False
 
 try:
+    # installed IPEX if the machine has XPUs.
+    import intel_extension_for_pytorch  # noqa: F401
+    import oneccl_bindings_for_pytorch  # noqa: F401
     import torch
     if hasattr(torch, 'xpu') and torch.xpu.is_available():
         is_xpu = True

View File

@@ -20,3 +20,7 @@ class XPUPlatform(Platform):
     def get_device_total_memory(cls, device_id: int = 0) -> int:
         device_props = torch.xpu.get_device_properties(device_id)
         return device_props.total_memory
+
+    @staticmethod
+    def inference_mode():
+        return torch.no_grad()
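A guess at the intent (not stated in the patch): ``torch.inference_mode()`` is not usable on every backend, so the platform hook hands back ``torch.no_grad()`` as a drop-in context manager for disabling gradient tracking. Below is a minimal sketch of how such a hook keeps calling code backend-agnostic; ``run_forward`` is hypothetical, not vLLM API:

.. code-block:: python

    import torch

    def run_forward(model: torch.nn.Module, x: torch.Tensor,
                    inference_ctx=torch.no_grad) -> torch.Tensor:
        # The caller never hard-codes inference_mode vs. no_grad; the platform
        # supplies whichever context manager it supports.
        with inference_ctx():
            return model(x)

    # run_forward(model, x)                                      # XPU-style hook
    # run_forward(model, x, inference_ctx=torch.inference_mode)  # elsewhere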

View File

@@ -14,7 +14,6 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          SpeculativeConfig)
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
-from vllm.distributed.parallel_state import get_pp_group
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
@@ -183,11 +182,10 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
         # use sockets as default Level zero IPC exchange backend. By
         # default oneccl will use `drmfd` as mechanism which need extra
         # dependency (libdrm and drm headers) on your system.
-        ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE",
-                                            "sockets")
+        ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi")
         ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE",
                                          str(parallel_config.world_size))
-        os.environ['CCL_ZE_IPC_EXCHANGE'] = ENV_CCL_ZE_IPC_EXCHANGE
+        os.environ["CCL_ATL_TRANSPORT"] = ENV_CCL_ATL_TRANSPORT
         os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE
         os.environ["LOCAL_RANK"] = str(self.local_rank)
         init_distributed_environment(
@@ -200,8 +198,5 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
         ensure_model_parallel_initialized(
             parallel_config.tensor_parallel_size,
             parallel_config.pipeline_parallel_size)
-
-        if parallel_config.pipeline_parallel_size > 1:
-            # torch-ccl xpu need a collective API warm up
-            # before calling send/recv API
-            get_pp_group().all_reduce(torch.zeros(1).xpu())
+        # global all_reduce needed for overall oneccl warm up
+        torch.distributed.all_reduce(torch.zeros(1).xpu())
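The replacement code above issues one throwaway all_reduce on the global group right after model-parallel initialization, so oneCCL sets up its communicators before any later collective or send/recv call. As a rough sketch of the same pattern outside vLLM (assuming torch.distributed has already been initialized with an XPU-capable backend; the helper name is hypothetical):

.. code-block:: python

    import torch
    import torch.distributed as dist

    def warm_up_oneccl(device: torch.device) -> None:
        # One dummy collective on the default (global) group warms up oneCCL
        # before the first real collective or point-to-point call.
        dummy = torch.zeros(1, device=device)
        dist.all_reduce(dummy)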