Skip to content

Commit 145944c

Browse files
authored
Improve pipeline partitioning (#13839)
1 parent 094b7d9 commit 145944c

File tree

2 files changed

+46
-8
lines changed

2 files changed

+46
-8
lines changed

tests/distributed/test_pipeline_partition.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,27 @@ def _verify(partition_str, num_layers, pp_size, goldens):
3434
# Wrong number of layers
3535
with pytest.raises(ValueError):
3636
_verify("5,5,5,5", 21, 4, [(0, 5), (5, 10), (10, 15), (15, 20)])
37+
38+
39+
@pytest.mark.parametrize(
40+
"num_hidden_layers,pp_size,pp_rank,indices",
41+
[
42+
# pp_size 2
43+
(2, 2, 0, (0, 1)),
44+
(2, 2, 1, (1, 2)),
45+
(3, 2, 0, (0, 2)),
46+
(3, 2, 1, (2, 3)),
47+
# pp_size 3
48+
(3, 3, 0, (0, 1)),
49+
(3, 3, 1, (1, 2)),
50+
(3, 3, 2, (2, 3)),
51+
(4, 3, 0, (0, 1)),
52+
(4, 3, 1, (1, 3)),
53+
(4, 3, 2, (3, 4)),
54+
(5, 3, 0, (0, 2)),
55+
(5, 3, 1, (2, 4)),
56+
(5, 3, 2, (4, 5)),
57+
])
58+
def test_uneven_auto_partition(num_hidden_layers: int, pp_size: int,
59+
pp_rank: int, indices: tuple[int, int]):
60+
assert indices == get_pp_indices(num_hidden_layers, pp_rank, pp_size)

vllm/distributed/utils.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,17 @@ def split_tensor_along_last_dim(
6767
def get_pp_indices(num_hidden_layers: int, pp_rank: int,
6868
pp_size: int) -> Tuple[int, int]:
6969
"""Try to evenly distribute layers across partitions.
70+
7071
If the number of layers is not divisible by the number of partitions,
71-
the last partition will have the remaining layers.
72+
the remaining layers are evenly distributed across all but the last
73+
partition. The last partition is excluded because it often contains an
74+
additional norm layer and we are attempting to balance compute.
75+
76+
If `pp_size > 2` and the number of remaining layers `x` satisfies
77+
`0 < x <= pp_size - 2`, then the remaining layers are assigned one each to
78+
the middle partitions nearest the end. The first and last partitions are
79+
excluded because they contain the input and output embeddings respectively
80+
and we are attempting to reduce maximum memory consumption across partitions.
7281
"""
7382
partition_list_str = envs.VLLM_PP_LAYER_PARTITION
7483
if partition_list_str is not None:
@@ -84,15 +93,20 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int,
8493
if sum(partitions) != num_hidden_layers:
8594
raise ValueError(
8695
f"{sum(partitions)=} does not match {num_hidden_layers=}.")
87-
start_layer = sum(partitions[:pp_rank])
88-
end_layer = start_layer + partitions[pp_rank]
8996
else:
9097
layers_per_partition = num_hidden_layers // pp_size
91-
start_layer = pp_rank * layers_per_partition
92-
end_layer = start_layer + layers_per_partition
93-
94-
if pp_rank == pp_size - 1:
95-
end_layer = num_hidden_layers
98+
partitions = [layers_per_partition for _ in range(pp_size)]
99+
100+
if remaining_layers := num_hidden_layers % pp_size:
101+
for i in range(2, remaining_layers + 2):
102+
partitions[-i] += 1
103+
logger.info("Hidden layers were unevenly partitioned: %s",
104+
",".join(str(p) for p in partitions))
105+
logger.info("This can be manually overridden using the "
106+
"VLLM_PP_LAYER_PARTITION environment variable")
107+
108+
start_layer = sum(partitions[:pp_rank])
109+
end_layer = start_layer + partitions[pp_rank]
96110

97111
return (start_layer, end_layer)
98112

0 commit comments

Comments
 (0)