44"""
55
66import json
7- import os
87import shutil
8+ from pathlib import Path
99
1010from pytensor .d3viz .formatting import PyDotFormatter
1111
1212
13- __path__ = os . path . dirname ( os . path . realpath ( __file__ ))
13+ __path__ = Path ( __file__ ). parent
1414
1515
1616def replace_patterns (x , replace ):
@@ -40,7 +40,7 @@ def safe_json(obj):
4040 return json .dumps (obj ).replace ("<" , "\\ u003c" )
4141
4242
43- def d3viz (fct , outfile , copy_deps = True , * args , ** kwargs ):
43+ def d3viz (fct , outfile : Path | str , copy_deps : bool = True , * args , ** kwargs ):
4444 """Create HTML file with dynamic visualizing of an PyTensor function graph.
4545
4646 In the HTML file, the whole graph or single nodes can be moved by drag and
@@ -59,7 +59,7 @@ def d3viz(fct, outfile, copy_deps=True, *args, **kwargs):
5959 ----------
6060 fct : pytensor.compile.function.types.Function
6161 A compiled PyTensor function, variable, apply or a list of variables.
62- outfile : str
62+ outfile : Path | str
6363 Path to output HTML file.
6464 copy_deps : bool, optional
6565 Copy javascript and CSS dependencies to output directory.
@@ -78,30 +78,28 @@ def d3viz(fct, outfile, copy_deps=True, *args, **kwargs):
7878 dot_graph = dot_graph .decode ("utf8" )
7979
8080 # Create output directory if not existing
81- outdir = os .path .dirname (outfile )
82- if outdir != "" and not os .path .exists (outdir ):
83- os .makedirs (outdir )
81+ outdir = Path (outfile ).parent
82+ outdir .mkdir (exist_ok = True )
8483
8584 # Read template HTML file
86- template_file = os .path .join (__path__ , "html" , "template.html" )
87- with open (template_file ) as f :
88- template = f .read ()
85+ template_file = __path__ / "html/template.html"
86+ template = template_file .read_text (encoding = "utf-8" )
8987
9088 # Copy dependencies to output directory
9189 src_deps = __path__
9290 if copy_deps :
93- dst_deps = "d3viz"
91+ dst_deps = outdir / "d3viz"
9492 for d in ("js" , "css" ):
95- dep = os . path . join ( outdir , dst_deps , d )
96- if not os . path . exists (dep ):
97- shutil .copytree (os . path . join ( src_deps , d ) , dep )
93+ dep = dst_deps / d
94+ if not dep . exists ():
95+ shutil .copytree (src_deps / d , dep )
9896 else :
9997 dst_deps = src_deps
10098
10199 # Replace patterns in template
102100 replace = {
103- "%% JS_DIR %%" : os . path . join ( dst_deps , "js" ) ,
104- "%% CSS_DIR %%" : os . path . join ( dst_deps , "css" ) ,
101+ "%% JS_DIR %%" : dst_deps / "js" ,
102+ "%% CSS_DIR %%" : dst_deps / "css" ,
105103 "%% DOT_GRAPH %%" : safe_json (dot_graph ),
106104 }
107105 html = replace_patterns (template , replace )
0 commit comments