From d7d20be24248537059732a3a7dca3f33499f4742 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 20 Feb 2023 13:51:45 +0100 Subject: [PATCH] Fix spurious warning from FusionOptimizer --- pytensor/tensor/rewriting/elemwise.py | 5 +++++ tests/tensor/rewriting/test_elemwise.py | 18 ++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 820131b5cc..494b08025d 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -692,6 +692,11 @@ def elemwise_scalar_op_has_c_code(node: Apply) -> bool: fuseable_clients: FUSEABLE_MAPPING = defaultdict(list) unfuseable_clients: UNFUSEABLE_MAPPING = defaultdict(set) for out, clients in fg.clients.items(): + # Old FunctionGraph nodes remain in the clients dictionary + # even after they are removed by rewrites + if not clients: + continue + out_maybe_fuseable = ( out.owner and isinstance(out.owner.op, Elemwise) diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index 712117fc3e..c30ed12f89 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -1,3 +1,5 @@ +import warnings + import numpy as np import pytest @@ -36,6 +38,7 @@ invert, iround, log, + log1mexp, log2, log10, mul, @@ -1370,6 +1373,21 @@ def rewrite_func(): assert benchmark(rewrite_func) == 103 + def test_no_warning_from_old_client(self): + # There used to be a warning issued when creating fuseable mapping + # for nodes that are no longer in the FunctionGraph + with warnings.catch_warnings(): + warnings.simplefilter("error") + # The -2 integer array cannot be passed directly to the C method + # of log1mexp as that can only handle floats. There is a rewrite + # that casts it to a float, but the FunctionGraph client retains + # the original log1mexp of the integer input, which caused + # a misleading warning for non C implementation in the FusionRewrite + assert np.isclose( + log1mexp(np.array(-2, dtype="int64")).eval(), + np.log(1 - np.exp(-2)), + ) + class TimesN(aes.basic.UnaryScalarOp): """