Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions pytensor/compile/compilelock.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@
from contextlib import contextmanager
from pathlib import Path

import filelock

from pytensor.configdefaults import config


Expand All @@ -35,8 +33,9 @@ def force_unlock(lock_dir: os.PathLike):
lock_dir : os.PathLike
Path to a directory that was locked with `lock_ctx`.
"""
from filelock import FileLock

fl = filelock.FileLock(Path(lock_dir) / ".lock")
fl = FileLock(Path(lock_dir) / ".lock")
fl.release(force=True)

dir_key = f"{lock_dir}-{os.getpid()}"
Expand All @@ -62,6 +61,8 @@ def lock_ctx(
Timeout in seconds for waiting in lock acquisition.
Defaults to `pytensor.config.compile__timeout`.
"""
from filelock import FileLock

if lock_dir is None:
lock_dir = config.compiledir

Expand All @@ -73,7 +74,7 @@ def lock_ctx(

if dir_key not in local_mem._locks:
local_mem._locks[dir_key] = True
fl = filelock.FileLock(Path(lock_dir) / ".lock")
fl = FileLock(Path(lock_dir) / ".lock")
fl.acquire(timeout=timeout)
try:
yield
Expand Down
7 changes: 3 additions & 4 deletions pytensor/configdefaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
import os
import platform
import re
import shutil
import socket
import sys
import textwrap
from pathlib import Path
from shutil import which

import numpy as np

Expand Down Expand Up @@ -349,7 +348,7 @@ def add_compile_configvars():

# Try to find the full compiler path from the name
if param != "":
newp = shutil.which(param)
newp = which(param)
if newp is not None:
param = newp
del newp
Expand Down Expand Up @@ -1190,7 +1189,7 @@ def _get_home_dir() -> Path:
"pytensor_version": pytensor.__version__,
"numpy_version": np.__version__,
"gxx_version": "xxx",
"hostname": socket.gethostname(),
"hostname": platform.node(),
}


Expand Down
4 changes: 2 additions & 2 deletions pytensor/configparser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
import os
import shlex
import sys
import warnings
from collections.abc import Callable, Sequence
Expand All @@ -14,6 +13,7 @@
from functools import wraps
from io import StringIO
from pathlib import Path
from shlex import shlex

from pytensor.utils import hash_from_code

