diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 5a3cf0036f..aba6787826 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -2471,6 +2471,18 @@ def make_node(self, axis, *tensors): if axis.type.ndim > 0: raise TypeError(f"Axis {axis} must be 0-d.") + # Convert negative constant axis to positive during canonicalization + if isinstance(axis, Constant) and tensors: + # Get the axis value directly from the constant's data + axis_val = axis.data.item() + # Check if it's negative and needs normalization + if axis_val < 0: + ndim = tensors[0].ndim + # Convert negative axis to positive + axis_val = normalize_axis_index(axis_val, ndim) + # Replace the original axis with the normalized one + axis = constant(axis_val, dtype=axis.type.dtype) + tensors = [as_tensor_variable(x) for x in tensors] if not builtins.all(targs.type.ndim > 0 for targs in tensors): diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index b01a50e2fa..f3b68f0e14 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -2179,6 +2179,15 @@ def test_join_performance(self, ndim, axis, memory_layout, gc, benchmark): assert fn(*test_values).shape == (n * 6, n)[:ndim] if axis == 0 else (n, n * 6) benchmark(fn, *test_values) + def test_join_negative_axis_rewrite(self): + """Test that constant negative axis is rewritten to positive axis in make_node.""" + v = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=self.floatX) + a = self.shared(v) + b = as_tensor_variable(v) + + assert equal_computations([join(-1, a, b)], [join(1, a, b)]) + assert equal_computations([join(-2, a, b)], [join(0, a, b)]) + def test_TensorFromScalar(): s = ps.constant(56)