Skip to content

Commit 613657b

Browse files
committed
Some more fixes
1 parent b8d3424 commit 613657b

File tree

5 files changed

+148
-0
lines changed

5 files changed

+148
-0
lines changed

src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCreator.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,8 @@ public static boolean validateRuleApplicability(RewriterRule rule, final RuleCon
305305

306306
Set<RewriterStatement> mVars = vars.stream().map(createdObjects::get).collect(Collectors.toSet());
307307

308+
//DMLExecutor.println(stmt.toParsableString(ctx));
309+
308310
RewriterStatement.MatcherContext mCtx = RewriterStatement.MatcherContext.exactMatch(ctx, stmt, stmt1ReplaceNCols);
309311
if (stmt1ReplaceNCols.match(mCtx)) {
310312
// Check if also the right variables are associated

src/main/java/org/apache/sysds/hops/rewriter/RewriterRuntimeUtils.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -896,6 +896,8 @@ private static RewriterStatement buildReorgOp(ReorgOp op, @Nullable String expec
896896
switch(op.getOpString()) {
897897
case "r(r')": // Matrix multiplication
898898
return RewriterUtils.parse("t(A)", ctx, matrixDefs, floatDefs, intDefs, boolDefs);
899+
case "r(rev)":
900+
return RewriterUtils.parse("rev(A)", ctx, "MATRIX:A");
899901
case "r(rdiag)":
900902
return RewriterUtils.parse("diag(A)", ctx, "MATRIX:A");
901903
}

src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterClusteringTest.java

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import java.util.Collections;
3131
import java.util.Comparator;
3232
import java.util.List;
33+
import java.util.Random;
3334
import java.util.UUID;
3435
import java.util.concurrent.atomic.AtomicLong;
3536
import java.util.function.Function;
@@ -212,6 +213,55 @@ public static void testExpressionClustering() {
212213
//System.out.println(ops + " >> " + actualCtr);
213214
});
214215
}
216+
217+
// Now we will just do random sampling for a few rounds
218+
Random rd = new Random(42);
219+
int nMaxN = RewriterAlphabetEncoder.getMaxSearchNumberForNumOps(4);
220+
for (int batch = 0; batch < 200 && System.currentTimeMillis() - startMillis < MAX_MILLIS && batch * BATCH_SIZE < maxN; batch++) {
221+
List<Integer> indices = IntStream.range(batch * BATCH_SIZE, (batch + 1) * BATCH_SIZE - 1).boxed().map(v -> maxN + rd.nextInt(nMaxN)).collect(Collectors.toList());
222+
//Collections.shuffle(indices);
223+
MutableInt ctr2 = new MutableInt(0);
224+
int maxSize = indices.size();
225+
final int mBATCH = batch;
226+
indices.parallelStream().forEach(idx -> {
227+
if (ctr2.incrementAndGet() % 10 == 0)
228+
System.out.println("Done: " + (mBATCH * BATCH_SIZE + ctr2.intValue()) + " / " + (mBATCH * BATCH_SIZE + maxSize));
229+
230+
List<RewriterAlphabetEncoder.Operand> ops = RewriterAlphabetEncoder.decodeOrderedStatements(idx);
231+
List<RewriterStatement> stmts = RewriterAlphabetEncoder.buildAllPossibleDAGs(ops, ctx, true);
232+
long actualCtr = 0;
233+
234+
for (RewriterStatement dag : stmts) {
235+
List<RewriterStatement> expanded = new ArrayList<>();
236+
expanded.add(dag);
237+
//expanded.addAll(RewriterAlphabetEncoder.buildAssertionVariations(dag, ctx, true));
238+
expanded.addAll(RewriterAlphabetEncoder.buildVariations(dag, ctx));
239+
actualCtr += expanded.size();
240+
for (RewriterStatement stmt : expanded) {
241+
try {
242+
String mstmt = stmt.toParsableString(ctx, true);
243+
stmt = RewriterUtils.parse(mstmt, ctx);
244+
ctx.metaPropagator.apply(stmt);
245+
RewriterStatement canonicalForm = converter.apply(stmt);
246+
247+
//canonicalForm.compress();
248+
//stmt.compress();
249+
synchronized (lock) {
250+
RewriterEquivalenceDatabase.DBEntry entry = canonicalExprDB.insert(ctx, canonicalForm, stmt);
251+
252+
if (entry.equivalences.size() == 2)
253+
foundEquivalences.add(entry);
254+
}
255+
} catch (Exception e) {
256+
System.err.println("Faulty expression: " + stmt.toParsableString(ctx));
257+
e.printStackTrace();
258+
}
259+
}
260+
}
261+
262+
//System.out.println(ops + " >> " + actualCtr);
263+
});
264+
}
215265
}
216266

217267
printEquivalences(/*foundEquivalences*/ Collections.emptyList(), System.currentTimeMillis() - startTime, generatedExpressions.longValue(), evaluatedExpressions.longValue(), totalCanonicalizationMillis.longValue(), failures.longValue(), true);

src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterStreamTests.java

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1282,6 +1282,82 @@ public void testSum() {
12821282
assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1));
12831283
}
12841284

