Skip to content

Commit 8fca8e6

Browse files
committed
Bugfix
1 parent edc93d9 commit 8fca8e6

File tree

8 files changed

+61
-58
lines changed

8 files changed

+61
-58
lines changed

src/main/java/org/apache/sysds/hops/rewriter/MetaPropagator.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,10 @@ private RewriterStatement propagateDims(RewriterStatement root, RewriterStatemen
316316
root.unsafePutMeta("ncol", root.getOperands().get(0).getMeta("ncol"));
317317
root.unsafePutMeta("nrow", RewriterStatement.literal(ctx, 1L));
318318
return null;
319+
case "cellMat(MATRIX)":
320+
root.unsafePutMeta("ncol", RewriterStatement.literal(ctx, 1L));
321+
root.unsafePutMeta("nrow", RewriterStatement.literal(ctx, 1L));
322+
return null;
319323
case "rev(MATRIX)":
320324
case "replace(MATRIX,FLOAT,FLOAT)":
321325
case "sumSq(MATRIX)":

src/main/java/org/apache/sysds/hops/rewriter/RewriterAlphabetEncoder.java

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -97,24 +97,6 @@ public class RewriterAlphabetEncoder {
9797

9898
private static RuleContext ctx;
9999

100-
/*private static List<String> allPossibleTypes(Operand op, int argNum) {
101-
if (op == null)
102-
return List.of("MATRIX", "FLOAT");
103-
104-
switch (op.op) {
105-
case "+":
106-
return List.of("MATRIX", "FLOAT");
107-
case "-":
108-
return List.of("MATRIX", "FLOAT");
109-
case "*":
110-
return List.of("MATRIX", "FLOAT");
111-
case "/":
112-
return List.of("MATRIX", "FLOAT");
113-
}
114-
115-
throw new NotImplementedException();
116-
}*/
117-
118100
public static int getMaxSearchNumberForNumOps(int numOps) {
119101
int out = 1;
120102
for (int i = 0; i < numOps; i++)

src/main/java/org/apache/sysds/hops/rewriter/RewriterContextSettings.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,7 @@ public static String getDefaultContextString() {
364364

365365
builder.append("rowVec(MATRIX)::MATRIX\n");
366366
builder.append("colVec(MATRIX)::MATRIX\n");
367+
builder.append("cellMat(MATRIX)::MATRIX\n");
367368

368369
builder.append("_m(INT,INT,FLOAT)::MATRIX\n");
369370
builder.append("_m(INT,INT,BOOL)::MATRIX\n");

src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCollection.java

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ public static void substituteEquivalentStatements(final List<RewriterRule> rules
389389
.parseGlobalVars("MATRIX:A")
390390
.parseGlobalVars("LITERAL_INT:1")
391391
.withParsedStatement("rowVec(A)")
392-
.toParsedStatement("[](A, 1, nrow(A), 1, 1)")
392+
.toParsedStatement("[]($1:A, 1, 1, 1, ncol(A))", hooks)
393393
.build()
394394
);
395395

@@ -398,7 +398,16 @@ public static void substituteEquivalentStatements(final List<RewriterRule> rules
398398
.parseGlobalVars("MATRIX:A")
399399
.parseGlobalVars("LITERAL_INT:1")
400400
.withParsedStatement("colVec(A)")
401-
.toParsedStatement("[](A, 1, 1, 1, ncol(A))")
401+
.toParsedStatement("[](A, 1, nrow(A), 1, 1)")
402+
.build()
403+
);
404+
405+
rules.add(new RewriterRuleBuilder(ctx, "cellMat(A) => [](A, ...)")
406+
.setUnidirectional(true)
407+
.parseGlobalVars("MATRIX:A")
408+
.parseGlobalVars("LITERAL_INT:1")
409+
.withParsedStatement("cellMat(A)")
410+
.toParsedStatement("[](A, 1, 1, 1, 1)")
402411
.build()
403412
);
404413

@@ -1306,7 +1315,6 @@ public static void pushdownStreamSelections(final List<RewriterRule> rules, fina
13061315
.build()
13071316
);
13081317

1309-
// TODO: We would have to take into account the offset of h, i
13101318
rules.add(new RewriterRuleBuilder(ctx, "Element selection pushdown")
13111319
.setUnidirectional(true)
13121320
.parseGlobalVars("MATRIX:A,B")
@@ -1667,7 +1675,9 @@ public static void canonicalExpandAfterFlattening(final List<RewriterRule> rules
16671675
.as(UUID.randomUUID().toString())
16681676
.withInstruction("sum")
16691677
.withOps(newIdxExpr);
1678+
System.out.println("Copying index list: " + newIdxExpr.toParsableString(ctx));
16701679
RewriterUtils.copyIndexList(newIdxExpr);
1680+
System.out.println("After copy: " + newIdxExpr.toParsableString(ctx));
16711681
newIdxExpr.refreshReturnType(ctx);
16721682
newSum.consolidate(ctx);
16731683
newArgList.getOperands().add(newSum);
@@ -1700,7 +1710,6 @@ public static void flattenedAlgebraRewrites(final List<RewriterRule> rules, fina
17001710
argList.getOperands().set(i, newStmt);
17011711
}
17021712

1703-
// TODO: This is inefficient
17041713
RewriterUtils.tryFlattenNestedOperatorPatterns(ctx, match.getNewExprRoot());
17051714
}, true)
17061715
.build()

src/main/java/org/apache/sysds/hops/rewriter/utils/RewriterUtils.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -858,7 +858,7 @@ public static void copyIndexList(RewriterStatement idxExprRoot) {
858858
operands.set(i, cpy);
859859
}
860860

861-
RewriterUtils.replaceReferenceAware(idxExprRoot.getChild(1), stmt -> {
861+
RewriterStatement out = RewriterUtils.replaceReferenceAware(idxExprRoot.getChild(1), stmt -> {
862862
UUID idxId = (UUID) stmt.getMeta("idxId");
863863
if (idxId != null) {
864864
RewriterStatement newStmt = replacements.get(idxId);
@@ -868,6 +868,8 @@ public static void copyIndexList(RewriterStatement idxExprRoot) {
868868

869869
return null;
870870
});
871+
872+
idxExprRoot.getOperands().set(1, out);
871873
}
872874

873875
public static void retargetIndexExpressions(RewriterStatement rootExpr, UUID oldIdxId, RewriterStatement newStatement) {
@@ -1441,6 +1443,7 @@ public static Function<RewriterStatement, RewriterStatement> buildCanonicalFormC
14411443
RewriterRuleCollection.canonicalizeBooleanStatements(algebraicCanonicalizationRules, ctx);
14421444
RewriterRuleCollection.canonicalizeAlgebraicStatements(algebraicCanonicalizationRules, ctx);
14431445
RewriterRuleCollection.eliminateMultipleCasts(algebraicCanonicalizationRules, ctx);
1446+
RewriterRuleCollection.buildElementWiseAlgebraicCanonicalization(algebraicCanonicalizationRules, ctx);
14441447
RewriterHeuristic algebraicCanonicalization = new RewriterHeuristic(new RewriterRuleSet(ctx, algebraicCanonicalizationRules));
14451448

14461449
ArrayList<RewriterRule> expRules = new ArrayList<>();
@@ -1603,7 +1606,6 @@ private static RewriterStatement tryPullOutSum(RewriterStatement sum, final Rule
16031606
components.add(add);
16041607
}
16051608

1606-
//add = foldConstants(add, ctx);
16071609
RewriterStatement out = RewriterStatement.multiArgInstr(ctx, "*", sumBody);
16081610
out.getChild(0).getOperands().addAll(components);
16091611
return foldConstants(out, ctx);

src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterNormalFormTests.java

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@ public void testSimplifyDistributiveBinaryOperation() {
8686

8787
@Test
8888
public void testSimplifyBushyBinaryOperation() {
89-
RewriterStatement stmt1 = RewriterUtils.parse("*(A,*(B, %*%(C, rowVec(D))))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0");
90-
RewriterStatement stmt2 = RewriterUtils.parse("*(*(A,B), %*%(C, rowVec(D)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0");
89+
RewriterStatement stmt1 = RewriterUtils.parse("*(A,*(B, %*%(C, colVec(D))))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0");
90+
RewriterStatement stmt2 = RewriterUtils.parse("*(*(A,B), %*%(C, colVec(D)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0");
9191

9292
stmt1 = canonicalConverter.apply(stmt1);
9393
stmt2 = canonicalConverter.apply(stmt2);
@@ -159,7 +159,7 @@ public void testSimplifyTraceMatrixMult() {
159159
@Test
160160
public void testSimplifySlicedMatrixMult() {
161161
RewriterStatement stmt1 = RewriterUtils.parse("[](%*%(A,B), 1, 1)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0", "LITERAL_INT:1");
162-
RewriterStatement stmt2 = RewriterUtils.parse("as.scalar(%*%(colVec(A), rowVec(B)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0", "LITERAL_INT:1");
162+
RewriterStatement stmt2 = RewriterUtils.parse("as.scalar(%*%(rowVec(A), colVec(B)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:1.0,2.0", "LITERAL_INT:1");
163163

164164
assert match(stmt1, stmt2);
165165
}
@@ -226,23 +226,23 @@ public void testSimplifyNotOverComparisons() {
226226
public void testRemoveEmptyRightIndexing() {
227227
// We do not directly support the specification of nnz, but we can emulate such a matrix by multiplying with 0
228228
RewriterStatement stmt1 = RewriterUtils.parse("[](*(A, 0.0), 1, nrow(A), 1, 1)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1");
229-
RewriterStatement stmt2 = RewriterUtils.parse("const(rowVec(A), 0.0)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1");
229+
RewriterStatement stmt2 = RewriterUtils.parse("const(colVec(A), 0.0)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1");
230230

231231
assert match(stmt1, stmt2);
232232
}
233233

234234
@Test
235235
public void testRemoveUnnecessaryRightIndexing() {
236-
RewriterStatement stmt1 = RewriterUtils.parse("[](rowVec(A), 1, nrow(A), 1, 1)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1");
237-
RewriterStatement stmt2 = RewriterUtils.parse("rowVec(A)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1");
236+
RewriterStatement stmt1 = RewriterUtils.parse("[](colVec(A), 1, nrow(A), 1, 1)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1");
237+
RewriterStatement stmt2 = RewriterUtils.parse("colVec(A)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1");
238238

239239
assert match(stmt1, stmt2);
240240
}
241241

242242
@Test
243243
public void testRemoveUnnecessaryReorgOperation3() {
244-
RewriterStatement stmt1 = RewriterUtils.parse("t(rowVec(colVec(A)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1");
245-
RewriterStatement stmt2 = RewriterUtils.parse("rowVec(colVec(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1");
244+
RewriterStatement stmt1 = RewriterUtils.parse("t(cellMat(A)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1");
245+
RewriterStatement stmt2 = RewriterUtils.parse("cellMat(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1");
246246

247247
assert match(stmt1, stmt2);
248248
}
@@ -275,41 +275,41 @@ public void testFuseDatagenAndReorgOperation() {
275275

276276
@Test
277277
public void testSimplifyColwiseAggregate() {
278-
RewriterStatement stmt1 = RewriterUtils.parse("colSums(colVec(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
279-
RewriterStatement stmt2 = RewriterUtils.parse("colVec(A)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
278+
RewriterStatement stmt1 = RewriterUtils.parse("colSums(rowVec(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
279+
RewriterStatement stmt2 = RewriterUtils.parse("rowVec(A)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
280280

281281
assert match(stmt1, stmt2);
282282
}
283283

284284
@Test
285285
public void testSimplifyRowwiseAggregate() {
286-
RewriterStatement stmt1 = RewriterUtils.parse("rowSums(rowVec(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
287-
RewriterStatement stmt2 = RewriterUtils.parse("rowVec(A)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
286+
RewriterStatement stmt1 = RewriterUtils.parse("rowSums(colVec(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
287+
RewriterStatement stmt2 = RewriterUtils.parse("colVec(A)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
288288

289289
assert match(stmt1, stmt2);
290290
}
291291

292292
// We don't have broadcasting semantics
293293
@Test
294294
public void testSimplifyColSumsMVMult() {
295-
RewriterStatement stmt1 = RewriterUtils.parse("colSums(*(rowVec(A), rowVec(B)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
296-
RewriterStatement stmt2 = RewriterUtils.parse("%*%(t(rowVec(B)), rowVec(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
295+
RewriterStatement stmt1 = RewriterUtils.parse("colSums(*(colVec(A), colVec(B)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
296+
RewriterStatement stmt2 = RewriterUtils.parse("%*%(t(colVec(B)), colVec(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
297297

298298
assert match(stmt1, stmt2);
299299
}
300300

301301
// We don't have broadcasting semantics
302302
@Test
303303
public void testSimplifyRowSumsMVMult() {
304-
RewriterStatement stmt1 = RewriterUtils.parse("rowSums(*(colVec(A), colVec(B)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
305-
RewriterStatement stmt2 = RewriterUtils.parse("%*%(colVec(A), t(colVec(B)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
304+
RewriterStatement stmt1 = RewriterUtils.parse("rowSums(*(rowVec(A), rowVec(B)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
305+
RewriterStatement stmt2 = RewriterUtils.parse("%*%(rowVec(A), t(rowVec(B)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
306306

307307
assert match(stmt1, stmt2);
308308
}
309309

310310
@Test
311311
public void testSimplifyUnnecessaryAggregate() {
312-
RewriterStatement stmt1 = RewriterUtils.parse("sum(rowVec(colVec(A)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
312+
RewriterStatement stmt1 = RewriterUtils.parse("sum(cellMat(A)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
313313
RewriterStatement stmt2 = RewriterUtils.parse("as.scalar(A)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
314314

315315
assert match(stmt1, stmt2);
@@ -350,16 +350,16 @@ public void testSimplifyEmptyMatrixMult() {
350350

351351
@Test
352352
public void testSimplifyEmptyMatrixMult2() {
353-
RewriterStatement stmt1 = RewriterUtils.parse("%*%(rowVec(A), cast.MATRIX(1.0))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
354-
RewriterStatement stmt2 = RewriterUtils.parse("rowVec(A)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
353+
RewriterStatement stmt1 = RewriterUtils.parse("%*%(colVec(A), cast.MATRIX(1.0))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
354+
RewriterStatement stmt2 = RewriterUtils.parse("colVec(A)", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
355355

356356
assert match(stmt1, stmt2);
357357
}
358358

359359
@Test
360360
public void testSimplifyScalarMatrixMult() {
361-
RewriterStatement stmt1 = RewriterUtils.parse("%*%(rowVec(A), cast.MATRIX(a))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
362-
RewriterStatement stmt2 = RewriterUtils.parse("*(rowVec(A), as.scalar(cast.MATRIX(a)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
361+
RewriterStatement stmt1 = RewriterUtils.parse("%*%(colVec(A), cast.MATRIX(a))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
362+
RewriterStatement stmt2 = RewriterUtils.parse("*(colVec(A), as.scalar(cast.MATRIX(a)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
363363

364364
assert match(stmt1, stmt2);
365365
}
@@ -405,8 +405,8 @@ public void testPushdownSumOnAdditiveBinary() {
405405

406406
@Test
407407
public void testSimplifyDotProductSum() {
408-
RewriterStatement stmt1 = RewriterUtils.parse("cast.MATRIX(sum(sq(rowVec(A))))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
409-
RewriterStatement stmt2 = RewriterUtils.parse("%*%(t(rowVec(A)), rowVec(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
408+
RewriterStatement stmt1 = RewriterUtils.parse("cast.MATRIX(sum(sq(colVec(A))))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
409+
RewriterStatement stmt2 = RewriterUtils.parse("%*%(t(colVec(A)), colVec(A))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
410410

411411
assert match(stmt1, stmt2);
412412
}
@@ -477,7 +477,7 @@ public void testSimplifyEmptyBinaryOperation3() {
477477

478478
//@Test
479479
public void testSimplifyScalarMVBinaryOperation() {
480-
RewriterStatement stmt1 = RewriterUtils.parse("*(A, rowVec(colVec(B)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
480+
RewriterStatement stmt1 = RewriterUtils.parse("*(A, colVec(colVec(B)))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
481481
RewriterStatement stmt2 = RewriterUtils.parse("*(A, as.scalar(B))", ctx, "MATRIX:A,B,C,D", "FLOAT:a,b,c", "LITERAL_FLOAT:0.0,1.0,2.0", "LITERAL_INT:1", "LITERAL_BOOL:TRUE,FALSE", "INT:i");
482482

483483
assert match(stmt1, stmt2);

src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterStreamTests.java

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,17 +71,22 @@ public void testAdditionMatrix1() {
7171
@Test
7272
public void testSubtractionFloat1() {
7373
RewriterStatement stmt = RewriterUtils.parse("+(-(a, b), 1)", ctx, "MATRIX:A,B,C", "FLOAT:a,b", "LITERAL_INT:0,1");
74+
RewriterStatement stmt2 = RewriterUtils.parse("+(argList(-(b), a, 1))", ctx, "FLOAT:a,b", "LITERAL_INT:0,1");
7475
stmt = canonicalConverter.apply(stmt);
76+
stmt2 = canonicalConverter.apply(stmt2);
7577
System.out.println(stmt.toParsableString(ctx, true));
76-
assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, RewriterUtils.parse("+(argList(-(b), a, 1))", ctx, "FLOAT:a,b", "LITERAL_INT:0,1"), stmt));
78+
System.out.println(stmt2.toParsableString(ctx, true));
79+
assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt));
7780
}
7881

7982
@Test
8083
public void testSubtractionFloat2() {
8184
RewriterStatement stmt = RewriterUtils.parse("+(1, -(a, -(b, c)))", ctx, "MATRIX:A,B,C", "FLOAT:a,b,c", "LITERAL_INT:0,1");
85+
RewriterStatement stmt2 = RewriterUtils.parse("+(argList(-(b), a, c, 1))", ctx, "FLOAT:a,b, c", "LITERAL_INT:0,1");
8286
stmt = canonicalConverter.apply(stmt);
87+
stmt2 = canonicalConverter.apply(stmt2);
8388
System.out.println(stmt.toParsableString(ctx, true));
84-
assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, RewriterUtils.parse("+(argList(-(b), a, c, 1))", ctx, "FLOAT:a,b, c", "LITERAL_INT:0,1"), stmt));
89+
assert stmt.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt));
8590
}
8691

8792
// Fusion will no longer be pursued
@@ -873,8 +878,8 @@ public void testSumEquality3() {
873878

874879
@Test
875880
public void testSumEquality4() {
876-
RewriterStatement stmt1 = RewriterUtils.parse("%*%(t(rowVec(A)), rowVec(A))", ctx, "MATRIX:A,B", "LITERAL_INT:1");
877-
RewriterStatement stmt2 = RewriterUtils.parse("as.matrix(sum(*(rowVec(A), rowVec(A))))", ctx, "MATRIX:A,B", "LITERAL_INT:1");
881+
RewriterStatement stmt1 = RewriterUtils.parse("%*%(t(colVec(A)), colVec(A))", ctx, "MATRIX:A,B", "LITERAL_INT:1");
882+
RewriterStatement stmt2 = RewriterUtils.parse("as.matrix(sum(*(colVec(A), colVec(A))))", ctx, "MATRIX:A,B", "LITERAL_INT:1");
878883

879884
stmt1 = canonicalConverter.apply(stmt1);
880885
stmt2 = canonicalConverter.apply(stmt2);
@@ -971,8 +976,8 @@ public void testMMEquivalence() {
971976

972977
@Test
973978
public void testMMEquivalence2() {
974-
RewriterStatement stmt1 = RewriterUtils.parse("cast.MATRIX(sum(*(t([](A, 1, 1, 1, ncol(A))), [](B, 1, nrow(B), 1, 1))))", ctx, "MATRIX:A,B", "FLOAT:b", "LITERAL_INT:1");
975-
RewriterStatement stmt2 = RewriterUtils.parse("%*%([](A, 1, 1, 1, ncol(A)), [](B, 1, nrow(B), 1, 1))", ctx, "MATRIX:A,B", "FLOAT:b", "LITERAL_INT:1");
979+
RewriterStatement stmt1 = RewriterUtils.parse("cast.MATRIX(sum(*(t(rowVec(A)), colVec(B))))", ctx, "MATRIX:A,B", "FLOAT:b", "LITERAL_INT:1");
980+
RewriterStatement stmt2 = RewriterUtils.parse("%*%(rowVec(A), colVec(B))", ctx, "MATRIX:A,B", "FLOAT:b", "LITERAL_INT:1");
976981

977982
System.out.println("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx));
978983
System.out.println("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx));

0 commit comments

Comments
 (0)