Skip to content

Commit 1a14965

Browse files
committed
style: Cleanup subtensor make_node
1 parent 6e495ed commit 1a14965

File tree

1 file changed

+11
-30
lines changed

1 file changed

+11
-30
lines changed

pytensor/tensor/subtensor.py

Lines changed: 11 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -732,46 +732,27 @@ def make_node(self, x, *inputs):
732732
f"Incompatible types for Subtensor template. Expected {input.type}, got {expected_type}."
733733
)
734734

735-
padded = list(self.idx_list) + [slice(None, None, None)] * (x.type.ndim - len(idx_list))
735+
padded = list(self.idx_list) + [slice(None, None, None)] * (
736+
x.type.ndim - len(idx_list)
737+
)
736738

737739
out_shape = []
738740

739-
def extract_const(value):
740-
if value is None:
741-
return value, True
742-
try:
743-
value = get_underlying_scalar_constant_value(value)
744-
return value, True
745-
except NotScalarConstantError:
746-
return value, False
747-
748741
for the_slice, length in zip(padded, x.type.shape):
749742
if isinstance(the_slice, slice):
750743
if length is None:
751744
out_shape.append(None)
752745
continue
753746

754-
start = the_slice.start
755-
stop = the_slice.stop
756-
step = the_slice.step
757-
758-
is_slice_const = True
759-
760-
start, is_const = extract_const(start)
761-
is_slice_const = is_slice_const and is_const
762-
763-
stop, is_const = extract_const(stop)
764-
is_slice_const = is_slice_const and is_const
765-
766-
step, is_const = extract_const(step)
767-
is_slice_const = is_slice_const and is_const
768-
769-
if not is_slice_const:
747+
try:
748+
start = get_underlying_scalar_constant_value(the_slice.start)
749+
stop = get_underlying_scalar_constant_value(the_slice.stop)
750+
step = get_underlying_scalar_constant_value(the_slice.step)
751+
except NotScalarConstantError:
770752
out_shape.append(None)
771-
continue
772-
773-
slice_length = len(range(*slice(start, stop, step).indices(length)))
774-
out_shape.append(slice_length)
753+
else:
754+
slice_length = len(range(*slice(start, stop, step).indices(length)))
755+
out_shape.append(slice_length)
775756

776757
return Apply(
777758
self,

0 commit comments

Comments
 (0)