Skip to content

Commit ecd2e59

Browse files
committed
Some improvements
1 parent 3337477 commit ecd2e59

File tree

11 files changed

+323
-209
lines changed

11 files changed

+323
-209
lines changed

src/main/java/org/apache/sysds/hops/rewriter/MetaPropagator.java

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,13 @@ private RewriterStatement propagateDims(RewriterStatement root, RewriterStatemen
151151
root.unsafePutMeta("nrow", firstMatrixStatement.get().getMeta("nrow"));
152152
root.unsafePutMeta("ncol", firstMatrixStatement.get().getMeta("ncol"));
153153
return null;
154+
case "cast.MATRIX":
155+
String mDT = root.getChild(0).getResultingDataType(ctx);
156+
if (mDT.equals("BOOL") || mDT.equals("INT") || mDT.equals("FLOAT")) {
157+
root.unsafePutMeta("ncol", new RewriterDataType().ofType("INT").as("1").asLiteral(1L));
158+
root.unsafePutMeta("nrow", new RewriterDataType().ofType("INT").as("1").asLiteral(1L));
159+
return null;
160+
}
154161
}
155162

156163
switch(root.trueTypedInstruction(ctx)) {
@@ -186,23 +193,43 @@ private RewriterStatement propagateDims(RewriterStatement root, RewriterStatemen
186193
root.unsafePutMeta("ncol", new RewriterDataType().ofType("INT").as("1").asLiteral(1L));
187194
return null;
188195
case "[](MATRIX,INT,INT,INT,INT)":
189-
Integer[] ints = new Integer[4];
196+
Long[] ints = new Long[4];
190197

191-
for (int i = 0; i < 4; i++)
192-
if (root.getOperands().get(1).isLiteral())
193-
ints[i] = (Integer)root.getOperands().get(1).getLiteral();
198+
for (int i = 1; i < 5; i++)
199+
if (root.getChild(i).isLiteral())
200+
if (root.getChild(i).getLiteral() instanceof Integer)
201+
ints[i-1] = (Long)root.getChild(i).getLiteral();
194202

195203
if (ints[0] != null && ints[1] != null) {
196-
root.unsafePutMeta("nrow", ints[1] - ints[0] + 1);
204+
String literalString = Long.toString(ints[1] - ints[0] + 1);
205+
root.unsafePutMeta("nrow", RewriterUtils.parse(literalString, ctx, "LITERAL_INT:" + literalString));
197206
} else {
198-
throw new NotImplementedException();
199-
// TODO:
207+
HashMap<String, RewriterStatement> subStmts = new HashMap<>();
208+
subStmts.put("i1", root.getOperands().get(2));
209+
subStmts.put("i0", root.getOperands().get(1));
210+
211+
if (ints[0] != null) {
212+
root.unsafePutMeta("nrow", RewriterUtils.parse("+(i1, " + (1 - ints[0]) + ")", ctx, subStmts, "LITERAL_INT:" + (1 - ints[0])));
213+
} else if (ints[1] != null) {
214+
root.unsafePutMeta("nrow", RewriterUtils.parse("-(" + (ints[1] + 1) + ", i0)", ctx, subStmts, "LITERAL_INT:" + (ints[1] + 1)));
215+
} else {
216+
root.unsafePutMeta("nrow", RewriterUtils.parse("+(-(i1, i0), 1)", ctx, subStmts, "LITERAL_INT:1"));
217+
}
200218
}
201219

202220
if (ints[2] != null && ints[3] != null) {
203221
root.unsafePutMeta("ncol", ints[3] - ints[2] + 1);
204222
} else {
205-
throw new NotImplementedException();
223+
HashMap<String, RewriterStatement> subStmts = new HashMap<>();
224+
subStmts.put("i3", root.getOperands().get(4));
225+
subStmts.put("i2", root.getOperands().get(3));
226+
if (ints[2] != null) {
227+
root.unsafePutMeta("ncol", RewriterUtils.parse("+(i3, " + (1 - ints[2]) + ")", ctx, subStmts, "LITERAL_INT:" + (1 - ints[2])));
228+
} else if (ints[3] != null) {
229+
root.unsafePutMeta("ncol", RewriterUtils.parse("-(" + (ints[3] + 1) + ", i2)", ctx, subStmts, "LITERAL_INT:" + (ints[3] + 1)));
230+
} else {
231+
root.unsafePutMeta("ncol", RewriterUtils.parse("+(-(i3, i2), 1)", ctx, subStmts, "LITERAL_INT:1"));
232+
}
206233
}
207234

208235
return null;
@@ -214,6 +241,10 @@ private RewriterStatement propagateDims(RewriterStatement root, RewriterStatemen
214241
root.unsafePutMeta("ncol", root.getOperands().get(0).getMeta("ncol"));
215242
root.unsafePutMeta("nrow", new RewriterDataType().ofType("INT").as("1").asLiteral(1L));
216243
return null;
244+
case "cast.MATRIX(MATRIX)":
245+
root.unsafePutMeta("nrow", root.getOperands().get(0).getMeta("nrow"));
246+
root.unsafePutMeta("ncol", root.getOperands().get(0).getMeta("ncol"));
247+
return null;
217248
}
218249

219250
RewriterInstruction instr = (RewriterInstruction) root;
@@ -230,6 +261,18 @@ private RewriterStatement propagateDims(RewriterStatement root, RewriterStatemen
230261
return null;
231262
}
232263

264+
if (instr.getProperties(ctx).contains("ElementWiseUnary.FLOAT")) {
265+
if (root.getOperands().get(0).getResultingDataType(ctx).startsWith("MATRIX")) {
266+
root.unsafePutMeta("nrow", root.getOperands().get(0).getMeta("nrow"));
267+
root.unsafePutMeta("ncol", root.getOperands().get(0).getMeta("ncol"));
268+
} else {
269+
root.unsafePutMeta("nrow", root.getOperands().get(1).getMeta("nrow"));
270+
root.unsafePutMeta("ncol", root.getOperands().get(1).getMeta("ncol"));
271+
}
272+
273+
return null;
274+
}
275+
233276
throw new NotImplementedException("Unknown instruction: " + instr.trueTypedInstruction(ctx) + "\n" + instr.toString(ctx));
234277
}
235278

src/main/java/org/apache/sysds/hops/rewriter/RewriterAssertions.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,8 @@ public boolean addEqualityAssertion(RewriterStatement stmt1, RewriterStatement s
123123
if (stmt1 == stmt2 || (stmt1.isLiteral() && stmt2.isLiteral() && stmt1.getLiteral().equals(stmt2.getLiteral())))
124124
return false;
125125

126-
if (!(stmt1 instanceof RewriterInstruction) || !(stmt2 instanceof RewriterInstruction))
127-
throw new UnsupportedOperationException("Asserting uninjectable objects is not yet supported: " + stmt1 + "; " + stmt2);
126+
//if (!(stmt1 instanceof RewriterInstruction) || !(stmt2 instanceof RewriterInstruction))
127+
// throw new UnsupportedOperationException("Asserting uninjectable objects is not yet supported: " + stmt1 + "; " + stmt2);
128128

129129
//System.out.println("Asserting: " + stmt1 + " := " + stmt2);
130130

src/main/java/org/apache/sysds/hops/rewriter/RewriterContextSettings.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,10 @@ public static String getDefaultContextString() {
227227
builder.append("as.float(BOOL)::FLOAT\n");
228228
builder.append("as.int(BOOL)::INT\n");
229229

230+
RewriterUtils.buildBinaryPermutations(ALL_TYPES, (tFrom, tTo) -> {
231+
builder.append("cast." + tTo + "(" + tFrom + ")::" + tTo + "\n");
232+
});
233+
230234
builder.append("max(MATRIX)::FLOAT\n");
231235
builder.append("min(MATRIX)::FLOAT\n");
232236

@@ -260,6 +264,12 @@ public static String getDefaultContextString() {
260264
builder.append("/(" + t1+ "," + t2 + ")::" + RewriterUtils.defaultTypeHierarchy(t1, t2) + "\n");
261265
});
262266

267+
// Unary ops
268+
ALL_TYPES.forEach(t -> {
269+
builder.append("ElementWiseUnary.FLOAT(" + t + ")::" + (t.equals("MATRIX") ? "MATRIX" : "FLOAT") + "\n");
270+
builder.append("impl sqrt\n");
271+
});
272+
263273

264274
// Meta-Instruction
265275
builder.append("_lower(INT)::FLOAT\n");

src/main/java/org/apache/sysds/hops/rewriter/RewriterDataType.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,6 @@ public boolean match(final MatcherContext mCtx) {
132132

133133
RewriterStatement assoc = mCtx.getDependencyMap().get(this);
134134
if (assoc == null) {
135-
// TODO: This is very inefficient
136135
if (!mCtx.allowDuplicatePointers && mCtx.getDependencyMap().containsValue(stmt))
137136
return false; // Then the statement variable is already associated with another variable
138137
mCtx.getDependencyMap().put(this, stmt);

src/main/java/org/apache/sysds/hops/rewriter/RewriterHeuristic.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,9 @@ public RewriterStatement apply(RewriterStatement currentStmt, @Nullable BiFuncti
5858
foundRewrite.setValue(true);
5959

6060
while (rule != null) {
61-
//System.out.println("Pre-apply: " + rule.rule.getName());
61+
/*System.out.println("Pre-apply: " + rule.rule.getName());
62+
System.out.println("Expr: " + rule.matches.get(0).getExpressionRoot().toParsableString(ruleSet.getContext()));
63+
System.out.println("At: " + rule.matches.get(0).getMatchRoot().toParsableString(ruleSet.getContext()));*/
6264
currentStmt = rule.rule.apply(rule.matches.get(0), currentStmt, rule.forward, false);
6365

6466
if (handler != null && !handler.apply(currentStmt, rule.rule))

src/main/java/org/apache/sysds/hops/rewriter/RewriterInstruction.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,14 @@ public boolean match(final MatcherContext mCtx) {
114114
if (this.operands.size() != inst.operands.size())
115115
return false;
116116

117-
RewriterStatement existingRef = mCtx.findInternalReference(new RewriterRule.IdentityRewriterStatement(this));
117+
RewriterStatement existingRef = mCtx.findInternalReference(this);
118+
118119
if (existingRef != null)
119120
return existingRef == stmt;
120121

122+
if (!mCtx.allowDuplicatePointers && mCtx.getInternalReferences().containsValue(stmt))
123+
return false;
124+
121125
RewriterRule.LinkObject ruleLink = mCtx.ruleLinks.get(this);
122126

123127
if (ruleLink != null)
@@ -127,12 +131,12 @@ public boolean match(final MatcherContext mCtx) {
127131

128132
for (int i = 0; i < s; i++) {
129133
mCtx.currentStatement = inst.operands.get(i);
130-
if (!operands.get(i).match(mCtx)) {
134+
135+
if (!operands.get(i).match(mCtx))
131136
return false;
132-
}
133137
}
134138

135-
mCtx.getInternalReferences().put(new RewriterRule.IdentityRewriterStatement(this), stmt);
139+
mCtx.getInternalReferences().put(this, stmt);
136140

137141
return true;
138142
}

0 commit comments

Comments
 (0)