1285+
@Test
1286+
public void testRowSums() {
1287+
RewriterStatement stmt1 = RewriterUtils.parse("*(rowSums(/(a,C)),b)", ctx, "MATRIX:A,B,C", "FLOAT:a,b");
1288+
RewriterStatement stmt2 = RewriterUtils.parse("rowSums(/(*(a,b),C))", ctx, "MATRIX:A,B,C", "FLOAT:a,b", "LITERAL_FLOAT:0.0");
1289+
1290+
System.out.println("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx));
1291+
System.out.println("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx));
1292+
1293+
stmt1 = canonicalConverter.apply(stmt1);
1294+
stmt2 = canonicalConverter.apply(stmt2);
1295+
1296+
System.out.println("==========");
1297+
System.out.println(stmt1.toParsableString(ctx, true));
1298+
System.out.println("==========");
1299+
System.out.println(stmt2.toParsableString(ctx, true));
1300+
1301+
assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1));
1302+
}
1303+
1304+
@Test
1305+
public void testRowSums2() {
1306+
RewriterStatement stmt1 = RewriterUtils.parse("rowSums(*(A,+(B,1.0)))", ctx, "MATRIX:A,B,C", "FLOAT:a,b", "LITERAL_FLOAT:1.0");
1307+
RewriterStatement stmt2 = RewriterUtils.parse("+(rowSums(A), rowSums(*(B,A)))", ctx, "MATRIX:A,B,C", "FLOAT:a,b", "LITERAL_FLOAT:1.0");
1308+
1309+
System.out.println("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx));
1310+
System.out.println("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx));
1311+
1312+
stmt1 = canonicalConverter.apply(stmt1);
1313+
stmt2 = canonicalConverter.apply(stmt2);
1314+
1315+
System.out.println("==========");
1316+
System.out.println(stmt1.toParsableString(ctx, true));
1317+
System.out.println("==========");
1318+
System.out.println(stmt2.toParsableString(ctx, true));
1319+
1320+
assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1));
1321+
}
1322+
1323+
@Test
1324+
public void testDistrib3() {
1325+
RewriterStatement stmt1 = RewriterUtils.parse("*(A,+(B,1.0))", ctx, "MATRIX:A,B,C", "FLOAT:a,b", "LITERAL_FLOAT:1.0");
1326+
RewriterStatement stmt2 = RewriterUtils.parse("+(A, *(B,A))", ctx, "MATRIX:A,B,C", "FLOAT:a,b", "LITERAL_FLOAT:1.0");
1327+
1328+
System.out.println("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx));
1329+
System.out.println("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx));
1330+
1331+
stmt1 = canonicalConverter.apply(stmt1);
1332+
stmt2 = canonicalConverter.apply(stmt2);
1333+
1334+
System.out.println("==========");
1335+
System.out.println(stmt1.toParsableString(ctx, true));
1336+
System.out.println("==========");
1337+
System.out.println(stmt2.toParsableString(ctx, true));
1338+
1339+
assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1));
1340+
}
1341+
1342+
@Test
1343+
public void testRev2() {
1344+
RewriterStatement stmt1 = RewriterUtils.parse("trace(rev(A))", ctx, "MATRIX:A,B,C", "FLOAT:a,b", "LITERAL_FLOAT:1.0");
1345+
RewriterStatement stmt2 = RewriterUtils.parse("trace(A)", ctx, "MATRIX:A,B,C", "FLOAT:a,b", "LITERAL_FLOAT:1.0");
1346+
1347+
System.out.println("Cost1: " + RewriterCostEstimator.estimateCost(stmt1, ctx));
1348+
System.out.println("Cost2: " + RewriterCostEstimator.estimateCost(stmt2, ctx));
1349+
1350+
stmt1 = canonicalConverter.apply(stmt1);
1351+
stmt2 = canonicalConverter.apply(stmt2);
1352+
1353+
System.out.println("==========");
1354+
System.out.println(stmt1.toParsableString(ctx, true));
1355+
System.out.println("==========");
1356+
System.out.println(stmt2.toParsableString(ctx, true));
1357+
1358+
assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2, stmt1));
1359+
}
1360+
12851361
@Test
12861362
public void testSumInequality() {
12871363
RewriterStatement stmt1 = RewriterUtils.parse("sum(+(a,*(B,c)))", ctx, "MATRIX:B", "FLOAT:a,c");

src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/DMLCodeGenTest.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,4 +165,22 @@ public void testFused3() {
165165

166166
assert RewriterRuleCreator.validateRuleApplicability(rule, ctx);
167167
}
168+
169+
@Test
170+
public void testRev() {
171+
String ruleStr = "MATRIX:A\n" +
172+
"FLOAT:b\n" +
173+
"\n" +
174+
"rev(*(rev(A),b))\n" +
175+
"=>\n" +
176+
"*(A,b)";
177+
178+
RewriterRule rule = RewriterUtils.parseRule(ruleStr, ctx);
179+
180+
System.out.println(DMLCodeGenerator.generateRuleValidationDML(rule, "test", ctx));
181+
182+
assert RewriterRuleCreator.validateRuleCorrectness(rule, ctx);
183+
184+
assert RewriterRuleCreator.validateRuleApplicability(rule, ctx);
185+
}
168186
}

0 commit comments

Comments
 (0)