@@ -968,30 +968,17 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
968968
969969
970970# SLogDet Rewrites
971- def check_sign_det (node ):
972- if not (isinstance (node .op , Elemwise ) and isinstance (node .op .scalar_op , Sign )):
971+ def check_log_abs_det (fgraph , client ):
972+ # First, we find abs
973+ if not (isinstance (client .op , Elemwise ) and isinstance (client .op .scalar_op , Abs )):
973974 return False
974975
975- return True
976-
977-
978- def check_log_abs_det (node ):
979- if not (isinstance (node .op , Elemwise ) and isinstance (node .op .scalar_op , Log )):
980- return False
981-
982- potential_abs = node .inputs [0 ].owner
983- if not (
984- isinstance (potential_abs .op , Elemwise )
985- and isinstance (potential_abs .op .scalar_op , Abs )
986- ):
987- return False
988-
989- return True
990-
991-
992- def check_log_det (node ):
993- if not (isinstance (node .op , Elemwise ) and isinstance (node .op .scalar_op , Log )):
994- return False
976+ # Check whether log is a client of abs
977+ for client_2 in fgraph .clients [client .outputs [0 ]]:
978+ if not (
979+ isinstance (client_2 .op , Elemwise ) and isinstance (client_2 .op .scalar_op , Log )
980+ ):
981+ return False
995982
996983 return True
997984
@@ -1001,17 +988,21 @@ def slogdet_specialization(fgraph, node):
1001988 x = node .inputs [0 ]
1002989 sign_det_x , slog_det_x = SLogDet ()(x )
1003990 replacements = {}
1004- for client in list ( fgraph .clients . keys ()) :
991+ for client in fgraph .clients [ node . outputs [ 0 ]] :
1005992 # Check for sign(det)
1006- if check_sign_det (client [0 ].owner ):
993+ if isinstance (client [0 ].op , Elemwise ) and isinstance (
994+ client [0 ].op .scalar_op , Sign
995+ ):
1007996 replacements [client [0 ].owner .outputs [0 ]] = sign_det_x
1008997
1009998 # Check for log(abs(det))
1010- elif check_log_abs_det (client [0 ]. owner ):
999+ elif check_log_abs_det (fgraph , client [0 ]):
10111000 replacements [client [0 ].owner .outputs [0 ]] = slog_det_x
10121001
10131002 # Check for log(det)
1014- elif check_log_det (client [0 ].owner ):
1003+ elif isinstance (client [0 ].op , Elemwise ) and isinstance (
1004+ client [0 ].op .scalar_op , Log
1005+ ):
10151006 pass
10161007 # replacements[client[0].owner.outputs[0]] = pt.where(pt.eq(sign_det_x, -1), np.nan, slog_det_x)
10171008
0 commit comments