Skip to content

Commit 0e15664

Browse files
committed
Some more improvements (probably expensive)
1 parent ecd2e59 commit 0e15664

File tree

7 files changed

+381
-60
lines changed

7 files changed

+381
-60
lines changed

src/main/java/org/apache/sysds/hops/rewriter/RewriterAssertions.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,8 @@ public Set<RewriterStatement> getAssertions(RewriterStatement stmt) {
201201
}
202202

203203
public RewriterStatement getAssertionStatement(RewriterStatement stmt, RewriterStatement parent) {
204+
//System.out.println("Checking: " + stmt);
205+
//System.out.println("In: " + this);
204206
RewriterAssertion set = assertionMatcher.get(stmt);
205207

206208
if (set == null)
@@ -222,6 +224,32 @@ public RewriterStatement getAssertionStatement(RewriterStatement stmt, RewriterS
222224
return mstmt;
223225
}
224226

227+
// TODO: This does not handle metadata
228+
public RewriterStatement update(RewriterStatement root) {
229+
RewriterStatement eClass = getAssertionStatement(root, null);
230+
231+
if (eClass == null)
232+
eClass = root;
233+
else if (root.getMeta("_assertions") != null)
234+
eClass.unsafePutMeta("_assertions", root.getMeta("_assertions"));
235+
236+
updateRecursively(eClass);
237+
238+
return eClass;
239+
}
240+
241+
private void updateRecursively(RewriterStatement cur) {
242+
for (int i = 0; i < cur.getOperands().size(); i++) {
243+
RewriterStatement child = cur.getChild(i);
244+
RewriterStatement eClass = getAssertionStatement(child, cur);
245+
246+
if (eClass != child)
247+
cur.getOperands().set(i, eClass);
248+
249+
updateRecursively(cur.getChild(i));
250+
}
251+
}
252+
225253
// TODO: We have to copy the assertions to the root node if it changes
226254
/*public RewriterStatement buildEquivalences(RewriterStatement stmt) {
227255
RewriterStatement mAssert = getAssertionStatement(stmt);

src/main/java/org/apache/sysds/hops/rewriter/RewriterContextSettings.java

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -293,15 +293,20 @@ public static String getDefaultContextString() {
293293
builder.append("[](MATRIX,INT,INT)::FLOAT\n");
294294
builder.append("[](MATRIX,INT,INT,INT,INT)::MATRIX\n");
295295
builder.append("diag(MATRIX)::MATRIX\n");
296-
builder.append("sum(FLOAT...)::FLOAT\n");
297-
builder.append("sum(FLOAT*)::FLOAT\n");
298-
builder.append("sum(FLOAT)::FLOAT\n");
296+
297+
List.of("INT", "FLOAT", "BOOL").forEach(t -> {
298+
builder.append("sum(" + t + "...)::" + t + "\n");
299+
builder.append("sum(" + t + "*)::" + t + "\n");
300+
builder.append("sum(" + t + ")::" + t + "\n");
301+
});
299302

300303
builder.append("_m(INT,INT,FLOAT)::MATRIX\n");
301-
builder.append("_idxExpr(INT,FLOAT)::FLOAT*\n");
302-
builder.append("_idxExpr(INT,FLOAT*)::FLOAT*\n");
303-
builder.append("_idxExpr(INT...,FLOAT)::FLOAT*\n");
304-
builder.append("_idxExpr(INT...,FLOAT*)::FLOAT*\n");
304+
List.of("FLOAT", "INT", "BOOL").forEach(t -> {
305+
builder.append("_idxExpr(INT," + t + ")::" + t + "*\n");
306+
builder.append("_idxExpr(INT," + t + "*)::" + t + "*\n");
307+
builder.append("_idxExpr(INT...," + t + ")::" + t + "*\n");
308+
builder.append("_idxExpr(INT...," + t + "*)::" + t + "*\n");
309+
});
305310
//builder.append("_idxExpr(INT,FLOAT...)::FLOAT*\n");
306311
builder.append("_idx(INT,INT)::INT\n");
307312
//builder.append("_nrow()::INT\n");

src/main/java/org/apache/sysds/hops/rewriter/RewriterInstruction.java

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,9 @@
1717

1818
public class RewriterInstruction extends RewriterStatement {
1919

20+
private String id;
2021
private String instr;
21-
private RewriterDataType result = new RewriterDataType();
22+
//private RewriterDataType result = new RewriterDataType();
2223
private ArrayList<RewriterStatement> operands = new ArrayList<>();
2324
private Function<List<RewriterStatement>, Long> costFunction = null;
2425
private boolean consolidated = false;
@@ -28,15 +29,15 @@ public class RewriterInstruction extends RewriterStatement {
2829

2930
@Override
3031
public String getId() {
31-
return result.getId();
32+
return id;
3233
}
3334

3435
@Override
3536
public String getResultingDataType(final RuleContext ctx) {
3637
if (isArgumentList()) {
3738
return getOperands().stream().map(op -> op.getResultingDataType(ctx)).reduce(RewriterUtils::defaultTypeHierarchy).get() + "...";
3839
}
39-
return getResult(ctx).getResultingDataType(ctx);
40+
return ctx.instrTypes.get(trueTypedInstruction(ctx));//getResult(ctx).getResultingDataType(ctx);
4041
}
4142

4243
@Override
@@ -63,27 +64,21 @@ public RewriterStatement consolidate(final RuleContext ctx) {
6364
for (RewriterStatement operand : operands)
6465
operand.consolidate(ctx);
6566

66-
getResult(ctx).consolidate(ctx);
67+
//getResult(ctx).consolidate(ctx);
6768

68-
if (isArgumentList())
69-
hashCode = Objects.hash(rid, refCtr, instr, getResultingDataType(ctx), operands);
70-
else
71-
hashCode = Objects.hash(rid, refCtr, instr, result, operands);
69+
hashCode = Objects.hash(rid, refCtr, instr, getResultingDataType(ctx), operands);
7270
consolidated = true;
7371

7472
return this;
7573
}
7674
@Override
7775
public int recomputeHashCodes(boolean recursively, final RuleContext ctx) {
7876
if (recursively) {
79-
result.recomputeHashCodes(true, ctx);
77+
//result.recomputeHashCodes(true, ctx);
8078
operands.forEach(op -> op.recomputeHashCodes(true, ctx));
8179
}
8280

83-
if (isArgumentList())
84-
hashCode = Objects.hash(rid, refCtr, instr, getResultingDataType(ctx), operands.stream().map(RewriterStatement::structuralHashCode).collect(Collectors.toList()));
85-
else
86-
hashCode = Objects.hash(rid, refCtr, instr, result.structuralHashCode(), operands.stream().map(RewriterStatement::structuralHashCode).collect(Collectors.toList()));
81+
hashCode = Objects.hash(rid, refCtr, instr, getResultingDataType(ctx), operands.stream().map(RewriterStatement::structuralHashCode).collect(Collectors.toList()));
8782
return hashCode;
8883
}
8984

@@ -148,7 +143,8 @@ public boolean match(final MatcherContext mCtx) {
148143
public RewriterStatement copyNode() {
149144
RewriterInstruction mCopy = new RewriterInstruction();
150145
mCopy.instr = instr;
151-
mCopy.result = (RewriterDataType)result.copyNode();
146+
//mCopy.result = (RewriterDataType)result.copyNode();
147+
mCopy.id = id;
152148
mCopy.costFunction = costFunction;
153149
mCopy.consolidated = consolidated;
154150
mCopy.operands = new ArrayList<>(operands);
@@ -173,7 +169,8 @@ public RewriterStatement nestedCopyOrInject(Map<RewriterStatement, RewriterState
173169

174170
RewriterInstruction mCopy = new RewriterInstruction();
175171
mCopy.instr = instr;
176-
mCopy.result = (RewriterDataType)result.copyNode();
172+
//mCopy.result = (RewriterDataType)result.copyNode();
173+
mCopy.id = id;
177174
mCopy.costFunction = costFunction;
178175
mCopy.consolidated = consolidated;
179176
mCopy.operands = new ArrayList<>(operands.size());
@@ -211,7 +208,8 @@ public boolean isInstruction() {
211208
public RewriterStatement clone() {
212209
RewriterInstruction mClone = new RewriterInstruction();
213210
mClone.instr = instr;
214-
mClone.result = (RewriterDataType)result.clone();
211+
//mClone.result = (RewriterDataType)result.clone();
212+
mClone.id = id;
215213
ArrayList<RewriterStatement> clonedOperands = new ArrayList<>(operands.size());
216214

217215
for (RewriterStatement stmt : operands)
@@ -224,13 +222,13 @@ public RewriterStatement clone() {
224222
return mClone;
225223
}
226224

227-
public void injectData(final RuleContext ctx, RewriterInstruction origData) {
225+
/*public void injectData(final RuleContext ctx, RewriterInstruction origData) {
228226
instr = origData.instr;
229227
result = (RewriterDataType)origData.getResult(ctx).copyNode();
230228
operands = new ArrayList<>(origData.operands);
231229
costFunction = origData.costFunction;
232230
meta = origData.meta;
233-
}
231+
}*/
234232

235233
/*public RewriterInstruction withLinks(DualHashBidiMap<RewriterStatement, RewriterStatement> links) {
236234
this.links = links;
@@ -328,11 +326,11 @@ public Optional<RewriterStatement> findOperand(String id) {
328326
public RewriterInstruction as(String id) {
329327
if (consolidated)
330328
throw new IllegalArgumentException("An instruction cannot be modified after consolidation");
331-
this.result.as(id);
329+
this.id = id;
332330
return this;
333331
}
334332

335-
public RewriterDataType getResult(final RuleContext ctx) {
333+
/*public RewriterDataType getResult(final RuleContext ctx) {
336334
if (this.result.getType() == null) {
337335
String type = ctx.instrTypes.get(typedInstruction(ctx));
338336
@@ -343,7 +341,7 @@ public RewriterDataType getResult(final RuleContext ctx) {
343341
}
344342
345343
return this.result;
346-
}
344+
}*/
347345

348346
public String typedInstruction(final RuleContext ctx) {
349347
return typedInstruction(this.instr, ctx);

src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCollection.java

Lines changed: 108 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,8 +1062,6 @@ public static void pushdownStreamSelections(final List<RewriterRule> rules, fina
10621062
);
10631063
});
10641064

