We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 95d2b66 commit 7045b91Copy full SHA for 7045b91
pytensor/link/numba/dispatch/basic.py
@@ -761,7 +761,9 @@ def specify_shape(x, {create_arg_string(shape_input_names)}):
761
def int_to_float_fn(inputs, out_dtype):
762
"""Create a Numba function that converts integer and boolean ``ndarray``s to floats."""
763
764
- if all(input.type.numpy_dtype == np.dtype(out_dtype) for input in inputs):
+ if all(
765
+ input.type.numpy_dtype == np.dtype(out_dtype) for input in inputs
766
+ ) and isinstance(np.dtype(out_dtype), np.floating):
767
768
@numba_njit
769
def inputs_cast(x):
0 commit comments