
Commit 82057db

ilmarkov committed

Addressing comments
Signed-off-by: ilmarkov <[email protected]>
1 parent 703b9db commit 82057db

6 files changed, +103 -85 lines changed

vllm/compilation/collective_fusion.py

Lines changed: 23 additions & 5 deletions
@@ -19,8 +19,7 @@
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape)
 from vllm.platforms import current_platform
-from vllm.utils import (_FI_ALLREDUCE_ONE_SHOT_MAX_SIZES,
-                        direct_register_custom_op, flashinfer_max_size)
+from vllm.utils import direct_register_custom_op

 from .inductor_pass import enable_fake_mode
 from .vllm_inductor_pass import VllmInductorPass
@@ -398,6 +397,22 @@ def __call__(self, graph: fx.Graph):
 if flashinfer_comm is not None:
     _FI_WORKSPACE_TENSOR = None

+    MiB = 1024 * 1024
+    # Max size of the input tensor per world size per device capability
+    # to use flashinfer one shot fused allreduce
+    _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES = {
+        "9.0": {
+            2: 32 * MiB,  # 32MB
+            4: 2 * MiB,  # 2MB
+            8: 1 * MiB,  # 1MB
+        },
+        "10.0": {
+            2: 32 * MiB,  # 32MB
+            4: 4 * MiB,  # 4MB
+            8: 1 * MiB,  # 1MB
+        },
+    }
+
     def call_trtllm_fused_allreduce_norm(
         allreduce_in: torch.Tensor,
         residual: torch.Tensor,
@@ -425,9 +440,11 @@ def call_trtllm_fused_allreduce_norm(
             f"element size {element_size}"
         device_capability = current_platform.get_device_capability(
         ).as_version_str()
-        max_sizes = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES.get(device_capability, {})
         # Get one shot input size limit for the current world size
-        max_one_shot_size = max_sizes.get(world_size, None)
+        # for the current device capability
+        max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES. \
+            get(device_capability, {}). \
+            get(world_size, None)
         # Use one shot if no max size is specified
         use_oneshot = max_one_shot_size is None or \
             current_tensor_size <= max_one_shot_size
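For reference, a minimal standalone sketch of the one-shot decision implemented above (the table values are copied from this diff; the capability string, world size, and tensor size are illustrative inputs, not values computed here):

MiB = 1024 * 1024

_FI_ALLREDUCE_ONE_SHOT_MAX_SIZES = {
    "9.0": {2: 32 * MiB, 4: 2 * MiB, 8: 1 * MiB},
    "10.0": {2: 32 * MiB, 4: 4 * MiB, 8: 1 * MiB},
}

def use_oneshot(device_capability: str, world_size: int,
                current_tensor_size: int) -> bool:
    # Two-level lookup: device capability first, then world size.
    max_one_shot_size = _FI_ALLREDUCE_ONE_SHOT_MAX_SIZES.get(
        device_capability, {}).get(world_size, None)
    # One shot is used when no limit applies or the input fits under it.
    return (max_one_shot_size is None
            or current_tensor_size <= max_one_shot_size)

# On capability "9.0" with world size 4, the cutoff is 2 MiB:
assert use_oneshot("9.0", 4, 1 * MiB)
assert not use_oneshot("9.0", 4, 4 * MiB)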
@@ -1449,7 +1466,8 @@ def __init__(self, config: VllmConfig):
                 "Flashinfer is not installed or comm module not found, "
                 "skipping allreduce fusion pass")
             return
-        max_size = flashinfer_max_size(self.tp_size, config)
+        max_size = config.compilation_config.\
+            pass_config.flashinfer_max_size(self.tp_size)
         if max_size is None:
             # Flashinfer doesn't support current world size
             logger.warning(

vllm/compilation/compiler_interface.py

Lines changed: 2 additions & 0 deletions
@@ -555,6 +555,8 @@ def set_inductor_config(config, compile_range):
     if isinstance(compile_range, tuple):
         # for a specific range of batchsizes, tuning triton kernel parameters
         # can be beneficial
+        # TODO(luka): max autotune is only present with -O3,
+        # and this should live in config: https://github.com/vllm-project/vllm/issues/20283
         config["max_autotune"] = True
         config["coordinate_descent_tuning"] = True
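A self-contained sketch of what this branch does to the inductor options dict (the function body mirrors the two assignments above; the range value is made up for the example):

def set_inductor_config(config: dict, compile_range) -> None:
    if isinstance(compile_range, tuple):
        # For a specific range of batch sizes, tuning triton kernel
        # parameters can be beneficial.
        config["max_autotune"] = True
        config["coordinate_descent_tuning"] = True

opts: dict = {}
set_inductor_config(opts, (1, 512))  # an illustrative compile range
assert opts["max_autotune"] and opts["coordinate_descent_tuning"]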

vllm/compilation/cuda_piecewise_backend.py

Lines changed: 1 addition & 16 deletions
@@ -50,22 +50,7 @@ def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,

         self.is_full_graph = total_piecewise_compiles == 1

-        self.compile_sizes: set[int] = set(
-            self.compilation_config.compile_sizes)
-        self.compile_ranges_split_points: list[
-            int] = self.compilation_config.compile_ranges_split_points
-        self.compile_ranges = []
-        split_points = sorted(
-            set(self.compile_sizes).union(set(
-                self.compile_ranges_split_points)))
-        for i, s in enumerate(split_points):
-            if i == 0:
-                self.compile_ranges.append((1, s))
-            else:
-                self.compile_ranges.append((split_points[i - 1], s))
-            if s in self.compile_sizes:
-                self.compile_ranges.append((s, s))
-        self.compile_ranges = sorted(self.compile_ranges)
+        self.compile_ranges = self.compilation_config.get_compile_ranges()
         log_string = f"PiecewiseBackend: compile_ranges: {self.compile_ranges}"
         logger.debug_once(log_string)

vllm/config/__init__.py

Lines changed: 3 additions & 2 deletions
@@ -50,7 +50,7 @@
 from vllm.transformers_utils.s3_utils import S3Model
 from vllm.transformers_utils.utils import is_s3, maybe_model_redirect
 from vllm.utils import (LayerBlockType, LazyLoader, common_broadcastable_dtype,
-                        flashinfer_max_size, random_uuid)
+                        random_uuid)

 if TYPE_CHECKING:
     from _typeshed import DataclassInstance
@@ -3877,7 +3877,8 @@ def _set_compile_ranges(self):
         # Add the compile ranges for flashinfer
         if compilation_config.pass_config.enable_fi_allreduce_fusion:
             tp_size = self.parallel_config.tensor_parallel_size
-            max_size = flashinfer_max_size(tp_size, self)
+            max_size = compilation_config.pass_config.flashinfer_max_size(
+                tp_size)
             if max_size is not None:
                 max_token_num = max_size // (
                     self.model_config.get_hidden_size() *
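To make the byte-to-token conversion concrete, a small worked sketch (the hidden size and element size are illustrative values, not taken from this diff):

MiB = 1024 * 1024
max_size = 2 * MiB      # e.g. the TP=4 threshold on capability "9.0"
hidden_size = 4096      # hypothetical model hidden size
element_size = 2        # bytes per element for fp16/bf16
max_token_num = max_size // (hidden_size * element_size)
assert max_token_num == 256  # 2,097,152 B // 8,192 B per token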

vllm/config/compilation.py

Lines changed: 74 additions & 4 deletions
@@ -94,10 +94,59 @@ class PassConfig:
     dictionary mapping each world size to the threshold in MB
     { <world size>: <max size in mb> }
     Unspecified world sizes will fallback to
-    { 2: 32, 4: 32, 8: 2 }"""
+    _FI_ALLREDUCE_MAX_INPUT_SIZES = {
+        "9.0": {
+            2: 64 * MiB,  # 64MB
+            4: 2 * MiB,  # 2MB
+            8: 1 * MiB,  # 1MB
+        },
+        "10.0": {
+            2: 64 * MiB,  # 64MB
+            4: 32 * MiB,  # 32MB
+            8: 1 * MiB,  # 1MB
+        },
+    }, where key is the device capability"""

     # TODO(luka) better pass enabling system.

+    def flashinfer_max_size(self, world_size: int) -> Optional[int]:
+        """
+        Returns the max communication size in bytes for flashinfer
+        allreduce fusion for the given world size. Falls back to
+        conservative defaults if the world size is not specified in config.
+        """
+
+        # import here to avoid circular dependencies
+        from vllm.platforms import current_platform
+        MiB = 1024 * 1024
+
+        # Max size of the input tensor per world size per device capability
+        # to use flashinfer fused allreduce
+        _FI_ALLREDUCE_MAX_INPUT_SIZES = {
+            "9.0": {
+                2: 64 * MiB,  # 64MB
+                4: 2 * MiB,  # 2MB
+                8: 1 * MiB,  # 1MB
+            },
+            "10.0": {
+                2: 64 * MiB,  # 64MB
+                4: 32 * MiB,  # 32MB
+                8: 1 * MiB,  # 1MB
+            },
+        }
+
+        device_capability = current_platform.get_device_capability(
+        ).as_version_str()
+        max_sizes = _FI_ALLREDUCE_MAX_INPUT_SIZES.get(device_capability, {})
+        max_sizes.update({
+            k: int(v * MiB)
+            for k, v in self.fi_allreduce_fusion_max_size_mb.items()
+        })
+        if world_size not in max_sizes:
+            # FlashInfer doesn't support other world sizes
+            return None
+        return max_sizes[world_size]
+
     def uuid(self):
         """
         Produces a hash unique to the pass configuration.
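A standalone sketch of the merge semantics in the new method: per-capability defaults, overridden per world size by fi_allreduce_fusion_max_size_mb (this helper and its arguments are illustrative stand-ins, not vLLM's API):

from typing import Optional

MiB = 1024 * 1024

_DEFAULTS = {
    "9.0": {2: 64 * MiB, 4: 2 * MiB, 8: 1 * MiB},
    "10.0": {2: 64 * MiB, 4: 32 * MiB, 8: 1 * MiB},
}

def flashinfer_max_size_sketch(
        world_size: int, device_capability: str,
        overrides_mb: dict[int, float]) -> Optional[int]:
    # Start from the per-capability defaults, then apply the MB overrides.
    max_sizes = dict(_DEFAULTS.get(device_capability, {}))
    max_sizes.update({k: int(v * MiB) for k, v in overrides_mb.items()})
    # Unsupported world sizes map to None.
    return max_sizes.get(world_size)

# Override the TP=4 threshold to 16 MB on capability "9.0":
assert flashinfer_max_size_sketch(4, "9.0", {4: 16}) == 16 * MiB
assert flashinfer_max_size_sketch(6, "9.0", {}) is None  # unsupported size

Returning None for an unlisted world size is what lets the caller in collective_fusion.py skip the fusion pass entirely rather than guess a threshold.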
@@ -223,9 +272,11 @@ class CompilationConfig:
     compile_ranges_split_points: Optional[list[int]] = None
     """Split points that represent compile ranges for inductor.
     The compile ranges are
-    [1, split_points[0]],
-    [split_points[0], split_points[1]], ...,
-    [split_points[-1], max_num_batched_tokens].
+    [1, split_points[0]),
+    [split_points[0], split_points[1]), ...,
+    [split_points[-1], max_num_batched_tokens + 1).
+    Compile sizes are also used as single-element ranges:
+    [compile_sizes[i], compile_sizes[i] + 1).
     """

     inductor_compile_config: dict = field(default_factory=dict)
@@ -579,3 +630,22 @@ def set_splitting_ops_for_v1(self):
     def splitting_ops_contain_attention(self) -> bool:
         return self.splitting_ops is not None and all(
             op in self.splitting_ops for op in self._attention_ops)
+
+    def get_compile_ranges(self) -> list[tuple[int, int]]:
+        """Get the compile ranges for the compilation config."""
+        compile_ranges_split_points = self.compile_ranges_split_points
+        compile_ranges: list[tuple[int, int]] = []
+        # max_num_batched_tokens + 1
+        max_split_point = max(compile_ranges_split_points)
+        split_points = sorted(
+            set(self.compile_sizes).union(set(compile_ranges_split_points)))
+        split_points = [x for x in split_points if x <= max_split_point]
+        for i, s in enumerate(split_points):
+            if i == 0:
+                compile_ranges.append((1, s))
+            else:
+                compile_ranges.append((split_points[i - 1], s))
+            if s in self.compile_sizes and s != 1:
+                compile_ranges.append((s, s))
+        return sorted(compile_ranges)
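A self-contained sketch of the range construction in get_compile_ranges, with made-up inputs (one compile size plus two split points; the values are illustrative only):

def compile_ranges_sketch(compile_sizes: set[int],
                          split_points_cfg: list[int]) -> list[tuple[int, int]]:
    # Mirrors the method above, with the config fields passed in directly.
    ranges: list[tuple[int, int]] = []
    max_split_point = max(split_points_cfg)  # max_num_batched_tokens + 1
    split_points = sorted(set(compile_sizes) | set(split_points_cfg))
    split_points = [x for x in split_points if x <= max_split_point]
    for i, s in enumerate(split_points):
        if i == 0:
            ranges.append((1, s))
        else:
            ranges.append((split_points[i - 1], s))
        if s in compile_sizes and s != 1:
            ranges.append((s, s))  # single-element range [s, s + 1)
    return sorted(ranges)

# With compile size 8 and split points [64, 8193]:
assert compile_ranges_sketch({8}, [64, 8193]) == [
    (1, 8), (8, 8), (8, 64), (64, 8193)]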

vllm/utils/__init__.py

Lines changed: 0 additions & 58 deletions
@@ -87,64 +87,6 @@
 POOLING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
 MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120

-# Max communication size for flashinfer fused allreduce
-MiB = 1024 * 1024
-
-# Max size of the input tensor per world size per device capability
-# to use flashinfer fused allreduce
-_FI_ALLREDUCE_MAX_INPUT_SIZES = {
-    "9.0": {
-        2: 64 * MiB,  # 64MB
-        4: 2 * MiB,  # 2MB
-        8: 1 * MiB,  # 1MB
-    },
-    "10.0": {
-        2: 64 * MiB,  # 64MB
-        4: 32 * MiB,  # 32MB
-        8: 1 * MiB,  # 1MB
-    },
-}
-
-# Max size of the input tensor per world size per device capability
-# to use flashinfer one shot fused allreduce
-_FI_ALLREDUCE_ONE_SHOT_MAX_SIZES = {
-    "9.0": {
-        2: 32 * MiB,  # 32MB
-        4: 2 * MiB,  # 2MB
-        8: 1 * MiB,  # 1MB
-    },
-    "10.0": {
-        2: 32 * MiB,  # 32MB
-        4: 4 * MiB,  # 4MB
-        8: 1 * MiB,  # 1MB
-    },
-}
-
-
-def flashinfer_max_size(world_size: int, config: VllmConfig) -> Optional[int]:
-    """
-    Returns the max communication size in bytes for flashinfer
-    allreduce fusion for the given world size. Falls back to
-    conservative defaults if the world size is not specified in config.
-    """
-
-    # import here to avoid circular dependencies
-    from vllm.platforms import current_platform
-
-    device_capability = current_platform.get_device_capability(
-    ).as_version_str()
-    max_sizes = _FI_ALLREDUCE_MAX_INPUT_SIZES.get(device_capability, {})
-    max_sizes.update({
-        k: int(v * MiB)
-        for k, v in config.compilation_config.pass_config.
-        fi_allreduce_fusion_max_size_mb.items()
-    })
-    if world_size not in max_sizes:
-        # FlashInfer doesn't support other world sizes
-        return None
-    return max_sizes[world_size]
-
-
 # Exception strings for non-implemented encoder/decoder scenarios

 # Reminder: Please update docs/features/compatibility_matrix.md
