Skip to content

Commit 8b866c6

Browse files
committed
Validation script implementation
1 parent e153ae1 commit 8b866c6

File tree

3 files changed

+320
-0
lines changed

3 files changed

+320
-0
lines changed
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
package org.apache.sysds.hops.rewriter;
2+
3+
import org.apache.commons.lang3.NotImplementedException;
4+
5+
import java.util.HashMap;
6+
import java.util.HashSet;
7+
import java.util.Map;
8+
import java.util.Set;
9+
import java.util.function.BiFunction;
10+
import java.util.function.Function;
11+
12+
public class DMLCodeGenerator {
13+
public static final double EPS = 1e-10;
14+
15+
16+
private static final HashSet<String> printAsBinary = new HashSet<>();
17+
private static final HashMap<String, BiFunction<RewriterStatement, StringBuilder, Boolean>> customEncoders = new HashMap<>();
18+
private static final RuleContext ctx = RewriterUtils.buildDefaultContext();
19+
20+
static {
21+
printAsBinary.add("+");
22+
printAsBinary.add("-");
23+
printAsBinary.add("*");
24+
printAsBinary.add("/");
25+
printAsBinary.add("^");
26+
printAsBinary.add("==");
27+
printAsBinary.add("!=");
28+
printAsBinary.add(">");
29+
printAsBinary.add(">=");
30+
printAsBinary.add("<");
31+
printAsBinary.add("<=");
32+
33+
customEncoders.put("[]", (stmt, sb) -> {
34+
if (stmt.getOperands().size() == 3) {
35+
sb.append('(');
36+
appendExpression(stmt.getChild(0), sb);
37+
sb.append(")[");
38+
appendExpression(stmt.getChild(1), sb);
39+
sb.append(", ");
40+
appendExpression(stmt.getChild(2), sb);
41+
sb.append(']');
42+
return true;
43+
} else if (stmt.getOperands().size() == 5) {
44+
sb.append('(');
45+
appendExpression(stmt.getChild(0), sb);
46+
sb.append(")[");
47+
appendExpression(stmt.getChild(1), sb);
48+
sb.append(" : ");
49+
appendExpression(stmt.getChild(2), sb);
50+
sb.append(", ");
51+
appendExpression(stmt.getChild(3), sb);
52+
sb.append(" : ");
53+
appendExpression(stmt.getChild(4), sb);
54+
sb.append(']');
55+
return true;
56+
}
57+
58+
return false;
59+
});
60+
}
61+
62+
public static String generateRuleValidationDML(RewriterRule rule, double eps, String sessionId) {
63+
RewriterStatement stmtFrom = rule.getStmt1();
64+
RewriterStatement stmtTo = rule.getStmt2();
65+
66+
Set<RewriterStatement> vars = new HashSet<>();
67+
68+
stmtFrom.forEachPostOrder((stmt, pred) -> {
69+
if (!stmt.isInstruction() && !stmt.isLiteral())
70+
vars.add(stmt);
71+
}, false);
72+
73+
stmtTo.forEachPostOrder((stmt, pred) -> {
74+
if (!stmt.isInstruction() && !stmt.isLiteral())
75+
vars.add(stmt);
76+
}, false);
77+
78+
StringBuilder sb = new StringBuilder();
79+
80+
for (RewriterStatement var : vars) {
81+
switch (var.getResultingDataType(ctx)) {
82+
case "MATRIX":
83+
sb.append(var.getId() + " = rand(rows=1000, cols=1000, min=0.0, max=1.0)\n");
84+
break;
85+
case "FLOAT":
86+
sb.append(var.getId() + " = as.scalar(rand())\n");
87+
break;
88+
case "INT":
89+
sb.append(var.getId() + " = as.integer(as.scalar(rand(min=0.0, max=10000.0)))\n");
90+
break;
91+
case "BOOL":
92+
sb.append(var.getId() + " = as.scalar(rand()) < 0.5\n");
93+
break;
94+
default:
95+
throw new NotImplementedException(var.getResultingDataType(ctx));
96+
}
97+
}
98+
99+
sb.append('\n');
100+
sb.append("R1 = ");
101+
sb.append(generateDML(stmtFrom));
102+
sb.append('\n');
103+
sb.append("R2 = ");
104+
sb.append(generateDML(stmtTo));
105+
sb.append('\n');
106+
sb.append("print(\"");
107+
sb.append(sessionId);
108+
sb.append(" valid: \" + (");
109+
sb.append(generateEqualityCheck("R1", "R2", stmtFrom.getResultingDataType(ctx), eps));
110+
sb.append("))");
111+
112+
return sb.toString();
113+
}
114+
115+
public static String generateEqualityCheck(String stmt1Var, String stmt2Var, String dataType, double eps) {
116+
switch (dataType) {
117+
case "MATRIX":
118+
return "sum(abs(" + stmt1Var + " - " + stmt2Var + ") < " + eps + ") == length(" + stmt1Var + ")";
119+
case "INT":
120+
case "BOOL":
121+
return stmt1Var + " == " + stmt2Var;
122+
case "FLOAT":
123+
return "abs(" + stmt1Var + " - " + stmt2Var + ") < " + eps;
124+
}
125+
126+
throw new NotImplementedException();
127+
}
128+
129+
public static String generateDMLDefs(Map<String, RewriterStatement> defs) {
130+
StringBuilder sb = new StringBuilder();
131+
132+
defs.forEach((k, v) -> {
133+
sb.append(k);
134+
sb.append(" = ");
135+
sb.append(generateDML(v));
136+
sb.append('\n');
137+
});
138+
139+
return sb.toString();
140+
}
141+
142+
public static String generateDML(RewriterStatement root) {
143+
StringBuilder sb = new StringBuilder();
144+
appendExpression(root, sb);
145+
146+
return sb.toString();
147+
}
148+
149+
private static void appendExpression(RewriterStatement cur, StringBuilder sb) {
150+
if (cur.isInstruction()) {
151+
resolveExpression((RewriterInstruction) cur, sb);
152+
} else {
153+
if (cur.isLiteral())
154+
sb.append(cur.getLiteral());
155+
else
156+
sb.append(cur.getId());
157+
}
158+
}
159+
160+
private static void resolveExpression(RewriterInstruction expr, StringBuilder sb) {
161+
String typedInstr = expr.trueTypedInstruction(ctx);
162+
String unTypedInstr = expr.trueInstruction();
163+
164+
if (expr.getOperands().size() == 2 && (printAsBinary.contains(typedInstr) || printAsBinary.contains(unTypedInstr))) {
165+
sb.append('(');
166+
appendExpression(expr.getChild(0), sb);
167+
sb.append(") ");
168+
sb.append(unTypedInstr);
169+
sb.append(" (");
170+
appendExpression(expr.getChild(1), sb);
171+
sb.append(')');
172+
return;
173+
}
174+
175+
BiFunction<RewriterStatement, StringBuilder, Boolean> customEncoder = customEncoders.get(typedInstr);
176+
177+
if (customEncoder == null)
178+
customEncoder = customEncoders.get(unTypedInstr);
179+
180+
if (customEncoder == null) {
181+
sb.append(unTypedInstr);
182+
sb.append('(');
183+
184+
for (int i = 0; i < expr.getOperands().size(); i++) {
185+
if (i != 0)
186+
sb.append(", ");
187+
188+
appendExpression(expr.getChild(i), sb);
189+
}
190+
191+
sb.append(')');
192+
} else {
193+
customEncoder.apply(expr, sb);
194+
}
195+
}
196+
}
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package org.apache.sysds.hops.rewriter;
2+
3+
import org.apache.sysds.api.DMLScript;
4+
5+
import java.io.OutputStream;
6+
import java.io.PrintStream;
7+
import java.util.function.Consumer;
8+
9+
public class DMLExecutor {
10+
private static PrintStream origPrintStream = System.out;
11+
12+
// This cannot run in parallel
13+
public static synchronized void executeCode(String code, Consumer<String> consoleInterceptor) {
14+
try {
15+
if (consoleInterceptor != null)
16+
System.setOut(new PrintStream(new CustomOutputStream(System.out, consoleInterceptor)));
17+
18+
DMLScript.executeScript(new String[]{"-s", code});
19+
20+
} catch (Exception e) {
21+
e.printStackTrace();
22+
}
23+
24+
if (consoleInterceptor != null)
25+
System.setOut(origPrintStream);
26+
}
27+
28+
// Bypasses the interceptor
29+
public static void println(Object o) {
30+
origPrintStream.println(o);
31+
}
32+
33+
private static class CustomOutputStream extends OutputStream {
34+
private PrintStream ps;
35+
private StringBuilder buffer = new StringBuilder();
36+
private Consumer<String> lineHandler;
37+
38+
public CustomOutputStream(PrintStream actualPrintStream, Consumer<String> lineHandler) {
39+
this.ps = actualPrintStream;
40+
this.lineHandler = lineHandler;
41+
}
42+
43+
@Override
44+
public void write(int b) {
45+
char c = (char) b;
46+
if (c == '\n') {
47+
lineHandler.accept(buffer.toString());
48+
buffer.setLength(0); // Clear the buffer after handling the line
49+
} else {
50+
buffer.append(c); // Accumulate characters until newline
51+
}
52+
// Handle the byte 'b', or you can write to any custom destination
53+
//ps.print((char) b); // Example: redirect to System.err
54+
}
55+
56+
@Override
57+
public void write(byte[] b, int off, int len) {
58+
for (int i = off; i < off + len; i++) {
59+
write(b[i]);
60+
}
61+
}
62+
}
63+
}
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
package org.apache.sysds.test.component.codegen.rewrite.functions;
2+
3+
import org.apache.commons.lang3.mutable.MutableBoolean;
4+
import org.apache.sysds.hops.rewriter.DMLCodeGenerator;
5+
import org.apache.sysds.hops.rewriter.DMLExecutor;
6+
import org.apache.sysds.hops.rewriter.RewriterRule;
7+
import org.apache.sysds.hops.rewriter.RewriterRuleSet;
8+
import org.apache.sysds.hops.rewriter.RewriterStatement;
9+
import org.apache.sysds.hops.rewriter.RewriterUtils;
10+
import org.apache.sysds.hops.rewriter.RuleContext;
11+
import org.junit.BeforeClass;
12+
import org.junit.Test;
13+
14+
import java.util.List;
15+
import java.util.UUID;
16+
import java.util.function.Function;
17+
18+
public class DMLCodeGenTest {
19+
20+
private static RuleContext ctx;
21+
private static Function<RewriterStatement, RewriterStatement> canonicalConverter;
22+
23+
@BeforeClass
24+
public static void setup() {
25+
ctx = RewriterUtils.buildDefaultContext();
26+
canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, false);
27+
}
28+
29+
@Test
30+
public void test1() {
31+
RewriterStatement stmt = RewriterUtils.parse("trace(+(A, t(B)))", ctx, "MATRIX:A,B");
32+
System.out.println(DMLCodeGenerator.generateDML(stmt));
33+
}
34+
35+
@Test
36+
public void test2() {
37+
String ruleStr1 = "MATRIX:A\nt(t(A))\n=>\nA";
38+
String ruleStr2 = "MATRIX:A\nrowSums(t(A))\n=>\nt(colSums(A))";
39+
RewriterRule rule1 = RewriterUtils.parseRule(ruleStr1, ctx);
40+
RewriterRule rule2 = RewriterUtils.parseRule(ruleStr2, ctx);
41+
42+
//RewriterRuleSet ruleSet = new RewriterRuleSet(ctx, List.of(rule1, rule2));
43+
String sessionId = UUID.randomUUID().toString();
44+
String validationScript = DMLCodeGenerator.generateRuleValidationDML(rule2, DMLCodeGenerator.EPS, sessionId);
45+
System.out.println("Validation script:");
46+
System.out.println(validationScript);
47+
MutableBoolean valid = new MutableBoolean(true);
48+
DMLExecutor.executeCode(validationScript, line -> {
49+
if (!line.startsWith(sessionId))
50+
return;
51+
52+
if (!line.endsWith("valid: TRUE")) {
53+
DMLExecutor.println("An invalid rule was found!");
54+
valid.setValue(false);
55+
}
56+
});
57+
58+
System.out.println("Exiting...");
59+
assert valid.booleanValue();
60+
}
61+
}

0 commit comments

Comments
 (0)