@@ -2179,6 +2179,72 @@ def step(s, xtm2, xtm1, z):
21792179 assert gg .eval ({seq : [1 , 1 ], x0 : [1 , 1 ], z : 2 }) == 12
21802180 assert gg .eval ({seq : [1 , 1 ], x0 : [1 , 1 ], z : 1 }) == 3 / 2
21812181
2182+ @pytest .mark .parametrize ("case" , ("inside-explicit" , "inside-implicit" , "outside" ))
2183+ def test_non_shaped_input_disconnected_gradient (self , case ):
2184+ """Test that Scan gradient works when non shaped variables are disconnected from the gradient.
2185+
2186+ Regression test for https://github.com/pymc-devs/pytensor/issues/6
2187+ """
2188+
2189+ # In all cases rng is disconnected from the output gradient
2190+ # Note that when it is an input to the scan (explicit or not) it is still not updated by the scan,
2191+ # so it is equivalent to the `outside` case. A rewrite could have legally hoisted the rng out of the scan.
2192+ rng = shared (np .random .default_rng ())
2193+
2194+ data = pt .zeros (16 )
2195+
2196+ nonlocal_random_index = pt .random .integers (16 , rng = rng )
2197+ nonlocal_random_datum = data [nonlocal_random_index ]
2198+
2199+ if case == "outside" :
2200+
2201+ def step (s , random_datum ):
2202+ return (random_datum + s ) ** 2
2203+
2204+ strict = True
2205+ non_sequences = [nonlocal_random_datum ]
2206+
2207+ elif case == "inside-implicit" :
2208+
2209+ def step (s ):
2210+ return (nonlocal_random_datum + s ) ** 2
2211+
2212+ strict = False
2213+ non_sequences = [] # Scan will introduce the non_sequences for us
2214+
2215+ elif case == "inside-explicit" :
2216+
2217+ def step (s , data , rng ):
2218+ random_index = pt .random .integers (
2219+ 16 , rng = rng
2220+ ) # Not updated by the scan
2221+ random_datum = data [random_index ]
2222+ return (random_datum + s ) ** 2
2223+
2224+ strict = (True ,)
2225+ non_sequences = [data , rng ]
2226+
2227+ else :
2228+ raise ValueError (f"Invalid case: { case } " )
2229+
2230+ seq = vector ("seq" )
2231+ xs , _ = scan (
2232+ step ,
2233+ sequences = [seq ],
2234+ non_sequences = non_sequences ,
2235+ strict = strict ,
2236+ )
2237+ x0 = xs [0 ]
2238+
2239+ np .testing .assert_allclose (
2240+ x0 .eval ({seq : [np .pi , np .nan , np .nan ]}),
2241+ np .pi ** 2 ,
2242+ )
2243+ np .testing .assert_allclose (
2244+ grad (x0 , seq )[0 ].eval ({seq : [np .pi , np .nan , np .nan ]}),
2245+ 2 * np .pi ,
2246+ )
2247+
21822248
21832249@pytest .mark .skipif (
21842250 not config .cxx , reason = "G++ not available, so we need to skip this test."
0 commit comments