1065-
1066-
10671065
rules.add(new RewriterRuleBuilder(ctx, "sum(sum(v)) => sum(v)")
10681066
.setUnidirectional(true)
10691067
.parseGlobalVars("MATRIX:A,B")
@@ -1151,6 +1149,85 @@ public static void pushdownStreamSelections(final List<RewriterRule> rules, fina
11511149
}
11521150
}
11531151

1152+
// This expands the statements to a common canonical form
1153+
// It is important, however, that
1154+
public static void canonicalExpandAfterFlattening(final List<RewriterRule> rules, final RuleContext ctx) {
1155+
HashMap<Integer, RewriterStatement> hooks = new HashMap<>();
1156+
1157+
rules.add(new RewriterRuleBuilder(ctx, "sum($1:_idxExpr(indices, -(A))) => -(sum($2:_idxExpr(indices, A)))")
1158+
.setUnidirectional(true)
1159+
.parseGlobalVars("FLOAT:a")
1160+
.parseGlobalVars("INT...:indices")
1161+
.withParsedStatement("sum($1:_idxExpr(indices, -(a)))", hooks)
1162+
.toParsedStatement("-(sum($2:_idxExpr(indices, a)))", hooks)
1163+
.link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta)
1164+
.build()
1165+
);
1166+
1167+
rules.add(new RewriterRuleBuilder(ctx, "sum($1:_idxExpr(indices, -(a))) => -(sum($2:_idxExpr(indices, a)))")
1168+
.setUnidirectional(true)
1169+
.parseGlobalVars("INT:a")
1170+
.parseGlobalVars("INT...:indices")
1171+
.withParsedStatement("sum($1:_idxExpr(indices, -(a)))", hooks)
1172+
.toParsedStatement("-(sum($2:_idxExpr(indices, a)))", hooks)
1173+
.link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta)
1174+
.build()
1175+
);
1176+
1177+
rules.add(new RewriterRuleBuilder(ctx, "sum(_idxExpr(indices, +(ops))) => +(argList(sum(_idxExpr(indices, op1)), sum(_idxExpr(...)), ...))")
1178+
.setUnidirectional(true)
1179+
.parseGlobalVars("INT...:indices")
1180+
.parseGlobalVars("FLOAT...:ops")
1181+
.withParsedStatement("sum($1:_idxExpr(indices, +(ops)))", hooks)
1182+
.toParsedStatement("+($3:argList(sum($2:_idxExpr(indices, +(ops)))))", hooks) // The inner +(ops) is temporary and will be removed
1183+
.link(hooks.get(1).getId(), hooks.get(2).getId(), RewriterStatement::transferMeta)
1184+
.apply(hooks.get(3).getId(), newArgList -> {
1185+
RewriterStatement oldArgList = newArgList.getChild(0, 0, 1, 0);
1186+
newArgList.getChild(0, 0).getOperands().set(1, oldArgList.getChild(0));
1187+
1188+
for (int i = 1; i < oldArgList.getOperands().size(); i++) {
1189+
RewriterStatement newIdxExpr = newArgList.getChild(0, 0).copyNode();
1190+
newIdxExpr.getOperands().set(1, oldArgList.getChild(i));
1191+
RewriterStatement newSum = new RewriterInstruction()
1192+
.as(UUID.randomUUID().toString())
1193+
.withInstruction("sum")
1194+
.withOps(newIdxExpr)
1195+
.consolidate(ctx);
1196+
RewriterUtils.copyIndexList(newIdxExpr);
1197+
newArgList.getOperands().add(newSum);
1198+
}
1199+
}, true)
1200+
.build()
1201+
);
1202+
}
1203+
1204+
public static void flattenedAlgebraRewrites(final List<RewriterRule> rules, final RuleContext ctx) {
1205+
HashMap<Integer, RewriterStatement> hooks = new HashMap<>();
1206+
1207+
// Minus pushdown
1208+
rules.add(new RewriterRuleBuilder(ctx, "-(+(...)) => +(-(el1), -(el2), ...)")
1209+
.setUnidirectional(true)
1210+
.parseGlobalVars("FLOAT...:ops")
1211+
.withParsedStatement("-(+(ops))", hooks)
1212+
.toParsedStatement("$1:+(ops)", hooks) // Temporary
1213+
.apply(hooks.get(1).getId(), (stmt, match) -> {
1214+
RewriterStatement argList = stmt.getChild(0);
1215+
1216+
for (int i = 0; i < argList.getOperands().size(); i++) {
1217+
RewriterInstruction newStmt = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("-").withOps(argList.getOperands().get(i));
1218+
newStmt.consolidate(ctx);
1219+
argList.getOperands().set(i, newStmt);
1220+
}
1221+
1222+
// TODO: This is inefficient
1223+
RewriterUtils.tryFlattenNestedOperatorPatterns(ctx, match.getNewExprRoot());
1224+
}, true)
1225+
.build()
1226+
);
1227+
1228+
// TODO: Distributive law
1229+
}
1230+
11541231
public static void buildElementWiseAlgebraicCanonicalization(final List<RewriterRule> rules, final RuleContext ctx) {
11551232
RewriterUtils.buildTernaryPermutations(List.of("FLOAT", "INT", "BOOL"), (t1, t2, t3) -> {
11561233
rules.add(new RewriterRuleBuilder(ctx, "*(+(a, b), c) => +(*(a, c), *(b, c))")
@@ -1215,23 +1292,37 @@ public static void flattenOperations(final List<RewriterRule> rules, final RuleC
12151292
HashMap<Integer, RewriterStatement> hooks = new HashMap<>();
12161293

12171294
RewriterUtils.buildBinaryPermutations(List.of("INT", "INT..."), (t1, t2) -> {
1218-
rules.add(new RewriterRuleBuilder(ctx)
1219-
.setUnidirectional(true)
1220-
.parseGlobalVars(t1 + ":i")
1221-
.parseGlobalVars(t2 + ":j")
1222-
.parseGlobalVars("FLOAT:v")
1223-
.withParsedStatement("$1:_idxExpr(i, $2:_idxExpr(j, v))", hooks)
1224-
.toParsedStatement("$3:_idxExpr(argList(i, j), v)", hooks)
1225-
.link(hooks.get(1).getId(), hooks.get(3).getId(), RewriterStatement::transferMeta)
1226-
.apply(hooks.get(3).getId(), (stmt, match) -> {
1227-
UUID newOwnerId = (UUID)stmt.getMeta("ownerId");
1295+
for (String t3 : List.of("FLOAT", "FLOAT*", "INT", "INT*", "BOOL", "BOOL*")) {
1296+
rules.add(new RewriterRuleBuilder(ctx)
1297+
.setUnidirectional(true)
1298+
.parseGlobalVars(t1 + ":i")
1299+
.parseGlobalVars(t2 + ":j")
1300+
.parseGlobalVars(t3 + ":v")
1301+
.withParsedStatement("$1:_idxExpr(i, $2:_idxExpr(j, v))", hooks)
1302+
.toParsedStatement("$3:_idxExpr(argList(i, j), v)", hooks)
1303+
.link(hooks.get(1).getId(), hooks.get(3).getId(), RewriterStatement::transferMeta)
1304+
.apply(hooks.get(3).getId(), (stmt, match) -> {
1305+
UUID newOwnerId = (UUID) stmt.getMeta("ownerId");
12281306

1229-
if (newOwnerId == null)
1230-
throw new IllegalArgumentException();
1307+
if (newOwnerId == null)
1308+
throw new IllegalArgumentException();
12311309

1232-
stmt.getOperands().get(0).getOperands().get(1).unsafePutMeta("ownerId", newOwnerId);
1233-
}, true)
1234-
.build());
1310+
stmt.getOperands().get(0).getOperands().get(1).unsafePutMeta("ownerId", newOwnerId);
1311+
}, true)
1312+
.build());
1313+
1314+
if (t1.equals("INT")) {
1315+
// This must be executed after the rule above
1316+
rules.add(new RewriterRuleBuilder(ctx)
1317+
.setUnidirectional(true)
1318+
.parseGlobalVars(t1 + ":i")
1319+
.parseGlobalVars(t3 + ":v")
1320+
.withParsedStatement("$1:_idxExpr(i, v)", hooks)
1321+
.toParsedStatement("$3:_idxExpr(argList(i), v)", hooks)
1322+
.link(hooks.get(1).getId(), hooks.get(3).getId(), RewriterStatement::transferMeta)
1323+
.build());
1324+
}
1325+
}
12351326
});
12361327

12371328
RewriterUtils.buildBinaryPermutations(List.of("MATRIX", "INT", "FLOAT", "BOOL"), (t1, t2) -> {

src/main/java/org/apache/sysds/hops/rewriter/RewriterStatement.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,10 @@ public RewriterStatement getNRow() {
596596
return (RewriterStatement) getMeta("nrow");
597597
}
598598

599+
public RewriterStatement getChild(int index) {
600+
return getOperands().get(index);
601+
}
602+
599603
public RewriterStatement getChild(int... indices) {
600604
RewriterStatement current = this;
601605

@@ -605,6 +609,21 @@ public RewriterStatement getChild(int... indices) {
605609
return current;
606610
}
607611

612+
// This can only be called from the root expression to add a new assertion manually
613+
public RewriterStatement givenThatEqual(RewriterStatement stmt1, RewriterStatement stmt2, final RuleContext ctx) {
614+
getAssertions(ctx).addEqualityAssertion(stmt1, stmt2);
615+
return this;
616+
}
617+
618+
public RewriterStatement recomputeAssertions() {
619+
RewriterAssertions assertions = (RewriterAssertions) getMeta("_assertions");
620+
621+
if (assertions != null)
622+
return assertions.update(this);
623+
624+
return this;
625+
}
626+
608627
public static void transferMeta(RewriterRule.ExplicitLink link) {
609628
if (link.oldStmt instanceof RewriterInstruction) {
610629
for (RewriterStatement mNew : link.newStmt) {

0 commit comments

Comments
 (0)