@@ -2226,21 +2226,21 @@ def make_node(self, x, axis, splits):
22262226
22272227 return Apply (self , inputs , outputs )
22282228
2229- def perform (self , node , inputs , outputs ):
2229+ def perform (self , node , inputs , outputs_storage ):
22302230 x , axis , splits = inputs
22312231
22322232 if len (splits ) != self .len_splits :
22332233 raise ValueError ("Length of splits is not equal to n_splits" )
2234- if np .sum (splits ) != x .shape [axis ]:
2234+ if splits .sum () != x .shape [axis ]:
22352235 raise ValueError (
2236- f"Split sizes sum to { np .sum (splits )} ; expected { x .shape [axis ]} "
2236+ f"Split sizes sum to { splits .sum ()} ; expected { x .shape [axis ]} "
22372237 )
2238- if np . any (splits < 0 ):
2238+ if (splits < 0 ). any ( ):
22392239 raise ValueError ("Split sizes cannot be negative" )
22402240
22412241 split_outs = np .split (x , np .cumsum (splits [:- 1 ]), axis = axis )
2242- for i , out in enumerate ( split_outs ):
2243- outputs [ i ] [0 ] = out
2242+ for out_storage , out in zip ( outputs_storage , split_outs , strict = False ):
2243+ out_storage [0 ] = out
22442244
22452245 def infer_shape (self , fgraph , node , in_shapes ):
22462246 axis = node .inputs [1 ]
@@ -2254,10 +2254,10 @@ def infer_shape(self, fgraph, node, in_shapes):
22542254 out_shapes .append (temp )
22552255 return out_shapes
22562256
2257- def grad (self , inputs , g_outputs ):
2257+ def L_op (self , inputs , outputs , g_outputs ):
22582258 """Join the gradients along the axis that was used to split x."""
22592259 x , axis , n = inputs
2260- outputs = self ( * inputs , return_list = True )
2260+
22612261 # If all the output gradients are disconnected, then so are the inputs
22622262 if builtins .all (isinstance (g .type , DisconnectedType ) for g in g_outputs ):
22632263 return [
0 commit comments