1212from numba .core .base import BaseContext
1313from numba .core .types .misc import NoneType
1414from numba .np import arrayobj
15- from numba .np .ufunc .wrappers import _ArrayArgLoader
1615
1716
1817def compute_itershape (
@@ -158,7 +157,7 @@ def make_loop_call(
158157 input_types : tuple [Any , ...],
159158 output_types : tuple [Any , ...],
160159):
161- # safe = (False, False)
160+ safe = (False , False )
162161
163162 n_outputs = len (outputs )
164163
@@ -183,14 +182,6 @@ def extract_array(aryty, obj):
183182 # input_scope_set = mod.add_metadata([input_scope, output_scope])
184183 # output_scope_set = mod.add_metadata([input_scope, output_scope])
185184
186- typ = input_types [0 ]
187- inp = inputs [0 ]
188- shape = cgutils .unpack_tuple (builder , inp .shape )
189- strides = cgutils .unpack_tuple (builder , inp .strides )
190- loader = _ArrayArgLoader (typ .dtype , typ .ndim , shape [- 1 ], False , shape , strides )
191-
192- inputs = tuple (extract_array (aryty , ary ) for aryty , ary in zip (input_types , inputs ))
193-
194185 outputs = tuple (
195186 extract_array (aryty , ary ) for aryty , ary in zip (output_types , outputs )
196187 )
@@ -221,13 +212,50 @@ def extract_array(aryty, obj):
221212
222213 # Load values from input arrays
223214 input_vals = []
224- for array_info , bc in zip (inputs , input_bc ):
225- idxs_bc = [zero if bc else idx for idx , bc in zip (idxs , bc )]
226- # ptr = cgutils.get_item_pointer2(context, builder, *array_info, idxs_bc, *safe)
227- val = loader .load (context , builder , inp .data , idxs [0 ] or zero )
228- # val = builder.load(ptr)
229- # val.set_metadata("alias.scope", input_scope_set)
230- # val.set_metadata("noalias", output_scope_set)
215+ for input , input_type , bc in zip (inputs , input_types , input_bc ):
216+ core_ndim = input_type .ndim - len (bc )
217+
218+ idxs_bc = [zero if bc else idx for idx , bc in zip (idxs , bc )] + [
219+ zero
220+ ] * core_ndim
221+ ptr = cgutils .get_item_pointer2 (
222+ context ,
223+ builder ,
224+ input .data ,
225+ cgutils .unpack_tuple (builder , input .shape ),
226+ cgutils .unpack_tuple (builder , input .strides ),
227+ input_type .layout ,
228+ idxs_bc ,
229+ * safe ,
230+ )
231+ if core_ndim == 0 :
232+ # Retrive scalar item at index
233+ val = builder .load (ptr )
234+ # val.set_metadata("alias.scope", input_scope_set)
235+ # val.set_metadata("noalias", output_scope_set)
236+ else :
237+ # Retrieve array item at index
238+ # This is a streamlined version of Numba's `GUArrayArg.load`
239+ # TODO check layout arg!
240+ core_arry_type = types .Array (
241+ dtype = input_type .dtype , ndim = core_ndim , layout = input_type .layout
242+ )
243+ core_array = context .make_array (core_arry_type )(context , builder )
244+ core_shape = cgutils .unpack_tuple (builder , input .shape )[- core_ndim :]
245+ core_strides = cgutils .unpack_tuple (builder , input .strides )[- core_ndim :]
246+ itemsize = context .get_abi_sizeof (context .get_data_type (input_type .dtype ))
247+ context .populate_array (
248+ core_array ,
249+ # TODO whey do we need to bitcast?
250+ data = builder .bitcast (ptr , core_array .data .type ),
251+ shape = cgutils .pack_array (builder , core_shape ),
252+ strides = cgutils .pack_array (builder , core_strides ),
253+ itemsize = context .get_constant (types .intp , itemsize ),
254+ # TODO what is meminfo about?
255+ meminfo = None ,
256+ )
257+ val = core_array ._getvalue ()
258+
231259 input_vals .append (val )
232260
233261 inner_codegen = context .get_function (scalar_func , scalar_signature )
@@ -350,17 +378,30 @@ def _vectorized(
350378
351379 batch_ndim = len (input_bc_patterns [0 ])
352380
353- if not all (input .ndim >= batch_ndim for input in inputs ):
354- raise TypingError ("Vectorized inputs must have the same rank." )
381+ if not all (
382+ len (pattern ) == batch_ndim for pattern in input_bc_patterns + output_bc_patterns
383+ ):
384+ raise TypingError (
385+ "Vectorized broadcastable patterns must have the same length."
386+ )
355387
356- if not all (len (pattern ) >= batch_ndim for pattern in output_bc_patterns ):
357- raise TypingError ("Invalid output broadcasting pattern." )
388+ core_input_types = []
389+ for input_type , bc_pattern in zip (inputs , input_bc_patterns ):
390+ core_ndim = input_type .ndim - len (bc_pattern )
391+ # TODO: Reconsider this
392+ if core_ndim == 0 :
393+ core_input_type = input_type .dtype
394+ else :
395+ core_input_type = types .Array (
396+ dtype = input_type .dtype , ndim = core_ndim , layout = input_type .layout
397+ )
398+ core_input_types .append (core_input_type )
358399
359- scalar_signature = typingctx .resolve_function_type (
400+ core_signature = typingctx .resolve_function_type (
360401 scalar_func ,
361402 [
362403 * constant_inputs ,
363- * [ in_type . dtype if in_type . ndim == 0 else in_type for in_type in inputs ] ,
404+ * core_input_types ,
364405 ],
365406 {},
366407 )
@@ -415,7 +456,7 @@ def codegen(
415456 ctx ,
416457 builder ,
417458 scalar_func ,
418- scalar_signature ,
459+ core_signature ,
419460 iter_shape ,
420461 constant_inputs ,
421462 inputs ,
0 commit comments