Improve pipeline partitioning (#13839)

This commit is contained in:
Harry Mellor 2025-02-26 02:53:56 +00:00 committed by GitHub
parent 094b7d9496
commit 145944cb94
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 45 additions and 7 deletions

View File

@ -34,3 +34,27 @@ def test_custom_layer_partition():
# Wrong number of layers # Wrong number of layers
with pytest.raises(ValueError): with pytest.raises(ValueError):
_verify("5,5,5,5", 21, 4, [(0, 5), (5, 10), (10, 15), (15, 20)]) _verify("5,5,5,5", 21, 4, [(0, 5), (5, 10), (10, 15), (15, 20)])
# Each row: (num_hidden_layers, pp_size, pp_rank, expected (start, end) slice).
_UNEVEN_PARTITION_CASES = [
    # pp_size 2
    (2, 2, 0, (0, 1)),
    (2, 2, 1, (1, 2)),
    (3, 2, 0, (0, 2)),
    (3, 2, 1, (2, 3)),
    # pp_size 3
    (3, 3, 0, (0, 1)),
    (3, 3, 1, (1, 2)),
    (3, 3, 2, (2, 3)),
    (4, 3, 0, (0, 1)),
    (4, 3, 1, (1, 3)),
    (4, 3, 2, (3, 4)),
    (5, 3, 0, (0, 2)),
    (5, 3, 1, (2, 4)),
    (5, 3, 2, (4, 5)),
]


@pytest.mark.parametrize(
    "num_hidden_layers,pp_size,pp_rank,indices",
    _UNEVEN_PARTITION_CASES)
def test_uneven_auto_partition(num_hidden_layers: int, pp_size: int,
                               pp_rank: int, indices: tuple[int, int]):
    """Check automatic partitioning when layers don't divide evenly.

    Each case pins the (start, end) layer indices expected for a single
    pipeline-parallel rank when no explicit partition string is set.
    """
    assert get_pp_indices(num_hidden_layers, pp_rank, pp_size) == indices

View File

@ -67,8 +67,17 @@ def split_tensor_along_last_dim(
def get_pp_indices(num_hidden_layers: int, pp_rank: int, def get_pp_indices(num_hidden_layers: int, pp_rank: int,
pp_size: int) -> Tuple[int, int]: pp_size: int) -> Tuple[int, int]:
"""Try to evenly distribute layers across partitions. """Try to evenly distribute layers across partitions.
If the number of layers is not divisible by the number of partitions, If the number of layers is not divisible by the number of partitions,
the last partition will have the remaining layers. the remaining layers are evenly distributed across all but the last
partition. The last partition is excluded because it often contains an
additional norm layer and we are attempting to balance compute.
If `pp_size > 2` and the number of remaining layers is
`0 < x <= pp_size - 2` then the remaining layers are evenly distributed
across the middle partitions. The first and last partitions are excluded
because they contain the input and output embeddings respectively and we
are attempting to reduce maximum memory consumption across partitions.
""" """
partition_list_str = envs.VLLM_PP_LAYER_PARTITION partition_list_str = envs.VLLM_PP_LAYER_PARTITION
if partition_list_str is not None: if partition_list_str is not None:
@ -84,15 +93,20 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int,
if sum(partitions) != num_hidden_layers: if sum(partitions) != num_hidden_layers:
raise ValueError( raise ValueError(
f"{sum(partitions)=} does not match {num_hidden_layers=}.") f"{sum(partitions)=} does not match {num_hidden_layers=}.")
start_layer = sum(partitions[:pp_rank])
end_layer = start_layer + partitions[pp_rank]
else: else:
layers_per_partition = num_hidden_layers // pp_size layers_per_partition = num_hidden_layers // pp_size
start_layer = pp_rank * layers_per_partition partitions = [layers_per_partition for _ in range(pp_size)]
end_layer = start_layer + layers_per_partition
if pp_rank == pp_size - 1: if remaining_layers := num_hidden_layers % pp_size:
end_layer = num_hidden_layers for i in range(2, remaining_layers + 2):
partitions[-i] += 1
logger.info("Hidden layers were unevenly partitioned: %s",
",".join(str(p) for p in partitions))
logger.info("This can be manually overridden using the "
"VLLM_PP_LAYER_PARTITION environment variable")
start_layer = sum(partitions[:pp_rank])
end_layer = start_layer + partitions[pp_rank]
return (start_layer, end_layer) return (start_layer, end_layer)