Skip to content

Commit 04867f3

Browse files
committed
Some bugfixes (origin still unknown)
1 parent 3bdc6c6 commit 04867f3

File tree

3 files changed

+51
-7
lines changed

3 files changed

+51
-7
lines changed

src/main/java/org/apache/sysds/hops/rewriter/RewriterAlphabetEncoder.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
import java.util.stream.Stream;
1919

2020
public class RewriterAlphabetEncoder {
21-
private static final List<String> ALL_TYPES = List.of("MATRIX", "FLOAT");
22-
private static final List<String> MATRIX = List.of("MATRIX");
21+
public static final List<String> ALL_TYPES = List.of("MATRIX", "FLOAT");
22+
public static final List<String> MATRIX = List.of("MATRIX");
2323

2424
private static Operand[] instructionAlphabet = new Operand[] {
2525
null,
@@ -146,13 +146,18 @@ public static List<RewriterStatement> buildAssertionVariations(RewriterStatement
146146

147147
private static RewriterStatement createVector(RewriterStatement of, boolean rowVector, Map<RewriterStatement, RewriterStatement> createdObjects) {
148148
// TODO: Why is it necessary to discard the old DataType?
149-
//RewriterStatement mCpy = new RewriterDataType().as(of.getId()).ofType(of.getResultingDataType(ctx)).consolidate(ctx);
149+
RewriterStatement mCpy = createdObjects.get(of);
150+
151+
if (mCpy == null) {
152+
mCpy = new RewriterDataType().as(of.getId()).ofType(of.getResultingDataType(ctx)).consolidate(ctx);
153+
createdObjects.put(of, mCpy);
154+
}
150155
//RewriterStatement nRowCol = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction(rowVector ? "nrow" : "ncol").withOps(mCpy).consolidate(ctx);
151156
//createdObjects.put(of, mCpy);
152157
return new RewriterInstruction()
153158
.as(of.getId())
154159
.withInstruction(rowVector ? "rowVec" : "colVec")
155-
.withOps(of)
160+
.withOps(mCpy)
156161
.consolidate(ctx);
157162
}
158163

src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterClusteringTest.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ public static void testExpressionClustering() {
189189
actualCtr += expanded.size();
190190
for (RewriterStatement stmt : expanded) {
191191
try {
192+
ctx.metaPropagator.apply(stmt);
192193
RewriterStatement canonicalForm = converter.apply(stmt);
193194
stmt.getCost(ctx);
194195

@@ -220,8 +221,8 @@ public static void testExpressionClustering() {
220221

221222
//stmt.getCost(ctx); // Fetch cost already
222223
// TODO: Not quite working yet
223-
//canonicalForm.compress();
224-
//stmt.compress();
224+
canonicalForm.compress();
225+
stmt.compress();
225226
synchronized (lock) {
226227
RewriterEquivalenceDatabase.DBEntry entry = canonicalExprDB.insert(ctx, canonicalForm, stmt);
227228

src/test/java/org/apache/sysds/test/component/codegen/rewrite/functions/RewriterAlphabetTest.java

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,12 @@
77
import org.junit.BeforeClass;
88
import org.junit.Test;
99

10+
import java.util.ArrayList;
1011
import java.util.List;
1112
import java.util.function.Function;
1213

14+
import static org.apache.sysds.hops.rewriter.RewriterAlphabetEncoder.MATRIX;
15+
1316
public class RewriterAlphabetTest {
1417

1518
private static RuleContext ctx;
@@ -51,7 +54,7 @@ public void testEncode1() {
5154
assert l == 27;
5255
}
5356

54-
@Test
57+
//@Test
5558
public void testRandomStatementGeneration() {
5659
int ctr = 0;
5760
for (int i = 0; i < 20; i++) {
@@ -80,4 +83,39 @@ public void test() {
8083
System.out.println(stmt.toParsableString(ctx));
8184
}
8285

86+
@Test
87+
public void testRandomStatementGeneration2() {
88+
int ctr = 0;
89+
List<List<RewriterAlphabetEncoder.Operand>> opList = new ArrayList<>();
90+
opList.add(List.of(new RewriterAlphabetEncoder.Operand("trace", 1, MATRIX), new RewriterAlphabetEncoder.Operand("%*%", 2, MATRIX)));
91+
opList.add(List.of(new RewriterAlphabetEncoder.Operand("sum", 1, MATRIX), new RewriterAlphabetEncoder.Operand("*", 2, MATRIX), new RewriterAlphabetEncoder.Operand("t", 1, MATRIX)));
92+
List<RewriterStatement> all = new ArrayList<>();
93+
for (List<RewriterAlphabetEncoder.Operand> ops : opList) {
94+
//List<RewriterAlphabetEncoder.Operand> ops = List.of(new RewriterAlphabetEncoder.Operand("sum", 1, MATRIX), new RewriterAlphabetEncoder.Operand("%*%", 1, MATRIX));
95+
//System.out.println("Idx: " + i);
96+
//System.out.println(ops);
97+
//System.out.println(RewriterAlphabetEncoder.buildAllPossibleDAGs(ops, ctx, false).size());
98+
for (RewriterStatement stmt : RewriterAlphabetEncoder.buildAllPossibleDAGs(ops, ctx, true)) {
99+
System.out.println("Base: " + stmt.toParsableString(ctx));
100+
List<RewriterStatement> expand = new ArrayList<>();
101+
expand.addAll(RewriterAlphabetEncoder.buildVariations(stmt, ctx));
102+
expand.addAll(RewriterAlphabetEncoder.buildAssertionVariations(stmt, ctx, false));
103+
104+
for (RewriterStatement sstmt : expand) {
105+
canonicalConverter.apply(sstmt);
106+
System.out.println(sstmt);
107+
sstmt.compress();
108+
all.add(sstmt);
109+
//System.out.println("Raw: " + sstmt);
110+
ctr++;
111+
}
112+
}
113+
}
114+
115+
System.out.println("Total DAGs: " + ctr);
116+
for (RewriterStatement sstmt : all) {
117+
System.out.println(sstmt);
118+
}
119+
}
120+
83121
}

0 commit comments

Comments
 (0)