Commit 32989d8

add pattern for final allreduce in model
Signed-off-by: Luka Govedič <[email protected]>
1 parent 5619bc3 commit 32989d8

1 file changed: +12 -0 lines changed

vllm/compilation/collective_fusion.py

Lines changed: 12 additions & 0 deletions
@@ -775,6 +775,18 @@ def replacement(
             pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass
         )
 
+        # Same pattern, but only return the output and not residual
+        # (helpful for end of graph where residual is not used again)
+        first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0]
+
+        pm.register_replacement(
+            first_return_only(pattern),
+            first_return_only(replacement),
+            self.get_inputs(),
+            pm.fwd_only,
+            pm_pass,
+        )
+
 
 class AllReduceFusedRMSNormStaticQuantFP8Pattern(BasePattern):
     """
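
The `first_return_only` wrapper added in this hunk can be tried in isolation. The sketch below is only an illustration: `toy_pattern` is a hypothetical stand-in that works on plain floats, whereas the real `pattern` and `replacement` in `collective_fusion.py` operate on tensors and are traced by the pattern matcher. It shows that the wrapper keeps the three-argument signature and drops everything except the first returned value, which is what lets the same pattern be registered for a graph whose residual output is never used.

```python
# Wrapper as written in the diff: takes a three-argument function that returns
# a tuple and yields a function returning only the first element.
first_return_only = lambda fn: lambda a, b, c: fn(a, b, c)[0]


def toy_pattern(x, residual, weight):
    # Hypothetical stand-in (not vLLM code): returns (output, new residual).
    new_residual = x + residual
    return new_residual * weight, new_residual


# Same inputs, but only the output is returned; the residual is discarded.
output_only = first_return_only(toy_pattern)

full = toy_pattern(1.0, 2.0, 3.0)
single = output_only(1.0, 2.0, 3.0)
print(full, single)
```

Because the wrapper hard-codes three positional arguments, it presumably relies on `pattern` and `replacement` both taking exactly three inputs; a function with a different arity would need its own wrapper.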
