File tree Expand file tree Collapse file tree 4 files changed +38
-0
lines changed Expand file tree Collapse file tree 4 files changed +38
-0
lines changed Original file line number Diff line number Diff line change @@ -1097,6 +1097,18 @@ def sync_shared(self):
10971097 # NOTE: sync was needed on old gpu backend
10981098 pass
10991099
1100+ def dprint (self , ** kwargs ):
1101+ """Debug print itself
1102+
1103+ Parameters
1104+ ----------
1105+ kwargs:
1106+ Optional keyword arguments to pass to debugprint function.
1107+ """
1108+ from pytensor .printing import debugprint
1109+
1110+ return debugprint (self , ** kwargs )
1111+
11001112
11011113# pickling/deepcopy support for Function
11021114def _pickle_Function (f ):
Original file line number Diff line number Diff line change @@ -927,3 +927,15 @@ def __contains__(self, item: Variable | Apply) -> bool:
927927 return item in self .apply_nodes
928928 else :
929929 raise TypeError ()
930+
931+ def dprint (self , ** kwargs ):
932+ """Debug print itself
933+
934+ Parameters
935+ ----------
936+ kwargs:
937+ Optional keyword arguments to pass to debugprint function.
938+ """
939+ from pytensor .printing import debugprint
940+
941+ return debugprint (self , ** kwargs )
Original file line number Diff line number Diff line change 1616from pytensor .graph .rewriting .basic import OpKeyGraphRewriter , PatternNodeRewriter
1717from pytensor .graph .utils import MissingInputError
1818from pytensor .link .vm import VMLinker
19+ from pytensor .printing import debugprint
1920from pytensor .tensor .math import dot , tanh
2021from pytensor .tensor .math import sum as pt_sum
2122from pytensor .tensor .type import (
@@ -862,6 +863,12 @@ def test_key_string_requirement(self):
862863 with pytest .raises (AssertionError ):
863864 function ([x ], outputs = {(1 , "b" ): x , 1.0 : x ** 2 })
864865
866+ def test_dprint (self ):
867+ x = pt .scalar ("x" )
868+ out = x + 1
869+ f = function ([x ], out )
870+ assert f .dprint (file = "str" ) == debugprint (f , file = "str" )
871+
865872
866873class TestPicklefunction :
867874 def test_deepcopy (self ):
Original file line number Diff line number Diff line change 88from pytensor .graph .basic import NominalVariable
99from pytensor .graph .fg import FunctionGraph
1010from pytensor .graph .utils import MissingInputError
11+ from pytensor .printing import debugprint
1112from tests .graph .utils import (
1213 MyConstant ,
1314 MyOp ,
@@ -706,3 +707,9 @@ def test_nominals(self):
706707 assert nm2 not in fg .inputs
707708 assert nm in fg .variables
708709 assert nm2 in fg .variables
710+
711+ def test_dprint (self ):
712+ r1 , r2 = MyVariable ("x" ), MyVariable ("y" )
713+ o1 = op1 (r1 , r2 )
714+ fg = FunctionGraph ([r1 , r2 ], [o1 ], clone = False )
715+ assert fg .dprint (file = "str" ) == debugprint (fg , file = "str" )
You can’t perform that action at this time.
0 commit comments