@@ -127,33 +127,6 @@ def arange_shape_func(attrs, inputs, _):
127127 """
128128 return [_arange_shape_func (* inputs )]
129129
130- @script
131- def _strided_slice_shape_func_input_data (data , begin , end , strides ,
132- slice_mode ):
133- ndim = len (data .shape )
134- out = output_tensor ((ndim ,), "int64" )
135- for i in const_range (ndim ):
136- cbegin = 0
137- cend = data .shape [i ]
138- cstride = 1
139- if strides .shape [0 ] > i :
140- cstride = strides [i ]
141- if begin .shape [0 ] > i :
142- cbegin = begin [i ]
143- if end .shape [0 ] <= i :
144- cend = data .shape [i ]
145- elif slice_mode != 0 :
146- cstride = 1
147- if end [i ] < 0 :
148- cend = data .shape [i ]
149- else :
150- cend = cbegin + end [i ]
151- else :
152- cend = end [i ]
153- assert cstride != 0 , "Strides can't be zero."
154- out [i ] = int64 (ceil_div ((int64 (cend ) - int64 (cbegin )), int64 (cstride )))
155- return out
156-
157130@script
158131def _strided_slice_shape_func_input_shape (data_shape , begin , end , strides , slice_mode ):
159132 ndim = data_shape .shape [0 ]
@@ -166,6 +139,8 @@ def _strided_slice_shape_func_input_shape(data_shape, begin, end, strides, slice
166139 cstride = int64 (strides [i ])
167140 if len (begin ) > i :
168141 cbegin = int64 (begin [i ])
142+ if cbegin < 0 :
143+ cbegin += int64 (data_shape [i ])
169144 if len (end ) <= i :
170145 cend = int64 (data_shape [i ])
171146 elif slice_mode != 0 :
@@ -175,23 +150,32 @@ def _strided_slice_shape_func_input_shape(data_shape, begin, end, strides, slice
175150 else :
176151 cend = cbegin + int64 (end [i ])
177152 else :
178- cend = int64 (end [i ])
153+ if end [i ] > data_shape [i ]:
154+ cend = int64 (data_shape [i ])
155+ else :
156+ cend = int64 (end [i ])
157+ if cend < 0 :
158+ cend += int64 (data_shape [i ])
179159 assert cstride != 0 , "Strides can't be zero."
180- out [i ] = int64 (ceil_div ((int64 (cend ) - int64 (cbegin )), int64 (cstride )))
160+ if cstride < 0 :
161+ slice_range = cbegin - cend
162+ step = - cstride
163+ else :
164+ slice_range = cend - cbegin
165+ step = cstride
166+
167+ out [i ] = int64 (ceil_div (slice_range , step ))
181168 return out
182169
183170
184- @_reg .register_shape_func ("strided_slice" , True )
171+ @_reg .register_shape_func ("strided_slice" , False )
185172def strided_slice_shape_func (attrs , inputs , _ ):
186173 """
187174 Shape func for strided_slice
188175 """
189176 slice_mode = convert (0 if attrs .slice_mode == "end" else 1 )
190- # data independent if begin, end and strides exist
191- if attrs .begin and attrs .end and attrs .strides :
192- return [_strided_slice_shape_func_input_shape (inputs [0 ], attrs .begin , attrs .end ,
193- attrs .strides , slice_mode )]
194- return [_strided_slice_shape_func_input_data (* inputs , slice_mode )]
177+ return [_strided_slice_shape_func_input_shape (inputs [0 ], attrs .begin , attrs .end ,
178+ attrs .strides , slice_mode )]
195179
196180@script
197181def _concatenate_shape_func (inputs , axis ):
0 commit comments