Skip to content

Commit 5f9be9e

Browse files
committed
Some improvements
1 parent 616c199 commit 5f9be9e

File tree

10 files changed

+115
-28
lines changed

10 files changed

+115
-28
lines changed

src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import org.apache.sysds.hops.OptimizerUtils;
3030
import org.apache.sysds.hops.rewriter.GeneratedRewriteClass;
3131
import org.apache.sysds.hops.rewriter.RewriteAutomaticallyGenerated;
32+
import org.apache.sysds.hops.rewriter.dml.DMLExecutor;
3233
import org.apache.sysds.parser.DMLProgram;
3334
import org.apache.sysds.parser.ForStatement;
3435
import org.apache.sysds.parser.ForStatementBlock;
@@ -143,6 +144,9 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites)
143144
_dagRuleSet.add(new RewriteAutomaticallyGenerated(new GeneratedRewriteClass()));
144145
}
145146
}
147+
148+
if (DMLExecutor.APPLY_INJECTED_REWRITES)
149+
_dagRuleSet.add(new RewriteAutomaticallyGenerated(DMLExecutor.REWRITE_FUNCTION));
146150
}
147151

148152
// cleanup after all rewrites applied

src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleCreator.java

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88
import org.apache.sysds.hops.rewriter.estimators.RewriterCostEstimator;
99
import org.apache.sysds.hops.rewriter.utils.RewriterUtils;
1010

11+
import javax.annotation.Nullable;
1112
import java.util.ArrayList;
1213
import java.util.Collections;
1314
import java.util.HashMap;
1415
import java.util.LinkedList;
1516
import java.util.List;
1617
import java.util.Map;
18+
import java.util.Objects;
1719
import java.util.Set;
1820
import java.util.UUID;
1921
import java.util.function.Consumer;
@@ -219,13 +221,16 @@ public static boolean validateRuleCorrectness(RewriterRule rule, final RuleConte
219221
}
220222

221223
public static boolean validateRuleApplicability(RewriterRule rule, final RuleContext ctx) {
222-
return validateRuleApplicability(rule, ctx, false);
224+
return validateRuleApplicability(rule, ctx, false, null);
223225
}
224226

