@@ -86,8 +86,8 @@ public void testSimplifyDistributiveBinaryOperation() {
8686
8787 @ Test
8888 public void testSimplifyBushyBinaryOperation () {
89- RewriterStatement stmt1 = RewriterUtils .parse ("*(A,*(B, %*%(C, rowVec (D))))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:1.0" );
90- RewriterStatement stmt2 = RewriterUtils .parse ("*(*(A,B), %*%(C, rowVec (D)))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:1.0" );
89+ RewriterStatement stmt1 = RewriterUtils .parse ("*(A,*(B, %*%(C, colVec (D))))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:1.0" );
90+ RewriterStatement stmt2 = RewriterUtils .parse ("*(*(A,B), %*%(C, colVec (D)))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:1.0" );
9191
9292 stmt1 = canonicalConverter .apply (stmt1 );
9393 stmt2 = canonicalConverter .apply (stmt2 );
@@ -159,7 +159,7 @@ public void testSimplifyTraceMatrixMult() {
159159 @ Test
160160 public void testSimplifySlicedMatrixMult () {
161161 RewriterStatement stmt1 = RewriterUtils .parse ("[](%*%(A,B), 1, 1)" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:1.0,2.0" , "LITERAL_INT:1" );
162- RewriterStatement stmt2 = RewriterUtils .parse ("as.scalar(%*%(colVec (A), rowVec (B)))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:1.0,2.0" , "LITERAL_INT:1" );
162+ RewriterStatement stmt2 = RewriterUtils .parse ("as.scalar(%*%(rowVec (A), colVec (B)))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:1.0,2.0" , "LITERAL_INT:1" );
163163
164164 assert match (stmt1 , stmt2 );
165165 }
@@ -226,23 +226,23 @@ public void testSimplifyNotOverComparisons() {
226226 public void testRemoveEmptyRightIndexing () {
227227 // We do not directly support the specification of nnz, but we can emulate such a matrix by multiplying with 0
228228 RewriterStatement stmt1 = RewriterUtils .parse ("[](*(A, 0.0), 1, nrow(A), 1, 1)" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" );
229- RewriterStatement stmt2 = RewriterUtils .parse ("const(rowVec (A), 0.0)" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" );
229+ RewriterStatement stmt2 = RewriterUtils .parse ("const(colVec (A), 0.0)" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" );
230230
231231 assert match (stmt1 , stmt2 );
232232 }
233233
234234 @ Test
235235 public void testRemoveUnnecessaryRightIndexing () {
236- RewriterStatement stmt1 = RewriterUtils .parse ("[](rowVec (A), 1, nrow(A), 1, 1)" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" );
237- RewriterStatement stmt2 = RewriterUtils .parse ("rowVec (A)" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" );
236+ RewriterStatement stmt1 = RewriterUtils .parse ("[](colVec (A), 1, nrow(A), 1, 1)" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" );
237+ RewriterStatement stmt2 = RewriterUtils .parse ("colVec (A)" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" );
238238
239239 assert match (stmt1 , stmt2 );
240240 }
241241
242242 @ Test
243243 public void testRemoveUnnecessaryReorgOperation3 () {
244- RewriterStatement stmt1 = RewriterUtils .parse ("t(rowVec(colVec (A)))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" );
245- RewriterStatement stmt2 = RewriterUtils .parse ("rowVec(colVec (A))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" );
244+ RewriterStatement stmt1 = RewriterUtils .parse ("t(cellMat (A)))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" );
245+ RewriterStatement stmt2 = RewriterUtils .parse ("cellMat (A))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" );
246246
247247 assert match (stmt1 , stmt2 );
248248 }
@@ -275,41 +275,41 @@ public void testFuseDatagenAndReorgOperation() {
275275
276276 @ Test
277277 public void testSimplifyColwiseAggregate () {
278- RewriterStatement stmt1 = RewriterUtils .parse ("colSums(colVec (A))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
279- RewriterStatement stmt2 = RewriterUtils .parse ("colVec (A)" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
278+ RewriterStatement stmt1 = RewriterUtils .parse ("colSums(rowVec (A))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
279+ RewriterStatement stmt2 = RewriterUtils .parse ("rowVec (A)" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
280280
281281 assert match (stmt1 , stmt2 );
282282 }
283283
284284 @ Test
285285 public void testSimplifyRowwiseAggregate () {
286- RewriterStatement stmt1 = RewriterUtils .parse ("rowSums(rowVec (A))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
287- RewriterStatement stmt2 = RewriterUtils .parse ("rowVec (A)" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
286+ RewriterStatement stmt1 = RewriterUtils .parse ("rowSums(colVec (A))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
287+ RewriterStatement stmt2 = RewriterUtils .parse ("colVec (A)" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
288288
289289 assert match (stmt1 , stmt2 );
290290 }
291291
292292 // We don't have broadcasting semantics
293293 @ Test
294294 public void testSimplifyColSumsMVMult () {
295- RewriterStatement stmt1 = RewriterUtils .parse ("colSums(*(rowVec (A), rowVec (B)))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
296- RewriterStatement stmt2 = RewriterUtils .parse ("%*%(t(rowVec (B)), rowVec (A))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
295+ RewriterStatement stmt1 = RewriterUtils .parse ("colSums(*(colVec (A), colVec (B)))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
296+ RewriterStatement stmt2 = RewriterUtils .parse ("%*%(t(colVec (B)), colVec (A))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
297297
298298 assert match (stmt1 , stmt2 );
299299 }
300300
301301 // We don't have broadcasting semantics
302302 @ Test
303303 public void testSimplifyRowSumsMVMult () {
304- RewriterStatement stmt1 = RewriterUtils .parse ("rowSums(*(colVec (A), colVec (B)))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
305- RewriterStatement stmt2 = RewriterUtils .parse ("%*%(colVec (A), t(colVec (B)))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
304+ RewriterStatement stmt1 = RewriterUtils .parse ("rowSums(*(rowVec (A), rowVec (B)))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
305+ RewriterStatement stmt2 = RewriterUtils .parse ("%*%(rowVec (A), t(rowVec (B)))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
306306
307307 assert match (stmt1 , stmt2 );
308308 }
309309
310310 @ Test
311311 public void testSimplifyUnnecessaryAggregate () {
312- RewriterStatement stmt1 = RewriterUtils .parse ("sum(rowVec(colVec (A)))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
312+ RewriterStatement stmt1 = RewriterUtils .parse ("sum(cellMat (A)))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
313313 RewriterStatement stmt2 = RewriterUtils .parse ("as.scalar(A)" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
314314
315315 assert match (stmt1 , stmt2 );
@@ -350,16 +350,16 @@ public void testSimplifyEmptyMatrixMult() {
350350
351351 @ Test
352352 public void testSimplifyEmptyMatrixMult2 () {
353- RewriterStatement stmt1 = RewriterUtils .parse ("%*%(rowVec (A), cast.MATRIX(1.0))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
354- RewriterStatement stmt2 = RewriterUtils .parse ("rowVec (A)" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
353+ RewriterStatement stmt1 = RewriterUtils .parse ("%*%(colVec (A), cast.MATRIX(1.0))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
354+ RewriterStatement stmt2 = RewriterUtils .parse ("colVec (A)" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
355355
356356 assert match (stmt1 , stmt2 );
357357 }
358358
359359 @ Test
360360 public void testSimplifyScalarMatrixMult () {
361- RewriterStatement stmt1 = RewriterUtils .parse ("%*%(rowVec (A), cast.MATRIX(a))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
362- RewriterStatement stmt2 = RewriterUtils .parse ("*(rowVec (A), as.scalar(cast.MATRIX(a)))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
361+ RewriterStatement stmt1 = RewriterUtils .parse ("%*%(colVec (A), cast.MATRIX(a))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
362+ RewriterStatement stmt2 = RewriterUtils .parse ("*(colVec (A), as.scalar(cast.MATRIX(a)))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
363363
364364 assert match (stmt1 , stmt2 );
365365 }
@@ -405,8 +405,8 @@ public void testPushdownSumOnAdditiveBinary() {
405405
406406 @ Test
407407 public void testSimplifyDotProductSum () {
408- RewriterStatement stmt1 = RewriterUtils .parse ("cast.MATRIX(sum(sq(rowVec (A))))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
409- RewriterStatement stmt2 = RewriterUtils .parse ("%*%(t(rowVec (A)), rowVec (A))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
408+ RewriterStatement stmt1 = RewriterUtils .parse ("cast.MATRIX(sum(sq(colVec (A))))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
409+ RewriterStatement stmt2 = RewriterUtils .parse ("%*%(t(colVec (A)), colVec (A))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
410410
411411 assert match (stmt1 , stmt2 );
412412 }
@@ -477,7 +477,7 @@ public void testSimplifyEmptyBinaryOperation3() {
477477
478478 //@Test
479479 public void testSimplifyScalarMVBinaryOperation () {
480- RewriterStatement stmt1 = RewriterUtils .parse ("*(A, rowVec (colVec(B)))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
480+ RewriterStatement stmt1 = RewriterUtils .parse ("*(A, colVec (colVec(B)))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
481481 RewriterStatement stmt2 = RewriterUtils .parse ("*(A, as.scalar(B))" , ctx , "MATRIX:A,B,C,D" , "FLOAT:a,b,c" , "LITERAL_FLOAT:0.0,1.0,2.0" , "LITERAL_INT:1" , "LITERAL_BOOL:TRUE,FALSE" , "INT:i" );
482482
483483 assert match (stmt1 , stmt2 );
0 commit comments