Skip to content

Commit 0286e48

Browse files
committed
Checkpoint
1 parent ba8d1b7 commit 0286e48

File tree

3 files changed

+25
-8
lines changed

3 files changed

+25
-8
lines changed

src/main/java/org/apache/sysds/hops/rewriter/RewriterCostEstimator.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,9 @@ public static long estimateCost(RewriterStatement stmt, Function<RewriterStateme
4949

5050
costFn = RewriterUtils.foldConstants(costFn, ctx);
5151

52-
if (!costFn.isLiteral())
53-
throw new IllegalArgumentException("Cost function must be a literal: " + costFn.toParsableString(ctx));
52+
if (!costFn.isLiteral()) {
53+
//throw new IllegalArgumentException("Cost function must be a literal: " + costFn.toParsableString(ctx) + "\nCorresponding statement:\n" + stmt.toParsableString(ctx));
54+
}
5455

5556
return (long)costFn.getLiteral();
5657
}

src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterClusteringTest.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,14 @@ public void testExpressionClustering() {
7979
db.parForEach(expr -> {
8080
if (ctr.incrementAndGet() % 10 == 0)
8181
System.out.println("Done: " + ctr.intValue() + " / " + size);
82-
if (ctr.intValue() > 1000)
83-
return; // Skip
82+
//if (ctr.intValue() > 1000)
83+
//return; // Skip
8484
// First, build all possible subtrees
8585
//System.out.println("Eval:\n" + expr.toParsableString(ctx, true));
86-
List<RewriterStatement> subExprs = RewriterUtils.generateSubtrees(expr, ctx, 500);
86+
List<RewriterStatement> subExprs = RewriterUtils.generateSubtrees(expr, ctx, 300);
8787
if (subExprs.size() > 100)
8888
System.out.println("Critical number of subtrees: " + subExprs.size());
89-
if (subExprs.size() > 2000) {
89+
if (subExprs.size() > 500) {
9090
System.out.println("Skipping subtrees...");
9191
subExprs = List.of(expr);
9292
}
@@ -261,8 +261,8 @@ private List<Tuple5<Double, Long, Long, RewriterStatement, RewriterStatement>> f
261261

262262
if (cost != null) {
263263
double score = (((double)cost.longValue()) / minCost - 1) * 1000; // Relative cost reduction
264-
score += cost.longValue() - minCost; // Absolute cost reduction
265-
if (score > 0.000001)
264+
score *= cost.longValue() - minCost; // Absolute cost reduction
265+
if (score > 1e-10)
266266
suggestedRewrites.add(new Tuple5<>(score, cost, minCost, eq, optimalStatement));
267267
}
268268
}

src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterStreamTests.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -738,4 +738,20 @@ public void testAdvancedEquivalence1() {
738738

739739
assert stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2));
740740
}
741+
742+
@Test
743+
public void testInequality() {
744+
RewriterStatement stmt1 = RewriterUtils.parse("/(*(A, A), B)", ctx, "MATRIX:A,B");
745+
RewriterStatement stmt2 = RewriterUtils.parse("/(*(A, A), sum(B))", ctx, "MATRIX:A,B");
746+
747+
stmt1 = canonicalConverter.apply(stmt1);
748+
stmt2 = canonicalConverter.apply(stmt2);
749+
750+
System.out.println("==========");
751+
System.out.println(stmt1.toParsableString(ctx, true));
752+
System.out.println("==========");
753+
System.out.println(stmt2.toParsableString(ctx, true));
754+
755+
assert !stmt1.match(RewriterStatement.MatcherContext.exactMatch(ctx, stmt2));
756+
}
741757
}

0 commit comments

Comments
 (0)