@@ -3207,13 +3207,14 @@ def tile(
32073207 return A_replicated .reshape (tiled_shape )
32083208
32093209
3210- class ARange (Op ):
3210+ class ARange (COp ):
32113211 """Create an array containing evenly spaced values within a given interval.
32123212
32133213 Parameters and behaviour are the same as numpy.arange().
32143214
32153215 """
32163216
3217+ # TODO: Arange should work with scalars as inputs, not arrays
32173218 __props__ = ("dtype" ,)
32183219
32193220 def __init__ (self , dtype ):
@@ -3293,13 +3294,30 @@ def upcast(var):
32933294 )
32943295 ]
32953296
def perform(self, node, inputs, output_storage):
    """Evaluate the range in Python and store it as the single output.

    Parameters
    ----------
    node
        Unused here.
    inputs
        Three objects ``(start, stop, step)``; each must expose ``.item()``
        so it can be reduced to a Python scalar.
    output_storage
        Nested storage; the result is written to ``output_storage[0][0]``.
    """
    start_arr, stop_arr, step_arr = inputs
    # np.arange is fed plain Python scalars rather than the input arrays.
    scalars = [arr.item() for arr in (start_arr, stop_arr, step_arr)]
    output_storage[0][0] = np.arange(*scalars, dtype=self.dtype)
3302+
def c_code(self, node, nodename, input_names, output_names, sub):
    """Return the C source that builds the range with ``PyArray_Arange``.

    Parameters
    ----------
    node, nodename
        Unused here.
    input_names
        Exactly three C variable names: start, stop and step.
    output_names
        Exactly one C variable name for the output array.
    sub
        Substitution mapping; only ``sub["fail"]`` is used, emitted when
        ``PyArray_Arange`` returns NULL.

    Returns
    -------
    str
        C code that reads the first element of each input as a ``double``,
        releases any previous output and allocates the new range.
    """
    [start_name, stop_name, step_name] = input_names
    [out_name] = output_names
    # NumPy type number of the requested output dtype, baked into the C call.
    typenum = np.dtype(self.dtype).num
    fail = sub["fail"]
    # NOTE(review): ``dtype_<name>`` is presumably a typedef supplied by the
    # code generator for each input variable - confirm against the COp docs.
    return f"""
    double start = ((dtype_{start_name}*)PyArray_DATA({start_name}))[0];
    double stop = ((dtype_{stop_name}*)PyArray_DATA({stop_name}))[0];
    double step = ((dtype_{step_name}*)PyArray_DATA({step_name}))[0];
    Py_XDECREF({out_name});
    {out_name} = (PyArrayObject*) PyArray_Arange(start, stop, step, {typenum});
    if (!{out_name}) {{
        {fail}
    }}
    """
3318+
def c_code_cache_version(self):
    """Return the version tuple for this Op's C implementation."""
    version = (0,)
    return version
33033321
def connection_pattern(self, node):
    """Return the fixed 3x1 boolean connection table for this Op.

    NOTE(review): presumably one row per input in (start, stop, step) order
    and one column for the single output - confirm against the Op API.
    """
    first_row = [True]
    second_row = [False]
    third_row = [True]
    return [first_row, second_row, third_row]
