| 
2 | 2 | import pytest  | 
3 | 3 | 
 
  | 
4 | 4 | from pytensor.compile.function import function  | 
5 |  | -from pytensor.compile.mode import Mode  | 
6 | 5 | from pytensor.configdefaults import config  | 
7 | 6 | from pytensor.graph.fg import FunctionGraph  | 
8 |  | -from pytensor.graph.op import get_test_value  | 
9 |  | -from pytensor.graph.rewriting.db import RewriteDatabaseQuery  | 
10 |  | -from pytensor.link.jax import JAXLinker  | 
11 |  | -from pytensor.tensor import blas as pt_blas  | 
12 | 7 | from pytensor.tensor import nlinalg as pt_nlinalg  | 
13 |  | -from pytensor.tensor.math import Argmax, Max, maximum  | 
14 |  | -from pytensor.tensor.math import max as pt_max  | 
15 |  | -from pytensor.tensor.type import dvector, matrix, scalar, tensor3, vector  | 
 | 8 | +from pytensor.tensor.type import matrix  | 
16 | 9 | from tests.link.jax.test_basic import compare_jax_and_py  | 
17 | 10 | 
 
  | 
18 | 11 | 
 
  | 
19 | 12 | jax = pytest.importorskip("jax")  | 
20 | 13 | 
 
  | 
21 | 14 | 
 
  | 
22 |  | -def test_jax_BatchedDot():  | 
23 |  | -    # tensor3 . tensor3  | 
24 |  | -    a = tensor3("a")  | 
25 |  | -    a.tag.test_value = (  | 
26 |  | -        np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))  | 
27 |  | -    )  | 
28 |  | -    b = tensor3("b")  | 
29 |  | -    b.tag.test_value = (  | 
30 |  | -        np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2))  | 
31 |  | -    )  | 
32 |  | -    out = pt_blas.BatchedDot()(a, b)  | 
33 |  | -    fgraph = FunctionGraph([a, b], [out])  | 
34 |  | -    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])  | 
35 |  | - | 
36 |  | -    # A dimension mismatch should raise a TypeError for compatibility  | 
37 |  | -    inputs = [get_test_value(a)[:-1], get_test_value(b)]  | 
38 |  | -    opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])  | 
39 |  | -    jax_mode = Mode(JAXLinker(), opts)  | 
40 |  | -    pytensor_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=jax_mode)  | 
41 |  | -    with pytest.raises(TypeError):  | 
42 |  | -        pytensor_jax_fn(*inputs)  | 
43 |  | - | 
44 |  | - | 
45 | 15 | def test_jax_basic_multiout():  | 
46 | 16 |     rng = np.random.default_rng(213234)  | 
47 | 17 | 
 
  | 
@@ -79,45 +49,6 @@ def assert_fn(x, y):  | 
79 | 49 |     compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)  | 
80 | 50 | 
 
  | 
81 | 51 | 
 
  | 
82 |  | -def test_jax_max_and_argmax():  | 
83 |  | -    # Test that a single output of a multi-output `Op` can be used as input to  | 
84 |  | -    # another `Op`  | 
85 |  | -    x = dvector()  | 
86 |  | -    mx = Max([0])(x)  | 
87 |  | -    amx = Argmax([0])(x)  | 
88 |  | -    out = mx * amx  | 
89 |  | -    out_fg = FunctionGraph([x], [out])  | 
90 |  | -    compare_jax_and_py(out_fg, [np.r_[1, 2]])  | 
91 |  | - | 
92 |  | - | 
93 |  | -def test_tensor_basics():  | 
94 |  | -    y = vector("y")  | 
95 |  | -    y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)  | 
96 |  | -    x = vector("x")  | 
97 |  | -    x.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX)  | 
98 |  | -    A = matrix("A")  | 
99 |  | -    A.tag.test_value = np.empty((2, 2), dtype=config.floatX)  | 
100 |  | -    alpha = scalar("alpha")  | 
101 |  | -    alpha.tag.test_value = np.array(3.0, dtype=config.floatX)  | 
102 |  | -    beta = scalar("beta")  | 
103 |  | -    beta.tag.test_value = np.array(5.0, dtype=config.floatX)  | 
104 |  | - | 
105 |  | -    # This should be converted into a `Gemv` `Op` when the non-JAX compatible  | 
106 |  | -    # optimizations are turned on; however, when using JAX mode, it should  | 
107 |  | -    # leave the expression alone.  | 
108 |  | -    out = y.dot(alpha * A).dot(x) + beta * y  | 
109 |  | -    fgraph = FunctionGraph([y, x, A, alpha, beta], [out])  | 
110 |  | -    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])  | 
111 |  | - | 
112 |  | -    out = maximum(y, x)  | 
113 |  | -    fgraph = FunctionGraph([y, x], [out])  | 
114 |  | -    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])  | 
115 |  | - | 
116 |  | -    out = pt_max(y)  | 
117 |  | -    fgraph = FunctionGraph([y], [out])  | 
118 |  | -    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])  | 
119 |  | - | 
120 |  | - | 
121 | 52 | def test_pinv():  | 
122 | 53 |     x = matrix("x")  | 
123 | 54 |     x_inv = pt_nlinalg.pinv(x)  | 
 | 
0 commit comments