Skip to content

Commit 8d3bda6

Browse files
add support for dtype kwarg
1 parent 17707fb commit 8d3bda6

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

pytensor/link/jax/dispatch/tensor_basic.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,10 @@ def scalar_from_tensor(x):
221221

222222
@jax_funcify.register(Tri)
223223
def jax_funcify_Tri(op, node, **kwargs):
224+
dtype = op.dtype
224225
tri_args = node.inputs
225226
constant_args = []
227+
226228
for arg in tri_args:
227229
if not isinstance(arg, Constant):
228230
raise NotImplementedError(
@@ -234,6 +236,6 @@ def jax_funcify_Tri(op, node, **kwargs):
234236
M, N, k = constant_args
235237

236238
def tri(*_):
237-
return jnp.tri(N, M, k)
239+
return jnp.tri(N, M, k, dtype=dtype)
238240

239241
return tri

0 commit comments

Comments
 (0)