Expand Down Expand Up @@ -541,7 +541,7 @@ def parse_config_string(
Parses a config string (comma-separated key=value components) into a dict.
"""
config_dict = {}
my_splitter = shlex.shlex(config_string, posix=True)
my_splitter = shlex(config_string, posix=True)
my_splitter.whitespace = ","
my_splitter.whitespace_split = True
for kv_pair in my_splitter:
Expand Down
15 changes: 6 additions & 9 deletions pytensor/d3viz/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,7 @@
from pytensor.compile import Function, builders
from pytensor.graph.basic import Apply, Constant, Variable, graph_inputs
from pytensor.graph.fg import FunctionGraph
from pytensor.printing import pydot_imported, pydot_imported_msg


try:
from pytensor.printing import pd
except ImportError:
pass
from pytensor.printing import _try_pydot_import


class PyDotFormatter:
Expand All @@ -41,8 +35,7 @@

def __init__(self, compact=True):
"""Construct PyDotFormatter object."""
if not pydot_imported:
raise ImportError("Failed to import pydot. " + pydot_imported_msg)
_try_pydot_import()

Check warning on line 38 in pytensor/d3viz/formatting.py

View check run for this annotation

Codecov / codecov/patch

pytensor/d3viz/formatting.py#L38

Added line #L38 was not covered by tests

self.compact = compact
self.node_colors = {
Expand Down Expand Up @@ -115,6 +108,8 @@
pydot.Dot
Pydot graph of `fct`
"""
pd = _try_pydot_import()

Check warning on line 111 in pytensor/d3viz/formatting.py

View check run for this annotation

Codecov / codecov/patch

pytensor/d3viz/formatting.py#L111

Added line #L111 was not covered by tests

if graph is None:
graph = pd.Dot()

Expand Down Expand Up @@ -356,6 +351,8 @@

def dict_to_pdnode(d):
"""Create pydot node from dict."""
pd = _try_pydot_import()

Check warning on line 354 in pytensor/d3viz/formatting.py

View check run for this annotation

Codecov / codecov/patch

pytensor/d3viz/formatting.py#L354

Added line #L354 was not covered by tests

e = dict()
for k, v in d.items():
if v is not None:
Expand Down
5 changes: 4 additions & 1 deletion pytensor/graph/rewriting/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import functools
import inspect
import logging
import pdb
import sys
import time
import traceback
Expand Down Expand Up @@ -237,6 +236,8 @@
if config.on_opt_error == "raise":
raise exc
elif config.on_opt_error == "pdb":
import pdb

Check warning on line 239 in pytensor/graph/rewriting/basic.py

View check run for this annotation

Codecov / codecov/patch

pytensor/graph/rewriting/basic.py#L239

Added line #L239 was not covered by tests

pdb.post_mortem(sys.exc_info()[2])

def __init__(self, *rewrites, failure_callback=None):
Expand Down Expand Up @@ -1752,6 +1753,8 @@
_logger.error("TRACEBACK:")
_logger.error(traceback.format_exc())
if config.on_opt_error == "pdb":
import pdb

Check warning on line 1756 in pytensor/graph/rewriting/basic.py

View check run for this annotation

Codecov / codecov/patch

pytensor/graph/rewriting/basic.py#L1756

Added line #L1756 was not covered by tests

pdb.post_mortem(sys.exc_info()[2])
elif isinstance(exc, AssertionError) or config.on_opt_error == "raise":
# We always crash on AssertionError because something may be
Expand Down
41 changes: 16 additions & 25 deletions pytensor/link/c/cmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,12 @@
from typing import TYPE_CHECKING, Protocol, cast

import numpy as np
from setuptools._distutils.sysconfig import (
get_config_h_filename,
get_config_var,
get_python_inc,
get_python_lib,
)

# we will abuse the lockfile mechanism when reading and writing the registry
from pytensor.compile.compilelock import lock_ctx
from pytensor.configdefaults import config, gcc_version_str
from pytensor.configparser import BoolParam, StrParam
from pytensor.graph.op import Op
from pytensor.link.c.exceptions import CompileError, MissingGXX
from pytensor.utils import (
LOCAL_BITWIDTH,
flatten,
Expand Down Expand Up @@ -266,6 +259,8 @@

def _get_ext_suffix():
"""Get the suffix for compiled extensions"""
from setuptools._distutils.sysconfig import get_config_var

dist_suffix = get_config_var("EXT_SUFFIX")
if dist_suffix is None:
dist_suffix = get_config_var("SO")
Expand Down Expand Up @@ -1697,6 +1692,8 @@


def std_include_dirs():
from setuptools._distutils.sysconfig import get_python_inc

numpy_inc_dirs = [np.get_include()]
py_inc = get_python_inc()
py_plat_spec_inc = get_python_inc(plat_specific=True)
Expand All @@ -1709,6 +1706,12 @@

@is_StdLibDirsAndLibsType
def std_lib_dirs_and_libs() -> tuple[list[str], ...] | None:
from setuptools._distutils.sysconfig import (
get_config_var,
get_python_inc,
get_python_lib,
)

# We cache the results as on Windows, this trigger file access and
# this method is called many times.
if std_lib_dirs_and_libs.data is not None:
Expand Down Expand Up @@ -2388,23 +2391,6 @@
# xcode's version.
cxxflags.append("-ld64")

if sys.platform == "win32":
# Workaround for https://github.com/Theano/Theano/issues/4926.
# https://github.com/python/cpython/pull/11283/ removed the "hypot"
# redefinition for recent CPython versions (>=2.7.16 and >=3.7.3).
# The following nullifies that redefinition, if it is found.
python_version = sys.version_info[:3]
if (3,) <= python_version < (3, 7, 3):
config_h_filename = get_config_h_filename()
try:
with open(config_h_filename) as config_h:
if any(
line.startswith("#define hypot _hypot") for line in config_h
):
cxxflags.append("-D_hypot=hypot")
except OSError:
pass

return cxxflags

@classmethod
Expand Down Expand Up @@ -2555,8 +2541,9 @@

"""
# TODO: Do not do the dlimport in this function

if not config.cxx:
from pytensor.link.c.exceptions import MissingGXX

Check warning on line 2545 in pytensor/link/c/cmodule.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/c/cmodule.py#L2545

Added line #L2545 was not covered by tests

raise MissingGXX("g++ not available! We can't compile c code.")

if include_dirs is None:
Expand Down Expand Up @@ -2586,6 +2573,8 @@
cppfile.write("\n")

if platform.python_implementation() == "PyPy":
from setuptools._distutils.sysconfig import get_config_var

Check warning on line 2576 in pytensor/link/c/cmodule.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/c/cmodule.py#L2576

Added line #L2576 was not covered by tests

suffix = "." + get_lib_extension()

dist_suffix = get_config_var("SO")
Expand Down Expand Up @@ -2642,6 +2631,8 @@
status = p_out[2]

if status:
from pytensor.link.c.exceptions import CompileError

tf = tempfile.NamedTemporaryFile(
mode="w", prefix="pytensor_compilation_error_", delete=False
)
Expand Down
3 changes: 2 additions & 1 deletion pytensor/link/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.link.basic import Container, LocalLinker
from pytensor.link.c.exceptions import MissingGXX
from pytensor.link.utils import (
gc_helper,
get_destroy_dependencies,
Expand Down Expand Up @@ -1006,6 +1005,8 @@ def make_vm(
compute_map,
updated_vars,
):
from pytensor.link.c.exceptions import MissingGXX

pre_call_clear = [storage_map[v] for v in self.no_recycling]

try:
Expand Down
83 changes: 44 additions & 39 deletions pytensor/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,39 +26,6 @@

IDTypesType = Literal["id", "int", "CHAR", "auto", ""]

pydot_imported = False
pydot_imported_msg = ""
try:
# pydot-ng is a fork of pydot that is better maintained
import pydot_ng as pd

if pd.find_graphviz():
pydot_imported = True
else:
pydot_imported_msg = "pydot-ng can't find graphviz. Install graphviz."
except ImportError:
try:
# fall back on pydot if necessary
import pydot as pd

if hasattr(pd, "find_graphviz"):
if pd.find_graphviz():
pydot_imported = True
else:
pydot_imported_msg = "pydot can't find graphviz"
else:
pd.Dot.create(pd.Dot())
pydot_imported = True
except ImportError:
# tests should not fail on optional dependency
pydot_imported_msg = (
"Install the python package pydot or pydot-ng. Install graphviz."
)
except Exception as e:
pydot_imported_msg = "An error happened while importing/trying pydot: "
pydot_imported_msg += str(e.args)


_logger = logging.getLogger("pytensor.printing")
VALID_ASSOC = {"left", "right", "either"}

Expand Down Expand Up @@ -1196,6 +1163,48 @@
}


def _try_pydot_import():
pydot_imported = False
pydot_imported_msg = ""
try:
# pydot-ng is a fork of pydot that is better maintained
import pydot_ng as pd

if pd.find_graphviz():
pydot_imported = True

Check warning on line 1174 in pytensor/printing.py

View check run for this annotation

Codecov / codecov/patch

pytensor/printing.py#L1174

Added line #L1174 was not covered by tests
else:
pydot_imported_msg = "pydot-ng can't find graphviz. Install graphviz."

Check warning on line 1176 in pytensor/printing.py

View check run for this annotation

Codecov / codecov/patch

pytensor/printing.py#L1176

Added line #L1176 was not covered by tests
except ImportError:
try:
# fall back on pydot if necessary
import pydot as pd

if hasattr(pd, "find_graphviz"):
if pd.find_graphviz():
pydot_imported = True

Check warning on line 1184 in pytensor/printing.py

View check run for this annotation

Codecov / codecov/patch

pytensor/printing.py#L1184

Added line #L1184 was not covered by tests
else:
pydot_imported_msg = "pydot can't find graphviz"

Check warning on line 1186 in pytensor/printing.py

View check run for this annotation

Codecov / codecov/patch

pytensor/printing.py#L1186

Added line #L1186 was not covered by tests
else:
pd.Dot.create(pd.Dot())
pydot_imported = True

Check warning on line 1189 in pytensor/printing.py

View check run for this annotation

Codecov / codecov/patch

pytensor/printing.py#L1188-L1189

Added lines #L1188 - L1189 were not covered by tests
except ImportError:
# tests should not fail on optional dependency
pydot_imported_msg = (
"Install the python package pydot or pydot-ng. Install graphviz."
)
except Exception as e:
pydot_imported_msg = "An error happened while importing/trying pydot: "
pydot_imported_msg += str(e.args)

Check warning on line 1197 in pytensor/printing.py

View check run for this annotation

Codecov / codecov/patch

pytensor/printing.py#L1195-L1197

Added lines #L1195 - L1197 were not covered by tests

if not pydot_imported:
raise ImportError(
"Failed to import pydot. You must install graphviz "
"and either pydot or pydot-ng for "
f"`pydotprint` to work:\n {pydot_imported_msg}",
)
return pd

Check warning on line 1205 in pytensor/printing.py

View check run for this annotation

Codecov / codecov/patch

pytensor/printing.py#L1205

Added line #L1205 was not covered by tests


def pydotprint(
fct,
outfile: Path | str | None = None,
Expand Down Expand Up @@ -1288,6 +1297,8 @@
scan separately after the top level debugprint output.

"""
pd = _try_pydot_import()

Check warning on line 1300 in pytensor/printing.py

View check run for this annotation

Codecov / codecov/patch

pytensor/printing.py#L1300

Added line #L1300 was not covered by tests

from pytensor.scan.op import Scan

if colorCodes is None:
Expand Down Expand Up @@ -1320,12 +1331,6 @@
outputs = fct.outputs
topo = fct.toposort()
fgraph = fct
if not pydot_imported:
raise RuntimeError(
"Failed to import pydot. You must install graphviz "
"and either pydot or pydot-ng for "
f"`pydotprint` to work:\n {pydot_imported_msg}",
)

g = pd.Dot()

Expand Down
Loading
Loading