diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 540a35e1..e195a03c 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -27,7 +27,6 @@ batchsize_forward_time: defaultdict = defaultdict(list) @dataclass class DPMetadata: - num_tokens_across_dp: list[int] cu_tokens_across_dp_cpu: torch.Tensor @@ -89,7 +88,7 @@ def set_forward_context(attn_metadata: Any, from vllm.distributed.parallel_state import get_dp_group dist.all_reduce(num_tokens_tensor, group=get_dp_group().cpu_group) cu_tokens_across_dp_cpu = torch.cumsum(num_tokens_tensor, dim=0) - dp_metadata = DPMetadata(num_tokens_across_dp, cu_tokens_across_dp_cpu) + dp_metadata = DPMetadata(cu_tokens_across_dp_cpu) global _forward_context prev_context = _forward_context