diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py
index 2f39a0e8..4a359725 100644
--- a/vllm/model_executor/layers/logits_processor.py
+++ b/vllm/model_executor/layers/logits_processor.py
@@ -8,7 +8,6 @@ import torch
 import torch.nn as nn
 
 import vllm.envs as envs
-from vllm.config import get_current_vllm_config
 from vllm.distributed import (tensor_model_parallel_all_gather,
                               tensor_model_parallel_gather)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -51,11 +50,7 @@ class LogitsProcessor(nn.Module):
         # Soft cap the logits. Used in Gemma 2.
         self.soft_cap = soft_cap
         # Whether to use gather or all-gather to gather the logits.
-        parallel_config = get_current_vllm_config().parallel_config
-        self.use_all_gather = current_platform.is_tpu() \
-            or current_platform.is_neuron() \
-            or envs.VLLM_USE_V1 \
-            or parallel_config.distributed_executor_backend == "external_launcher"  # noqa
+        self.use_all_gather = current_platform.use_all_gather()
 
     def forward(
         self,
@@ -83,7 +78,8 @@
             logits *= self.scale
 
         # Apply logits processors (if any).
-        if sampling_metadata is not None:
+        if sampling_metadata is not None and \
+                sampling_metadata.seq_groups is not None:
             logits = _apply_logits_processors(logits, sampling_metadata)
 
         return logits
diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py
index d81a66e4..e7e55e11 100644
--- a/vllm/platforms/interface.py
+++ b/vllm/platforms/interface.py
@@ -330,6 +330,19 @@ class Platform:
         """
         return "vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase"  # noqa
 
+    @classmethod
+    def use_all_gather(cls) -> bool:
+        """
+        Whether to use allgather in LogitsProcessor to gather the logits.
+        """
+        import vllm.envs as envs
+        from vllm.config import get_current_vllm_config
+
+        parallel_config = get_current_vllm_config().parallel_config
+        return (envs.VLLM_USE_V1
+                or parallel_config.distributed_executor_backend
+                == "external_launcher")
+
 
 class UnspecifiedPlatform(Platform):
     _enum = PlatformEnum.UNSPECIFIED
diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py
index 5a03f5f7..b2eadb79 100644
--- a/vllm/platforms/neuron.py
+++ b/vllm/platforms/neuron.py
@@ -55,3 +55,7 @@ class NeuronPlatform(Platform):
     def is_pin_memory_available(cls) -> bool:
         logger.warning("Pin memory is not supported on Neuron.")
         return False
+
+    @classmethod
+    def use_all_gather(cls) -> bool:
+        return True
diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py
index cdf835a5..0b66b527 100644
--- a/vllm/platforms/tpu.py
+++ b/vllm/platforms/tpu.py
@@ -119,3 +119,7 @@ class TpuPlatform(Platform):
     @classmethod
     def get_device_communicator_cls(cls) -> str:
         return "vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator"  # noqa
+
+    @classmethod
+    def use_all_gather(cls) -> bool:
+        return True
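
For reviewers, here is a minimal, self-contained sketch of the dispatch pattern this diff introduces: the device-specific branching (`is_tpu()`/`is_neuron()`) moves out of `LogitsProcessor.__init__` and into an overridable `Platform.use_all_gather()` classmethod, so each backend expresses its own policy. The class names mirror vLLM's, but the module-level flags below are simplified stand-ins for `vllm.envs` and `ParallelConfig`, not vLLM's actual config API, and the "gather" ops are reduced to strings.

```python
# Sketch only: USE_V1 and DISTRIBUTED_EXECUTOR_BACKEND are stand-ins
# for envs.VLLM_USE_V1 and parallel_config.distributed_executor_backend.
USE_V1 = False
DISTRIBUTED_EXECUTOR_BACKEND = "mp"


class Platform:
    @classmethod
    def use_all_gather(cls) -> bool:
        # Default policy, mirroring the base Platform.use_all_gather()
        # added in interface.py: all-gather only for V1 or the
        # external_launcher backend.
        return USE_V1 or DISTRIBUTED_EXECUTOR_BACKEND == "external_launcher"


class TpuPlatform(Platform):
    @classmethod
    def use_all_gather(cls) -> bool:
        # TPU (like Neuron) always all-gathers, so it simply overrides
        # the hook instead of being special-cased at the call site.
        return True


class LogitsProcessor:
    """Stand-in consumer: no is_tpu()/is_neuron() checks remain here."""

    def __init__(self, current_platform: type[Platform]) -> None:
        # Mirrors the one-liner this diff leaves in __init__.
        self.use_all_gather = current_platform.use_all_gather()

    def gather_op(self) -> str:
        return ("tensor_model_parallel_all_gather" if self.use_all_gather
                else "tensor_model_parallel_gather")


if __name__ == "__main__":
    print(LogitsProcessor(Platform).gather_op())     # tensor_model_parallel_gather
    print(LogitsProcessor(TpuPlatform).gather_op())  # tensor_model_parallel_all_gather
```

With this shape, a backend that needs all-gather touches only its own `Platform` subclass, which is exactly what the `neuron.py` and `tpu.py` hunks above do with their four-line overrides.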