@@ -865,19 +865,19 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
865865 if not isinstance (node .op .core_op , Cholesky ):
866866 return None
867867
868- inputs = node .inputs [ 0 ]
868+ [ input ] = node .inputs
869869 # Check for use of pt.diag first
870870 if (
871- inputs .owner
872- and isinstance (inputs .owner .op , AllocDiag )
873- and AllocDiag .is_offset_zero (inputs .owner )
871+ input .owner
872+ and isinstance (input .owner .op , AllocDiag )
873+ and AllocDiag .is_offset_zero (input .owner )
874874 ):
875- diag_input = inputs .owner .inputs [0 ]
875+ diag_input = input .owner .inputs [0 ]
876876 cholesky_val = pt .diag (diag_input ** 0.5 )
877877 return [cholesky_val ]
878878
879879 # Check if the input is an elemwise multiply with identity matrix -- this also results in a diagonal matrix
880- inputs_or_none = _find_diag_from_eye_mul (inputs )
880+ inputs_or_none = _find_diag_from_eye_mul (input )
881881 if inputs_or_none is None :
882882 return None
883883
@@ -887,7 +887,7 @@ def rewrite_cholesky_diag_to_sqrt_diag(fgraph, node):
887887 if len (non_eye_inputs ) != 1 :
888888 return None
889889
890- non_eye_input = non_eye_inputs [ 0 ]
890+ [ non_eye_input ] = non_eye_inputs
891891
892892 # Now, we can simply return the matrix consisting of sqrt values of the original diagonal elements
893893 # For a matrix, we have to first extract the diagonal (non-zero values) and then only use those
0 commit comments