diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index fa493fef..e0eeeffb 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -194,9 +194,11 @@ class GroupCoordinator: from vllm.platforms import current_platform - # TODO: fix it for other platforms if current_platform.is_cuda_alike(): self.device = torch.device(f"cuda:{local_rank}") + elif current_platform.is_out_of_tree(): + self.device = torch.device( + f"{current_platform.device_name}:{local_rank}") else: self.device = torch.device("cpu")