@@ -1062,8 +1062,6 @@ public static void pushdownStreamSelections(final List<RewriterRule> rules, fina
10621062 );
10631063 });
10641064
1065-
1066-
10671065 rules .add (new RewriterRuleBuilder (ctx , "sum(sum(v)) => sum(v)" )
10681066 .setUnidirectional (true )
10691067 .parseGlobalVars ("MATRIX:A,B" )
@@ -1151,6 +1149,85 @@ public static void pushdownStreamSelections(final List<RewriterRule> rules, fina
11511149 }
11521150 }
11531151
1152+ // This expands the statements to a common canonical form
1153+ // It is important, however, that
1154+ public static void canonicalExpandAfterFlattening (final List <RewriterRule > rules , final RuleContext ctx ) {
1155+ HashMap <Integer , RewriterStatement > hooks = new HashMap <>();
1156+
1157+ rules .add (new RewriterRuleBuilder (ctx , "sum($1:_idxExpr(indices, -(A))) => -(sum($2:_idxExpr(indices, A)))" )
1158+ .setUnidirectional (true )
1159+ .parseGlobalVars ("FLOAT:a" )
1160+ .parseGlobalVars ("INT...:indices" )
1161+ .withParsedStatement ("sum($1:_idxExpr(indices, -(a)))" , hooks )
1162+ .toParsedStatement ("-(sum($2:_idxExpr(indices, a)))" , hooks )
1163+ .link (hooks .get (1 ).getId (), hooks .get (2 ).getId (), RewriterStatement ::transferMeta )
1164+ .build ()
1165+ );
1166+
1167+ rules .add (new RewriterRuleBuilder (ctx , "sum($1:_idxExpr(indices, -(a))) => -(sum($2:_idxExpr(indices, a)))" )
1168+ .setUnidirectional (true )
1169+ .parseGlobalVars ("INT:a" )
1170+ .parseGlobalVars ("INT...:indices" )
1171+ .withParsedStatement ("sum($1:_idxExpr(indices, -(a)))" , hooks )
1172+ .toParsedStatement ("-(sum($2:_idxExpr(indices, a)))" , hooks )
1173+ .link (hooks .get (1 ).getId (), hooks .get (2 ).getId (), RewriterStatement ::transferMeta )
1174+ .build ()
1175+ );
1176+
1177+ rules .add (new RewriterRuleBuilder (ctx , "sum(_idxExpr(indices, +(ops))) => +(argList(sum(_idxExpr(indices, op1)), sum(_idxExpr(...)), ...))" )
1178+ .setUnidirectional (true )
1179+ .parseGlobalVars ("INT...:indices" )
1180+ .parseGlobalVars ("FLOAT...:ops" )
1181+ .withParsedStatement ("sum($1:_idxExpr(indices, +(ops)))" , hooks )
1182+ .toParsedStatement ("+($3:argList(sum($2:_idxExpr(indices, +(ops)))))" , hooks ) // The inner +(ops) is temporary and will be removed
1183+ .link (hooks .get (1 ).getId (), hooks .get (2 ).getId (), RewriterStatement ::transferMeta )
1184+ .apply (hooks .get (3 ).getId (), newArgList -> {
1185+ RewriterStatement oldArgList = newArgList .getChild (0 , 0 , 1 , 0 );
1186+ newArgList .getChild (0 , 0 ).getOperands ().set (1 , oldArgList .getChild (0 ));
1187+
1188+ for (int i = 1 ; i < oldArgList .getOperands ().size (); i ++) {
1189+ RewriterStatement newIdxExpr = newArgList .getChild (0 , 0 ).copyNode ();
1190+ newIdxExpr .getOperands ().set (1 , oldArgList .getChild (i ));
1191+ RewriterStatement newSum = new RewriterInstruction ()
1192+ .as (UUID .randomUUID ().toString ())
1193+ .withInstruction ("sum" )
1194+ .withOps (newIdxExpr )
1195+ .consolidate (ctx );
1196+ RewriterUtils .copyIndexList (newIdxExpr );
1197+ newArgList .getOperands ().add (newSum );
1198+ }
1199+ }, true )
1200+ .build ()
1201+ );
1202+ }
1203+
1204+ public static void flattenedAlgebraRewrites (final List <RewriterRule > rules , final RuleContext ctx ) {
1205+ HashMap <Integer , RewriterStatement > hooks = new HashMap <>();
1206+
1207+ // Minus pushdown
1208+ rules .add (new RewriterRuleBuilder (ctx , "-(+(...)) => +(-(el1), -(el2), ...)" )
1209+ .setUnidirectional (true )
1210+ .parseGlobalVars ("FLOAT...:ops" )
1211+ .withParsedStatement ("-(+(ops))" , hooks )
1212+ .toParsedStatement ("$1:+(ops)" , hooks ) // Temporary
1213+ .apply (hooks .get (1 ).getId (), (stmt , match ) -> {
1214+ RewriterStatement argList = stmt .getChild (0 );
1215+
1216+ for (int i = 0 ; i < argList .getOperands ().size (); i ++) {
1217+ RewriterInstruction newStmt = new RewriterInstruction ().as (UUID .randomUUID ().toString ()).withInstruction ("-" ).withOps (argList .getOperands ().get (i ));
1218+ newStmt .consolidate (ctx );
1219+ argList .getOperands ().set (i , newStmt );
1220+ }
1221+
1222+ // TODO: This is inefficient
1223+ RewriterUtils .tryFlattenNestedOperatorPatterns (ctx , match .getNewExprRoot ());
1224+ }, true )
1225+ .build ()
1226+ );
1227+
1228+ // TODO: Distributive law
1229+ }
1230+
11541231 public static void buildElementWiseAlgebraicCanonicalization (final List <RewriterRule > rules , final RuleContext ctx ) {
11551232 RewriterUtils .buildTernaryPermutations (List .of ("FLOAT" , "INT" , "BOOL" ), (t1 , t2 , t3 ) -> {
11561233 rules .add (new RewriterRuleBuilder (ctx , "*(+(a, b), c) => +(*(a, c), *(b, c))" )
@@ -1215,23 +1292,37 @@ public static void flattenOperations(final List<RewriterRule> rules, final RuleC
12151292 HashMap <Integer , RewriterStatement > hooks = new HashMap <>();
12161293
12171294 RewriterUtils .buildBinaryPermutations (List .of ("INT" , "INT..." ), (t1 , t2 ) -> {
1218- rules .add (new RewriterRuleBuilder (ctx )
1219- .setUnidirectional (true )
1220- .parseGlobalVars (t1 + ":i" )
1221- .parseGlobalVars (t2 + ":j" )
1222- .parseGlobalVars ("FLOAT:v" )
1223- .withParsedStatement ("$1:_idxExpr(i, $2:_idxExpr(j, v))" , hooks )
1224- .toParsedStatement ("$3:_idxExpr(argList(i, j), v)" , hooks )
1225- .link (hooks .get (1 ).getId (), hooks .get (3 ).getId (), RewriterStatement ::transferMeta )
1226- .apply (hooks .get (3 ).getId (), (stmt , match ) -> {
1227- UUID newOwnerId = (UUID )stmt .getMeta ("ownerId" );
1295+ for (String t3 : List .of ("FLOAT" , "FLOAT*" , "INT" , "INT*" , "BOOL" , "BOOL*" )) {
1296+ rules .add (new RewriterRuleBuilder (ctx )
1297+ .setUnidirectional (true )
1298+ .parseGlobalVars (t1 + ":i" )
1299+ .parseGlobalVars (t2 + ":j" )
1300+ .parseGlobalVars (t3 + ":v" )
1301+ .withParsedStatement ("$1:_idxExpr(i, $2:_idxExpr(j, v))" , hooks )
1302+ .toParsedStatement ("$3:_idxExpr(argList(i, j), v)" , hooks )
1303+ .link (hooks .get (1 ).getId (), hooks .get (3 ).getId (), RewriterStatement ::transferMeta )
1304+ .apply (hooks .get (3 ).getId (), (stmt , match ) -> {
1305+ UUID newOwnerId = (UUID ) stmt .getMeta ("ownerId" );
12281306
1229- if (newOwnerId == null )
1230- throw new IllegalArgumentException ();
1307+ if (newOwnerId == null )
1308+ throw new IllegalArgumentException ();
12311309
1232- stmt .getOperands ().get (0 ).getOperands ().get (1 ).unsafePutMeta ("ownerId" , newOwnerId );
1233- }, true )
1234- .build ());
1310+ stmt .getOperands ().get (0 ).getOperands ().get (1 ).unsafePutMeta ("ownerId" , newOwnerId );
1311+ }, true )
1312+ .build ());
1313+
1314+ if (t1 .equals ("INT" )) {
1315+ // This must be executed after the rule above
1316+ rules .add (new RewriterRuleBuilder (ctx )
1317+ .setUnidirectional (true )
1318+ .parseGlobalVars (t1 + ":i" )
1319+ .parseGlobalVars (t3 + ":v" )
1320+ .withParsedStatement ("$1:_idxExpr(i, v)" , hooks )
1321+ .toParsedStatement ("$3:_idxExpr(argList(i), v)" , hooks )
1322+ .link (hooks .get (1 ).getId (), hooks .get (3 ).getId (), RewriterStatement ::transferMeta )
1323+ .build ());
1324+ }
1325+ }
12351326 });
12361327
12371328 RewriterUtils .buildBinaryPermutations (List .of ("MATRIX" , "INT" , "FLOAT" , "BOOL" ), (t1 , t2 ) -> {
0 commit comments