
Revert patch for index underflow after location 127 when mode=JAX #395

@jessegrabowski

Description


Describe the issue:

It seems the JAX linker downcasts index constants to an 8-bit integer type (int8?), since constant indices up to 127 work but 128 does not. mode=None and mode="NUMBA" work as expected. Declaring the index as a variable (i = pt.lscalar('i'); z = x[i]) also works as expected.
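The 127/128 boundary matches the largest signed 8-bit value, whereas uint8 would only overflow above 255. The following NumPy-only sketch (it does not touch PyTensor or JAX, and is not a claim about what the JAX linker actually does internally) shows how out-of-range values wrap when cast to each 8-bit type:

```python
import numpy as np

# Index values around the reported failure point, stored as int64 first.
idx = np.array([127, 128, 255, 256], dtype=np.int64)

# Largest representable value for each 8-bit type.
int8_max = np.iinfo(np.int8).max    # 127
uint8_max = np.iinfo(np.uint8).max  # 255

# Integer-to-integer casts with astype wrap around instead of raising.
as_int8 = idx.astype(np.int8)    # [127, -128, -1, 0]
as_uint8 = idx.astype(np.uint8)  # [127, 128, 255, 0]

print(int8_max, uint8_max)
print(as_int8.tolist(), as_uint8.tolist())
```

Only the int8 cast changes the value 128, which is consistent with the failure starting exactly at index 128.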

Reproducible code example:

```python
import pytensor
import pytensor.tensor as pt
import numpy as np

x = pt.dvector('x')
z1 = x[127]
z2 = x[128]
f = pytensor.function([x], [z1, z2], mode='JAX')

f(np.arange(200))
# out: [Array(127., dtype=float64), Array(0., dtype=float64)]
```
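For comparison, the same two lookups done directly in NumPy return 127.0 and 128.0, so the second JAX output is wrong. A minimal sketch of a sanity check against the reported JAX result; the helper check_outputs is hypothetical and only compares two lists of numbers:

```python
import numpy as np

x = np.arange(200, dtype=np.float64)

# Reference result for x[127] and x[128] using plain NumPy indexing.
expected = [float(x[127]), float(x[128])]

# Output reported above for mode='JAX'.
jax_reported = [127.0, 0.0]


def check_outputs(expected, actual):
    """Hypothetical helper: return the positions where two result lists disagree."""
    return [i for i, (e, a) in enumerate(zip(expected, actual)) if not np.isclose(e, a)]


print(expected)                               # [127.0, 128.0]
print(check_outputs(expected, jax_reported))  # [1]
```

Only position 1 (the x[128] lookup) disagrees, matching the description above.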

Error message:

No response

PyTensor version information:

PyTensor 2.13.1

Context for the issue:

No response
