@@ -94,10 +94,59 @@ class PassConfig:
9494 dictionary mapping each world size to the threshold in MB
9595 { <world size>: <max size in mb> }
9696 Unspecified world sizes will fall back to
97- { 2: 32, 4: 32, 8: 2 }"""
97+ _FI_ALLREDUCE_MAX_INPUT_SIZES = {
98+ "9.0": {
99+ 2: 64 * MiB, # 64MB
100+ 4: 2 * MiB, # 2MB
101+ 8: 1 * MiB, # 1MB
102+ },
103+ "10.0": {
104+ 2: 64 * MiB, # 64MB
105+ 4: 32 * MiB, # 32MB
106+ 8: 1 * MiB, # 1MB
107+ },
108+ }, where key is the device capability"""
98109
99110 # TODO(luka) better pass enabling system.
100111
def flashinfer_max_size(self, world_size: int) -> Optional[int]:
    """
    Returns the max communication size in bytes for flashinfer
    allreduce fusion for the given world size.

    Built-in defaults are selected by the current device capability
    string and are then overridden, per world size, by the entries of
    `self.fi_allreduce_fusion_max_size_mb` (given in MB and converted
    to bytes here).

    Args:
        world_size: the world size to look up.

    Returns:
        The threshold in bytes, or None when neither the defaults for
        the current device capability nor the config contain an entry
        for `world_size`.
    """

    # import here to avoid circular dependencies
    from vllm.platforms import current_platform

    MiB = 1024 * 1024

    # Max size of the input tensor per world size per device capability
    # to use flashinfer fused allreduce
    _FI_ALLREDUCE_MAX_INPUT_SIZES = {
        "9.0": {
            2: 64 * MiB,  # 64MB
            4: 2 * MiB,  # 2MB
            8: 1 * MiB,  # 1MB
        },
        "10.0": {
            2: 64 * MiB,  # 64MB
            4: 32 * MiB,  # 32MB
            8: 1 * MiB,  # 1MB
        },
    }

    capability = current_platform.get_device_capability()
    # NOTE(review): assumes get_device_capability() may return None on
    # platforms without a device capability -- confirm against
    # vllm.platforms. In that case only user-configured sizes apply
    # instead of failing on `.as_version_str()`.
    if capability is None:
        max_sizes = {}
    else:
        max_sizes = _FI_ALLREDUCE_MAX_INPUT_SIZES.get(
            capability.as_version_str(), {})

    # User-provided thresholds (in MB) take precedence over the defaults.
    max_sizes.update({
        k: int(v * MiB)
        for k, v in self.fi_allreduce_fusion_max_size_mb.items()
    })

    # FlashInfer doesn't support other world sizes, so a missing entry
    # yields None.
    return max_sizes.get(world_size)
149+
101150 def uuid (self ):
102151 """
103152 Produces a hash unique to the pass configuration.
@@ -223,9 +272,11 @@ class CompilationConfig:
223272 compile_ranges_split_points : Optional [list [int ]] = None
224273 """Split points that represent compile ranges for inductor.
225274 The compile ranges are
226- [1, split_points[0]],
227- [split_points[0], split_points[1]], ...,
228- [split_points[-1], max_num_batched_tokens].
275+ [1, split_points[0]),
276+ [split_points[0], split_points[1]), ...,
277+ [split_points[-1], max_num_batched_tokens + 1).
278+ Compile sizes are also used as single-element ranges:
279+ [compile_sizes[i], compile_sizes[i] + 1).
229280 """
230281
231282 inductor_compile_config : dict = field (default_factory = dict )
@@ -579,3 +630,22 @@ def set_splitting_ops_for_v1(self):
def splitting_ops_contain_attention(self) -> bool:
    """Report whether every attention op is listed in `splitting_ops`.

    Returns:
        False when `splitting_ops` is None; otherwise True only if each
        entry of `self._attention_ops` is present in `splitting_ops`.
    """
    configured_ops = self.splitting_ops
    if configured_ops is None:
        return False
    for attention_op in self._attention_ops:
        if attention_op not in configured_ops:
            return False
    return True
633+
def get_compile_ranges(self) -> list[tuple[int, int]]:
    """Get the compile ranges for the compilation config.

    The range boundaries are the union of `compile_ranges_split_points`
    and `compile_sizes`, restricted to values that do not exceed the
    largest split point. Consecutive boundaries form `(start, end)`
    tuples, with the first range starting at 1. Every compile size other
    than 1 additionally contributes a `(size, size)` entry.

    Returns:
        The sorted list of `(start, end)` tuples. Empty when
        `compile_ranges_split_points` is None or empty.
    """
    configured_split_points = self.compile_ranges_split_points or []
    if not configured_split_points:
        # Without split points there is no upper bound to build ranges
        # against.
        return []

    compile_sizes = set(self.compile_sizes or [])

    # Presumably max_num_batched_tokens + 1 -- confirm against the code
    # that populates compile_ranges_split_points.
    max_split_point = max(configured_split_points)
    split_points = sorted(
        point for point in compile_sizes.union(configured_split_points)
        if point <= max_split_point)

    compile_ranges: list[tuple[int, int]] = []
    range_start = 1
    for s in split_points:
        compile_ranges.append((range_start, s))
        # NOTE(review): the compile_ranges_split_points docs describe
        # half-open ranges with compile sizes as [s, s + 1); confirm
        # whether consumers expect (s, s) or (s, s + 1) here.
        # s == 1 is skipped because (1, 1) was already emitted above.
        if s in compile_sizes and s != 1:
            compile_ranges.append((s, s))
        range_start = s
    return sorted(compile_ranges)
0 commit comments