11import itertools
2+ import operator
23import sys
34from collections import Counter , defaultdict , deque
45from collections .abc import Generator
5- from functools import cache
6+ from functools import cache , reduce
67from typing import TypeVar
78from warnings import warn
89
1617from pytensor .graph .features import ReplaceValidate
1718from pytensor .graph .fg import Output
1819from pytensor .graph .rewriting .basic import (
19- EquilibriumGraphRewriter ,
2020 GraphRewriter ,
2121 copy_stack_trace ,
2222 in2out ,
2323 node_rewriter ,
24+ out2in ,
2425)
2526from pytensor .graph .rewriting .db import SequenceDB
2627from pytensor .graph .utils import InconsistencyError , MethodNotDefined
2930 MakeVector ,
3031 alloc ,
3132 cast ,
33+ constant ,
3234 get_underlying_scalar_constant_value ,
3335)
3436from pytensor .tensor .elemwise import CAReduce , DimShuffle , Elemwise
3537from pytensor .tensor .exceptions import NotScalarConstantError
36- from pytensor .tensor .math import exp
38+ from pytensor .tensor .math import add , exp , mul
3739from pytensor .tensor .rewriting .basic import (
3840 alloc_like ,
41+ broadcasted_by ,
3942 register_canonicalize ,
4043 register_specialize ,
4144)
@@ -542,8 +545,8 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
542545 return rval
543546
544547
545- @node_rewriter ([Elemwise ])
546- def local_add_mul_fusion (fgraph , node ):
548+ @node_rewriter ([add , mul ])
549+ def flatten_nested_add_mul (fgraph , node ):
547550 """Fuse consecutive add or mul in one such node with more inputs.
548551
549552 It is better to fuse add/mul that way then in a Composite node as
@@ -554,27 +557,16 @@ def local_add_mul_fusion(fgraph, node):
554557 This rewrite is almost useless after the AlgebraicCanonizer is used,
555558 but it catches a few edge cases that are not canonicalized by it
556559 """
557- if not (
558- isinstance (node .op , Elemwise ) and isinstance (node .op .scalar_op , ps .Add | ps .Mul )
559- ):
560- return False
561-
562- s_op = node .op .scalar_op .__class__
560+ s_op = node .op .scalar_op
563561 new_inp = []
564562 fused = False
565- nb_inputs = len (node .inputs )
566- max_inputs = float ("inf" )
567- if hasattr (node .op , "max_inputs" ):
568- max_inputs = node .op .max_inputs (node )
569563 for inp in node .inputs :
570564 if (
571565 inp .owner
572566 and isinstance (inp .owner .op , Elemwise )
573- and isinstance (inp .owner .op .scalar_op , s_op )
574- and
567+ and inp .owner .op .scalar_op == s_op
575568 # Do not duplicate the operation.
576- len (fgraph .clients [inp ]) == 1
577- and (nb_inputs + len (inp .owner .inputs ) - 1 ) <= max_inputs
569+ and len (fgraph .clients [inp ]) == 1
578570 ):
579571 new_inp .extend (inp .owner .inputs )
580572 fused = True
@@ -590,7 +582,7 @@ def local_add_mul_fusion(fgraph, node):
590582 # Do the recursion here to help lower the number of
591583 # FusionOptimizer iteration.
592584 if output .owner :
593- output2 = local_add_mul_fusion .transform (fgraph , output .owner )
585+ output2 = flatten_nested_add_mul .transform (fgraph , output .owner )
594586 if output2 :
595587 return output2
596588 return [output ]
@@ -1237,6 +1229,76 @@ def local_inline_composite_constants(fgraph, node):
12371229 return new_outputs
12381230
12391231
1232+ @node_rewriter (tracks = [add , mul ])
1233+ def constant_fold_branches_of_add_mul (fgraph , node ):
1234+ old_constants = [inp for inp in node .inputs if isinstance (inp , TensorConstant )]
1235+
1236+ if len (old_constants ) <= 1 :
1237+ return None
1238+
1239+ new_constants = old_constants .copy ()
1240+
1241+ # Multiply constants if it doesn't result in higher intermediate memory
1242+ while True :
1243+ n_constants = len (new_constants )
1244+ if n_constants <= 1 :
1245+ break
1246+
1247+ for i in range (n_constants ):
1248+ reference_inp = new_constants [i ]
1249+ other_inps = []
1250+ for j in range (n_constants ):
1251+ if i == j :
1252+ continue
1253+ other_inp = new_constants [j ]
1254+ if not broadcasted_by (reference_inp , other_inp ):
1255+ other_inps .append (other_inp )
1256+ if other_inps :
1257+ python_op = operator .mul if node .op == mul else operator .add
1258+ folded_inputs = [reference_inp , * other_inps ]
1259+ new_inp = constant (
1260+ reduce (python_op , (const .data for const in folded_inputs ))
1261+ )
1262+ new_constants = [
1263+ new_inp ,
1264+ * (inp for inp in new_constants if inp not in folded_inputs ),
1265+ ]
1266+ break
1267+ else : # no-break
1268+ break
1269+
1270+ if len (new_constants ) == len (old_constants ):
1271+ return None
1272+
1273+ non_constants = [inp for inp in node .inputs if not isinstance (inp , TensorConstant )]
1274+ new_out = node .op (
1275+ * new_constants ,
1276+ * non_constants ,
1277+ )
1278+ copy_stack_trace (node .outputs [0 ], new_out )
1279+ return [new_out ]
1280+
1281+
1282+ add_mul_flat_seqopt = SequenceDB ()
1283+ compile .optdb .register (
1284+ "add_mul_flat" ,
1285+ add_mul_flat_seqopt ,
1286+ "fast_run" ,
1287+ position = 48 , # Before Elemwise fusion
1288+ )
1289+ add_mul_flat_seqopt .register (
1290+ flatten_nested_add_mul .__name__ ,
1291+ out2in (flatten_nested_add_mul , ignore_newtrees = False ),
1292+ "fast_run" ,
1293+ position = 0 ,
1294+ )
1295+ add_mul_flat_seqopt .register (
1296+ constant_fold_branches_of_add_mul .__name__ ,
1297+ in2out (constant_fold_branches_of_add_mul , ignore_newtrees = True ),
1298+ "fast_run" ,
1299+ position = 1 ,
1300+ )
1301+
12401302# Register fusion database just before AddDestroyHandler(49.5) (inplace rewrites)
12411303fuse_seqopt = SequenceDB ()
12421304compile .optdb .register (
@@ -1248,14 +1310,6 @@ def local_inline_composite_constants(fgraph, node):
12481310 "FusionOptimizer" ,
12491311 position = 49 ,
12501312)
1251-
1252- fuse_seqopt .register (
1253- "local_add_mul_fusion" ,
1254- EquilibriumGraphRewriter (rewriters = [local_add_mul_fusion ], max_use_ratio = 1000 ),
1255- "fast_run" ,
1256- "fusion" ,
1257- position = 0 ,
1258- )
12591313fuse_seqopt .register (
12601314 "composite_elemwise_fusion" ,
12611315 FusionOptimizer (),
@@ -1279,7 +1333,7 @@ def local_inline_composite_constants(fgraph, node):
12791333)
12801334fuse_seqopt .register (
12811335 "local_inline_composite_constants" ,
1282- in2out (local_inline_composite_constants ),
1336+ in2out (local_inline_composite_constants , ignore_newtrees = True ),
12831337 "fast_run" ,
12841338 "fusion" ,
12851339 position = 20 ,
0 commit comments