[Bugfix] Fix multi nodes TP+PP for XPU (#8884)
Signed-off-by: YiSheng5 <syhm@mail.ustc.edu.cn>
Signed-off-by: yan ma <yan.ma@intel.com>
Co-authored-by: YiSheng5 <syhm@mail.ustc.edu.cn>

parent 62fac4b9aa
commit 04a3ae0aca
@@ -60,3 +60,21 @@ Build from source
 - FP16 is the default data type in the current XPU backend. The BF16 data
   type will be supported in the future.
+
+Distributed inference and serving
+---------------------------------
+
+The XPU platform supports tensor parallel inference/serving and also supports pipeline parallel as a beta feature for online serving. We require Ray as the distributed runtime backend. For example, a reference execution looks like the following:
+
+.. code-block:: console
+
+    $ python -m vllm.entrypoints.openai.api_server \
+    $ --model=facebook/opt-13b \
+    $ --dtype=bfloat16 \
+    $ --device=xpu \
+    $ --max_model_len=1024 \
+    $ --distributed-executor-backend=ray \
+    $ --pipeline-parallel-size=2 \
+    $ -tp=8
+
+By default, a Ray instance will be launched automatically if no existing one is detected in the system, with ``num-gpus`` equal to ``parallel_config.world_size``. We recommend properly starting a Ray cluster before execution; refer to the helper `script <https://github.com/vllm-project/vllm/tree/main/examples/run_cluster.sh>`_.
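For readers who prefer the recommended pre-started cluster over the implicit auto-launch, a minimal manual alternative to the linked ``run_cluster.sh`` helper is to bring Ray up with its standard CLI before running the command above. This is only an illustrative sketch: the port and head-node address are placeholders, and the same vLLM XPU environment is assumed on every node.

.. code-block:: console

    $ # on the head node
    $ ray start --head --port=6379
    $ # on each worker node, pointing at the head node's address
    $ ray start --address=<head_node_ip>:6379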
@@ -13,4 +13,4 @@ torch == 2.3.1+cxx11.abi
 intel-extension-for-pytorch == 2.3.110+xpu
 oneccl_bind_pt == 2.3.100+xpu
 
-triton-xpu == 3.0.0b2
+triton-xpu == 3.0.0b1
@@ -431,6 +431,28 @@ class GroupCoordinator:
         if dim < 0:
             # Convert negative dim to positive.
             dim += input_.dim()
+        # For xpu path, gather doesn't work properly together with ray
+        # cluster so we use all_gather instead for now.
+        if current_platform.is_xpu():
+            input_size = input_.size()
+            # Allocate output tensor.
+            output_tensor = torch.empty((world_size, ) + input_size,
+                                        dtype=input_.dtype,
+                                        device=input_.device)
+            # All-gather.
+            torch.distributed.all_gather_into_tensor(output_tensor,
+                                                     input_,
+                                                     group=self.device_group)
+            if self.rank_in_group == dst:
+                # Reshape
+                output_tensor = output_tensor.movedim(0, dim)
+                output_tensor = output_tensor.reshape(input_size[:dim] +
+                                                      (world_size *
+                                                       input_size[dim], ) +
+                                                      input_size[dim + 1:])
+            else:
+                output_tensor = None
+            return output_tensor
         # Allocate output tensor.
         if self.rank_in_group == dst:
             gather_list = [torch.empty_like(input_) for _ in range(world_size)]
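A note on the new XPU gather path above: ``all_gather_into_tensor`` fills a ``(world_size, *input_size)`` tensor, and the ``movedim``/``reshape`` pair turns it into the same layout that a plain gather-and-concatenate along ``dim`` would produce. The following CPU-only sketch (not part of the commit; ``torch.stack`` stands in for the collective's output) checks that equivalence:

.. code-block:: python

    import torch

    # Stand-ins for the distributed pieces: 4 "ranks", gathering along dim=1.
    world_size, dim = 4, 1
    per_rank = [torch.randn(2, 3, 5) for _ in range(world_size)]
    input_size = per_rank[0].size()

    # all_gather_into_tensor fills a (world_size, *input_size) tensor, one
    # slice per rank; torch.stack produces the same layout locally.
    output_tensor = torch.stack(per_rank, dim=0)

    # The reshape trick from the hunk: move the rank axis next to `dim`,
    # then merge the two axes so the ranks end up concatenated along `dim`.
    output_tensor = output_tensor.movedim(0, dim)
    output_tensor = output_tensor.reshape(input_size[:dim] +
                                          (world_size * input_size[dim], ) +
                                          input_size[dim + 1:])

    # Same result that a gather followed by torch.cat along `dim` would give.
    assert torch.equal(output_tensor, torch.cat(per_rank, dim=dim))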
@@ -44,7 +44,7 @@ class XPUExecutor(GPUExecutor):
         self.cache_config = cache_config
         self.load_config = load_config
         self.lora_config = lora_config
-        self.parallel_config = parallel_config
+        self.parallel_config = _verify_and_get_parallel_config(parallel_config)
         self.scheduler_config = scheduler_config
         self.device_config = device_config
         self.prompt_adapter_config = prompt_adapter_config
@@ -94,3 +94,13 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
             "mode.")
         config.enforce_eager = True
     return config
+
+
+def _verify_and_get_parallel_config(config: ParallelConfig) -> ParallelConfig:
+    if (config.distributed_executor_backend is not None
+            and config.distributed_executor_backend != "ray"):
+        logger.warning(
+            "%s is not supported on XPU, fallback to ray distributed executor "
+            "backend.", config.distributed_executor_backend)
+        config.distributed_executor_backend = "ray"
+    return config
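The new ``_verify_and_get_parallel_config`` above coerces any non-Ray executor choice back to Ray before the XPU executor builds its workers. A small self-contained sketch of that behaviour, using a hypothetical stand-in dataclass rather than vLLM's real ``ParallelConfig``:

.. code-block:: python

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class FakeParallelConfig:
        """Hypothetical stand-in for vllm.config.ParallelConfig."""
        distributed_executor_backend: Optional[str] = None


    def force_ray_backend(config: FakeParallelConfig) -> FakeParallelConfig:
        # Mirrors the check in the hunk: anything other than "ray" (or unset)
        # is overridden, since the XPU multi-node path only runs on Ray.
        if (config.distributed_executor_backend is not None
                and config.distributed_executor_backend != "ray"):
            print(f"{config.distributed_executor_backend} is not supported on "
                  "XPU, fallback to ray distributed executor backend.")
            config.distributed_executor_backend = "ray"
        return config


    assert force_ray_backend(FakeParallelConfig("mp")).distributed_executor_backend == "ray"
    assert force_ray_backend(FakeParallelConfig()).distributed_executor_backend is None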
@@ -45,6 +45,9 @@ except Exception:
     is_xpu = False
 
 try:
+    # installed IPEX if the machine has XPUs.
+    import intel_extension_for_pytorch  # noqa: F401
+    import oneccl_bindings_for_pytorch  # noqa: F401
     import torch
     if hasattr(torch, 'xpu') and torch.xpu.is_available():
         is_xpu = True
@@ -20,3 +20,7 @@ class XPUPlatform(Platform):
     def get_device_total_memory(cls, device_id: int = 0) -> int:
         device_props = torch.xpu.get_device_properties(device_id)
         return device_props.total_memory
+
+    @staticmethod
+    def inference_mode():
+        return torch.no_grad()
@@ -14,7 +14,6 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                          SpeculativeConfig)
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
-from vllm.distributed.parallel_state import get_pp_group
 from vllm.logger import init_logger
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
@@ -183,11 +182,10 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
         # use sockets as default Level zero IPC exchange backend. By
         # default oneccl will use `drmfd` as mechanism which need extra
         # dependency (libdrm and drm headers) on your system.
-        ENV_CCL_ZE_IPC_EXCHANGE = os.getenv("CCL_ZE_IPC_EXCHANGE",
-                                            "sockets")
+        ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi")
         ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE",
                                          str(parallel_config.world_size))
-        os.environ['CCL_ZE_IPC_EXCHANGE'] = ENV_CCL_ZE_IPC_EXCHANGE
+        os.environ["CCL_ATL_TRANSPORT"] = ENV_CCL_ATL_TRANSPORT
         os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE
         os.environ["LOCAL_RANK"] = str(self.local_rank)
         init_distributed_environment(
@@ -200,8 +198,5 @@ class XPUWorker(LoraNotSupportedWorkerBase, Worker):
         ensure_model_parallel_initialized(
             parallel_config.tensor_parallel_size,
             parallel_config.pipeline_parallel_size)
-
-        if parallel_config.pipeline_parallel_size > 1:
-            # torch-ccl xpu need a collective API warm up
-            # before calling send/recv API
-            get_pp_group().all_reduce(torch.zeros(1).xpu())
+        # global all_reduce needed for overall oneccl warm up
+        torch.distributed.all_reduce(torch.zeros(1).xpu())
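The replacement warm-up above runs one global ``all_reduce`` right after the process groups are initialized, so the oneCCL communicators are exercised before the first real collective or send/recv. A hedged, CPU-only analogue of that pattern, with a single-process ``gloo`` group standing in for the multi-node XPU/oneCCL setup:

.. code-block:: python

    import torch
    import torch.distributed as dist

    # Single-process gloo group: a stand-in for the multi-node oneCCL setup,
    # just to show the "one collective right after init" warm-up pattern.
    dist.init_process_group(backend="gloo",
                            init_method="tcp://127.0.0.1:29500",
                            rank=0,
                            world_size=1)

    # A global all_reduce on a dummy tensor warms up the communicator before
    # any real collectives or point-to-point calls are issued.
    dist.all_reduce(torch.zeros(1))

    dist.destroy_process_group()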