@@ -2171,8 +2171,6 @@ class Split(COp):
21712171 array([3, 4])
21722172 >>> c
21732173 array([5])
2174-
2175- TODO: Don't make a copy in C impl
21762174 """
21772175
21782176 len_splits = None
@@ -2285,142 +2283,107 @@ def R_op(self, inputs, eval_points):
22852283 return self .make_node (eval_points [0 ], * inputs [1 :]).outputs
22862284
22872285 def c_code_cache_version (self ):
2288- return (2 ,)
2289-
2290- def c_support_code (self , ** kwargs ):
2291- return """
2292- /* Return 1 if output has the correct shape. */
2293- int split_output_shape_is_correct (
2294- PyArrayObject* output, PyArrayObject* array_to_split, int axis_to_split, npy_intp split_size
2295- ) {
2296- return
2297- PyArray_NDIM(output) == PyArray_NDIM(array_to_split)
2298- && memcmp(
2299- PyArray_DIMS(output),
2300- PyArray_DIMS(array_to_split),
2301- axis_to_split * sizeof(npy_intp)
2302- ) == 0
2303- && memcmp(
2304- PyArray_DIMS(output) + axis_to_split + 1,
2305- PyArray_DIMS(array_to_split) + axis_to_split + 1,
2306- (PyArray_NDIM(array_to_split) - axis_to_split - 1) * sizeof(npy_intp)
2307- ) == 0
2308- && split_size == PyArray_DIM(output, axis_to_split);
2309- }
2310- """
2286+ return (3 ,)
23112287
23132289 def c_code (self , node , name , inputs , outputs , sub ):
23142290 if self .len_splits == 0 :
2315- # There are no outputs, then nothing to do.
2316- return ""
2290+ # This would be a view Op, anyway shouldn't be triggered
2291+ raise NotImplementedError ()
23162292
23172293 # outputs_pointers lists the addresses of the pointers to the outputs.
23182294 outputs_pointers = "&" + (", &" .join (outputs ))
23192295 x , axis , splits = inputs
23202296 fail = sub ["fail" ]
2321- x_typenum = np .dtype (node .inputs [0 ].dtype ).num
2322- x_itemsize = np .dtype (node .inputs [0 ].dtype ).itemsize
2323- axis_dtype = node .inputs [1 ].type .dtype_specs ()[1 ]
23242297 splits_dtype = node .inputs [2 ].type .dtype_specs ()[1 ]
2325- expected_splits_count = self .len_splits
2298+ len_splits = self .len_splits
2299+ ndim = node .inputs [0 ].type .ndim
2300+
2301+ # Most times axis is constant, inline it
2302+ # This is safe to do because the hash of the c_code includes the constant signature
2303+ if isinstance (node .inputs [1 ], Constant ):
2304+ static_axis = int (node .inputs [1 ].data )
2305+ static_axis = normalize_axis_index (static_axis , ndim )
2306+ axis_def = f"{ static_axis } ;"
2307+ axis_check = ""
2308+ else :
2309+ axis_dtype = node .inputs [1 ].type .dtype_specs ()[1 ]
2310+ axis_def = f"(({ axis_dtype } *)PyArray_DATA({ axis } ))[0];"
2311+ axis_check = f"""
2312+ if (axis < 0){{
2313+ axis = ndim + axis;
2314+ }}
2315+ if (axis >= ndim || axis < 0) {{
2316+ PyErr_SetString(PyExc_ValueError, "Split axis is out of bounds");
2317+ { fail }
2318+ }}
2319+ """
23262320
23272321 return f"""
2328- int ndim = PyArray_NDIM( { x } ) ;
2329- int axis = (int)(*( { axis_dtype } *)PyArray_GETPTR1( { axis } , 0));
2322+ int ndim = { ndim } ;
2323+ int axis = { axis_def }
23302324 int splits_count = PyArray_DIM({ splits } , 0);
2331- npy_intp len_along_axis, sum_of_splits = 0, current_split_length = 0, current_split_start = 0;
2332- npy_intp* split_dims = NULL;
2333- PyObject* split_view = NULL;
2334- npy_intp data_offset;
2335- int i;
2325+ npy_intp sum_of_splits = 0, current_split_start = 0;
23362326 PyArrayObject** outputs[] = {{{ outputs_pointers } }};
2327+ npy_intp split_dims[{ ndim }];
23372328
23382329 /* Check inputs. */
2339-
2340- if (splits_count != { expected_splits_count } ) {{
2341- PyErr_Format(PyExc_ValueError,
2342- "Split: splits count (%d) != expected count (%d).", splits_count, { expected_splits_count } );
2330+ if (PyArray_NDIM({ x } ) != ndim) {{
2331+ PyErr_Format(PyExc_ValueError, "Input to Split does not have expected ndim");
23432332 { fail }
23442333 }}
2345-
2346- if (axis < 0) {{
2347- axis += ndim;
2348- }}
2349- if (axis < 0 || axis >= ndim) {{
2350- PyErr_Format(PyExc_IndexError, "Split: invalid axis %d for a %d-D array.", axis, ndim);
2334+ if (splits_count != { len_splits } ) {{
2335+ PyErr_Format(PyExc_ValueError, "Split: splits count (%d) != expected count (%d).", splits_count, { len_splits } );
23512336 { fail }
23522337 }}
2353- len_along_axis = PyArray_DIM({ x } , axis);
23542338
2355- for (i = 0; i < splits_count; ++i) {{
2356- current_split_length = (npy_intp)(*({ splits_dtype } *)PyArray_GETPTR1({ splits } , i));
2339+ { axis_check }
2340+
2341+ for (int i = 0; i < splits_count; ++i) {{
2342+ npy_intp current_split_length = (npy_intp)(*({ splits_dtype } *)PyArray_GETPTR1({ splits } , i));
23572343 if (current_split_length < 0) {{
23582344 PyErr_Format(PyExc_ValueError,
23592345 "Split: you try to take a negative number (%ld) of elements.", current_split_length);
23602346 { fail }
23612347 }}
23622348 sum_of_splits += current_split_length;
23632349 }}
2364- if (sum_of_splits != len_along_axis) {{
2365- PyErr_Format(PyExc_ValueError, "Split: the splits sums to %ld, expected %ld.", sum_of_splits, len_along_axis);
2366- { fail }
2367- }}
2368-
2369- /* Check outputs. */
2370-
2371- split_dims = (npy_intp*) malloc(ndim * sizeof(npy_intp));
2372- if (split_dims == NULL) {{
2373- PyErr_NoMemory();
2350+ if (sum_of_splits != PyArray_DIM({ x } , axis)) {{
2351+ PyErr_Format(PyExc_ValueError, "Split: the splits sums to %ld, expected %ld.", sum_of_splits, PyArray_DIM({ x } , axis));
23742352 { fail }
23752353 }}
23762354
2355+ /* Compute split. */
23772356 memcpy(split_dims, PyArray_DIMS({ x } ), ndim * sizeof(npy_intp));
23782357
2379- for (i = 0; i < splits_count; ++i) {{
2380- PyArrayObject** output = outputs[i];
2381- current_split_length = (npy_intp) (* ({ splits_dtype } *) PyArray_GETPTR1({ splits } , i));
2382- if (*output == NULL || !split_output_shape_is_correct(*output, { x } , axis, current_split_length)) {{
2383- Py_XDECREF(*output);
2384- split_dims[axis] = current_split_length;
2385- *output = (PyArrayObject*)PyArray_EMPTY(ndim, split_dims, { x_typenum } , PyArray_IS_F_CONTIGUOUS({ x } ));
2386- if (outputs == NULL) {{
2387- PyErr_SetString(PyExc_RuntimeError, "Split: unable to allocate an output.");
2388- free(split_dims);
2389- { fail }
2390- }}
2391- }}
2392- }}
2358+ for (int i = 0; i < splits_count; ++i) {{
2359+ Py_XDECREF(*outputs[i]);
23932360
2394- /* Compute split. */
2395-
2396- for (i = 0; i < splits_count; ++i) {{
2397- current_split_length = (npy_intp) (* ({ splits_dtype } *) PyArray_GETPTR1({ splits } , i));
2398- data_offset = PyArray_STRIDE({ x } , axis) * current_split_start;
2361+ // Create view of input
2362+ npy_intp data_offset = PyArray_STRIDE({ x } , axis) * current_split_start;
2363+ npy_intp current_split_length = (npy_intp)(*({ splits_dtype } *)PyArray_GETPTR1({ splits } , i));
23992364 split_dims[axis] = current_split_length;
2400- split_view = PyArray_New(&PyArray_Type,
2401- ndim, split_dims,
2402- { x_typenum } ,
2403- PyArray_STRIDES({ x } ),
2404- PyArray_BYTES({ x } ) + data_offset,
2405- { x_itemsize } ,
2406- PyArray_FLAGS({ x } ),
2407- NULL);
2408- if (split_view == NULL) {{
2365+ PyArray_Descr *descr = PyArray_DESCR({ x } );
2366+ Py_INCREF(descr);
2367+ *outputs[i] = (PyArrayObject*)PyArray_NewFromDescr(&PyArray_Type,
2368+ descr, // PyArray_NewFromDescr steals this reference
2369+ ndim, split_dims,
2370+ PyArray_STRIDES({ x } ),
2371+ PyArray_BYTES({ x } ) + data_offset,
2372+ PyArray_FLAGS({ x } ) & ~NPY_ARRAY_OWNDATA,
2373+ NULL);
2374+
2375+ if (*outputs[i] == NULL) {{
24092376 PyErr_SetString(PyExc_RuntimeError, "Split: unable to create a view for a split.");
2410- free(split_dims);
2411- { fail }
2412- }}
2413- if (PyArray_CopyInto(*outputs[i], (PyArrayObject*)split_view) != 0) {{
2414- PyErr_SetString(PyExc_RuntimeError, "Split: unable to copy a split view into the output.");
2415- Py_XDECREF(split_view);
2416- free(split_dims);
24172377 { fail }
24182378 }}
2419- Py_XDECREF(split_view);
2379+
2380+ // Set as a view of input
2381+ Py_INCREF((PyObject*){ x } );
2382+ PyArray_SetBaseObject(*outputs[i], (PyObject*){ x } );
2383+
2384+ // Update split slice pointer
24202385 current_split_start += current_split_length;
24212386 }}
2422-
2423- free(split_dims);
24242387 """
24252388
24262389
0 commit comments