From 15ce5fd572d51f45fc2f812abeb5a30aac61f9cc Mon Sep 17 00:00:00 2001
From: ricardoV94
Date: Thu, 31 Oct 2024 09:28:25 +0100
Subject: [PATCH 1/2] Fix bug in `local_useless_slice` rewrite

Canonical slice start and stop values depend on the sign of the step.
The rewrite wrongly assumed they were always 0:len(dim).
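
For a vector of length 7, the canonical "whole dimension" slice flips
with the sign of the step. A minimal sketch of the standard NumPy
indexing semantics this relies on:

    import numpy as np

    x = np.arange(7)

    # Positive step: the canonical defaults are start=0, stop=len(x)
    assert (x[0:7:1] == x[::1]).all()

    # Negative step: the canonical defaults are start=-1, stop=-len(x) - 1
    assert (x[-1:-8:-1] == x[::-1]).all()

    # Assuming start=0 with a negative step keeps only the first element
    assert (x[0::-1] == np.array([0])).all()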
---
 pytensor/tensor/rewriting/subtensor.py   |  45 ++++++----
 tests/tensor/rewriting/test_subtensor.py | 100 +++++++++++++++--------
 2 files changed, 94 insertions(+), 51 deletions(-)

diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py
index 7699169143..cb453a44e4 100644
--- a/pytensor/tensor/rewriting/subtensor.py
+++ b/pytensor/tensor/rewriting/subtensor.py
@@ -342,14 +342,18 @@ def local_subtensor_of_dot(fgraph, node):
 @node_rewriter([Subtensor])
 def local_useless_slice(fgraph, node):
     """
-    Remove Subtensor of the form:
+    Remove useless slice(None) of the form:
         1. X[0, :] -> X[0]
         2. X[:] -> X
 
-    Also, rewrite Subtensor of the form:
+    Also, canonicalize slices of the form:
         X[0:7:1] -> X[None:None:None]
         where X is a vector of length 7
 
+    And:
+        X[-1:-8:-1] -> X[::-1]
+        where X is a vector of length 7
+
     """
     idxs = get_idx_list(node.inputs, node.op.idx_list)
     x = node.inputs[0]
@@ -368,32 +372,40 @@ def local_useless_slice(fgraph, node):
         if s == slice(None):
             continue
 
+        step = s.step
+
+        if step is None:
+            positive_step = True
+        elif isinstance(step, Constant):
+            step_value = step.data
+            positive_step = step.data > 0
+            if step_value == 1:
+                change_flag = True
+                step = None
+        else:
+            # We can only canonicalize start and stop if we know the sign of step
+            last_useful_idx = dim
+            continue
+
         start = s.start
         stop = s.stop
-        step = s.step
-        if (
-            start is not None
-            and extract_constant(start, only_process_constants=True) == 0
-        ):
+
+        if start is not None and extract_constant(
+            start, only_process_constants=True
+        ) == (0 if positive_step else -1):
             change_flag = True
             start = None
 
         if (
             stop is not None
             and x.type.shape[dim] is not None
-            and extract_constant(stop, only_process_constants=True) == x.type.shape[dim]
+            and extract_constant(stop, only_process_constants=True)
+            == (x.type.shape[dim] if positive_step else -x.type.shape[dim] - 1)
         ):
             change_flag = True
             stop = None
 
-        if (
-            step is not None
-            and extract_constant(step, only_process_constants=True) == 1
-        ):
-            change_flag = True
-            step = None
-
-        if not (start is None and stop is None and step is None):
+        if start is not None or stop is not None or step is not None:
             last_useful_idx = dim
 
         new_idxs[dim] = slice(start, stop, step)
@@ -402,7 +414,6 @@
         out = x[tuple(new_idxs[: last_useful_idx + 1])]
         # Copy over previous output stacktrace
         copy_stack_trace(node.outputs, out)
-
         return [out]
 
diff --git a/tests/tensor/rewriting/test_subtensor.py b/tests/tensor/rewriting/test_subtensor.py
index 91575bc7da..72a7a0f235 100644
--- a/tests/tensor/rewriting/test_subtensor.py
+++ b/tests/tensor/rewriting/test_subtensor.py
@@ -2404,42 +2404,74 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
     np.testing.assert_allclose(fn(test_x, test_y), expected_out)
 
 
-def test_slice_canonicalize():
-    rng = np.random.default_rng(43)
-    x = tensor(shape=(3, 5, None, 9))
-    test_x = rng.normal(size=(3, 5, 8, 9))
-    # Test case 1
-    y = x[0:None, 0:5, 0:7, 0:9:1]
-    f = pytensor.function([x], y, allow_input_downcast=True)
-
-    # Get the DeepCopy input and assert that the Op is a DeepCopy
-    test_y = f.maker.fgraph.outputs[0].owner.inputs[0]
-    assert isinstance(f.maker.fgraph.outputs[0].owner.op, DeepCopyOp)
-
-    expected_y = x[None:None:None, None:None:None, None:7:None]
-
-    assert equal_computations([test_y], [expected_y])
-
-    np.testing.assert_allclose(
-        f(test_x),
-        test_x[
-            0:None, 0:5, 0:7, 0:9:1
-        ],  # Use the unoptimized slice to make sure our rewrite logic is correct
-    )
+class TestUselessSlice:
+    def test_positive_step(self):
+        # When steps are positive, the canonical start and stop are `0` and `len(dim)`
+        x = tensor(shape=(3, 5, None, 9), dtype="float64")
+        test_x = np.random.normal(size=(3, 5, 8, 9))
+
+        y = x[0:3:1, 1:5:2, 0:7:1, 0:9:1]
+        f = pytensor.function([x], y)
+
+        # Get the DeepCopy input and assert that the Op is a DeepCopy
+        deep_copy_node = f.maker.fgraph.outputs[0].owner
+        assert isinstance(deep_copy_node.op, DeepCopyOp)
+
+        rewritten_y = deep_copy_node.inputs[0]
+        expected_y = x[None:None:None, 1:None:2, None:7:None]
+        assert equal_computations([rewritten_y], [expected_y])
+
+        np.testing.assert_allclose(
+            f(test_x),
+            # Use the unoptimized slice to make sure our rewrite logic is correct
+            test_x[0:3:1, 1:5:2, 0:7:1, 0:9:1],
+        )
 
-    # Test case 2
-    y1 = x[0:-1, 0:5, 0:7, 0:-1:-1]
-    f1 = pytensor.function([x], y1, allow_input_downcast=True)
+    def test_negative_step(self):
+        # When steps are negative, the canonical start and stop are `-1` and `-len(dim) - 1`
+        x = tensor(shape=(3, 5, None, 9), dtype="float64")
+        test_x = np.random.normal(size=(3, 5, 8, 9))
 
-    # Get the DeepCopy input and assert that the Op is a DeepCopy
-    test_y1 = f1.maker.fgraph.outputs[0].owner.inputs[0]
-    assert isinstance(f1.maker.fgraph.outputs[0].owner.op, DeepCopyOp)
+        y = x[-1:-4:-1, 0:5:-2, -1:-9:-1, 0:9:None]
+        f = pytensor.function([x], y)
 
-    expected_y1 = x[None:-1:None, None:None:None, None:7:None, None:-1:-1]
+        # Get the DeepCopy input and assert that the Op is a DeepCopy
+        deep_copy_node = f.maker.fgraph.outputs[0].owner
+        assert isinstance(deep_copy_node.op, DeepCopyOp)
 
-    assert equal_computations([test_y1], [expected_y1])
+        rewritten_y = deep_copy_node.inputs[0]
+        expected_y = x[None:None:-1, 0:5:-2, None:-9:-1]
+        assert equal_computations([rewritten_y], [expected_y])
 
-    np.testing.assert_allclose(
-        f1(test_x),
-        test_x[0:-1, 0:5, 0:7, 0:-1:-1],
-    )
+        np.testing.assert_allclose(
+            f(test_x),
+            test_x[-1:-4:-1, 0:5:-2, -1:-9:-1, 0:9:None],
+        )
+
+    def test_unknown_step(self):
+        # If the sign of step isn't known, we can't canonicalize start and stop
+        step = pt.scalar("step", dtype=int)
+        x = tensor(shape=(3, 5, None), dtype="float64")
+        test_x = np.random.normal(size=(3, 5, 7))
+
+        y = x[0:3:step, -1:-6:-step, ::]
+        # Need this rewrite under `FAST_COMPILE`, otherwise `-step` is built as `-1 * step` instead of `neg(step)`
+        mode = get_default_mode().including("local_mul_specialize")
+        f = pytensor.function([x, step], y, mode=mode)
+
+        # Get the DeepCopy input and assert that the Op is a DeepCopy
+        deep_copy_node = f.maker.fgraph.outputs[0].owner
+        assert isinstance(deep_copy_node.op, DeepCopyOp)
+
+        rewritten_y = deep_copy_node.inputs[0]
+        expected_y = x[0:3:step, -1:-6:-step]
+        assert equal_computations([rewritten_y], [expected_y])
+
+        np.testing.assert_allclose(
+            f(test_x, 1),
+            test_x[0:3:1, -1:-6:-1, ::],
+        )
+        np.testing.assert_allclose(
+            f(test_x, -2),
+            test_x[0:3:-2, -1:-6:2, ::],
+        )

From 55239849ff8eff3414da5b3ac476f1c0c6f29dc1 Mon Sep 17 00:00:00 2001
From: ricardoV94
Date: Thu, 31 Oct 2024 18:41:26 +0100
Subject: [PATCH 2/2] Fix bug in `local_reduce_join` rewrite
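
`local_reduce_join` collapses a reduction over a join into a single
elementwise op on the join inputs. A minimal sketch, in plain NumPy,
of the identity the rewrite relies on:

    import numpy as np

    x, y = np.arange(3.0), np.ones(3)

    # sum(join(0, x[None], y[None]), axis=0) is just x + y
    assert np.allclose(
        np.concatenate([x[None], y[None]], axis=0).sum(axis=0),
        x + y,
    )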
The helper `apply_local_dimshuffle_lift` requires a FunctionGraph when
elemwise inputs are involved.
---
 pytensor/tensor/rewriting/math.py   |  2 +-
 tests/tensor/rewriting/test_math.py | 19 +++++++++++++++++++
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index 68cc0e5e96..b230f035cc 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -1620,7 +1620,7 @@ def local_reduce_join(fgraph, node):
         if not inp.type.broadcastable[join_axis]:
             return None
         # Most times inputs to join have an expand_dims, we eagerly clean up those here
-        new_input = apply_local_dimshuffle_lift(None, inp.squeeze(join_axis))
+        new_input = apply_local_dimshuffle_lift(fgraph, inp.squeeze(join_axis))
         new_inputs.append(new_input)
 
     ret = Elemwise(node.op.scalar_op)(*new_inputs)
diff --git a/tests/tensor/rewriting/test_math.py b/tests/tensor/rewriting/test_math.py
index 1212ee4fbd..e4a08cdf81 100644
--- a/tests/tensor/rewriting/test_math.py
+++ b/tests/tensor/rewriting/test_math.py
@@ -103,6 +103,7 @@
     local_mul_canonizer,
     local_mul_switch_sink,
     local_reduce_chain,
+    local_reduce_join,
     local_sum_prod_of_mul_or_div,
     mul_canonizer,
     parse_mul_tree,
@@ -3415,6 +3416,24 @@ def test_not_supported_unequal_shapes(self):
             f(x, y), np.sum(np.concatenate([x, y], axis=0), axis=0)
         )
 
+    def test_non_ds_inputs(self):
+        """Make sure the rewrite works when inputs to join are not the usual DimShuffle,
+        as in this motivating graph:
+
+        Sum{axis=1} [id A]
+         └─ Join [id B]
+            ├─ 1 [id C]
+            ├─ ExpandDims{axis=1} [id D]
+            ├─ Sub [id E]
+            └─ Sub [id F]
+        """
+        x = vector("x")
+        out = join(0, exp(x[None]), log(x[None])).sum(axis=0)
+
+        fg = FunctionGraph([x], [out], clone=False)
+        [rewritten_out] = local_reduce_join.transform(fg, out.owner)
+        expected_out = add(exp(x), log(x))
+        assert equal_computations([rewritten_out], [expected_out])
+
 
 def test_local_useless_adds():
     default_mode = get_default_mode()
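
As a quick end-to-end sanity check of the graph exercised by the new
test (a sketch, assuming a pytensor build that includes both fixes):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    out = pt.join(0, pt.exp(x[None]), pt.log(x[None])).sum(axis=0)
    fn = pytensor.function([x], out)

    test_x = np.array([1.0, 2.0, 3.0])
    # The compiled function must agree with the unrewritten NumPy computation
    np.testing.assert_allclose(fn(test_x), np.exp(test_x) + np.log(test_x))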