Improve pipeline partitioning (#13839)
This commit is contained in:
parent
094b7d9496
commit
145944cb94
@ -34,3 +34,27 @@ def test_custom_layer_partition():
|
|||||||
# Wrong number of layers
|
# Wrong number of layers
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
_verify("5,5,5,5", 21, 4, [(0, 5), (5, 10), (10, 15), (15, 20)])
|
_verify("5,5,5,5", 21, 4, [(0, 5), (5, 10), (10, 15), (15, 20)])
|
||||||
|
|
||||||
|
|
||||||
|
# Each row: (num_hidden_layers, pp_size, pp_rank, expected (start, end)
# layer slice for that rank). The expectations encode the uneven-split
# policy visible in the data: when layers don't divide evenly, the final
# rank never receives a surplus layer (e.g. 3 layers / 2 ranks -> [2, 1],
# 4 layers / 3 ranks -> [1, 2, 1], 5 layers / 3 ranks -> [2, 2, 1]).
_UNEVEN_PARTITION_CASES = [
    # pp_size 2
    (2, 2, 0, (0, 1)),
    (2, 2, 1, (1, 2)),
    (3, 2, 0, (0, 2)),
    (3, 2, 1, (2, 3)),
    # pp_size 3
    (3, 3, 0, (0, 1)),
    (3, 3, 1, (1, 2)),
    (3, 3, 2, (2, 3)),
    (4, 3, 0, (0, 1)),
    (4, 3, 1, (1, 3)),
    (4, 3, 2, (3, 4)),
    (5, 3, 0, (0, 2)),
    (5, 3, 1, (2, 4)),
    (5, 3, 2, (4, 5)),
]


@pytest.mark.parametrize("num_hidden_layers,pp_size,pp_rank,indices",
                         _UNEVEN_PARTITION_CASES)
def test_uneven_auto_partition(num_hidden_layers: int, pp_size: int,
                               pp_rank: int, indices: tuple[int, int]):
    """Automatic partitioning must assign each rank its expected slice."""
    # Note the argument order: get_pp_indices(layers, rank, size).
    assert get_pp_indices(num_hidden_layers, pp_rank, pp_size) == indices
|
||||||
|
@ -67,8 +67,17 @@ def split_tensor_along_last_dim(
|
|||||||
def get_pp_indices(num_hidden_layers: int, pp_rank: int,
|
def get_pp_indices(num_hidden_layers: int, pp_rank: int,
|
||||||
pp_size: int) -> Tuple[int, int]:
|
pp_size: int) -> Tuple[int, int]:
|
||||||
"""Try to evenly distribute layers across partitions.
|
"""Try to evenly distribute layers across partitions.
|
||||||
|
|
||||||
If the number of layers is not divisible by the number of partitions,
|
If the number of layers is not divisible by the number of partitions,
|
||||||
the last partition will have the remaining layers.
|
the remaining layers are evenly distributed across all but the last
|
||||||
|
partition. The last partition is excluded because it often contains an
|
||||||
|
additional norm layer and we are attempting to balance compute.
|
||||||
|
|
||||||
|
If `pp_size > 2` and the number of remaining layers is
|
||||||
|
`0 < x <= pp_size - 2` then the remaining layers are evenly distributed
|
||||||
|
across the middle partitions. The first and last partitions are excluded
|
||||||
|
because they contain the input and output embeddings respectively and we
|
||||||
|
are attempting to reduce maximum memory consumption across partitions.
|
||||||
"""
|
"""
|
||||||
partition_list_str = envs.VLLM_PP_LAYER_PARTITION
|
partition_list_str = envs.VLLM_PP_LAYER_PARTITION
|
||||||
if partition_list_str is not None:
|
if partition_list_str is not None:
|
||||||
@ -84,15 +93,20 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int,
|
|||||||
if sum(partitions) != num_hidden_layers:
|
if sum(partitions) != num_hidden_layers:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"{sum(partitions)=} does not match {num_hidden_layers=}.")
|
f"{sum(partitions)=} does not match {num_hidden_layers=}.")
|
||||||
start_layer = sum(partitions[:pp_rank])
|
|
||||||
end_layer = start_layer + partitions[pp_rank]
|
|
||||||
else:
|
else:
|
||||||
layers_per_partition = num_hidden_layers // pp_size
|
layers_per_partition = num_hidden_layers // pp_size
|
||||||
start_layer = pp_rank * layers_per_partition
|
partitions = [layers_per_partition for _ in range(pp_size)]
|
||||||
end_layer = start_layer + layers_per_partition
|
|
||||||
|
|
||||||
if pp_rank == pp_size - 1:
|
if remaining_layers := num_hidden_layers % pp_size:
|
||||||
end_layer = num_hidden_layers
|
for i in range(2, remaining_layers + 2):
|
||||||
|
partitions[-i] += 1
|
||||||
|
logger.info("Hidden layers were unevenly partitioned: %s",
|
||||||
|
",".join(str(p) for p in partitions))
|
||||||
|
logger.info("This can be manually overridden using the "
|
||||||
|
"VLLM_PP_LAYER_PARTITION environment variable")
|
||||||
|
|
||||||
|
start_layer = sum(partitions[:pp_rank])
|
||||||
|
end_layer = start_layer + partitions[pp_rank]
|
||||||
|
|
||||||
return (start_layer, end_layer)
|
return (start_layer, end_layer)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user