|
9 | 9 | from pytensor.graph.op import Op |
10 | 10 | from pytensor.link.numba.dispatch import basic as numba_basic |
11 | 11 | from pytensor.link.numba.dispatch.basic import ( |
12 | | - create_numba_signature, |
13 | 12 | numba_funcify, |
14 | 13 | numba_njit, |
15 | | - use_optimized_cheap_pass, |
16 | 14 | ) |
17 | 15 | from pytensor.link.numba.dispatch.vectorize_codegen import ( |
18 | 16 | _jit_options, |
@@ -245,47 +243,6 @@ def {careduce_fn_name}(x): |
245 | 243 | return careduce_fn |
246 | 244 |
|
247 | 245 |
|
248 | | -def jit_compile_reducer( |
249 | | - node, fn, *, reduce_to_scalar=False, infer_signature=True, **kwds |
250 | | -): |
251 | | - """Compile Python source for reduction loops using additional optimizations. |
252 | | -
|
253 | | - Parameters |
254 | | - ========== |
255 | | - node |
256 | | - An node from which the signature can be derived. |
257 | | - fn |
258 | | - The Python function object to compile. |
259 | | - reduce_to_scalar: bool, default False |
260 | | - Whether to reduce output to a scalar (instead of 0d array) |
261 | | - infer_signature: bool: default True |
262 | | - Whether to try and infer the function signature from the Apply node. |
263 | | - kwds |
264 | | - Extra keywords to be added to the :func:`numba.njit` function. |
265 | | -
|
266 | | - Returns |
267 | | - ======= |
268 | | - A :func:`numba.njit`-compiled function. |
269 | | -
|
270 | | - """ |
271 | | - if infer_signature: |
272 | | - signature = create_numba_signature(node, reduce_to_scalar=reduce_to_scalar) |
273 | | - args = (signature,) |
274 | | - else: |
275 | | - args = () |
276 | | - |
277 | | - # Eagerly compile the function using increased optimizations. This should |
278 | | - # help improve nested loop reductions. |
279 | | - with use_optimized_cheap_pass(): |
280 | | - res = numba_basic.numba_njit( |
281 | | - *args, |
282 | | - boundscheck=False, |
283 | | - **kwds, |
284 | | - )(fn) |
285 | | - |
286 | | - return res |
287 | | - |
288 | | - |
289 | 246 | def create_axis_apply_fn(fn, axis, ndim, dtype): |
290 | 247 | axis = normalize_axis_index(axis, ndim) |
291 | 248 |
|
@@ -448,7 +405,7 @@ def numba_funcify_CAReduce(op, node, **kwargs): |
448 | 405 | np.dtype(node.outputs[0].type.dtype), |
449 | 406 | ) |
450 | 407 |
|
451 | | - careduce_fn = jit_compile_reducer(node, careduce_py_fn, reduce_to_scalar=False) |
| 408 | + careduce_fn = numba_njit(careduce_py_fn, boundscheck=False) |
452 | 409 | return careduce_fn |
453 | 410 |
|
454 | 411 |
|
@@ -579,7 +536,7 @@ def softmax_py_fn(x): |
579 | 536 | sm = e_x / w |
580 | 537 | return sm |
581 | 538 |
|
582 | | - softmax = jit_compile_reducer(node, softmax_py_fn) |
| 539 | + softmax = numba_njit(softmax_py_fn, boundscheck=False) |
583 | 540 |
|
584 | 541 | return softmax |
585 | 542 |
|
@@ -608,8 +565,7 @@ def softmax_grad_py_fn(dy, sm): |
608 | 565 | dx = dy_times_sm - sum_dy_times_sm * sm |
609 | 566 | return dx |
610 | 567 |
|
611 | | - # The signature inferred by jit_compile_reducer is wrong when dy is a constant (readonly=True) |
612 | | - softmax_grad = jit_compile_reducer(node, softmax_grad_py_fn, infer_signature=False) |
| 568 | + softmax_grad = numba_njit(softmax_grad_py_fn, boundscheck=False) |
613 | 569 |
|
614 | 570 | return softmax_grad |
615 | 571 |
|
@@ -647,7 +603,7 @@ def log_softmax_py_fn(x): |
647 | 603 | lsm = xdev - np.log(reduce_sum(np.exp(xdev))) |
648 | 604 | return lsm |
649 | 605 |
|
650 | | - log_softmax = jit_compile_reducer(node, log_softmax_py_fn) |
| 606 | + log_softmax = numba_njit(log_softmax_py_fn, boundscheck=False) |
651 | 607 | return log_softmax |
652 | 608 |
|
653 | 609 |
|
|
0 commit comments