Skip to content

Commit 4751581

Browse files
committed
Some progress
1 parent f8250c1 commit 4751581

File tree

3 files changed

+20
-7
lines changed

3 files changed

+20
-7
lines changed

src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCollection.java

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -765,7 +765,7 @@ public static void expandStreamingExpressions(final List<RewriterRule> rules, fi
765765
.parseGlobalVars("LITERAL_INT:1")
766766
.parseGlobalVars("LITERAL_FLOAT:0.0")
767767
.withParsedStatement("diag(A)", hooks)
768-
.toParsedStatement("$4:_m($1:_idx(1, nrow(A)), $2:_idx(1, ncol(A)), [](A, $1, $1))", hooks)
768+
.toParsedStatement("$4:_m($1:_idx(1, nrow(A)), $2:_idx(1, ncol(A)), ifelse(==($1,$2), [](A, $1, $2), 0.0))", hooks)
769769
.apply(hooks.get(1).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide
770770
.apply(hooks.get(2).getId(), stmt -> stmt.unsafePutMeta("idxId", UUID.randomUUID()), true) // Assumes it will never collide
771771
.apply(hooks.get(4).getId(), (stmt, match) -> {
@@ -775,8 +775,8 @@ public static void expandStreamingExpressions(final List<RewriterRule> rules, fi
775775

776776
RewriterStatement aRef = stmt.getChild(0, 1, 0);
777777

778-
System.out.println("GETTING: ");
779-
System.out.println(match.getNewExprRoot().getAssertions(ctx).getAssertionStatement(aRef.getNCol(), null));
778+
//System.out.println("GETTING: ");
779+
//System.out.println(match.getNewExprRoot().getAssertions(ctx).getAssertionStatement(aRef.getNCol(), null));
780780
match.getNewExprRoot().getAssertions(ctx).addEqualityAssertion(aRef.getNCol(), aRef.getNRow(), match.getNewExprRoot());
781781
}, true) // Assumes it will never collide
782782
.build()
@@ -1163,6 +1163,18 @@ public static void expandArbitraryMatrices(final List<RewriterRule> rules, final
11631163
public static void pushdownStreamSelections(final List<RewriterRule> rules, final RuleContext ctx) {
11641164
HashMap<Integer, RewriterStatement> hooks = new HashMap<>();
11651165

1166+
// ifelse merging
1167+
// TODO: Permutations e.g. ==(l2, l1) etc.
1168+
rules.add(new RewriterRuleBuilder(ctx)
1169+
.setUnidirectional(true)
1170+
.parseGlobalVars("FLOAT:a,b,c,d")
1171+
.parseGlobalVars("INT:l1,l2")
1172+
.withParsedStatement("$1:ElementWiseInstruction(ifelse(==(l1, l2), a, b), ifelse(==(l1, l2), c, d))", hooks)
1173+
.toParsedStatement("ifelse(==(l1, l2), $2:ElementWiseInstruction(a, c), $3:ElementWiseInstruction(b, d))", hooks)
1174+
.linkManyUnidirectional(hooks.get(1).getId(), List.of(hooks.get(2).getId(), hooks.get(3).getId()), RewriterStatement::transferMeta, true)
1175+
.build()
1176+
);
1177+
11661178
rules.add(new RewriterRuleBuilder(ctx)
11671179
.setUnidirectional(true)
11681180
.parseGlobalVars("INT:l")

src/main/java/org/apache/sysds/hops/rewriter/assertions/RewriterAssertions.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -466,11 +466,11 @@ public boolean addEqualityAssertion(RewriterStatement stmt1, RewriterStatement s
466466

467467
//System.out.println("MNew parts: " + partOfAssertion);
468468

469-
System.out.println("New assertion1: " + newAssertion);
469+
//System.out.println("New assertion1: " + newAssertion);
470470
return true;
471471
}
472472

473-
System.out.println("Assertion already exists");
473+
//System.out.println("Assertion already exists");
474474
return false; // The assertion already exists
475475
}
476476

@@ -497,7 +497,7 @@ public boolean addEqualityAssertion(RewriterStatement stmt1, RewriterStatement s
497497
return true;
498498
}, false);
499499

500-
System.out.println("New assertion2: " + existingAssertion);
500+
//System.out.println("New assertion2: " + existingAssertion);
501501
return true;
502502
}
503503

@@ -521,7 +521,7 @@ public boolean addEqualityAssertion(RewriterStatement stmt1, RewriterStatement s
521521
if (stmt1Assertions.stmt != null)
522522
assertionMatcher.put(stmt1Assertions.stmt, stmt2Assertions); // Only temporary
523523

524-
System.out.println("New assertion3: " + stmt2Assertions);
524+
//System.out.println("New assertion3: " + stmt2Assertions);
525525
resolveCyclicAssertions(stmt2Assertions);
526526
stmt2Assertions.deduplicate();
527527

src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterStreamTests.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1432,5 +1432,6 @@ public void testDiag1() {
14321432
System.out.println(stmt2.toParsableString(ctx, true));
14331433

14341434
assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1));
1435+
assert RewriterCostEstimator.estimateCost(stmt1, ctx) > RewriterCostEstimator.estimateCost(stmt2, ctx);
14351436
}
14361437
}

0 commit comments

Comments
 (0)