22
33import time
44import warnings
5+ from collections .abc import Mapping , MutableSequence , Sequence
56from functools import partial , reduce
6- from typing import (
7- TYPE_CHECKING ,
8- Callable ,
9- Dict ,
10- List ,
11- Literal ,
12- Mapping ,
13- MutableSequence ,
14- Optional ,
15- Sequence ,
16- Tuple ,
17- TypeVar ,
18- Union ,
19- )
7+ from typing import TYPE_CHECKING , Callable , Literal , Optional , TypeVar , Union
208
219import numpy as np
2210
4432# TODO: Add `overload` variants
4533def as_list_or_tuple (
4634 use_list : bool , use_tuple : bool , outputs : Union [V , Sequence [V ]]
47- ) -> Union [V , List [V ], Tuple [V , ...]]:
35+ ) -> Union [V , list [V ], tuple [V , ...]]:
4836 """Return either a single object or a list/tuple of objects.
4937
5038 If `use_list` is True, `outputs` is returned as a list (if `outputs`
@@ -206,17 +194,17 @@ def Rop(
206194 """
207195
208196 if not isinstance (wrt , (list , tuple )):
209- _wrt : List [Variable ] = [pytensor .tensor .as_tensor_variable (wrt )]
197+ _wrt : list [Variable ] = [pytensor .tensor .as_tensor_variable (wrt )]
210198 else :
211199 _wrt = [pytensor .tensor .as_tensor_variable (x ) for x in wrt ]
212200
213201 if not isinstance (eval_points , (list , tuple )):
214- _eval_points : List [Variable ] = [pytensor .tensor .as_tensor_variable (eval_points )]
202+ _eval_points : list [Variable ] = [pytensor .tensor .as_tensor_variable (eval_points )]
215203 else :
216204 _eval_points = [pytensor .tensor .as_tensor_variable (x ) for x in eval_points ]
217205
218206 if not isinstance (f , (list , tuple )):
219- _f : List [Variable ] = [pytensor .tensor .as_tensor_variable (f )]
207+ _f : list [Variable ] = [pytensor .tensor .as_tensor_variable (f )]
220208 else :
221209 _f = [pytensor .tensor .as_tensor_variable (x ) for x in f ]
222210
@@ -237,7 +225,7 @@ def Rop(
237225 # Tensor, Sparse have the ndim attribute
238226 pass
239227
240- seen_nodes : Dict [Apply , Sequence [Variable ]] = {}
228+ seen_nodes : dict [Apply , Sequence [Variable ]] = {}
241229
242230 def _traverse (node ):
243231 """TODO: writeme"""
@@ -310,7 +298,7 @@ def _traverse(node):
310298 for out in _f :
311299 _traverse (out .owner )
312300
313- rval : List [Optional [Variable ]] = []
301+ rval : list [Optional [Variable ]] = []
314302 for out in _f :
315303 if out in _wrt :
316304 rval .append (_eval_points [_wrt .index (out )])
@@ -394,19 +382,19 @@ def Lop(
394382 If `f` is a list/tuple, then return a list/tuple with the results.
395383 """
396384 if not isinstance (eval_points , (list , tuple )):
397- _eval_points : List [Variable ] = [pytensor .tensor .as_tensor_variable (eval_points )]
385+ _eval_points : list [Variable ] = [pytensor .tensor .as_tensor_variable (eval_points )]
398386 else :
399387 _eval_points = [pytensor .tensor .as_tensor_variable (x ) for x in eval_points ]
400388
401389 if not isinstance (f , (list , tuple )):
402- _f : List [Variable ] = [pytensor .tensor .as_tensor_variable (f )]
390+ _f : list [Variable ] = [pytensor .tensor .as_tensor_variable (f )]
403391 else :
404392 _f = [pytensor .tensor .as_tensor_variable (x ) for x in f ]
405393
406394 grads = list (_eval_points )
407395
408396 if not isinstance (wrt , (list , tuple )):
409- _wrt : List [Variable ] = [pytensor .tensor .as_tensor_variable (wrt )]
397+ _wrt : list [Variable ] = [pytensor .tensor .as_tensor_variable (wrt )]
410398 else :
411399 _wrt = [pytensor .tensor .as_tensor_variable (x ) for x in wrt ]
412400
@@ -504,7 +492,7 @@ def grad(
504492 raise TypeError ("Cost must be a scalar." )
505493
506494 if not isinstance (wrt , Sequence ):
507- _wrt : List [Variable ] = [wrt ]
495+ _wrt : list [Variable ] = [wrt ]
508496 else :
509497 _wrt = list (wrt )
510498
@@ -1677,7 +1665,7 @@ def mode_not_slow(mode):
16771665
16781666def verify_grad (
16791667 fun : Callable ,
1680- pt : List [np .ndarray ],
1668+ pt : list [np .ndarray ],
16811669 n_tests : int = 2 ,
16821670 rng : Optional [Union [np .random .Generator , np .random .RandomState ]] = None ,
16831671 eps : Optional [float ] = None ,
0 commit comments