Skip to content

Commit 5bbda9a

Browse files
committed
Sparsity esimtation
1 parent edc3f17 commit 5bbda9a

File tree

5 files changed

+147
-5
lines changed

5 files changed

+147
-5
lines changed

src/main/java/org/apache/sysds/hops/rewriter/RewriterHeuristic.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package org.apache.sysds.hops.rewriter;
22

3+
import org.apache.commons.collections4.bidimap.DualHashBidiMap;
34
import org.apache.commons.lang3.NotImplementedException;
45
import org.apache.commons.lang3.mutable.MutableBoolean;
56

@@ -14,17 +15,15 @@ public class RewriterHeuristic implements RewriterHeuristicTransformation {
1415
private final RewriterRuleSet ruleSet;
1516
private final Function<RewriterStatement, RewriterStatement> f;
1617
private final boolean accelerated;
17-
//private final List<String> desiredProperties;
1818

1919
public RewriterHeuristic(RewriterRuleSet ruleSet) {
2020
this(ruleSet, true);
2121
}
2222

23-
public RewriterHeuristic(RewriterRuleSet ruleSet, boolean accelerated/*, List<String> desiredProperties*/) {
23+
public RewriterHeuristic(RewriterRuleSet ruleSet, boolean accelerated) {
2424
this.ruleSet = ruleSet;
2525
this.accelerated = accelerated;
2626
this.f = null;
27-
//this.desiredProperties = desiredProperties;
2827
}
2928

3029
public RewriterHeuristic(Function<RewriterStatement, RewriterStatement> f) {

src/main/java/org/apache/sysds/hops/rewriter/RewriterInstruction.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ public RewriterInstruction() {
3434

3535
public RewriterInstruction(String instr, final RuleContext ctx, RewriterStatement... ops) {
3636
id = UUID.randomUUID().toString();
37+
this.instr = instr;
3738
withOps(ops);
3839
consolidate(ctx);
3940
}

src/main/java/org/apache/sysds/hops/rewriter/RewriterUtils.java

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1273,6 +1273,25 @@ public static Function<RewriterStatement, RewriterStatement> unfuseOperators(fin
12731273
return lastUnfuse;
12741274
}
12751275

1276+
private static RuleContext lastSparsityCtx;
1277+
private static Function<RewriterStatement, RewriterStatement> lastPrepareForSparsity;
1278+
1279+
public static Function<RewriterStatement, RewriterStatement> prepareForSparsityEstimation(final RuleContext ctx) {
1280+
if (lastSparsityCtx == ctx)
1281+
return lastPrepareForSparsity;
1282+
1283+
ArrayList<RewriterRule> mRules = new ArrayList<>();
1284+
RewriterRuleCollection.substituteFusedOps(mRules, ctx);
1285+
RewriterRuleCollection.substituteEquivalentStatements(mRules, ctx);
1286+
RewriterRuleCollection.eliminateMultipleCasts(mRules, ctx);
1287+
RewriterRuleCollection.canonicalizeBooleanStatements(mRules, ctx);
1288+
RewriterRuleCollection.canonicalizeAlgebraicStatements(mRules, ctx);
1289+
RewriterHeuristic heur = new RewriterHeuristic(new RewriterRuleSet(ctx, mRules));
1290+
lastSparsityCtx = ctx;
1291+
lastPrepareForSparsity = heur::apply;
1292+
return lastPrepareForSparsity;
1293+
}
1294+
12761295
public static Function<RewriterStatement, RewriterStatement> buildCanonicalFormConverter(final RuleContext ctx, boolean debug) {
12771296
ArrayList<RewriterRule> algebraicCanonicalizationRules = new ArrayList<>();
12781297
RewriterRuleCollection.substituteEquivalentStatements(algebraicCanonicalizationRules, ctx);

src/main/java/org/apache/sysds/hops/rewriter/estimators/RewriterSparsityEstimator.java

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,69 @@
11
package org.apache.sysds.hops.rewriter.estimators;
22

33
import org.apache.sysds.hops.rewriter.ConstantFoldingFunctions;
4+
import org.apache.sysds.hops.rewriter.RewriterDataType;
45
import org.apache.sysds.hops.rewriter.RewriterInstruction;
56
import org.apache.sysds.hops.rewriter.RewriterStatement;
7+
import org.apache.sysds.hops.rewriter.RewriterUtils;
68
import org.apache.sysds.hops.rewriter.RuleContext;
79
import org.apache.sysds.hops.rewriter.utils.StatementUtils;
810

11+
import java.util.HashMap;
912
import java.util.Map;
1013
import java.util.UUID;
1114

1215
public class RewriterSparsityEstimator {
13-
public static RewriterStatement estimateNNZ(RewriterStatement stmt, Map<RewriterStatement, Long> matrixNNZs, final RuleContext ctx) {
14-
long[] nnzs = stmt.getOperands().stream().mapToLong(matrixNNZs::get).toArray();
1516

17+
/*public static RewriterStatement getCanonicalized(RewriterStatement instr, final RuleContext ctx) {
18+
RewriterStatement cpy = instr.copyNode();
19+
Map<RewriterStatement, RewriterStatement> mmap = new HashMap<>();
20+
21+
for (int i = 0; i < cpy.getOperands().size(); i++) {
22+
RewriterStatement existing = mmap.get(cpy.getOperands().get(i));
23+
24+
if (existing != null) {
25+
cpy.getOperands().set(i, existing);
26+
} else {
27+
RewriterStatement mDat = new RewriterDataType().as(UUID.randomUUID().toString()).ofType(cpy.getOperands().get(i).getResultingDataType(ctx)).consolidate(ctx);
28+
}
29+
}
30+
31+
RewriterUtils.prepareForSparsityEstimation();
32+
}*/
33+
34+
public static RewriterStatement rollupSparsities(RewriterStatement sparsityEstimate, Map<RewriterStatement, RewriterStatement> sparsityMap, final RuleContext ctx) {
35+
sparsityEstimate.forEachPreOrder(cur -> {
36+
for (int i = 0; i < cur.getOperands().size(); i++) {
37+
RewriterStatement child = cur.getChild(i);
38+
39+
if (child.isInstruction() && child.trueInstruction().equals("_nnz")) {
40+
RewriterStatement subEstimate = sparsityMap.get(child.getChild(0));
41+
42+
if (subEstimate != null) {
43+
cur.getOperands().set(i, subEstimate);
44+
}
45+
}
46+
}
47+
return true;
48+
}, false);
49+
50+
return sparsityEstimate;
51+
}
52+
53+
public static Map<RewriterStatement, RewriterStatement> estimateAllNNZ(RewriterStatement stmt, final RuleContext ctx) {
54+
Map<RewriterStatement, RewriterStatement> map = new HashMap<>();
55+
stmt.forEachPostOrder((cur, pred) -> {
56+
RewriterStatement estimation = estimateNNZ(cur, ctx);
57+
if (estimation != null)
58+
map.put(cur, estimation);
59+
}, false);
60+
61+
return map;
62+
}
63+
64+
public static RewriterStatement estimateNNZ(RewriterStatement stmt, final RuleContext ctx) {
65+
if (!stmt.isInstruction())
66+
return null;
1667
switch (stmt.trueInstruction()) {
1768
case "%*%":
1869
return new RewriterInstruction("*", ctx, StatementUtils.min(ctx, new RewriterInstruction("*", ctx, stmt.getNRow(), stmt.getNCol()), RewriterStatement.nnz(stmt.getChild(1), ctx)), new RewriterInstruction("*", ctx, stmt.getNRow(), stmt.getNCol()), RewriterStatement.nnz(stmt.getChild(0), ctx));
@@ -49,6 +100,31 @@ public static RewriterStatement estimateNNZ(RewriterStatement stmt, Map<Rewriter
49100

50101
case "sqrt(MATRIX)":
51102
return RewriterStatement.nnz(stmt.getChild(0), ctx);
103+
104+
case "diag(MATRIX)":
105+
return StatementUtils.min(ctx, stmt.getNRow(), RewriterStatement.nnz(stmt.getChild(0), ctx));
106+
107+
case "/(MATRIX,FLOAT)":
108+
case "/(MATRIX,MATRIX)":
109+
return RewriterStatement.nnz(stmt.getChild(0), ctx);
110+
case "/(FLOAT,MATRIX)":
111+
if (stmt.getChild(0).isLiteral() && ConstantFoldingFunctions.isNeutralElement(stmt.getChild(0).getLiteral(), "+"))
112+
return RewriterStatement.literal(ctx, 0L);
113+
return StatementUtils.length(ctx, stmt);
114+
115+
116+
// Fused operators
117+
case "log_nz(MATRIX)":
118+
case "*2(MATRIX)":
119+
case "sq(MATRIX)":
120+
return RewriterStatement.nnz(stmt.getChild(0), ctx);
121+
case "1-*(MATRIX,MATRIX)":
122+
return StatementUtils.length(ctx, stmt);
123+
case "+*(MATRIX,FLOAT,MATRIX)":
124+
case "-*(MATRIX,FLOAT,MATRIX)":
125+
if (stmt.getChild(1).isLiteral() && ConstantFoldingFunctions.isNeutralElement(stmt.getChild(1).getLiteral(), "+"))
126+
return RewriterStatement.nnz(stmt.getChild(0), ctx);
127+
return StatementUtils.min(ctx, new RewriterInstruction("+", ctx, RewriterStatement.nnz(stmt.getChild(0), ctx), RewriterStatement.nnz(stmt.getChild(2), ctx)), StatementUtils.length(ctx, stmt));
52128
}
53129

54130
return StatementUtils.length(ctx, stmt);
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package org.apache.sysds.test.component.codegen.rewrite.functions;
2+
3+
import org.apache.sysds.hops.rewriter.RewriterStatement;
4+
import org.apache.sysds.hops.rewriter.RewriterUtils;
5+
import org.apache.sysds.hops.rewriter.RuleContext;
6+
import org.apache.sysds.hops.rewriter.estimators.RewriterSparsityEstimator;
7+
import org.junit.BeforeClass;
8+
import org.junit.Test;
9+
10+
import java.util.Map;
11+
import java.util.function.Function;
12+
13+
public class SparsityEstimationTest {
14+
private static RuleContext ctx;
15+
private static Function<RewriterStatement, RewriterStatement> canonicalConverter;
16+
17+
@BeforeClass
18+
public static void setup() {
19+
ctx = RewriterUtils.buildDefaultContext();
20+
canonicalConverter = RewriterUtils.buildCanonicalFormConverter(ctx, false);
21+
}
22+
23+
@Test
24+
public void test1() {
25+
RewriterStatement stmt = RewriterUtils.parse("+*(A, 0.0, B)", ctx, "MATRIX:A,B", "LITERAL_FLOAT:0.0");
26+
System.out.println(RewriterSparsityEstimator.estimateNNZ(stmt, ctx).toParsableString(ctx));
27+
}
28+
29+
@Test
30+
public void test2() {
31+
RewriterStatement stmt = RewriterUtils.parse("+*(A, a, B)", ctx, "MATRIX:A,B", "FLOAT:a");
32+
System.out.println(RewriterSparsityEstimator.estimateNNZ(stmt, ctx).toParsableString(ctx));
33+
}
34+
35+
@Test
36+
public void test3() {
37+
RewriterStatement stmt = RewriterUtils.parse("+(A, -(B, A))", ctx, "MATRIX:A,B", "FLOAT:a");
38+
Map<RewriterStatement, RewriterStatement> estimates = RewriterSparsityEstimator.estimateAllNNZ(stmt, ctx);
39+
40+
estimates.forEach((k, v) -> {
41+
System.out.println("K: " + k.toParsableString(ctx));
42+
System.out.println("Sparsity: " + v.toParsableString(ctx));
43+
});
44+
45+
System.out.println("Rollup: " + RewriterSparsityEstimator.rollupSparsities(estimates.get(stmt), estimates, ctx).toParsableString(ctx));
46+
}
47+
}

0 commit comments

Comments
 (0)