22from collections .abc import Callable
33from typing import cast
44
5+ import numpy as np
6+
57from pytensor import Variable
68from pytensor import tensor as pt
79from pytensor .graph import Apply , FunctionGraph
@@ -967,23 +969,24 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
967969 return [eye_input * (non_eye_input ** 0.5 )]
968970
969971
970- # SLogDet Rewrites
971- def check_log_abs_det (fgraph , client ):
972+ def _check_log_abs_det (fgraph , client ):
972973 # First, we find abs
973974 if not (isinstance (client .op , Elemwise ) and isinstance (client .op .scalar_op , Abs )):
974975 return False
975976
976977 # Check whether log is a client of abs
977978 for client_2 in fgraph .clients [client .outputs [0 ]]:
978979 if not (
979- isinstance (client_2 .op , Elemwise ) and isinstance (client_2 .op .scalar_op , Log )
980+ isinstance (client_2 [0 ].op , Elemwise )
981+ and isinstance (client_2 [0 ].op .scalar_op , Log )
980982 ):
981983 return False
982984
983985 return True
984986
985987
986- @node_rewriter (tracks = [det ])
988+ @register_specialize
989+ @node_rewriter ([det ])
987990def slogdet_specialization (fgraph , node ):
988991 replacements = {}
989992 for client in fgraph .clients [node .outputs [0 ]]:
@@ -996,19 +999,22 @@ def slogdet_specialization(fgraph, node):
996999 replacements [client [0 ].outputs [0 ]] = sign_det_x
9971000
9981001 # Check for log(abs(det))
999- elif check_log_abs_det (fgraph , client [0 ]):
1002+ elif _check_log_abs_det (fgraph , client [0 ]):
10001003 x = node .inputs [0 ]
10011004 sign_det_x , slog_det_x = SLogDet ()(x )
10021005 replacements [fgraph .clients [client [0 ].outputs [0 ]][0 ][0 ].outputs [0 ]] = (
10031006 slog_det_x
10041007 )
10051008
10061009 # Check for log(det)
1007- # elif isinstance(client[0].op, Elemwise) and isinstance(
1008- # client[0].op.scalar_op, Log
1009- # ):
1010- # pass
1011- # replacements[client[0].owner.outputs[0]] = pt.where(pt.eq(sign_det_x, -1), np.nan, slog_det_x)
1010+ elif isinstance (client [0 ].op , Elemwise ) and isinstance (
1011+ client [0 ].op .scalar_op , Log
1012+ ):
1013+ x = node .inputs [0 ]
1014+ sign_det_x , slog_det_x = SLogDet ()(x )
1015+ replacements [client [0 ].outputs [0 ]] = pt .where (
1016+ pt .eq (sign_det_x , - 1 ), np .nan , slog_det_x
1017+ )
10121018
10131019 # Det is used directly for something else, don't rewrite to avoid computing two dets
10141020 else :
0 commit comments