@@ -1675,6 +1675,7 @@ def verify_grad(
     mode: Optional[Union["Mode", str]] = None,
     cast_to_output_type: bool = False,
     no_debug_ref: bool = True,
+    sum_outputs: bool = False,
 ):
     """Test a gradient by Finite Difference Method. Raise error on failure.
 
@@ -1722,7 +1723,12 @@ def verify_grad(
         float16 is not handled here.
     no_debug_ref
         Don't use `DebugMode` for the numerical gradient function.
+    sum_outputs : bool, default False
+        If True, the gradient of the sum of all outputs is verified. If
+        False, an error is raised when the function has multiple outputs.
 
     Notes
     -----
-    This function does not support multiple outputs. In `tests.scan.test_basic`
+    By default this function does not support multiple outputs; pass
+    ``sum_outputs=True`` to verify the gradient of the sum of all outputs
+    instead. In `tests.scan.test_basic`
@@ -1782,7 +1785,8 @@ def verify_grad(
     # fun can be either a function or an actual Op instance
     o_output = fun(*tensor_pt)
 
-    if isinstance(o_output, list):
+    if isinstance(o_output, list) and not sum_outputs:
         raise NotImplementedError(
-            "Can't (yet) auto-test the gradient of a function with multiple outputs"
+            "Can't auto-test the gradient of a function with multiple outputs; "
+            "pass `sum_outputs=True` to verify the gradient of their sum instead"
         )
@@ -1793,7 +1796,7 @@ def verify_grad(
     o_fn = fn_maker(tensor_pt, o_output, name="gradient.py fwd")
     o_fn_out = o_fn(*[p.copy() for p in pt])
 
-    if isinstance(o_fn_out, tuple) or isinstance(o_fn_out, list):
+    if isinstance(o_fn_out, (tuple, list)) and not sum_outputs:
         raise TypeError(
             "It seems like you are trying to use verify_grad "
             "on an Op or a function which outputs a list: there should"
@@ -1802,33 +1805,45 @@ def random_projection():
         )
 
     # random_projection should not have elements too small,
    # otherwise too much precision is lost in numerical gradient
-    def random_projection():
-        plain = rng.random(o_fn_out.shape) + 0.5
-        if cast_to_output_type and o_output.dtype == "float32":
-            return np.array(plain, o_output.dtype)
+    def random_projection(shape, dtype):
+        plain = rng.random(shape) + 0.5
+        if cast_to_output_type and dtype == "float32":
+            return np.array(plain, dtype)
         return plain
 
-    t_r = shared(random_projection(), borrow=True)
-    t_r.name = "random_projection"
-
     # random projection of o onto t_r
     # This sum() is defined above, it's not the builtin sum.
-    cost = pytensor.tensor.sum(t_r * o_output)
+    if sum_outputs:
+        t_rs = [
+            shared(
+                value=random_projection(o.shape, o.dtype),
+                borrow=True,
+                name=f"random_projection_{i}",
+            )
+            for i, o in enumerate(o_fn_out)
+        ]
+        cost = pytensor.tensor.sum(
+            [pytensor.tensor.sum(x * y) for x, y in zip(t_rs, o_output)]
+        )
+    else:
+        t_r = shared(
+            value=random_projection(o_fn_out.shape, o_fn_out.dtype),
+            borrow=True,
+            name="random_projection",
+        )
+        cost = pytensor.tensor.sum(t_r * o_output)
 
     if no_debug_ref:
         mode_for_cost = mode_not_slow(mode)
     else:
         mode_for_cost = mode
 
     cost_fn = fn_maker(tensor_pt, cost, name="gradient.py cost", mode=mode_for_cost)
-
     symbolic_grad = grad(cost, tensor_pt, disconnected_inputs="ignore")
-
     grad_fn = fn_maker(tensor_pt, symbolic_grad, name="gradient.py symbolic grad")
 
     for test_num in range(n_tests):
         num_grad = numeric_grad(cost_fn, [p.copy() for p in pt], eps, out_type)
-
         analytic_grad = grad_fn(*[p.copy() for p in pt])
 
         # Since `tensor_pt` is a list, `analytic_grad` should be one too.
@@ -1853,7 +1868,16 @@ def random_projection():
 
         # get new random projection for next test
         if test_num < n_tests - 1:
-            t_r.set_value(random_projection(), borrow=True)
+            if sum_outputs:
+                for r in t_rs:
+                    r.set_value(
+                        random_projection(r.get_value().shape, r.get_value().dtype)
+                    )
+            else:
+                t_r.set_value(
+                    random_projection(t_r.get_value().shape, t_r.get_value().dtype),
+                    borrow=True,
+                )
 
 
 class GradientError(Exception):
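For reference, a minimal usage sketch of the new flag. The two-output function, shapes, and seed below are illustrative, not from the patch; only `verify_grad`'s own parameters are assumed:

```python
import numpy as np

import pytensor.tensor as pt
from pytensor.gradient import verify_grad

rng = np.random.default_rng(42)


def two_outputs(x):
    # Hypothetical multi-output function: before this patch, verify_grad
    # raised NotImplementedError as soon as `fun` returned a list.
    return [pt.exp(x), pt.tanh(x)]


# With sum_outputs=True, each output is projected onto its own random
# tensor, the projections are summed into one scalar cost, and the
# symbolic gradient of that cost is compared against finite differences.
verify_grad(two_outputs, [rng.random((3, 4))], rng=rng, sum_outputs=True)
```

With `sum_outputs=False` (the default) the same call still raises, so existing single-output behavior is unchanged.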