@@ -3685,8 +3703,7 @@ def inverse_permutation(perm):
36853703 )
36863704
36873705
3688- # TODO: optimization to insert ExtractDiag with view=True
3689- class ExtractDiag (Op ):
3706+ class ExtractDiag (COp ):
36903707 """
36913708 Return specified diagonals.
36923709
@@ -3742,7 +3759,7 @@ class ExtractDiag(Op):
37423759
37433760 __props__ = ("offset" , "axis1" , "axis2" , "view" )
37443761
3745- def __init__ (self , offset = 0 , axis1 = 0 , axis2 = 1 , view = False ):
3762+ def __init__ (self , offset = 0 , axis1 = 0 , axis2 = 1 , view = True ):
37463763 self .view = view
37473764 if self .view :
37483765 self .view_map = {0 : [0 ]}
@@ -3765,24 +3782,74 @@ def make_node(self, x):
37653782 if x .ndim < 2 :
37663783 raise ValueError ("ExtractDiag needs an input with 2 or more dimensions" , x )
37673784
3768- out_shape = [
3769- st_dim
3770- for i , st_dim in enumerate (x .type .shape )
3771- if i not in (self .axis1 , self .axis2 )
3772- ] + [None ]
3785+ if (dim1 := x .type .shape [self .axis1 ]) is not None and (
3786+ dim2 := x .type .shape [self .axis2 ]
3787+ ) is not None :
3788+ offset = self .offset
3789+ if offset > 0 :
3790+ diag_size = int (np .clip (dim2 - offset , 0 , dim1 ))
3791+ elif offset < 0 :
3792+ diag_size = int (np .clip (dim1 + offset , 0 , dim2 ))
3793+ else :
3794+ diag_size = int (np .minimum (dim1 , dim2 ))
3795+ else :
3796+ diag_size = None
3797+
3798+ out_shape = (
3799+ * (
3800+ dim
3801+ for i , dim in enumerate (x .type .shape )
3802+ if i not in (self .axis1 , self .axis2 )
3803+ ),
3804+ diag_size ,
3805+ )
37733806
37743807 return Apply (
37753808 self ,
37763809 [x ],
3777- [x .type .clone (dtype = x .dtype , shape = tuple ( out_shape ) )()],
3810+ [x .type .clone (dtype = x .dtype , shape = out_shape )()],
37783811 )
37793812
def perform(self, node, inputs, output_storage):
    """Extract the requested diagonal and store it as the single output.

    When ``self.view`` is true the diagonal view is made writeable and
    stored as-is; if the writeable flag cannot be set, a copy is stored
    instead. When ``self.view`` is false a copy is always stored.

    Parameters
    ----------
    node
        Unused here.
    inputs
        A one-element sequence holding the input array.
    output_storage
        Nested storage; the result is written to ``output_storage[0][0]``.
    """
    (x,) = inputs
    diag = x.diagonal(self.offset, self.axis1, self.axis2)
    must_copy = not self.view
    if self.view:
        try:
            diag.flags.writeable = True
        except ValueError:
            # The view cannot be made writeable, so fall back to a copy.
            must_copy = True
    if must_copy:
        diag = diag.copy()
    output_storage[0][0] = diag
3825+
def c_code(self, node, nodename, input_names, output_names, sub):
    """Return the C source that extracts the diagonal with ``PyArray_Diagonal``.

    Parameters
    ----------
    node, nodename
        Unused here.
    input_names
        Exactly one C variable name for the input array.
    output_names
        Exactly one C variable name for the output array.
    sub
        Substitution mapping; only ``sub["fail"]`` is used, emitted whenever
        a NumPy call returns NULL.

    Returns
    -------
    str
        C code that releases any previous output, takes the diagonal and
        then either marks it writeable (``self.view`` set and the input is
        writeable) or replaces it with a copy.
    """
    [x_name] = input_names
    [out_name] = output_names
    fail = sub["fail"]
    # In the copy branch the output variable is rebound to the copy BEFORE the
    # NULL check. If the copy fails, the variable then holds NULL instead of a
    # pointer that was just released, so a later Py_XDECREF on it (this code
    # itself starts with one; the failure path presumably does too) cannot
    # release the same object twice.
    return f"""
    Py_XDECREF({out_name});

    {out_name} = (PyArrayObject*) PyArray_Diagonal({x_name}, {self.offset}, {self.axis1}, {self.axis2});
    if (!{out_name}) {{
        {fail}  // Error already set by Numpy
    }}

    if ({int(self.view)} && PyArray_ISWRITEABLE({x_name})) {{
        // Make output writeable if input was writeable
        PyArray_ENABLEFLAGS({out_name}, NPY_ARRAY_WRITEABLE);
    }} else {{
        // Make a copy
        PyArrayObject *{out_name}_copy = (PyArrayObject*) PyArray_Copy({out_name});
        Py_DECREF({out_name});
        {out_name} = {out_name}_copy;
        if (!{out_name}) {{
            {fail}  // Error already set by Numpy
        }}
    }}
    """


def c_code_cache_version(self):
    """Return the version tuple for this Op's C implementation."""
    # Bumped from (0,) because the generated C changed.
    return (1,)
37863853
37873854 def grad (self , inputs , gout ):
37883855 # Avoid circular import
@@ -3829,19 +3896,6 @@ def infer_shape(self, fgraph, node, shapes):
38293896 out_shape .append (diag_size )
38303897 return [tuple (out_shape )]
38313898
3832- def __setstate__ (self , state ):
3833- self .__dict__ .update (state )
3834-
3835- if self .view :
3836- self .view_map = {0 : [0 ]}
3837-
3838- if "offset" not in state :
3839- self .offset = 0
3840- if "axis1" not in state :
3841- self .axis1 = 0
3842- if "axis2" not in state :
3843- self .axis2 = 1
3844-
38453899
38463900def extract_diag (x ):
38473901 warnings .warn (
0 commit comments