Skip to content

Commit 98090c0

Browse files
committed
Bugfix
1 parent 94d46a8 commit 98090c0

File tree

4 files changed

+23
-7
lines changed

4 files changed

+23
-7
lines changed

src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterHeuristics.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,9 @@ public void addRepeated(String name, RewriterHeuristicTransformation heur) {
5757
public RewriterStatement apply(RewriterStatement stmt, @Nullable BiFunction<RewriterStatement, RewriterRule, Boolean> func, MutableBoolean bool, boolean print) {
5858
for (HeuristicEntry entry : heuristics) {
5959
if (print) {
60-
LOG.info("\n");
61-
LOG.info("> " + entry.name + " <");
62-
LOG.info("\n");
60+
System.out.println("\n");
61+
System.out.println("> " + entry.name + " <");
62+
System.out.println("\n");
6363
}
6464

6565
stmt = entry.heuristics.apply(stmt, func, bool, print);

src/main/java/org/apache/sysds/hops/rewriter/rule/RewriterRuleCollection.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1387,7 +1387,8 @@ public static void flattenOperations(final List<RewriterRule> rules, final RuleC
13871387
if (newOwnerId == null)
13881388
throw new IllegalArgumentException();
13891389

1390-
stmt.getOperands().get(0).getOperands().get(1).unsafePutMeta("ownerId", newOwnerId);
1390+
if (!stmt.getChild(0, 1).isLiteral())
1391+
stmt.getOperands().get(0).getOperands().get(1).unsafePutMeta("ownerId", newOwnerId);
13911392
}, true)
13921393
.build());
13931394

src/main/java/org/apache/sysds/hops/rewriter/utils/RewriterUtils.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -920,6 +920,8 @@ private static RewriterStatement tryPullOutSum(RewriterStatement sum, final Rule
920920
List<RewriterStatement> components = new ArrayList<>();
921921

922922
for (RewriterStatement idx : indices) {
923+
if (idx.isLiteral())
924+
continue;
923925
RewriterStatement idxFrom = idx.getChild(0);
924926
RewriterStatement idxTo = idx.getChild(1);
925927
RewriterStatement negation = new RewriterInstruction().as(UUID.randomUUID().toString()).withInstruction("-").withOps(/*RewriterStatement.ensureFloat(ctx, idxFrom)*/idxFrom).consolidate(ctx);
@@ -1104,9 +1106,12 @@ private static RewriterStatement foldNaryReducible(RewriterStatement stmt, final
11041106
int[] literals = IntStream.range(0, argList.size()).filter(i -> argList.get(i).isLiteral()).toArray();
11051107

11061108
if (literals.length == 1) {
1107-
RewriterStatement overwrite = ConstantFoldingUtils.overwritesLiteral((Number)argList.get(literals[0]).getLiteral(), stmt.trueInstruction(), ctx);
1108-
if (overwrite != null)
1109-
return overwrite;
1109+
Object literal = argList.get(literals[0]).getLiteral();
1110+
if (literal instanceof Number) {
1111+
RewriterStatement overwrite = ConstantFoldingUtils.overwritesLiteral((Number) literal, stmt.trueInstruction(), ctx);
1112+
if (overwrite != null)
1113+
return overwrite;
1114+
}
11101115

11111116
// Check if is neutral element
11121117
if (ConstantFoldingUtils.isNeutralElement(argList.get(literals[0]).getLiteral(), stmt.trueInstruction())) {

src/test/java/org/apache/sysds/test/component/codegen/rewrite/RewriterStreamTests.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1758,4 +1758,14 @@ public void testSparsityComparison() {
17581758

17591759
assert can2.match(RewriterStatement.MatcherContext.exactMatch(ctx, can1, can2));
17601760
}
1761+
1762+
@Test
1763+
public void testTEST() {
1764+
RewriterStatement stmt1 = RewriterUtils.parse("t(/(<=(A,B),rowSums(<=(C,B))))", ctx, "MATRIX:A,B,C,D,E", "LITERAL_FLOAT:1.0", "LITERAL_INT:1");
1765+
1766+
stmt1 = canonicalConverter.apply(stmt1);
1767+
1768+
LOG.info("==========");
1769+
LOG.info(stmt1.toParsableString(ctx, true));
1770+
}
17611771
}

0 commit comments

Comments
 (0)