@@ -709,29 +709,66 @@ def test_shape_tuple():
709709
710710
711711class TestVectorize :
712+ @pytensor .config .change_flags (cxx = "" ) # For faster eval
712713 def test_shape (self ):
713- vec = tensor (shape = (None ,))
714- mat = tensor (shape = (None , None ))
715-
714+ vec = tensor (shape = (None ,), dtype = "float64" )
715+ mat = tensor (shape = (None , None ), dtype = "float64" )
716716 node = shape (vec ).owner
717- vect_node = vectorize_node (node , mat )
718- assert equal_computations (vect_node .outputs , [shape (mat )])
719717
718+ [vect_out ] = vectorize_node (node , mat ).outputs
719+ assert equal_computations (
720+ [vect_out ], [broadcast_to (mat .shape [1 :], (* mat .shape [:1 ], 1 ))]
721+ )
722+
723+ mat_test_value = np .ones ((5 , 3 ))
724+ ref_fn = np .vectorize (lambda vec : np .asarray (vec .shape ), signature = "(vec)->(1)" )
725+ np .testing .assert_array_equal (
726+ vect_out .eval ({mat : mat_test_value }),
727+ ref_fn (mat_test_value ),
728+ )
729+
730+ mat = tensor (shape = (None , None ), dtype = "float64" )
731+ tns = tensor (shape = (None , None , None , None ), dtype = "float64" )
732+ node = shape (mat ).owner
733+ [vect_out ] = vectorize_node (node , tns ).outputs
734+ assert equal_computations (
735+ [vect_out ], [broadcast_to (tns .shape [2 :], (* tns .shape [:2 ], 2 ))]
736+ )
737+
738+ tns_test_value = np .ones ((4 , 6 , 5 , 3 ))
739+ ref_fn = np .vectorize (
740+ lambda vec : np .asarray (vec .shape ), signature = "(m1,m2)->(2)"
741+ )
742+ np .testing .assert_array_equal (
743+ vect_out .eval ({tns : tns_test_value }),
744+ ref_fn (tns_test_value ),
745+ )
746+
747+ @pytensor .config .change_flags (cxx = "" ) # For faster eval
720748 def test_reshape (self ):
721749 x = scalar ("x" , dtype = int )
722- vec = tensor (shape = (None ,))
723- mat = tensor (shape = (None , None ))
750+ vec = tensor (shape = (None ,), dtype = "float64" )
751+ mat = tensor (shape = (None , None ), dtype = "float64" )
724752
725- shape = (2 , x )
753+ shape = (- 1 , x )
726754 node = reshape (vec , shape ).owner
727- vect_node = vectorize_node (node , mat , shape )
728- assert equal_computations (
729- vect_node .outputs , [reshape (mat , (* mat .shape [:1 ], 2 , x ))]
755+
756+ [vect_out ] = vectorize_node (node , mat , shape ).outputs
757+ assert equal_computations ([vect_out ], [reshape (mat , (* mat .shape [:1 ], - 1 , x ))])
758+
759+ x_test_value = 2
760+ mat_test_value = np .ones ((5 , 6 ))
761+ ref_fn = np .vectorize (
762+ lambda x , vec : vec .reshape (- 1 , x ), signature = "(),(vec1)->(mat1,mat2)"
763+ )
764+ np .testing .assert_array_equal (
765+ vect_out .eval ({x : x_test_value , mat : mat_test_value }),
766+ ref_fn (x_test_value , mat_test_value ),
730767 )
731768
732- new_shape = (5 , 2 , x )
733- vect_node = vectorize_node (node , mat , new_shape )
734- assert equal_computations (vect_node . outputs , [reshape (mat , new_shape )])
769+ new_shape = (5 , - 1 , x )
770+ [ vect_out ] = vectorize_node (node , mat , new_shape ). outputs
771+ assert equal_computations ([ vect_out ] , [reshape (mat , new_shape )])
735772
736773 with pytest .raises (NotImplementedError ):
737774 vectorize_node (node , vec , broadcast_to (as_tensor ([5 , 2 , x ]), (2 , 3 )))
0 commit comments