225-
public static boolean validateRuleApplicability(RewriterRule rule, final RuleContext ctx, boolean print) {
227+
public static boolean validateRuleApplicability(RewriterRule rule, final RuleContext ctx, boolean print, @Nullable Function<Hop, Hop> injectedRewriteClass) {
226228
RewriterStatement _mstmt = rule.getStmt1();
227-
if (ctx.metaPropagator != null)
229+
RewriterStatement _mstmt2 = rule.getStmt2();
230+
if (ctx.metaPropagator != null) {
228231
ctx.metaPropagator.apply(_mstmt);
232+
ctx.metaPropagator.apply(_mstmt2);
233+
}
229234

230235
final RewriterStatement stmt1 = RewriterUtils.unfuseOperators(_mstmt, ctx);
231236

@@ -243,6 +248,8 @@ public static boolean validateRuleApplicability(RewriterRule rule, final RuleCon
243248

244249
MutableBoolean isRelevant = new MutableBoolean(false);
245250

251+
final RewriterStatement expectedStmt = injectedRewriteClass != null ? _mstmt2 : _mstmt;
252+
246253
RewriterRuntimeUtils.attachHopInterceptor(prog -> {
247254
Hop hop;
248255

@@ -301,7 +308,7 @@ public static boolean validateRuleApplicability(RewriterRule rule, final RuleCon
301308

302309
Map<RewriterStatement, RewriterStatement> createdObjects = new HashMap<>();
303310

304-
RewriterStatement stmt1ReplaceNCols = _mstmt.nestedCopyOrInject(createdObjects, mstmt -> {
311+
RewriterStatement stmt1ReplaceNCols = expectedStmt.nestedCopyOrInject(createdObjects, mstmt -> {
305312
if (mstmt.isInstruction() && (mstmt.trueInstruction().equals("ncol") || mstmt.trueInstruction().equals("nrow")))
306313
return RewriterStatement.literal(ctx, DMLCodeGenerator.MATRIX_DIMS);
307314
return null;
@@ -310,7 +317,7 @@ public static boolean validateRuleApplicability(RewriterRule rule, final RuleCon
310317
stmt1ReplaceNCols.prepareForHashing();
311318
stmt1ReplaceNCols.recomputeHashCodes(ctx);
312319

313-
Set<RewriterStatement> mVars = vars.stream().map(createdObjects::get).collect(Collectors.toSet());
320+
Set<RewriterStatement> mVars = vars.stream().map(createdObjects::get).filter(Objects::nonNull).collect(Collectors.toSet());
314321

315322
if (print) {
316323
DMLExecutor.println("Observed statement: " + stmt.toParsableString(ctx));
@@ -322,34 +329,51 @@ public static boolean validateRuleApplicability(RewriterRule rule, final RuleCon
322329
// Check if also the right variables are associated
323330
boolean assocsMatching = true;
324331
//DMLExecutor.println(mCtx.getDependencyMap());
325-
for (RewriterStatement var : mVars) {
326-
RewriterStatement assoc = mCtx.getDependencyMap().get(var.isInstruction() && !var.trueInstruction().equals("const") ? var.getChild(0) : var);
332+
if (mCtx.getDependencyMap() != null) {
333+
for (RewriterStatement var : mVars) {
334+
//DMLExecutor.println("Var: " + var);
335+
RewriterStatement assoc = mCtx.getDependencyMap().get(var.isInstruction() && !var.trueInstruction().equals("const") ? var.getChild(0) : var);
327336

328-
if (assoc == null)
329-
throw new IllegalArgumentException("Association is null!");
337+
if (assoc == null)
338+
throw new IllegalArgumentException("Association is null!");
330339

331-
if (!assoc.getId().equals(var.getId())) {
332-
assocsMatching = false;
333-
break;
340+
if (!assoc.getId().equals(var.getId())) {
341+
assocsMatching = false;
342+
break;
343+
}
334344
}
335345
}
336346

337347
if (assocsMatching) {
338348
// Then the rule matches, meaning that the statement is not rewritten by SystemDS
339349
isRelevant.setValue(true);
350+
//DMLExecutor.println("MATCH");
340351
}
341352
}
342353

343354
// TODO: Maybe we can still rewrite the new graph if it still has less cost
344355

345356
// TODO: Evaluate cost and if our rule can still be applied
346-
return false; // The program should not be executed as we just want to extract any rewrites that are applied to the current statement
357+
return injectedRewriteClass != null; // The program should not be executed as we just want to extract any rewrites that are applied to the current statement
347358
});
348359

349-
DMLExecutor.executeCode(code2, true);
360+
MutableBoolean wasApplied = new MutableBoolean(true);
361+
362+
if (injectedRewriteClass != null) {
363+
String ruleStr = rule.toString();
364+
wasApplied.setValue(false);
365+
DMLExecutor.executeCode(code2, s -> {
366+
if (s.equals("Applying rewrite: " + ruleStr)) {
367+
wasApplied.setValue(true);
368+
}
369+
}, injectedRewriteClass);
370+
} else {
371+
DMLExecutor.executeCode(code2, true);
372+
}
373+
350374
RewriterRuntimeUtils.detachHopInterceptor();
351375

352-
return isRelevant.booleanValue();
376+
return isRelevant.booleanValue() && wasApplied.booleanValue();
353377
}
354378

355379
public static RewriterRule createRule(RewriterStatement from, RewriterStatement to, RewriterStatement canonicalForm1, RewriterStatement canonicalForm2, final RuleContext ctx) {

src/main/java/org/apache/sysds/hops/rewriter/RewriterRuleSet.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,29 @@ public String serialize(final RuleContext ctx) {
274274
return sb.toString();
275275
}
276276

277+
public boolean generateCodeAndTest(boolean optimize, boolean print) {
278+
String javaCode = toJavaCode("MGeneratedRewriteClass", optimize, false, true);
279+
Function<Hop, Hop> f = RewriterCodeGen.compile(javaCode, "MGeneratedRewriteClass");
280+
281+
if (f == null)
282+
return false; // Then, the code could not compile
283+
284+
int origSize = rules.size();
285+
286+
for (int i = 0; i < rules.size(); i++) {
287+
if (!RewriterRuleCreator.validateRuleApplicability(rules.get(i), ctx, print, f)) {
288+
System.out.println("Faulty rule: " + rules.get(i));
289+
rules.remove(i);
290+
i--;
291+
}
292+
}
293+
294+
if (rules.size() != origSize)
295+
accelerate();
296+
297+
return true;
298+
}
299+
277300
public static RewriterRuleSet deserialize(String data, final RuleContext ctx) {
278301
return deserialize(data.split("\n"), ctx);
279302
}

src/main/java/org/apache/sysds/hops/rewriter/RewriterRuntimeUtils.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848

4949
public class RewriterRuntimeUtils {
5050
public static final boolean interceptAll = false;
51-
public static final boolean printUnknowns = true;
51+
public static boolean printUnknowns = true;
5252
public static final String dbFile = "/Users/janniklindemann/Dev/MScThesis/expressions.db";
5353
public static final boolean readDB = true;
5454
public static final boolean writeDB = true;

src/main/java/org/apache/sysds/hops/rewriter/codegen/RewriterCodeGen.java

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,19 @@ public static Function<Hop, Hop> compileRewrites(String className, List<Tuple2<S
3434
return (Function<Hop, Hop>) instance;
3535
}
3636

37+
public static Function<Hop, Hop> compile(String javaCode, String className) {
38+
try {
39+
SimpleCompiler compiler = new SimpleCompiler();
40+
compiler.cook(javaCode);
41+
Class<?> mClass = compiler.getClassLoader().loadClass(className);
42+
Object instance = mClass.getDeclaredConstructor().newInstance();
43+
return (Function<Hop, Hop>) instance;
44+
} catch (Exception e) {
45+
e.printStackTrace();
46+
return null;
47+
}
48+
}
49+
3750
public static String generateClass(String className, List<Tuple2<String, RewriterRule>> rewrites, boolean optimize, boolean includePackageInfo, final RuleContext ctx, boolean ignoreErrors, boolean printErrors) {
3851
StringBuilder msb = new StringBuilder();
3952

@@ -54,6 +67,7 @@ public static String generateClass(String className, List<Tuple2<String, Rewrite
5467
msb.append("import org.apache.sysds.hops.TernaryOp;\n");
5568
msb.append("import org.apache.sysds.common.Types;\n");
5669
msb.append("import org.apache.sysds.hops.rewrite.HopRewriteUtils;\n");
70+
msb.append("import org.apache.sysds.hops.rewriter.dml.DMLExecutor;\n");
5771
msb.append("\n");
5872
msb.append("public class " + className + " implements Function {\n\n");
5973

@@ -165,8 +179,8 @@ private static void buildMatchingSequence(String name, RewriterStatement from, R
165179
recursivelyBuildMatchingSequence(from, sb, "hi", ctx, indentation, vars, allowedMultiRefs, allowCombinations);
166180

167181
if (fromCost != null && toCost != null) {
168-
System.out.println("FromCost: " + fromCost.toParsableString(ctx));
169-
System.out.println("ToCost: " + toCost.toParsableString(ctx));
182+
//System.out.println("FromCost: " + fromCost.toParsableString(ctx));
183+
//System.out.println("ToCost: " + toCost.toParsableString(ctx));
170184

171185
StringBuilder msb = new StringBuilder();
172186
StringBuilder msb2 = new StringBuilder();
@@ -232,6 +246,8 @@ private static void buildMatchingSequence(String name, RewriterStatement from, R
232246
if (DEBUG) {
233247
indent(indentation, sb);
234248
sb.append("System.out.println(\"Applying rewrite: " + name + "\");\n");
249+
//indent(indentation, sb);
250+
//sb.append("DMLExecutor.println(\"Applying rewrite: " + name + "\");\n");
235251
}
236252

237253
Set<RewriterStatement> activeStatements = buildRewrite(to, sb, combinedAssertions, vars, ctx, indentation);

src/main/java/org/apache/sysds/hops/rewriter/dml/DMLCodeGenerator.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -240,13 +240,13 @@ public static String generateDMLVariables(Set<RewriterStatement> vars) {
240240
continue;
241241
}
242242
}
243-
sb.append(mId + " = (rand(rows=" + nrow + ", cols=" + ncol + ") * rand(rows=" + nrow + ", cols=" + ncol + ", min=(as.scalar(rand())+1.0), max=(as.scalar(rand())+2.0), seed=" + rd.nextInt(1000) + "))^as.scalar(rand())\n");
243+
sb.append(mId + " = cos((rand(rows=" + nrow + ", cols=" + ncol + ") * rand(rows=" + nrow + ", cols=" + ncol + ", min=(as.scalar(rand())+1.0), max=(as.scalar(rand())+2.0), seed=" + rd.nextInt(1000) + "))^as.scalar(rand()))\n");
244244
break;
245245
case "FLOAT":
246-
sb.append(var.getId() + " = as.scalar(rand(min=(as.scalar(rand())+1.0), max=(as.scalar(rand())+2.0), seed=" + rd.nextInt(1000) + "))^as.scalar(rand())\n");
246+
sb.append(var.getId() + " = cos(as.scalar(rand(min=(as.scalar(rand())+1.0), max=(as.scalar(rand())+2.0), seed=" + rd.nextInt(1000) + "))^as.scalar(rand()))\n");
247247
break;
248248
case "INT":
249-
sb.append(var.getId() + " = as.integer(as.scalar(rand(min=(as.scalar(rand())+1.0), max=(as.scalar(rand()+200000.0)), seed=" + rd.nextInt(1000) + "))^as.scalar(rand()))\n");
249+
sb.append(var.getId() + " = as.integer(cos(as.scalar(rand(min=(as.scalar(rand())+1.0), max=(as.scalar(rand()+200000.0)), seed=" + rd.nextInt(1000) + "))^as.scalar(rand())))\n");
250250
break;
251251
case "BOOL":
252252
sb.append(var.getId() + " = as.scalar(rand()) < 0.5\n");

src/main/java/org/apache/sysds/hops/rewriter/dml/DMLExecutor.java

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,30 @@
11
package org.apache.sysds.hops.rewriter.dml;
22

33
import org.apache.sysds.api.DMLScript;
4+
import org.apache.sysds.hops.Hop;
45

56
import java.io.OutputStream;
67
import java.io.PrintStream;
78
import java.util.function.Consumer;
9+
import java.util.function.Function;
810

911
public class DMLExecutor {
1012
private static PrintStream origPrintStream = System.out;
1113

12-
public static synchronized void executeCode(String code, boolean intercept, String... additionalArgs) {
14+
public static boolean APPLY_INJECTED_REWRITES = false;
15+
public static Function<Hop, Hop> REWRITE_FUNCTION = null;
16+
17+
public static void executeCode(String code, boolean intercept, String... additionalArgs) {
1318
executeCode(code, intercept ? s -> {} : null, additionalArgs);
1419
}
1520

21+
public static void executeCode(String code, Consumer<String> consoleInterceptor, String... additionalArgs) {
22+
executeCode(code, consoleInterceptor, null, additionalArgs);
23+
}
24+
1625
// TODO: We will probably need some kind of watchdog
1726
// This cannot run in parallel
18-
public static synchronized void executeCode(String code, Consumer<String> consoleInterceptor, String... additionalArgs) {
27+
public static synchronized void executeCode(String code, Consumer<String> consoleInterceptor, Function<Hop, Hop> injectedRewriteClass, String... additionalArgs) {
1928
try {
2029
if (consoleInterceptor != null)
2130
System.setOut(new PrintStream(new CustomOutputStream(System.out, consoleInterceptor)));
@@ -27,12 +36,21 @@ public static synchronized void executeCode(String code, Consumer<String> consol
2736

2837
args[additionalArgs.length] = "-s";
2938
args[additionalArgs.length + 1] = code;
39+
40+
if (injectedRewriteClass != null) {
41+
APPLY_INJECTED_REWRITES = true;
42+
REWRITE_FUNCTION = injectedRewriteClass;
43+
}
44+
3045
DMLScript.executeScript(args);
3146

3247
} catch (Exception e) {
3348
e.printStackTrace();
3449
}
3550

51+
APPLY_INJECTED_REWRITES = false;
52+
REWRITE_FUNCTION = null;
53+
3654
if (consoleInterceptor != null)
3755
System.setOut(origPrintStream);
3856
}

src/test/java/org/apache/sysds/test/AutomatedTestBase.java

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1482,9 +1482,6 @@ private ByteArrayOutputStream runTestWithTimeout(boolean newWay, boolean excepti
14821482
TestUtils.printDMLScript(fullDMLScriptName);
14831483
}
14841484
}
1485-
1486-
// TODO
1487-
//args.add("-applyGeneratedRewrites");
14881485

14891486
ByteArrayOutputStream buff = outputBuffering ? new ByteArrayOutputStream() : null;
14901487
PrintStream old = System.out;

src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/CodeGenTests.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import org.apache.sysds.hops.ReorgOp;
1010
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
1111
import org.apache.sysds.hops.rewriter.RewriteAutomaticallyGenerated;
12+
import org.apache.sysds.hops.rewriter.RewriterRuntimeUtils;
1213
import org.apache.sysds.hops.rewriter.codegen.RewriterCodeGen;
1314
import org.apache.sysds.hops.rewriter.RewriterRule;
1415
import org.apache.sysds.hops.rewriter.RewriterRuleBuilder;
@@ -213,6 +214,10 @@ public void codeGen() {
213214
try {
214215
List<String> lines = Files.readAllLines(Paths.get(RewriteAutomaticallyGenerated.FILE_PATH));
215216
RewriterRuleSet ruleSet = RewriterRuleSet.deserialize(lines, ctx);
217+
218+
RewriterRuntimeUtils.printUnknowns = false;
219+
ruleSet.generateCodeAndTest(true, false);
220+
216221
RewriterCodeGen.DEBUG = true;
217222
String javaCode = ruleSet.toJavaCode("GeneratedRewriteClass", true, true, true);
218223
String filePath = "/Users/janniklindemann/Dev/MScThesis/other/GeneratedRewriteClass.java";

src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/DMLCodeGenTest.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ public void testFused4() {
201201

202202
assert RewriterRuleCreator.validateRuleCorrectness(rule, ctx);
203203

204-
assert RewriterRuleCreator.validateRuleApplicability(rule, ctx, true);
204+
assert RewriterRuleCreator.validateRuleApplicability(rule, ctx, true, null);
205205
}
206206

207207
@Test
@@ -224,6 +224,6 @@ public void testFused5() {
224224

225225
assert RewriterRuleCreator.validateRuleCorrectness(rule, ctx);
226226

227-
assert RewriterRuleCreator.validateRuleApplicability(rule, ctx, true);
227+
assert RewriterRuleCreator.validateRuleApplicability(rule, ctx, true, null);
228228
}
229229
}

0 commit comments

Comments
 (0)