@@ -604,31 +604,40 @@ def no_shared_fn(n, x_tm1, M):
604604 out = pytensor .function ([M ], out , updates = updates , mode = "FAST_RUN" )
605605
606606 expected_output = """forall_inplace,cpu,scan_fn} [id A] 2 (outer_out_sit_sot-0)
607- |TensorConstant{20000} [id B] (n_steps)
608- |TensorConstant{[ 0 ..998 19999]} [id C] (outer_in_seqs-0)
609- |IncSubtensor{InplaceSet;:int64:} [id D] 1 (outer_in_sit_sot-0)
610- | |AllocEmpty{dtype='int64'} [id E] 0
611- | | |TensorConstant{20000} [id B]
612- | |TensorConstant{(1,) of 0} [id F]
613- | |ScalarConstant{1} [id G]
614- |<TensorType(float64, (20000, 2, 2))> [id H] (outer_in_non_seqs-0)
607+ |TensorConstant{20000} [id B] (n_steps)
608+ |TensorConstant{[ 0 ..998 19999]} [id C] (outer_in_seqs-0)
609+ |IncSubtensor{InplaceSet;:int64:} [id D] 1 (outer_in_sit_sot-0)
610+ | |AllocEmpty{dtype='int64'} [id E] 0
611+ | | |TensorConstant{20000} [id B]
612+ | |TensorConstant{(1,) of 0} [id F]
613+ | |ScalarConstant{1} [id G]
614+ |<TensorType(float64, (20000, 2, 2))> [id H] (outer_in_non_seqs-0)
615615
616616 Inner graphs:
617617
618618 forall_inplace,cpu,scan_fn} [id A] (outer_out_sit_sot-0)
619- >Elemwise{Composite{Switch(LT(i0, i1), i2, i0)}} [id I] (inner_out_sit_sot-0)
620- > |TensorConstant{0} [id J]
621- > |Subtensor{int64, int64, uint8} [id K]
622- > | |*2-<TensorType(float64, (20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0)
623- > | |ScalarFromTensor [id M]
624- > | | |*0-<TensorType(int64, ())> [id N] -> [id C] (inner_in_seqs-0)
625- > | |ScalarFromTensor [id O]
626- > | | |*1-<TensorType(int64, ())> [id P] -> [id D] (inner_in_sit_sot-0)
627- > | |ScalarConstant{0} [id Q]
628- > |TensorConstant{1} [id R]
619+ >Elemwise{Composite} [id I] (inner_out_sit_sot-0)
620+ > |TensorConstant{0} [id J]
621+ > |Subtensor{int64, int64, uint8} [id K]
622+ > | |*2-<TensorType(float64, (20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0)
623+ > | |ScalarFromTensor [id M]
624+ > | | |*0-<TensorType(int64, ())> [id N] -> [id C] (inner_in_seqs-0)
625+ > | |ScalarFromTensor [id O]
626+ > | | |*1-<TensorType(int64, ())> [id P] -> [id D] (inner_in_sit_sot-0)
627+ > | |ScalarConstant{0} [id Q]
628+ > |TensorConstant{1} [id R]
629+
630+ Elemwise{Composite} [id I]
631+ >Switch [id S]
632+ > |LT [id T]
633+ > | |<int64> [id U]
634+ > | |<float64> [id V]
635+ > |<int64> [id W]
636+ > |<int64> [id U]
629637 """
630638
631639 output_str = debugprint (out , file = "str" , print_op_info = True )
640+ print (output_str )
632641 lines = output_str .split ("\n " )
633642
634643 for truth , out in zip (expected_output .split ("\n " ), lines ):
0 commit comments