@@ -193,14 +193,19 @@ public void testSumEquality6() {
193193 @ Test
194194 public void testSumEquality () {
195195 RewriterStatement stmt = RewriterUtils .parse ("sum(+(B, sum(*(a, A))))" , ctx , "MATRIX:A,B" , "FLOAT:a" );
196- RewriterStatement stmt2 = RewriterUtils .parse ("*(a, sum(+(B, sum(A))))" , ctx , "MATRIX:A,B" , "FLOAT:a" );
196+ RewriterStatement stmt2 = RewriterUtils .parse ("+(*(a, length(A)), sum(+(B, sum(A))))" , ctx , "MATRIX:A,B" , "FLOAT:a" );
197+ RewriterStatement stmt3 = RewriterUtils .parse ("sum(+(B, *(a, sum(A))))" , ctx , "MATRIX:A,B" , "FLOAT:a" );
197198 stmt = canonicalConverter .apply (stmt );
198199 stmt2 = canonicalConverter .apply (stmt2 );
200+ stmt3 = canonicalConverter .apply (stmt3 );
199201
200202 System .out .println ("==========" );
201203 System .out .println (stmt .toParsableString (ctx , true ));
202204 System .out .println ("==========" );
205+ System .out .println (stmt3 .toParsableString (ctx , true ));
206+ System .out .println ("==========" );
203207 System .out .println (stmt2 .toParsableString (ctx , true ));
208+ assert stmt .match (RewriterStatement .MatcherContext .exactMatch (ctx , stmt3 , stmt ));
204209 assert stmt .match (RewriterStatement .MatcherContext .exactMatch (ctx , stmt2 , stmt ));
205210 }
206211
@@ -1245,4 +1250,42 @@ public void testFused6() {
12451250
12461251 assert !stmt1 .match (RewriterStatement .MatcherContext .exactMatch (ctx , stmt2 , stmt1 ));
12471252 }
1253+
1254+ @ Test
1255+ public void testSum () {
1256+ RewriterStatement stmt1 = RewriterUtils .parse ("sum(+(a,A))" , ctx , "MATRIX:A,B" , "FLOAT:a" );
1257+ RewriterStatement stmt2 = RewriterUtils .parse ("+(*(a, length(A)), sum(A))" , ctx , "MATRIX:A,B" , "FLOAT:a" , "LITERAL_FLOAT:0.0" );
1258+
1259+ System .out .println ("Cost1: " + RewriterCostEstimator .estimateCost (stmt1 , ctx ));
1260+ System .out .println ("Cost2: " + RewriterCostEstimator .estimateCost (stmt2 , ctx ));
1261+
1262+ stmt1 = canonicalConverter .apply (stmt1 );
1263+ stmt2 = canonicalConverter .apply (stmt2 );
1264+
1265+ System .out .println ("==========" );
1266+ System .out .println (stmt1 .toParsableString (ctx , true ));
1267+ System .out .println ("==========" );
1268+ System .out .println (stmt2 .toParsableString (ctx , true ));
1269+
1270+ assert stmt1 .match (RewriterStatement .MatcherContext .exactMatch (ctx , stmt2 , stmt1 ));
1271+ }
1272+
1273+ @ Test
1274+ public void testSumInequality () {
1275+ RewriterStatement stmt1 = RewriterUtils .parse ("sum(+(a,*(B,c)))" , ctx , "MATRIX:B" , "FLOAT:a,c" );
1276+ RewriterStatement stmt2 = RewriterUtils .parse ("*(a, sum(+(B,c)))" , ctx , "MATRIX:B" , "FLOAT:a,c" , "LITERAL_FLOAT:0.0" );
1277+
1278+ System .out .println ("Cost1: " + RewriterCostEstimator .estimateCost (stmt1 , ctx ));
1279+ System .out .println ("Cost2: " + RewriterCostEstimator .estimateCost (stmt2 , ctx ));
1280+
1281+ stmt1 = canonicalConverter .apply (stmt1 );
1282+ stmt2 = canonicalConverter .apply (stmt2 );
1283+
1284+ System .out .println ("==========" );
1285+ System .out .println (stmt1 .toParsableString (ctx , true ));
1286+ System .out .println ("==========" );
1287+ System .out .println (stmt2 .toParsableString (ctx , true ));
1288+
1289+ assert !stmt1 .match (RewriterStatement .MatcherContext .exactMatch (ctx , stmt2 , stmt1 ));
1290+ }
12481291}
0 commit comments