From e056fcf27491d0cff0ef494bce61ac322a31c974 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 9 Jan 2023 10:17:27 +0100 Subject: [PATCH] Seed logsumexp benchmark tests Also adds missing numba benchmark test Co-authored-by: Brandon T. Willard --- tests/link/jax/test_elemwise.py | 3 ++- tests/link/numba/test_elemwise.py | 23 +++++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/tests/link/jax/test_elemwise.py b/tests/link/jax/test_elemwise.py index 947bea6a5b..0f903a33b2 100644 --- a/tests/link/jax/test_elemwise.py +++ b/tests/link/jax/test_elemwise.py @@ -111,7 +111,8 @@ def test_logsumexp_benchmark(size, axis, benchmark): X_max = at.switch(at.isinf(X_max), 0, X_max) X_lse = at.log(at.sum(at.exp(X - X_max), axis=axis, keepdims=True)) + X_max - X_val = np.random.normal(size=size) + rng = np.random.default_rng(23920) + X_val = rng.normal(size=size) X_lse_fn = pytensor.function([X], X_lse, mode="JAX") diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py index 0958e90034..d05b5ab95d 100644 --- a/tests/link/numba/test_elemwise.py +++ b/tests/link/numba/test_elemwise.py @@ -2,7 +2,9 @@ import numpy as np import pytest +import scipy.special +import pytensor import pytensor.tensor as at import pytensor.tensor.inplace as ati import pytensor.tensor.math as aem @@ -532,3 +534,24 @@ def test_MaxAndArgmax(x, axes, exc): if not isinstance(i, (SharedVariable, Constant)) ], ) + + +@pytest.mark.parametrize("size", [(10, 10), (1000, 1000), (10000, 10000)]) +@pytest.mark.parametrize("axis", [0, 1]) +def test_logsumexp_benchmark(size, axis, benchmark): + + X = at.matrix("X") + X_max = at.max(X, axis=axis, keepdims=True) + X_max = at.switch(at.isinf(X_max), 0, X_max) + X_lse = at.log(at.sum(at.exp(X - X_max), axis=axis, keepdims=True)) + X_max + + rng = np.random.default_rng(23920) + X_val = rng.normal(size=size) + + X_lse_fn = pytensor.function([X], X_lse, mode="JAX") + + # JIT compile first + _ = X_lse_fn(X_val) + res = benchmark(X_lse_fn, X_val) + exp_res = scipy.special.logsumexp(X_val, axis=axis, keepdims=True) + np.testing.assert_array_almost_equal(res, exp_res)