@@ -765,7 +765,7 @@ public static void expandStreamingExpressions(final List<RewriterRule> rules, fi
765765 .parseGlobalVars ("LITERAL_INT:1" )
766766 .parseGlobalVars ("LITERAL_FLOAT:0.0" )
767767 .withParsedStatement ("diag(A)" , hooks )
768- .toParsedStatement ("$4:_m($1:_idx(1, nrow(A)), $2:_idx(1, ncol(A)), [](A, $1, $1 ))" , hooks )
768+ .toParsedStatement ("$4:_m($1:_idx(1, nrow(A)), $2:_idx(1, ncol(A)), ifelse(==($1,$2), [](A, $1, $2), 0.0 ))" , hooks )
769769 .apply (hooks .get (1 ).getId (), stmt -> stmt .unsafePutMeta ("idxId" , UUID .randomUUID ()), true ) // Assumes it will never collide
770770 .apply (hooks .get (2 ).getId (), stmt -> stmt .unsafePutMeta ("idxId" , UUID .randomUUID ()), true ) // Assumes it will never collide
771771 .apply (hooks .get (4 ).getId (), (stmt , match ) -> {
@@ -775,8 +775,8 @@ public static void expandStreamingExpressions(final List<RewriterRule> rules, fi
775775
776776 RewriterStatement aRef = stmt .getChild (0 , 1 , 0 );
777777
778- System .out .println ("GETTING: " );
779- System .out .println (match .getNewExprRoot ().getAssertions (ctx ).getAssertionStatement (aRef .getNCol (), null ));
778+ // System.out.println("GETTING: ");
779+ // System.out.println(match.getNewExprRoot().getAssertions(ctx).getAssertionStatement(aRef.getNCol(), null));
780780 match .getNewExprRoot ().getAssertions (ctx ).addEqualityAssertion (aRef .getNCol (), aRef .getNRow (), match .getNewExprRoot ());
781781 }, true ) // Assumes it will never collide
782782 .build ()
@@ -1163,6 +1163,18 @@ public static void expandArbitraryMatrices(final List<RewriterRule> rules, final
11631163 public static void pushdownStreamSelections (final List <RewriterRule > rules , final RuleContext ctx ) {
11641164 HashMap <Integer , RewriterStatement > hooks = new HashMap <>();
11651165
1166+ // ifelse merging
1167+ // TODO: Permutations e.g. ==(l2, l1) etc.
1168+ rules .add (new RewriterRuleBuilder (ctx )
1169+ .setUnidirectional (true )
1170+ .parseGlobalVars ("FLOAT:a,b,c,d" )
1171+ .parseGlobalVars ("INT:l1,l2" )
1172+ .withParsedStatement ("$1:ElementWiseInstruction(ifelse(==(l1, l2), a, b), ifelse(==(l1, l2), c, d))" , hooks )
1173+ .toParsedStatement ("ifelse(==(l1, l2), $2:ElementWiseInstruction(a, c), $3:ElementWiseInstruction(b, d))" , hooks )
1174+ .linkManyUnidirectional (hooks .get (1 ).getId (), List .of (hooks .get (2 ).getId (), hooks .get (3 ).getId ()), RewriterStatement ::transferMeta , true )
1175+ .build ()
1176+ );
1177+
11661178 rules .add (new RewriterRuleBuilder (ctx )
11671179 .setUnidirectional (true )
11681180 .parseGlobalVars ("INT:l" )
0 commit comments