@@ -3903,5 +3903,139 @@ def main(inp_0: R.Tensor((2, 3), dtype="float32")) -> R.Tensor((), dtype="bool")
39033903 verify_model (IsFloatingPoint (), [([2 , 3 ], "float32" )], {}, Expected )
39043904
39053905
3906+ def test_gather ():
3907+ class Gather0 (Module ):
3908+ def forward (self , data , indices ):
3909+ return torch .gather (data , 0 , indices )
3910+
3911+ class Gather1 (Module ):
3912+ def forward (self , data , indices ):
3913+ return torch .gather (data , 1 , indices )
3914+
3915+ class Gather2 (Module ):
3916+ def forward (self , data , indices ):
3917+ return torch .gather (data , - 1 , indices )
3918+
3919+ class Gather3 (Module ):
3920+ def forward (self , data , indices ):
3921+ return torch .gather (data , - 2 , indices )
3922+
3923+ @tvm .script .ir_module
3924+ class Expected0 :
3925+ @R .function
3926+ def main (
3927+ inp_0 : R .Tensor ((2 , 3 ), dtype = "float32" ),
3928+ inp_1 : R .Tensor ((2 , 3 ), dtype = "int32" ),
3929+ ) -> R .Tensor ((2 , 3 ), dtype = "float32" ):
3930+ with R .dataflow ():
3931+ lv : R .Tensor ((2 , 3 ), dtype = "float32" ) = R .gather_elements (inp_0 , inp_1 , axis = 0 )
3932+ gv : R .Tensor ((2 , 3 ), dtype = "float32" ) = lv
3933+ R .output (gv )
3934+ return gv
3935+
3936+ @tvm .script .ir_module
3937+ class Expected1 :
3938+ @R .function
3939+ def main (
3940+ inp_0 : R .Tensor ((2 , 3 ), dtype = "float32" ),
3941+ inp_1 : R .Tensor ((2 , 3 ), dtype = "int32" ),
3942+ ) -> R .Tensor ((2 , 3 ), dtype = "float32" ):
3943+ with R .dataflow ():
3944+ lv : R .Tensor ((2 , 3 ), dtype = "float32" ) = R .gather_elements (inp_0 , inp_1 , axis = 1 )
3945+ gv : R .Tensor ((2 , 3 ), dtype = "float32" ) = lv
3946+ R .output (gv )
3947+ return gv
3948+
3949+ @tvm .script .ir_module
3950+ class Expected2 :
3951+ @R .function
3952+ def main (
3953+ inp_0 : R .Tensor ((2 , 3 ), dtype = "float32" ),
3954+ inp_1 : R .Tensor ((2 , 3 ), dtype = "int32" ),
3955+ ) -> R .Tensor ((2 , 3 ), dtype = "float32" ):
3956+ with R .dataflow ():
3957+ lv : R .Tensor ((2 , 3 ), dtype = "float32" ) = R .gather_elements (inp_0 , inp_1 , axis = - 1 )
3958+ gv : R .Tensor ((2 , 3 ), dtype = "float32" ) = lv
3959+ R .output (gv )
3960+ return gv
3961+
3962+ @tvm .script .ir_module
3963+ class Expected3 :
3964+ @R .function
3965+ def main (
3966+ inp_0 : R .Tensor ((2 , 3 ), dtype = "float32" ),
3967+ inp_1 : R .Tensor ((2 , 3 ), dtype = "int32" ),
3968+ ) -> R .Tensor ((2 , 3 ), dtype = "float32" ):
3969+ with R .dataflow ():
3970+ lv : R .Tensor ((2 , 3 ), dtype = "float32" ) = R .gather_elements (inp_0 , inp_1 , axis = - 2 )
3971+ gv : R .Tensor ((2 , 3 ), dtype = "float32" ) = lv
3972+ R .output (gv )
3973+ return gv
3974+
3975+ verify_model (Gather0 (), [([2 , 3 ], "float32" ), ([2 , 3 ], "int32" )], {}, Expected0 )
3976+ verify_model (Gather1 (), [([2 , 3 ], "float32" ), ([2 , 3 ], "int32" )], {}, Expected1 )
3977+ verify_model (Gather2 (), [([2 , 3 ], "float32" ), ([2 , 3 ], "int32" )], {}, Expected2 )
3978+ verify_model (Gather3 (), [([2 , 3 ], "float32" ), ([2 , 3 ], "int32" )], {}, Expected3 )
3979+
3980+
3981+ def test_flip ():
3982+ class Flip0 (Module ):
3983+ def forward (self , data ):
3984+ return torch .flip (data , [0 ])
3985+
3986+ class Flip1 (Module ):
3987+ def forward (self , data ):
3988+ return torch .flip (data , [1 ])
3989+
3990+ @tvm .script .ir_module
3991+ class Expected0 :
3992+ @R .function
3993+ def main (
3994+ inp_0 : R .Tensor ((2 , 2 ), dtype = "float32" ),
3995+ ) -> R .Tensor ((2 , 2 ), dtype = "float32" ):
3996+ with R .dataflow ():
3997+ lv : R .Tensor ((2 , 2 ), dtype = "float32" ) = R .flip (inp_0 , axis = 0 )
3998+ gv : R .Tensor ((2 , 2 ), dtype = "float32" ) = lv
3999+ R .output (gv )
4000+ return gv
4001+
4002+ @tvm .script .ir_module
4003+ class Expected1 :
4004+ @R .function
4005+ def main (
4006+ inp_0 : R .Tensor ((2 , 2 ), dtype = "float32" ),
4007+ ) -> R .Tensor ((2 , 2 ), dtype = "float32" ):
4008+ with R .dataflow ():
4009+ lv : R .Tensor ((2 , 2 ), dtype = "float32" ) = R .flip (inp_0 , axis = 1 )
4010+ gv : R .Tensor ((2 , 2 ), dtype = "float32" ) = lv
4011+ R .output (gv )
4012+ return gv
4013+
4014+ verify_model (Flip0 (), [([2 , 2 ], "float32" )], {}, Expected0 )
4015+ verify_model (Flip1 (), [([2 , 2 ], "float32" )], {}, Expected1 )
4016+
4017+
4018+ def test_take ():
4019+ class Take (Module ):
4020+ def forward (self , data , indices ):
4021+ return torch .take (data , indices )
4022+
4023+ @tvm .script .ir_module
4024+ class Expected :
4025+ @R .function
4026+ def main (
4027+ inp_0 : R .Tensor ((5 ,), dtype = "float32" ),
4028+ inp_1 : R .Tensor ((3 ,), dtype = "int32" ),
4029+ ) -> R .Tensor ((3 ,), dtype = "float32" ):
4030+ with R .dataflow ():
4031+ lv : R .Tensor ((3 ,), dtype = "int32" ) = R .astype (inp_1 , "int32" )
4032+ lv1 : R .Tensor ((3 ,), dtype = "float32" ) = R .take (inp_0 , lv )
4033+ gv : R .Tensor ((3 ,), dtype = "float32" ) = lv1
4034+ R .output (gv )
4035+ return gv
4036+
4037+ verify_model (Take (), [([5 ], "float32" ), ([3 ], "int32" )], {}, Expected )
4038+
4039+
39064040if __name__ == "__main__" :
39074041 tvm .testing .main ()
0 commit comments