@@ -1200,18 +1200,18 @@ def __call__(self, *args):
12001200
12011201def pydotprint (
12021202 fct ,
1203- outfile = None ,
1204- compact = True ,
1205- format = "png" ,
1206- with_ids = False ,
1207- high_contrast = True ,
1203+ outfile : str | None = None ,
1204+ compact : bool = True ,
1205+ format : str = "png" ,
1206+ with_ids : bool = False ,
1207+ high_contrast : bool = True ,
12081208 cond_highlight = None ,
1209- colorCodes = None ,
1210- max_label_size = 70 ,
1211- scan_graphs = False ,
1212- var_with_name_simple = False ,
1213- print_output_file = True ,
1214- return_image = False ,
1209+ colorCodes : dict | None = None ,
1210+ max_label_size : int = 70 ,
1211+ scan_graphs : bool = False ,
1212+ var_with_name_simple : bool = False ,
1213+ print_output_file : bool = True ,
1214+ return_image : bool = False ,
12151215):
12161216 """Print to a file the graph of a compiled pytensor function's ops. Supports
12171217 all pydot output formats, including png and svg.
@@ -1676,7 +1676,9 @@ def get_tag(self):
16761676 return rval
16771677
16781678
1679- def min_informative_str (obj , indent_level = 0 , _prev_obs = None , _tag_generator = None ):
1679+ def min_informative_str (
1680+ obj , indent_level : int = 0 , _prev_obs : dict | None = None , _tag_generator = None
1681+ ) -> str :
16801682 """
16811683 Returns a string specifying to the user what obj is
16821684 The string will print out as much of the graph as is needed
@@ -1776,7 +1778,7 @@ def min_informative_str(obj, indent_level=0, _prev_obs=None, _tag_generator=None
17761778 return rval
17771779
17781780
1779- def var_descriptor (obj , _prev_obs = None , _tag_generator = None ):
1781+ def var_descriptor (obj , _prev_obs : dict | None = None , _tag_generator = None ) -> str :
17801782 """
17811783 Returns a string, with no endlines, fully specifying
17821784 how a variable is computed. Does not include any memory
@@ -1832,7 +1834,7 @@ def var_descriptor(obj, _prev_obs=None, _tag_generator=None):
18321834 return rval
18331835
18341836
1835- def position_independent_str (obj ):
1837+ def position_independent_str (obj ) -> str :
18361838 if isinstance (obj , Variable ):
18371839 rval = "pytensor_var"
18381840 rval += "{type=" + str (obj .type ) + "}"
@@ -1842,7 +1844,7 @@ def position_independent_str(obj):
18421844 return rval
18431845
18441846
1845- def hex_digest (x ) :
1847+ def hex_digest (x : np . ndarray ) -> str :
18461848 """
18471849 Returns a short, mostly hexadecimal hash of a numpy ndarray
18481850 """
@@ -1852,8 +1854,8 @@ def hex_digest(x):
18521854 # because the buffer interface only exposes the raw data, not
18531855 # any info about the semantics of how that data should be arranged
18541856 # into a tensor
1855- rval = rval + "|strides=[" + "," .join (str (stride ) for stride in x .strides ) + "]"
1856- rval = rval + "|shape=[" + "," .join (str (s ) for s in x .shape ) + "]"
1857+ rval += "|strides=[" + "," .join (str (stride ) for stride in x .strides ) + "]"
1858+ rval += "|shape=[" + "," .join (str (s ) for s in x .shape ) + "]"
18571859 return rval
18581860
18591861
0 commit comments