@@ -643,35 +643,37 @@ def no_shared_fn(n, x_tm1, M):
643643 # (i.e. from `Scan._fn`)
644644 out = pytensor .function ([M ], out , updates = updates , mode = "FAST_RUN" )
645645
646- expected_output = """Scan{scan_fn, while_loop=False, inplace=all} [id A] 2 (outer_out_sit_sot-0)
647- ├─ 20000 [id B] (n_steps)
648- ├─ [ 0 ... 998 19999] [id C] (outer_in_seqs-0)
649- ├─ SetSubtensor{:stop} [id D] 1 (outer_in_sit_sot-0)
650- │ ├─ AllocEmpty{dtype='int64'} [id E] 0
651- │ │ └─ 20000 [id B]
652- │ ├─ [0] [id F]
653- │ └─ 1 [id G]
654- └─ <Tensor3(float64, shape=(20000, 2, 2))> [id H] (outer_in_non_seqs-0)
655-
656- Inner graphs:
657-
658- Scan{scan_fn, while_loop=False, inplace=all} [id A]
659- ← Composite{switch(lt(0, i0), 1, 0)} [id I] (inner_out_sit_sot-0)
660- └─ Subtensor{i, j, k} [id J]
661- ├─ *2-<Tensor3(float64, shape=(20000, 2, 2))> [id K] -> [id H] (inner_in_non_seqs-0)
662- ├─ ScalarFromTensor [id L]
663- │ └─ *0-<Scalar(int64, shape=())> [id M] -> [id C] (inner_in_seqs-0)
664- ├─ ScalarFromTensor [id N]
665- │ └─ *1-<Scalar(int64, shape=())> [id O] -> [id D] (inner_in_sit_sot-0)
666- └─ 0 [id P]
667-
668- Composite{switch(lt(0, i0), 1, 0)} [id I]
669- ← Switch [id Q] 'o0'
670- ├─ LT [id R]
671- │ ├─ 0 [id S]
672- │ └─ i0 [id T]
673- ├─ 1 [id U]
674- └─ 0 [id S]
646+ expected_output = """Subtensor{start:} [id A] 3
647+ ├─ Scan{scan_fn, while_loop=False, inplace=all} [id B] 2 (outer_out_sit_sot-0)
648+ │ ├─ 20000 [id C] (n_steps)
649+ │ ├─ [ 0 ... 998 19999] [id D] (outer_in_seqs-0)
650+ │ ├─ SetSubtensor{:stop} [id E] 1 (outer_in_sit_sot-0)
651+ │ │ ├─ AllocEmpty{dtype='int64'} [id F] 0
652+ │ │ │ └─ 20001 [id G]
653+ │ │ ├─ [0] [id H]
654+ │ │ └─ 1 [id I]
655+ │ └─ <Tensor3(float64, shape=(20000, 2, 2))> [id J] (outer_in_non_seqs-0)
656+ └─ 1 [id I]
657+
658+ Inner graphs:
659+
660+ Scan{scan_fn, while_loop=False, inplace=all} [id B]
661+ ← Composite{switch(lt(0, i0), 1, 0)} [id K] (inner_out_sit_sot-0)
662+ └─ Subtensor{i, j, k} [id L]
663+ ├─ *2-<Tensor3(float64, shape=(20000, 2, 2))> [id M] -> [id J] (inner_in_non_seqs-0)
664+ ├─ ScalarFromTensor [id N]
665+ │ └─ *0-<Scalar(int64, shape=())> [id O] -> [id D] (inner_in_seqs-0)
666+ ├─ ScalarFromTensor [id P]
667+ │ └─ *1-<Scalar(int64, shape=())> [id Q] -> [id E] (inner_in_sit_sot-0)
668+ └─ 0 [id R]
669+
670+ Composite{switch(lt(0, i0), 1, 0)} [id K]
671+ ← Switch [id S] 'o0'
672+ ├─ LT [id T]
673+ │ ├─ 0 [id U]
674+ │ └─ i0 [id V]
675+ ├─ 1 [id W]
676+ └─ 0 [id U]
675677 """
676678
677679 output_str = debugprint (out , file = "str" , print_op_info = True )
0 commit comments