Skip to content

Commit 7045b91

Browse files
committed
fix(numba): Cast arguments to dot to float
Numba doesn't support dot with non-floating point arguments.
1 parent 95d2b66 commit 7045b91

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -761,7 +761,9 @@ def specify_shape(x, {create_arg_string(shape_input_names)}):
761761
def int_to_float_fn(inputs, out_dtype):
762762
"""Create a Numba function that converts integer and boolean ``ndarray``s to floats."""
763763

764-
if all(input.type.numpy_dtype == np.dtype(out_dtype) for input in inputs):
764+
if all(
765+
input.type.numpy_dtype == np.dtype(out_dtype) for input in inputs
766+
) and isinstance(np.dtype(out_dtype), np.floating):
765767

766768
@numba_njit
767769
def inputs_cast(x):

0 commit comments

Comments
 (0)