Skip to content

Commit 99b720e

Browse files
committed
Some improvements
1 parent f11e557 commit 99b720e

15 files changed

+220
-93
lines changed

src/main/java/org/apache/sysds/hops/rewriter/DMLCodeGenerator.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ public static Consumer<String> ruleValidationScript(String sessionId, Consumer<B
7777
return;
7878

7979
if (line.endsWith("valid: TRUE")) {
80+
//DMLExecutor.println("Rule is valid!");
8081
validator.accept(true);
8182
} else {
8283
DMLExecutor.println("An invalid rule was found!");

src/main/java/org/apache/sysds/hops/rewriter/MetaPropagator.java

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,28 @@ private RewriterStatement propagateDims(RewriterStatement root, RewriterStatemen
117117
Object rowAccess;
118118

119119
if (root.getOperands() == null || root.getOperands().isEmpty()) {
120-
if (root.getMeta("ncol") == null)
120+
RewriterStatement ncol = root.getNCol();
121+
122+
if (ncol == null) {
121123
root.unsafePutMeta("ncol", new RewriterInstruction().withInstruction("ncol").withOps(root).as(UUID.randomUUID().toString()).consolidate(ctx));
122-
if (root.getMeta("nrow") == null)
124+
} /*else {
125+
RewriterStatement asserted = assertions != null ? assertions.getAssertionStatement(ncol, null) : null;
126+
127+
if (asserted != null && asserted != ncol)
128+
root.unsafePutMeta("ncol", asserted);
129+
}*/
130+
131+
RewriterStatement nrow = root.getNRow();
132+
133+
if (nrow == null) {
123134
root.unsafePutMeta("nrow", new RewriterInstruction().withInstruction("nrow").withOps(root).as(UUID.randomUUID().toString()).consolidate(ctx));
135+
} /*else {
136+
RewriterStatement asserted = assertions != null ? assertions.getAssertionStatement(nrow, null) : null;
137+
138+
if (asserted != null && asserted != ncol)
139+
root.unsafePutMeta("nrow", asserted);
140+
}*/
141+
124142
return null;
125143
}
126144

src/main/java/org/apache/sysds/hops/rewriter/RewriterDataType.java

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,47 @@ public boolean match(final MatcherContext mCtx) {
224224
}
225225
}
226226

227+
// If matrix, check if the dimensions
228+
if (dType.equals("MATRIX")) {
229+
RewriterStatement ncolEquiv = getNCol();
230+
RewriterStatement nrowEquiv = getNRow();
231+
232+
if (ncolEquiv != null && nrowEquiv != null) {
233+
if (!mCtx.wasVisited(this)) {
234+
mCtx.dontVisitAgain(this);
235+
RewriterStatement ncolEquivThat = stmt.getNCol();
236+
RewriterStatement nrowEquivThat = stmt.getNRow();
237+
238+
RewriterAssertions assertionsThis = mCtx.getOldAssertionsThis();
239+
RewriterAssertions assertionsThat = mCtx.getOldAssertionsThat();
240+
241+
if (assertionsThis != null) {
242+
RewriterStatement ncolAssertion = assertionsThis.getAssertionStatement(ncolEquiv, null);
243+
244+
RewriterStatement nrowAssertion = assertionsThis.getAssertionStatement(nrowEquiv, null);
245+
ncolEquiv = ncolAssertion == null ? ncolEquiv : ncolAssertion;
246+
nrowEquiv = nrowAssertion == null ? nrowEquiv : nrowAssertion;
247+
}
248+
249+
if (assertionsThat != null) {
250+
RewriterStatement ncolAssertionThat = assertionsThat.getAssertionStatement(ncolEquivThat, null);
251+
252+
RewriterStatement nrowAssertionThat = assertionsThat.getAssertionStatement(nrowEquivThat, null);
253+
ncolEquivThat = ncolAssertionThat == null ? ncolEquiv : ncolAssertionThat;
254+
nrowEquivThat = nrowAssertionThat == null ? nrowEquiv : nrowAssertionThat;
255+
}
256+
257+
// Now, match those statements
258+
mCtx.currentStatement = ncolEquivThat;
259+
if (!ncolEquiv.match(mCtx))
260+
return false;
261+
mCtx.currentStatement = nrowEquivThat;
262+
if (!nrowEquiv.match(mCtx))
263+
return false;
264+
}
265+
}
266+
}
267+
227268
RewriterStatement assoc = mCtx.getDependencyMap().get(this);
228269
if (assoc == null) {
229270
if (!mCtx.allowDuplicatePointers && mCtx.getDependencyMap().containsValue(stmt)) {

src/main/java/org/apache/sysds/hops/rewriter/RewriterRule.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ public boolean matchStmt1(RewriterStatement stmt, ArrayList<RewriterStatement.Ma
122122
}*/
123123

124124
public RewriterStatement.MatchingSubexpression matchSingleStmt1(RewriterStatement exprRoot, RewriterStatement.RewriterPredecessor pred, RewriterStatement stmt, HashMap<RewriterStatement, RewriterStatement> dependencyMap, List<ExplicitLink> links, Map<RewriterStatement, LinkObject> ruleLinks) {
125-
RewriterStatement.MatcherContext mCtx = new RewriterStatement.MatcherContext(ctx, stmt, pred, exprRoot, true, true, false, true, true, false, true, false, false, linksStmt1ToStmt2);
125+
RewriterStatement.MatcherContext mCtx = new RewriterStatement.MatcherContext(ctx, stmt, pred, exprRoot, getStmt1(), true, true, false, true, true, false, true, false, false, linksStmt1ToStmt2);
126126
mCtx.currentStatement = stmt;
127127
boolean match = getStmt1().match(mCtx);
128128

@@ -142,7 +142,7 @@ public boolean matchStmt2(RewriterStatement stmt, ArrayList<RewriterStatement.Ma
142142
}*/
143143

144144
public RewriterStatement.MatchingSubexpression matchSingleStmt2(RewriterStatement exprRoot, RewriterStatement.RewriterPredecessor pred, RewriterStatement stmt, HashMap<RewriterStatement, RewriterStatement> dependencyMap, List<ExplicitLink> links, Map<RewriterStatement, LinkObject> ruleLinks) {
145-
RewriterStatement.MatcherContext mCtx = new RewriterStatement.MatcherContext(ctx, stmt, pred, exprRoot, true, true, false, true, true, false, true, false, false, linksStmt2ToStmt1);
145+
RewriterStatement.MatcherContext mCtx = new RewriterStatement.MatcherContext(ctx, stmt, pred, exprRoot, getStmt2(), true, true, false, true, true, false, true, false, false, linksStmt2ToStmt1);
146146
mCtx.currentStatement = stmt;
147147
boolean match = getStmt2().match(mCtx);
148148

src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCreator.java

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ public synchronized boolean registerRule(RewriterRule rule, long preCost, long p
3939
boolean converged = false;
4040
boolean changed = false;
4141

42+
List<RewriterRule> appliedRules = new ArrayList<>();
43+
4244
for (int i = 0; i < 500; i++) {
4345
RewriterRuleSet.ApplicableRule applicableRule = ruleSet.acceleratedFindFirst(toTest);
4446

@@ -48,11 +50,12 @@ public synchronized boolean registerRule(RewriterRule rule, long preCost, long p
4850
}
4951

5052
toTest = applicableRule.rule.apply(applicableRule.matches.get(0), toTest, applicableRule.forward, false);
53+
appliedRules.add(applicableRule.rule);
5154
changed = true;
5255
}
5356

5457
if (!converged)
55-
throw new IllegalArgumentException("The existing rule-set did not seem to converge for the example: \n" + toTest.toParsableString(ctx, true));
58+
throw new IllegalArgumentException("The existing rule-set did not seem to converge for the example: \n" + toTest.toParsableString(ctx, true) + "\n" + String.join("\n", appliedRules.stream().map(rl -> rl.toParsableString(ctx)).collect(Collectors.toList())));
5659

5760
if (changed) {
5861
long existingPostCost;
@@ -73,6 +76,8 @@ public synchronized boolean registerRule(RewriterRule rule, long preCost, long p
7376
if (!validateRuleCorrectnessAndGains(rule, ctx))
7477
return false; // Then, either the rule is incorrect or is already implemented
7578

79+
System.out.println("Rule is correct!");
80+
7681
RewriterRuleSet probingSet = new RewriterRuleSet(ctx, List.of(rule));
7782
List<RewriterRule> rulesToRemove = new ArrayList<>();
7883
List<RewriterRule> rulesThatMustComeBefore = new ArrayList<>();
@@ -143,12 +148,16 @@ public static boolean validateRuleCorrectnessAndGains(RewriterRule rule, final R
143148
String code = DMLCodeGenerator.generateRuleValidationDML(rule, sessionId);
144149

145150
MutableBoolean isValid = new MutableBoolean(false);
146-
System.out.println(code);
151+
//System.out.println("=== CODE ===");
152+
//System.out.println(code);
147153
DMLExecutor.executeCode(code, DMLCodeGenerator.ruleValidationScript(sessionId, isValid::setValue));
148154

149155
if (!isValid.booleanValue())
150156
return false;
151157

158+
if (true)
159+
return true;
160+
152161
Set<RewriterStatement> vars = DMLCodeGenerator.getVariables(rule.getStmt1());
153162
Set<String> varNames = vars.stream().map(RewriterStatement::getId).collect(Collectors.toSet());
154163
String code2Header = DMLCodeGenerator.generateDMLVariables(vars);
@@ -192,7 +201,7 @@ public static boolean validateRuleCorrectnessAndGains(RewriterRule rule, final R
192201
stmt.prepareForHashing();
193202
stmt.recomputeHashCodes(ctx);
194203

195-
RewriterStatement.MatcherContext mCtx = RewriterStatement.MatcherContext.exactMatch(ctx, stmt);
204+
RewriterStatement.MatcherContext mCtx = RewriterStatement.MatcherContext.exactMatch(ctx, stmt, rule.getStmt1());
196205
if (rule.getStmt1().match(mCtx)) {
197206
// Check if also the right variables are associated
198207
boolean assocsMatching = true;
@@ -255,7 +264,7 @@ private static Map<RewriterStatement, RewriterStatement> getAssociations(Rewrite
255264
Map<RewriterStatement, RewriterStatement> fromCanonicalLink = getAssociationToCanonicalForm(from, canonicalFormFrom, true, ctx);
256265
Map<RewriterStatement, RewriterStatement> toCanonicalLink = getAssociationToCanonicalForm(to, canonicalFormTo, true, ctx);
257266

258-
RewriterStatement.MatcherContext matcher = RewriterStatement.MatcherContext.exactMatch(ctx, canonicalFormTo);
267+
RewriterStatement.MatcherContext matcher = RewriterStatement.MatcherContext.exactMatch(ctx, canonicalFormTo, canonicalFormFrom);
259268
canonicalFormFrom.match(matcher);
260269

261270
Map<RewriterStatement, RewriterStatement> assocs = new HashMap<>();

src/main/java/org/apache/sysds/hops/rewriter/RewriterStatement.java

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ public static class MatcherContext {
149149
final boolean traceVariableEliminations;
150150
final Map<RewriterStatement, RewriterRule.LinkObject> ruleLinks;
151151
final RewriterStatement expressionRoot;
152+
final RewriterStatement thisExpressionRoot;
152153
RewriterStatement matchRoot;
153154
RewriterPredecessor pred;
154155

@@ -161,16 +162,21 @@ public static class MatcherContext {
161162
private List<MatcherContext> subMatches;
162163
private Tuple2<RewriterStatement, RewriterStatement> firstMismatch;
163164
private boolean debug;
165+
private boolean assertionsFetched = false;
166+
private RewriterAssertions assertionsThat;
167+
private RewriterAssertions assertionsThis;
168+
private Set<RewriterStatement> dontVisitAgain;
164169

165-
public MatcherContext(final RuleContext ctx, RewriterStatement matchRoot, RewriterStatement expressionRoot) {
166-
this(ctx, matchRoot, expressionRoot, false, false, false, false, false, false, false, false, false, Collections.emptyMap());
170+
public MatcherContext(final RuleContext ctx, RewriterStatement matchRoot, RewriterStatement expressionRoot, RewriterStatement thisExpressionRoot) {
171+
this(ctx, matchRoot, expressionRoot, thisExpressionRoot, false, false, false, false, false, false, false, false, false, Collections.emptyMap());
167172
}
168173

169-
public MatcherContext(final RuleContext ctx, RewriterStatement matchRoot, RewriterStatement expressionRoot, final boolean statementsCanBeVariables, final boolean literalsCanBeVariables, final boolean ignoreLiteralValues, final boolean allowDuplicatePointers, final boolean allowPropertyScan, final boolean allowTypeHierarchy, final boolean terminateOnFirstMatch, final boolean findMinimalMismatchRoot, boolean traceVariableEliminations, final Map<RewriterStatement, RewriterRule.LinkObject> ruleLinks) {
174+
public MatcherContext(final RuleContext ctx, RewriterStatement matchRoot, RewriterStatement expressionRoot, RewriterStatement thisExpressionRoot, final boolean statementsCanBeVariables, final boolean literalsCanBeVariables, final boolean ignoreLiteralValues, final boolean allowDuplicatePointers, final boolean allowPropertyScan, final boolean allowTypeHierarchy, final boolean terminateOnFirstMatch, final boolean findMinimalMismatchRoot, boolean traceVariableEliminations, final Map<RewriterStatement, RewriterRule.LinkObject> ruleLinks) {
170175
this.ctx = ctx;
171176
this.matchRoot = matchRoot;
172177
this.pred = new RewriterPredecessor();
173178
this.expressionRoot = expressionRoot;
179+
this.thisExpressionRoot = thisExpressionRoot;
174180
this.statementsCanBeVariables = statementsCanBeVariables;
175181
this.currentStatement = matchRoot;
176182
this.literalsCanBeVariables = literalsCanBeVariables;
@@ -185,11 +191,12 @@ public MatcherContext(final RuleContext ctx, RewriterStatement matchRoot, Rewrit
185191
this.debug = false;
186192
}
187193

188-
public MatcherContext(final RuleContext ctx, RewriterStatement matchRoot, RewriterPredecessor pred, RewriterStatement expressionRoot, final boolean statementsCanBeVariables, final boolean literalsCanBeVariables, final boolean ignoreLiteralValues, final boolean allowDuplicatePointers, final boolean allowPropertyScan, final boolean allowTypeHierarchy, final boolean terminateOnFirstMatch, final boolean findMinimalMismatchRoot, boolean traceVariableEliminations, final Map<RewriterStatement, RewriterRule.LinkObject> ruleLinks) {
194+
public MatcherContext(final RuleContext ctx, RewriterStatement matchRoot, RewriterPredecessor pred, RewriterStatement expressionRoot, RewriterStatement thisExprRoot, final boolean statementsCanBeVariables, final boolean literalsCanBeVariables, final boolean ignoreLiteralValues, final boolean allowDuplicatePointers, final boolean allowPropertyScan, final boolean allowTypeHierarchy, final boolean terminateOnFirstMatch, final boolean findMinimalMismatchRoot, boolean traceVariableEliminations, final Map<RewriterStatement, RewriterRule.LinkObject> ruleLinks) {
189195
this.ctx = ctx;
190196
this.matchRoot = matchRoot;
191197
this.pred = pred;
192198
this.expressionRoot = expressionRoot;
199+
this.thisExpressionRoot = thisExprRoot;
193200
this.currentStatement = matchRoot;
194201
this.statementsCanBeVariables = statementsCanBeVariables;
195202
this.literalsCanBeVariables = literalsCanBeVariables;
@@ -204,6 +211,41 @@ public MatcherContext(final RuleContext ctx, RewriterStatement matchRoot, Rewrit
204211
this.debug = false;
205212
}
206213

214+
private void fetchAssertions() {
215+
if (!assertionsFetched) {
216+
assertionsThat = (RewriterAssertions) expressionRoot.getMeta("_assertions");
217+
assertionsThis = (RewriterAssertions) thisExpressionRoot.getMeta("_assertions");
218+
assertionsFetched = true;
219+
}
220+
}
221+
222+
public void dontVisitAgain(RewriterStatement stmt) {
223+
if (dontVisitAgain == null) {
224+
dontVisitAgain = new HashSet<>();
225+
}
226+
227+
dontVisitAgain.add(stmt);
228+
}
229+
230+
public boolean wasVisited(RewriterStatement stmt) {
231+
if (dontVisitAgain == null)
232+
return false;
233+
234+
return dontVisitAgain.contains(stmt);
235+
}
236+
237+
public RewriterAssertions getOldAssertionsThat() {
238+
fetchAssertions();
239+
240+
return assertionsThat;
241+
}
242+
243+
public RewriterAssertions getOldAssertionsThis() {
244+
fetchAssertions();
245+
246+
return assertionsThis;
247+
}
248+
207249
public Map<RewriterStatement, RewriterStatement> getDependencyMap() {
208250
if (dependencyMap == null)
209251
if (allowDuplicatePointers)
@@ -304,16 +346,16 @@ public boolean isDebug() {
304346
return debug;
305347
}
306348

307-
public static MatcherContext exactMatch(final RuleContext ctx, RewriterStatement stmt) {
308-
return new MatcherContext(ctx, stmt, stmt);
349+
public static MatcherContext exactMatch(final RuleContext ctx, RewriterStatement stmt, RewriterStatement thisExprRoot) {
350+
return new MatcherContext(ctx, stmt, stmt, thisExprRoot);
309351
}
310352

311-
public static MatcherContext exactMatchWithDifferentLiteralValues(final RuleContext ctx, RewriterStatement stmt) {
312-
return new MatcherContext(ctx, stmt, stmt, false, false, true, false, false, false, false, false, false, Collections.emptyMap());
353+
public static MatcherContext exactMatchWithDifferentLiteralValues(final RuleContext ctx, RewriterStatement stmt, RewriterStatement thisExprRoot) {
354+
return new MatcherContext(ctx, stmt, stmt, thisExprRoot, false, false, true, false, false, false, false, false, false, Collections.emptyMap());
313355
}
314356

315-
public static MatcherContext findMinimalDifference(final RuleContext ctx, RewriterStatement stmt) {
316-
return new MatcherContext(ctx, stmt, stmt, false, false, true, false, false, false, false, true, false, Collections.emptyMap());
357+
public static MatcherContext findMinimalDifference(final RuleContext ctx, RewriterStatement stmt, RewriterStatement thisExpressionRoot) {
358+
return new MatcherContext(ctx, stmt, stmt, thisExpressionRoot, false, false, true, false, false, false, false, true, false, Collections.emptyMap());
317359
}
318360
}
319361

src/main/java/org/apache/sysds/hops/rewriter/RewriterStatementEntry.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ public boolean equals(Object o) {
2525
return true;
2626
if (instr.structuralHashCode() != ((RewriterStatement)o).structuralHashCode())
2727
return false;
28-
return instr.match(new RewriterStatement.MatcherContext(ctx, (RewriterStatement) o, new RewriterStatement.RewriterPredecessor(), (RewriterStatement) o, false, false, false, false, false, false, true, false, false, new HashMap<>()));
28+
return instr.match(new RewriterStatement.MatcherContext(ctx, (RewriterStatement) o, new RewriterStatement.RewriterPredecessor(), (RewriterStatement) o, instr, false, false, false, false, false, false, true, false, false, new HashMap<>()));
2929
}
3030

3131
if (o.hashCode() != hashCode())
@@ -34,7 +34,7 @@ public boolean equals(Object o) {
3434
if (o instanceof RewriterStatementEntry) {
3535
if (instr == ((RewriterStatementEntry) o).instr)
3636
return true;
37-
return instr.match(new RewriterStatement.MatcherContext(ctx, ((RewriterStatementEntry) o).instr, new RewriterStatement.RewriterPredecessor(), ((RewriterStatementEntry) o).instr, false, false, false, false, false, false, true, false, false, new HashMap<>()));
37+
return instr.match(new RewriterStatement.MatcherContext(ctx, ((RewriterStatementEntry) o).instr, new RewriterStatement.RewriterPredecessor(), ((RewriterStatementEntry) o).instr, instr, false, false, false, false, false, false, true, false, false, new HashMap<>()));
3838
}
3939
return false;
4040
}

src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterClusteringTest.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -357,9 +357,9 @@ private boolean checkRelevance(List<RewriterStatement> stmts) {
357357
TopologicalSort.sort(stmt1, ctx);
358358
TopologicalSort.sort(stmt2, ctx);
359359

360-
if (!stmt1.match(RewriterStatement.MatcherContext.exactMatchWithDifferentLiteralValues(ctx, stmt2))) {
360+
if (!stmt1.match(RewriterStatement.MatcherContext.exactMatchWithDifferentLiteralValues(ctx, stmt2, stmt1))) {
361361
// TODO: Minimal difference can still prune valid rewrites (e.g. sum(A %*% B) -> sum(A * t(B)))
362-
RewriterStatement.MatcherContext mCtx = RewriterStatement.MatcherContext.findMinimalDifference(ctx, stmts.get(j));
362+
RewriterStatement.MatcherContext mCtx = RewriterStatement.MatcherContext.findMinimalDifference(ctx, stmts.get(j), stmts.get(i));
363363
stmts.get(i).match(mCtx);
364364
Tuple2<RewriterStatement, RewriterStatement> minimalDifference = mCtx.getFirstMismatch();
365365

@@ -372,7 +372,7 @@ private boolean checkRelevance(List<RewriterStatement> stmts) {
372372
minStmt1 = converter.apply(minStmt1);
373373
minStmt2 = converter.apply(minStmt2);
374374

375-
if (minStmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, minStmt2))) {
375+
if (minStmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, minStmt2, minStmt1))) {
376376
// Then the minimal difference does not imply equivalence
377377
// For now, just keep every result then
378378
match = false;

0 commit comments

Comments
 (0)