@@ -67,8 +67,17 @@ def split_tensor_along_last_dim(
6767def get_pp_indices (num_hidden_layers : int , pp_rank : int ,
6868 pp_size : int ) -> Tuple [int , int ]:
6969 """Try to evenly distribute layers across partitions.
70+
7071 If the number of layers is not divisible by the number of partitions,
71- the last partition will have the remaining layers.
72+ the remaining layers are evenly distributed across all but the last
73+ partition. The last partition is excluded because it often contains an
74+ additional norm layer and we are attempting to balance compute.
75+
76+ If `pp_size > 2` and the number of remaining layers is
77+ `0 < x <= pp_size - 2` then the remaining layers are evenly distributed
78+ across the middle partitions. The first and last partitions are excluded
79+ because they contain the input and output embeddings respectively and we
80+ are attempting to reduce maximum memory consumption across partitions.
7281 """
7382 partition_list_str = envs .VLLM_PP_LAYER_PARTITION
7483 if partition_list_str is not None :
@@ -84,15 +93,20 @@ def get_pp_indices(num_hidden_layers: int, pp_rank: int,
8493 if sum (partitions ) != num_hidden_layers :
8594 raise ValueError (
8695 f"{ sum (partitions )= } does not match { num_hidden_layers = } ." )
87- start_layer = sum (partitions [:pp_rank ])
88- end_layer = start_layer + partitions [pp_rank ]
8996 else :
9097 layers_per_partition = num_hidden_layers // pp_size
91- start_layer = pp_rank * layers_per_partition
92- end_layer = start_layer + layers_per_partition
93-
94- if pp_rank == pp_size - 1 :
95- end_layer = num_hidden_layers
98+ partitions = [layers_per_partition for _ in range (pp_size )]
99+
100+ if remaining_layers := num_hidden_layers % pp_size :
101+ for i in range (2 , remaining_layers + 2 ):
102+ partitions [- i ] += 1
103+ logger .info ("Hidden layers were unevenly partitioned: %s" ,
104+ "," .join (str (p ) for p in partitions ))
105+ logger .info ("This can be manually overridden using the "
106+ "VLLM_PP_LAYER_PARTITION environment variable" )
107+
108+ start_layer = sum (partitions [:pp_rank ])
109+ end_layer = start_layer + partitions [pp_rank ]
96110
97111 return (start_layer , end_layer )
98112
0 commit comments