@@ -732,46 +732,27 @@ def make_node(self, x, *inputs):
732
732
f"Incompatible types for Subtensor template. Expected { input .type } , got { expected_type } ."
733
733
)
734
734
735
- padded = list (self .idx_list ) + [slice (None , None , None )] * (x .type .ndim - len (idx_list ))
735
+ padded = list (self .idx_list ) + [slice (None , None , None )] * (
736
+ x .type .ndim - len (idx_list )
737
+ )
736
738
737
739
out_shape = []
738
740
739
- def extract_const (value ):
740
- if value is None :
741
- return value , True
742
- try :
743
- value = get_underlying_scalar_constant_value (value )
744
- return value , True
745
- except NotScalarConstantError :
746
- return value , False
747
-
748
741
for the_slice , length in zip (padded , x .type .shape ):
749
742
if isinstance (the_slice , slice ):
750
743
if length is None :
751
744
out_shape .append (None )
752
745
continue
753
746
754
- start = the_slice .start
755
- stop = the_slice .stop
756
- step = the_slice .step
757
-
758
- is_slice_const = True
759
-
760
- start , is_const = extract_const (start )
761
- is_slice_const = is_slice_const and is_const
762
-
763
- stop , is_const = extract_const (stop )
764
- is_slice_const = is_slice_const and is_const
765
-
766
- step , is_const = extract_const (step )
767
- is_slice_const = is_slice_const and is_const
768
-
769
- if not is_slice_const :
747
+ try :
748
+ start = get_underlying_scalar_constant_value (the_slice .start )
749
+ stop = get_underlying_scalar_constant_value (the_slice .stop )
750
+ step = get_underlying_scalar_constant_value (the_slice .step )
751
+ except NotScalarConstantError :
770
752
out_shape .append (None )
771
- continue
772
-
773
- slice_length = len (range (* slice (start , stop , step ).indices (length )))
774
- out_shape .append (slice_length )
753
+ else :
754
+ slice_length = len (range (* slice (start , stop , step ).indices (length )))
755
+ out_shape .append (slice_length )
775
756
776
757
return Apply (
777
758
self ,
0 commit comments