From ac6f89ea488fa9b6892c181c59d8fd3ac701e566 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Sun, 27 Jul 2025 18:01:38 -0400 Subject: [PATCH 01/11] add rewrite for log(sqrt(x)) --- pytensor/tensor/rewriting/math.py | 21 +++++++++++++++++++++ tests/tensor/rewriting/test_math.py | 12 ++++++++++++ 2 files changed, 33 insertions(+) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index acc8becbfb..e330a6a68b 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -552,6 +552,27 @@ def local_sqrt_sqr(fgraph, node): return [new_out] +@register_canonicalize +@node_rewriter([log]) +def local_log_sqrt(fgraph, node): + x = node.inputs[0] + + if not (x.owner and isinstance(x.owner.op, Elemwise)): + return + + prev_op = x.owner.op.scalar_op + node_op = node.op.scalar_op + + if isinstance(prev_op, ps.Sqrt) and isinstance(node_op, ps.Log): + # Case for log(sqrt(x)) -> 0.5 * log(x) + x = x.owner.inputs[0] + old_out = node.outputs[0] + new_out = mul(0.5, log(x)) + if new_out.dtype != old_out.dtype: + new_out = cast(new_out, old_out.dtype) + return [new_out] + + @register_specialize @node_rewriter([exp, expm1]) def local_exp_log_nan_switch(fgraph, node): diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index a6e734ae82..b2447462c1 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -1989,6 +1989,18 @@ def test_exp_log_nested(self, nested_expression, expected_switches): assert len(ops_graph) == expected_switches +def test_log_sqrt() -> None: + x = pt.tensor("x", shape=(None, None)) + out = log(sqrt(x)) + + out = rewrite_graph(out, include=["canonicalize", "specialize", "stabilize"]) + + assert equal_computations( + [out], + [mul(np.array([[0.5]]), log(x))], + ) + + class TestSqrSqrt: def setup_method(self): mode = get_default_mode() From 49cab313c77c251aee94e2d27ef16f34b6dd3246 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Sun, 27 Jul 2025 18:21:55 -0400 Subject: [PATCH 02/11] remove duplicate check --- pytensor/tensor/rewriting/math.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index e330a6a68b..f9c5cf48bd 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -561,9 +561,8 @@ def local_log_sqrt(fgraph, node): return prev_op = x.owner.op.scalar_op - node_op = node.op.scalar_op - if isinstance(prev_op, ps.Sqrt) and isinstance(node_op, ps.Log): + if isinstance(prev_op, ps.Sqrt): # Case for log(sqrt(x)) -> 0.5 * log(x) x = x.owner.inputs[0] old_out = node.outputs[0] From 67390dc37f2bf21348180b0855c4285191f1b0b2 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Sun, 27 Jul 2025 18:29:51 -0400 Subject: [PATCH 03/11] combine if block --- pytensor/tensor/rewriting/math.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index f9c5cf48bd..77a896f584 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -557,19 +557,18 @@ def local_sqrt_sqr(fgraph, node): def local_log_sqrt(fgraph, node): x = node.inputs[0] - if not (x.owner and isinstance(x.owner.op, Elemwise)): + if not (x.owner and isinstance(x.owner.op, Elemwise)) or not isinstance( + x.owner.op.scalar_op, ps.Sqrt + ): return - prev_op = x.owner.op.scalar_op - - if isinstance(prev_op, ps.Sqrt): - # Case for log(sqrt(x)) -> 0.5 * log(x) - x = x.owner.inputs[0] - old_out = node.outputs[0] - new_out = mul(0.5, log(x)) - if new_out.dtype != old_out.dtype: - new_out = cast(new_out, old_out.dtype) - return [new_out] + # Case for log(sqrt(x)) -> 0.5 * log(x) + x = x.owner.inputs[0] + old_out = node.outputs[0] + new_out = mul(0.5, log(x)) + if new_out.dtype != old_out.dtype: + new_out = cast(new_out, old_out.dtype) + return [new_out] @register_specialize From c24e72538d5d75d6acc50bc4f9fdfe7856e07f7e Mon Sep 17 00:00:00 2001 From: Will Dean Date: Sun, 27 Jul 2025 18:37:18 -0400 Subject: [PATCH 04/11] copy stack trace --- pytensor/tensor/rewriting/math.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 77a896f584..1bf8720a85 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -568,6 +568,7 @@ def local_log_sqrt(fgraph, node): new_out = mul(0.5, log(x)) if new_out.dtype != old_out.dtype: new_out = cast(new_out, old_out.dtype) + copy_stack_trace(node.out, new_out) return [new_out] From 1ab1dc3c2b6ee7420682945e5b312ed039864a46 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Sun, 27 Jul 2025 19:47:37 -0400 Subject: [PATCH 05/11] use as_tensor_variable instead --- tests/tensor/rewriting/test_math.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index b2447462c1..7171b683d3 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -1997,7 +1997,7 @@ def test_log_sqrt() -> None: assert equal_computations( [out], - [mul(np.array([[0.5]]), log(x))], + [mul(pt.as_tensor_variable([[0.5]]), log(x))], ) From a69b79ff544e81473ca91a7346017c632be1b142 Mon Sep 17 00:00:00 2001 From: Will Dean <57733339+williambdean@users.noreply.github.com> Date: Mon, 28 Jul 2025 06:37:44 -0400 Subject: [PATCH 06/11] Update math.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pytensor/tensor/rewriting/math.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 1bf8720a85..805041c64b 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -557,7 +557,7 @@ def local_sqrt_sqr(fgraph, node): def local_log_sqrt(fgraph, node): x = node.inputs[0] - if not (x.owner and isinstance(x.owner.op, Elemwise)) or not isinstance( + if not (x.owner and isinstance(x.owner.op, Elemwise)) and isinstance( x.owner.op.scalar_op, ps.Sqrt ): return From 38719b496ab34d2def144a60b9d29c13f998c987 Mon Sep 17 00:00:00 2001 From: Will Dean <57733339+williambdean@users.noreply.github.com> Date: Mon, 28 Jul 2025 06:37:54 -0400 Subject: [PATCH 07/11] Update test_math.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- tests/tensor/rewriting/test_math.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 7171b683d3..5942b84cf8 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -1993,7 +1993,7 @@ def test_log_sqrt() -> None: x = pt.tensor("x", shape=(None, None)) out = log(sqrt(x)) - out = rewrite_graph(out, include=["canonicalize", "specialize", "stabilize"]) + out = rewrite_graph(out, include=["canonicalize"]) assert equal_computations( [out], From 1bce3d58614ddddf9c758f42deb58c25493f167a Mon Sep 17 00:00:00 2001 From: Will Dean Date: Mon, 28 Jul 2025 10:04:48 -0400 Subject: [PATCH 08/11] always copy stack trace --- pytensor/tensor/rewriting/math.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 805041c64b..9d5dfa9e28 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -568,7 +568,8 @@ def local_log_sqrt(fgraph, node): new_out = mul(0.5, log(x)) if new_out.dtype != old_out.dtype: new_out = cast(new_out, old_out.dtype) - copy_stack_trace(node.out, new_out) + + copy_stack_trace(node.out, new_out) return [new_out] From 3f85e50d3fc1f30ec2945eee03023f0d61ea4208 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Mon, 28 Jul 2025 10:18:26 -0400 Subject: [PATCH 09/11] simplify condition and change to specialize --- pytensor/tensor/rewriting/math.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 9d5dfa9e28..0b05dee6ca 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -552,13 +552,15 @@ def local_sqrt_sqr(fgraph, node): return [new_out] -@register_canonicalize +@register_specialize @node_rewriter([log]) def local_log_sqrt(fgraph, node): x = node.inputs[0] - if not (x.owner and isinstance(x.owner.op, Elemwise)) and isinstance( - x.owner.op.scalar_op, ps.Sqrt + if ( + not x.owner + or not isinstance(x.owner.op, Elemwise) + or not isinstance(x.owner.op.scalar_op, ps.Sqrt) ): return From 5ed6975140518f8528199418a8e325b374ff4434 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Mon, 28 Jul 2025 10:19:11 -0400 Subject: [PATCH 10/11] change to specialize --- tests/tensor/rewriting/test_math.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 5942b84cf8..7bf98d1979 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -1993,7 +1993,7 @@ def test_log_sqrt() -> None: x = pt.tensor("x", shape=(None, None)) out = log(sqrt(x)) - out = rewrite_graph(out, include=["canonicalize"]) + out = rewrite_graph(out, include=["specialize"]) assert equal_computations( [out], From c139f75fd3ce51cca79543956b2e7ce333b5b083 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Mon, 28 Jul 2025 10:53:48 -0400 Subject: [PATCH 11/11] use the type from x --- pytensor/tensor/rewriting/math.py | 2 +- tests/tensor/rewriting/test_math.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 0b05dee6ca..327e14951e 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -567,7 +567,7 @@ def local_log_sqrt(fgraph, node): # Case for log(sqrt(x)) -> 0.5 * log(x) x = x.owner.inputs[0] old_out = node.outputs[0] - new_out = mul(0.5, log(x)) + new_out = mul(as_tensor_variable(0.5, dtype=x.dtype), log(x)) if new_out.dtype != old_out.dtype: new_out = cast(new_out, old_out.dtype) diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py index 7bf98d1979..43f65c2282 100644 --- a/tests/tensor/rewriting/test_math.py +++ b/tests/tensor/rewriting/test_math.py @@ -1995,9 +1995,9 @@ def test_log_sqrt() -> None: out = rewrite_graph(out, include=["specialize"]) - assert equal_computations( + assert utt.assert_equal_computations( [out], - [mul(pt.as_tensor_variable([[0.5]]), log(x))], + [mul(pt.as_tensor_variable([[0.5]], dtype=x.dtype), log(x))], )