@@ -151,6 +151,13 @@ private RewriterStatement propagateDims(RewriterStatement root, RewriterStatemen
151151					root .unsafePutMeta ("nrow" , firstMatrixStatement .get ().getMeta ("nrow" ));
152152					root .unsafePutMeta ("ncol" , firstMatrixStatement .get ().getMeta ("ncol" ));
153153					return  null ;
154+ 				case  "cast.MATRIX" :
155+ 					String  mDT  = root .getChild (0 ).getResultingDataType (ctx );
156+ 					if  (mDT .equals ("BOOL" ) || mDT .equals ("INT" ) || mDT .equals ("FLOAT" )) {
157+ 						root .unsafePutMeta ("ncol" , new  RewriterDataType ().ofType ("INT" ).as ("1" ).asLiteral (1L ));
158+ 						root .unsafePutMeta ("nrow" , new  RewriterDataType ().ofType ("INT" ).as ("1" ).asLiteral (1L ));
159+ 						return  null ;
160+ 					}
154161			}
155162
156163			switch (root .trueTypedInstruction (ctx )) {
@@ -186,23 +193,43 @@ private RewriterStatement propagateDims(RewriterStatement root, RewriterStatemen
186193					root .unsafePutMeta ("ncol" , new  RewriterDataType ().ofType ("INT" ).as ("1" ).asLiteral (1L ));
187194					return  null ;
188195				case  "[](MATRIX,INT,INT,INT,INT)" :
189- 					Integer [] ints  = new  Integer [4 ];
196+ 					Long [] ints  = new  Long [4 ];
190197
191- 					for  (int  i  = 0 ; i  < 4 ; i ++)
192- 						if  (root .getOperands ().get (1 ).isLiteral ())
193- 							ints [i ] = (Integer )root .getOperands ().get (1 ).getLiteral ();
198+ 					for  (int  i  = 1 ; i  < 5 ; i ++)
199+ 						if  (root .getChild (i ).isLiteral ())
200+ 							if  (root .getChild (i ).getLiteral () instanceof  Integer )
201+ 								ints [i -1 ] = (Long )root .getChild (i ).getLiteral ();
194202
195203					if  (ints [0 ] != null  && ints [1 ] != null ) {
196- 						root .unsafePutMeta ("nrow" , ints [1 ] - ints [0 ] + 1 );
204+ 						String  literalString  = Long .toString (ints [1 ] - ints [0 ] + 1 );
205+ 						root .unsafePutMeta ("nrow" , RewriterUtils .parse (literalString , ctx , "LITERAL_INT:"  + literalString ));
197206					} else  {
198- 						throw  new  NotImplementedException ();
199- 						// TODO: 
207+ 						HashMap <String , RewriterStatement > subStmts  = new  HashMap <>();
208+ 						subStmts .put ("i1" , root .getOperands ().get (2 ));
209+ 						subStmts .put ("i0" , root .getOperands ().get (1 ));
210+ 
211+ 						if  (ints [0 ] != null ) {
212+ 							root .unsafePutMeta ("nrow" , RewriterUtils .parse ("+(i1, "  + (1  - ints [0 ]) + ")" , ctx , subStmts , "LITERAL_INT:"  + (1  - ints [0 ])));
213+ 						} else  if  (ints [1 ] != null ) {
214+ 							root .unsafePutMeta ("nrow" , RewriterUtils .parse ("-("  + (ints [1 ] + 1 ) + ", i0)" , ctx , subStmts , "LITERAL_INT:"  + (ints [1 ] + 1 )));
215+ 						} else  {
216+ 							root .unsafePutMeta ("nrow" , RewriterUtils .parse ("+(-(i1, i0), 1)" , ctx , subStmts , "LITERAL_INT:1" ));
217+ 						}
200218					}
201219
202220					if  (ints [2 ] != null  && ints [3 ] != null ) {
203221						root .unsafePutMeta ("ncol" , ints [3 ] - ints [2 ] + 1 );
204222					} else  {
205- 						throw  new  NotImplementedException ();
223+ 						HashMap <String , RewriterStatement > subStmts  = new  HashMap <>();
224+ 						subStmts .put ("i3" , root .getOperands ().get (4 ));
225+ 						subStmts .put ("i2" , root .getOperands ().get (3 ));
226+ 						if  (ints [2 ] != null ) {
227+ 							root .unsafePutMeta ("ncol" , RewriterUtils .parse ("+(i3, "  + (1  - ints [2 ]) + ")" , ctx , subStmts , "LITERAL_INT:"  + (1  - ints [2 ])));
228+ 						} else  if  (ints [3 ] != null ) {
229+ 							root .unsafePutMeta ("ncol" , RewriterUtils .parse ("-("  + (ints [3 ] + 1 ) + ", i2)" , ctx , subStmts , "LITERAL_INT:"  + (ints [3 ] + 1 )));
230+ 						} else  {
231+ 							root .unsafePutMeta ("ncol" , RewriterUtils .parse ("+(-(i3, i2), 1)" , ctx , subStmts , "LITERAL_INT:1" ));
232+ 						}
206233					}
207234
208235					return  null ;
@@ -214,6 +241,10 @@ private RewriterStatement propagateDims(RewriterStatement root, RewriterStatemen
214241					root .unsafePutMeta ("ncol" , root .getOperands ().get (0 ).getMeta ("ncol" ));
215242					root .unsafePutMeta ("nrow" , new  RewriterDataType ().ofType ("INT" ).as ("1" ).asLiteral (1L ));
216243					return  null ;
244+ 				case  "cast.MATRIX(MATRIX)" :
245+ 					root .unsafePutMeta ("nrow" , root .getOperands ().get (0 ).getMeta ("nrow" ));
246+ 					root .unsafePutMeta ("ncol" , root .getOperands ().get (0 ).getMeta ("ncol" ));
247+ 					return  null ;
217248			}
218249
219250			RewriterInstruction  instr  = (RewriterInstruction ) root ;
@@ -230,6 +261,18 @@ private RewriterStatement propagateDims(RewriterStatement root, RewriterStatemen
230261				return  null ;
231262			}
232263
264+ 			if  (instr .getProperties (ctx ).contains ("ElementWiseUnary.FLOAT" )) {
265+ 				if  (root .getOperands ().get (0 ).getResultingDataType (ctx ).startsWith ("MATRIX" )) {
266+ 					root .unsafePutMeta ("nrow" , root .getOperands ().get (0 ).getMeta ("nrow" ));
267+ 					root .unsafePutMeta ("ncol" , root .getOperands ().get (0 ).getMeta ("ncol" ));
268+ 				} else  {
269+ 					root .unsafePutMeta ("nrow" , root .getOperands ().get (1 ).getMeta ("nrow" ));
270+ 					root .unsafePutMeta ("ncol" , root .getOperands ().get (1 ).getMeta ("ncol" ));
271+ 				}
272+ 
273+ 				return  null ;
274+ 			}
275+ 
233276			throw  new  NotImplementedException ("Unknown instruction: "  + instr .trueTypedInstruction (ctx ) + "\n "  + instr .toString (ctx ));
234277		}
235278
0 commit comments