Skip to content

Commit ab76040

Browse files
committed
Some more improvements
1 parent 9852e36 commit ab76040

File tree

4 files changed

+74
-3
lines changed

4 files changed

+74
-3
lines changed

src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCollection.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,14 @@ public static void substituteFusedOps(final List<RewriterRule> rules, final Rule
471471
.build()
472472
);
473473

474+
rules.add(new RewriterRuleBuilder(ctx, "*2(A) => +(A,A)")
475+
.setUnidirectional(true)
476+
.parseGlobalVars("MATRIX:A")
477+
.withParsedStatement("*2(A)")
478+
.toParsedStatement("+(A,A)")
479+
.build()
480+
);
481+
474482
// TODO
475483
/*rules.add(new RewriterRuleBuilder(ctx, "replace(A, a, b) => A")
476484
.setUnidirectional(true)

src/main/java/org/apache/sysds/hops/rewriter/RewriterUtils.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1461,6 +1461,8 @@ private static RewriterStatement tryPullOutSum(RewriterStatement sum, final Rule
14611461
return RewriterStatement.multiArgInstr(ctx, "*", toRemove.toArray(RewriterStatement[]::new));
14621462
}
14631463
} else if (sumBody.trueInstruction().equals("+")) {
1464+
// TODO: What about sum(+(A, *(a, B)))? We could pull out a
1465+
14641466
// We have to assume here, that this instruction is not referenced anywhere else in the graph
14651467
List<RewriterStatement> argList = sumBody.getChild(0).getOperands();
14661468
List<RewriterStatement> toRemove = new ArrayList<>(argList.size());
@@ -1489,9 +1491,10 @@ private static RewriterStatement tryPullOutSum(RewriterStatement sum, final Rule
14891491
}
14901492

14911493
mul.add(outerSum);
1492-
mul.add(sum);
1494+
RewriterStatement mulStmt = RewriterStatement.multiArgInstr(ctx, "*", mul.toArray(RewriterStatement[]::new));
1495+
//mul.add(sum);
14931496

1494-
return RewriterStatement.multiArgInstr(ctx, "*", mul.toArray(RewriterStatement[]::new));
1497+
return RewriterStatement.multiArgInstr(ctx, "+", mulStmt, sum);
14951498
}
14961499
}
14971500

src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterStreamTests.java

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,14 +193,19 @@ public void testSumEquality6() {
193193
@Test
194194
public void testSumEquality() {
195195
RewriterStatement stmt = RewriterUtils.parse("sum(+(B, sum(*(a, A))))", ctx, "MATRIX:A,B", "FLOAT:a");
196-
RewriterStatement stmt2 = RewriterUtils.parse("*(a, sum(+(B, sum(A))))", ctx, "MATRIX:A,B", "FLOAT:a");
196+
RewriterStatement stmt2 = RewriterUtils.parse("+(*(a, length(A)), sum(+(B, sum(A))))", ctx, "MATRIX:A,B", "FLOAT:a");
197+
RewriterStatement stmt3 = RewriterUtils.parse("sum(+(B, *(a, sum(A))))", ctx, "MATRIX:A,B", "FLOAT:a");
197198
stmt = canonicalConverter.apply(stmt);
198199
stmt2 = canonicalConverter.apply(stmt2);
200+
stmt3 = canonicalConverter.apply(stmt3);
199201

200202
System.out.println("==========");
201203
System.out.println(stmt.toParsableString(ctx, true));
202204
System.out.println("==========");
205+
System.out.println(stmt3.toParsableString(ctx, true));
206+
System.out.println("==========");
203207
System.out.println(stmt2.toParsableString(ctx, true));
208+
assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt3, stmt));
204209
assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt));
205210
}
206211

@@ -1245,4 +1250,42 @@ public void testFused6() {
12451250

12461251
assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1));
12471252
}
1253+
1254+
@Test
1255+
public void testSum() {
1256+
RewriterStatement stmt1 = RewriterUtils.parse("sum(+(a,A))", ctx, "MATRIX:A,B", "FLOAT:a");
1257+
RewriterStatement stmt2 = RewriterUtils.parse("+(*(a, length(A)), sum(A))", ctx, "MATRIX:A,B", "FLOAT:a", "LITERAL_FLOAT:0.0");
1258+
1259+
System.out.println("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx));
1260+
System.out.println("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx));
1261+
1262+
stmt1 = canonicalConverter.apply(stmt1);
1263+
stmt2 = canonicalConverter.apply(stmt2);
1264+
1265+
System.out.println("==========");
1266+
System.out.println(stmt1.toParsableString(ctx, true));
1267+
System.out.println("==========");
1268+
System.out.println(stmt2.toParsableString(ctx, true));
1269+
1270+
assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1));
1271+
}
1272+
1273+
@Test
1274+
public void testSumInequality() {
1275+
RewriterStatement stmt1 = RewriterUtils.parse("sum(+(a,*(B,c)))", ctx, "MATRIX:B", "FLOAT:a,c");
1276+
RewriterStatement stmt2 = RewriterUtils.parse("*(a, sum(+(B,c)))", ctx, "MATRIX:B", "FLOAT:a,c", "LITERAL_FLOAT:0.0");
1277+
1278+
System.out.println("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx));
1279+
System.out.println("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx));
1280+
1281+
stmt1 = canonicalConverter.apply(stmt1);
1282+
stmt2 = canonicalConverter.apply(stmt2);
1283+
1284+
System.out.println("==========");
1285+
System.out.println(stmt1.toParsableString(ctx, true));
1286+
System.out.println("==========");
1287+
System.out.println(stmt2.toParsableString(ctx, true));
1288+
1289+
assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1));
1290+
}
12481291
}

src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/DMLCodeGenTest.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,4 +148,21 @@ public void testFused2() {
148148

149149
assert RewriterRuleCreator.validateRuleApplicability(rule, ctx);
150150
}
151+
152+
@Test
153+
public void testFused3() {
154+
// TODO: This rule has been ignored, but why?
155+
String ruleStr = "MATRIX:A,B\nLITERAL_FLOAT:0.0,1.0\n" +
156+
"+(-(A,B),A)\n" +
157+
"=>\n" +
158+
"-(*2(A), B)";
159+
160+
RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx);
161+
162+
System.out.println(DMLCodeGenerator.generateRuleValidationDML(rule, "test", ctx));
163+
164+
assert RewriterRuleCreator.validateRuleCorrectness(rule, ctx);
165+
166+
assert RewriterRuleCreator.validateRuleApplicability(rule, ctx);
167+
}
151168
}

0 commit comments

Comments
 (0)