From 4348da7719766a4a0a9a338e02261c6d0894adb9 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 26 Aug 2023 00:49:39 -0400 Subject: [PATCH 01/15] Add numba overload for --- pytensor/link/numba/dispatch/__init__.py | 1 + pytensor/link/numba/dispatch/slinalg.py | 267 +++++++++++++++++++++++ tests/link/numba/test_slinalg.py | 22 ++ 3 files changed, 290 insertions(+) create mode 100644 pytensor/link/numba/dispatch/slinalg.py create mode 100644 tests/link/numba/test_slinalg.py diff --git a/pytensor/link/numba/dispatch/__init__.py b/pytensor/link/numba/dispatch/__init__.py index c7cb2632a1..9810e14178 100644 --- a/pytensor/link/numba/dispatch/__init__.py +++ b/pytensor/link/numba/dispatch/__init__.py @@ -10,5 +10,6 @@ import pytensor.link.numba.dispatch.elemwise import pytensor.link.numba.dispatch.scan import pytensor.link.numba.dispatch.sparse +import pytensor.link.numba.dispatch.slinalg # isort: on diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py new file mode 100644 index 0000000000..1d5b1f6001 --- /dev/null +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -0,0 +1,267 @@ +import ctypes + +import numba +import numpy as np +import scipy +from numba.core import cgutils, types +from numba.extending import get_cython_function_address, intrinsic, overload +from numba.np.linalg import _blas_kinds, _copy_to_fortran_order, ensure_lapack +from scipy import linalg + +from pytensor.link.numba.dispatch import basic as numba_basic +from pytensor.link.numba.dispatch.basic import numba_funcify +from pytensor.tensor.slinalg import SolveTriangular + + +_PTR = ctypes.POINTER + +_dbl = ctypes.c_double +_float = ctypes.c_float +_char = ctypes.c_char +_int = ctypes.c_int + +_ptr_float = _PTR(_float) +_ptr_dbl = _PTR(_dbl) +_ptr_char = _PTR(_char) +_ptr_int = _PTR(_int) + + +@intrinsic +def val_to_dptr(typingctx, data): + def impl(context, builder, signature, args): + ptr = cgutils.alloca_once_value(builder, args[0]) + return ptr + + sig = types.CPointer(types.float64)(types.float64) + return sig, impl + + +@intrinsic +def val_to_zptr(typingctx, data): + def impl(context, builder, signature, args): + ptr = cgutils.alloca_once_value(builder, args[0]) + return ptr + + sig = types.CPointer(types.complex128)(types.complex128) + return sig, impl + + +@intrinsic +def val_to_sptr(typingctx, data): + def impl(context, builder, signature, args): + ptr = cgutils.alloca_once_value(builder, args[0]) + return ptr + + sig = types.CPointer(types.float32)(types.float32) + return sig, impl + + +@intrinsic +def val_to_int_ptr(typingctx, data): + def impl(context, builder, signature, args): + ptr = cgutils.alloca_once_value(builder, args[0]) + return ptr + + sig = types.CPointer(types.int32)(types.int32) + return sig, impl + + +@intrinsic +def int_ptr_to_val(typingctx, data): + def impl(context, builder, signature, args): + val = builder.load(args[0]) + return val + + sig = types.int32(types.CPointer(types.int32)) + return sig, impl + + +@intrinsic +def dptr_to_val(typingctx, data): + def impl(context, builder, signature, args): + val = builder.load(args[0]) + return val + + sig = types.float64(types.CPointer(types.float64)) + return sig, impl + + +@intrinsic +def sptr_to_val(typingctx, data): + def impl(context, builder, signature, args): + val = builder.load(args[0]) + return val + + sig = types.float32(types.CPointer(types.float32)) + return sig, impl + + +def _get_float_pointer_for_dtype(blas_dtype): + if blas_dtype in ["s", "c"]: + return _ptr_float + elif blas_dtype in ["d", "z"]: + return _ptr_dbl + + +def _get_underlying_float(dtype): + s_dtype = str(dtype) + out_type = s_dtype + if s_dtype == "complex64": + out_type = "float32" + elif s_dtype == "complex128": + out_type = "float64" + + return np.dtype(out_type) + + +def _get_addr_and_float_pointer(dtype, name): + d = _blas_kinds[dtype] + func_name = f"{d}{name}" + float_pointer = _get_float_pointer_for_dtype(d) + addr = get_cython_function_address("scipy.linalg.cython_lapack", func_name) + + return addr, float_pointer + + +def _check_scipy_linalg_matrix(a, func_name): + """ + Adapted from https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L831 + """ + prefix = "scipy.linalg" + interp = (prefix, func_name) + # Unpack optional type + if isinstance(a, types.Optional): + a = a.type + if not isinstance(a, types.Array): + msg = "%s.%s() only supported for array types" % interp + raise numba.TypingError(msg, highlighting=False) + if not a.ndim == 2: + msg = "%s.%s() only supported on 2-D arrays." % interp + raise numba.TypingError(msg, highlighting=False) + if not isinstance(a.dtype, (types.Float, types.Complex)): + msg = "%s.%s() only supported on " "float and complex arrays." % interp + raise numba.TypingError(msg, highlighting=False) + + +class _LAPACK: + """ + Functions to return type signatures for wrapped LAPACK functions. + + Patterned after https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L74 + """ + + def __init__(self): + ensure_lapack() + + @classmethod + def test_blas_kinds(cls, dtype): + return _blas_kinds[dtype] + + @classmethod + def numba_xtrtrs(cls, dtype): + """ + Called by scipy.linalg.solve_triangular + """ + d = _blas_kinds[dtype] + func_name = f"{d}trtrs" + float_pointer = _get_float_pointer_for_dtype(d) + + addr = get_cython_function_address("scipy.linalg.cython_lapack", func_name) + functype = ctypes.CFUNCTYPE( + None, + _ptr_int, # UPLO + _ptr_int, # TRANS + _ptr_int, # DIAG + _ptr_int, # N + _ptr_int, # NRHS + float_pointer, # A + _ptr_int, # LDA + float_pointer, # B + _ptr_int, # LDB + _ptr_int, + ) # INFO + + return functype(addr) + + +@overload(scipy.linalg.solve_triangular) +def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False): + ensure_lapack() + + _check_scipy_linalg_matrix(A, "solve_triangular") + _check_scipy_linalg_matrix(B, "solve_triangular") + + dtype = A.dtype + w_type = _get_underlying_float(dtype) + + numba_trtrs = _LAPACK().numba_xtrtrs(dtype) + + def impl(A, B, trans=0, lower=False, unit_diagonal=False): + _N = np.int32(A.shape[-1]) + if A.shape[-2] != _N: + raise linalg.LinAlgError("Last 2 dimensions of A must be square") + + if A.shape[0] != B.shape[0]: + raise linalg.LinAlgError("Dimensions of A and B do not conform") + + A_copy = _copy_to_fortran_order(A) + B_copy = _copy_to_fortran_order(B) + + # if isinstance(trans, str): + # if trans not in ['N', 'C', 'T']: + # raise ValueError('Parameter "trans" should be one of N, C, T or 0, 1, 2') + # transval = ord(trans) + + # else: + if trans not in [0, 1, 2]: + raise ValueError('Parameter "trans" should be one of N, C, T or 0, 1, 2') + if trans == 0: + transval = ord("N") + elif trans == 1: + transval = ord("T") + else: + transval = ord("C") + + UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) + TRANS = val_to_int_ptr(transval) + DIAG = val_to_int_ptr(ord("U") if unit_diagonal else ord("N")) + N = val_to_int_ptr(_N) + NRHS = val_to_int_ptr(B.shape[1]) + LDA = val_to_int_ptr(_N) + LDB = val_to_int_ptr(_N) + INFO = val_to_int_ptr(0) + + numba_trtrs( + UPLO, + TRANS, + DIAG, + N, + NRHS, + A_copy.view(w_type).ctypes, + LDA, + B_copy.view(w_type).ctypes, + LDB, + INFO, + ) + + return B_copy + + return impl + + +@numba_funcify.register(SolveTriangular) +def numba_funcify_SolveTriangular(op, node, **kwargs): + trans = op.trans + lower = op.lower + unit_diagonal = op.unit_diagonal + check_finite = op.check_finite + + @numba_basic.numba_njit(inline="always") + def solve_triangular(a, b): + res = scipy.linalg.solve_triangular(a, b, trans, lower, unit_diagonal) + if check_finite: + if np.any(np.isinf(res)): + raise ValueError + return res + + return solve_triangular diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py new file mode 100644 index 0000000000..dac7dd23bf --- /dev/null +++ b/tests/link/numba/test_slinalg.py @@ -0,0 +1,22 @@ +import numpy as np + +import pytensor +import pytensor.tensor as pt +from pytensor import config + + +def test_solve_triangular(): + A = pt.matrix("A") + b = pt.matrix("b") + + X = pt.linalg.solve_triangular(A, b, lower=True) + f = pytensor.function([A, b], X, mode="NUMBA") + + A_val = np.random.normal(size=(5, 5)).astype(config.floatX) + A_sym = A_val @ A_val.T + A_tri = np.linalg.cholesky(A_sym) + + b = np.random.normal(size=(5, 1)).astype(config.floatX) + + X_np = f(A_tri, b) + np.testing.assert_allclose(A_tri @ X_np, b) From c9f5f4fb42cfb6037cf5b95df9e0e41197a4d5c6 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Mon, 28 Aug 2023 01:59:16 +0200 Subject: [PATCH 02/15] Overload dummy function instead of scipy.linalg --- pytensor/link/numba/dispatch/slinalg.py | 11 ++++++++--- tests/link/numba/test_slinalg.py | 17 +++++++++++++++++ 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index 1d5b1f6001..c1afe46f1d 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -2,7 +2,6 @@ import numba import numpy as np -import scipy from numba.core import cgutils, types from numba.extending import get_cython_function_address, intrinsic, overload from numba.np.linalg import _blas_kinds, _copy_to_fortran_order, ensure_lapack @@ -184,7 +183,13 @@ def numba_xtrtrs(cls, dtype): return functype(addr) -@overload(scipy.linalg.solve_triangular) +def _solve_triangular(A, B, trans=0, lower=False, unit_diagonal=False): + return linalg.solve_triangular( + A, B, trans=trans, lower=lower, unit_diagonal=unit_diagonal + ) + + +@overload(_solve_triangular) def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False): ensure_lapack() @@ -258,7 +263,7 @@ def numba_funcify_SolveTriangular(op, node, **kwargs): @numba_basic.numba_njit(inline="always") def solve_triangular(a, b): - res = scipy.linalg.solve_triangular(a, b, trans, lower, unit_diagonal) + res = _solve_triangular(a, b, trans, lower, unit_diagonal) if check_finite: if np.any(np.isinf(res)): raise ValueError diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py index dac7dd23bf..86e4d750c0 100644 --- a/tests/link/numba/test_slinalg.py +++ b/tests/link/numba/test_slinalg.py @@ -1,4 +1,7 @@ +import numba as nb import numpy as np +import pytest +from scipy import linalg import pytensor import pytensor.tensor as pt @@ -20,3 +23,17 @@ def test_solve_triangular(): X_np = f(A_tri, b) np.testing.assert_allclose(A_tri @ X_np, b) + + +def test_scipy_solve_triangular_not_overloaded(): + A_val = np.random.normal(size=(5, 5)).astype(config.floatX) + A_sym = A_val @ A_val.T + A_tri = np.linalg.cholesky(A_sym) + b = np.random.normal(size=(5, 1)).astype(config.floatX) + + @nb.njit + def test_solve_tri(a, b): + return linalg.solve_triangular(a, b) + + with pytest.raises(nb.TypingError, match="Failed in nopython mode"): + test_solve_tri(A_tri, b) From 08de8faceae31487348031c04cb2e0bad0a925a0 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Mon, 28 Aug 2023 02:09:14 +0200 Subject: [PATCH 03/15] Add tolerance for float32 tests --- tests/link/numba/test_slinalg.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py index 86e4d750c0..cd7d2013a8 100644 --- a/tests/link/numba/test_slinalg.py +++ b/tests/link/numba/test_slinalg.py @@ -8,6 +8,10 @@ from pytensor import config +ATOL = 0 if config.floatX.endswith('64') else 1e-6 +RTOL = 1e-7 if config.floatX.endswith('64') else 1e-6 + + def test_solve_triangular(): A = pt.matrix("A") b = pt.matrix("b") @@ -22,7 +26,7 @@ def test_solve_triangular(): b = np.random.normal(size=(5, 1)).astype(config.floatX) X_np = f(A_tri, b) - np.testing.assert_allclose(A_tri @ X_np, b) + np.testing.assert_allclose(A_tri @ X_np, b, atol=ATOL, rtol=RTOL) def test_scipy_solve_triangular_not_overloaded(): From a323a1d0d86f6835f5b79fe9c47a199546e2751d Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Mon, 28 Aug 2023 02:09:37 +0200 Subject: [PATCH 04/15] Add tolerance for float32 tests --- tests/link/numba/test_slinalg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py index cd7d2013a8..7ecaab5489 100644 --- a/tests/link/numba/test_slinalg.py +++ b/tests/link/numba/test_slinalg.py @@ -8,8 +8,8 @@ from pytensor import config -ATOL = 0 if config.floatX.endswith('64') else 1e-6 -RTOL = 1e-7 if config.floatX.endswith('64') else 1e-6 +ATOL = 0 if config.floatX.endswith("64") else 1e-6 +RTOL = 1e-7 if config.floatX.endswith("64") else 1e-6 def test_solve_triangular(): From 268f583d884b825da3e113cba97977d522c5f612 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Mon, 28 Aug 2023 02:17:57 +0200 Subject: [PATCH 05/15] Remove overload test --- tests/link/numba/test_slinalg.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py index 7ecaab5489..e423fb3325 100644 --- a/tests/link/numba/test_slinalg.py +++ b/tests/link/numba/test_slinalg.py @@ -1,7 +1,4 @@ -import numba as nb import numpy as np -import pytest -from scipy import linalg import pytensor import pytensor.tensor as pt @@ -27,17 +24,3 @@ def test_solve_triangular(): X_np = f(A_tri, b) np.testing.assert_allclose(A_tri @ X_np, b, atol=ATOL, rtol=RTOL) - - -def test_scipy_solve_triangular_not_overloaded(): - A_val = np.random.normal(size=(5, 5)).astype(config.floatX) - A_sym = A_val @ A_val.T - A_tri = np.linalg.cholesky(A_sym) - b = np.random.normal(size=(5, 1)).astype(config.floatX) - - @nb.njit - def test_solve_tri(a, b): - return linalg.solve_triangular(a, b) - - with pytest.raises(nb.TypingError, match="Failed in nopython mode"): - test_solve_tri(A_tri, b) From 3837a28bc1272d24c5d1dcac44358020bcff8aaa Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Mon, 28 Aug 2023 03:42:01 +0200 Subject: [PATCH 06/15] Allow b to be 1d array Remove test_SolveTriangular from numba\test_nlinalg.py --- pytensor/link/numba/dispatch/slinalg.py | 30 ++++++++++++++-------- tests/link/numba/test_nlinalg.py | 34 ------------------------- tests/link/numba/test_slinalg.py | 14 +++++++--- 3 files changed, 31 insertions(+), 47 deletions(-) diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index c1afe46f1d..afe8730058 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -134,8 +134,10 @@ def _check_scipy_linalg_matrix(a, func_name): if not isinstance(a, types.Array): msg = "%s.%s() only supported for array types" % interp raise numba.TypingError(msg, highlighting=False) - if not a.ndim == 2: - msg = "%s.%s() only supported on 2-D arrays." % interp + if a.ndim not in [1, 2]: + msg = "%s.%s() only supported on 1d or 2d arrays, found %s." % ( + interp + (a.ndim,) + ) raise numba.TypingError(msg, highlighting=False) if not isinstance(a.dtype, (types.Float, types.Complex)): msg = "%s.%s() only supported on " "float and complex arrays." % interp @@ -161,11 +163,8 @@ def numba_xtrtrs(cls, dtype): """ Called by scipy.linalg.solve_triangular """ - d = _blas_kinds[dtype] - func_name = f"{d}trtrs" - float_pointer = _get_float_pointer_for_dtype(d) + addr, float_pointer = _get_addr_and_float_pointer(dtype, "trtrs") - addr = get_cython_function_address("scipy.linalg.cython_lapack", func_name) functype = ctypes.CFUNCTYPE( None, _ptr_int, # UPLO @@ -177,8 +176,8 @@ def numba_xtrtrs(cls, dtype): _ptr_int, # LDA float_pointer, # B _ptr_int, # LDB - _ptr_int, - ) # INFO + _ptr_int, # INFO + ) return functype(addr) @@ -202,6 +201,8 @@ def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False): numba_trtrs = _LAPACK().numba_xtrtrs(dtype) def impl(A, B, trans=0, lower=False, unit_diagonal=False): + B_is_1d = B.ndim == 1 + _N = np.int32(A.shape[-1]) if A.shape[-2] != _N: raise linalg.LinAlgError("Last 2 dimensions of A must be square") @@ -210,7 +211,12 @@ def impl(A, B, trans=0, lower=False, unit_diagonal=False): raise linalg.LinAlgError("Dimensions of A and B do not conform") A_copy = _copy_to_fortran_order(A) - B_copy = _copy_to_fortran_order(B) + + # Need to expand B here; I tried everywhere else and it doesn't work + if B_is_1d: + B_copy = _copy_to_fortran_order(np.expand_dims(B, -1)) + else: + B_copy = _copy_to_fortran_order(B) # if isinstance(trans, str): # if trans not in ['N', 'C', 'T']: @@ -227,11 +233,13 @@ def impl(A, B, trans=0, lower=False, unit_diagonal=False): else: transval = ord("C") + B_NDIM = 1 if B_is_1d else B.shape[1] + UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) TRANS = val_to_int_ptr(transval) DIAG = val_to_int_ptr(ord("U") if unit_diagonal else ord("N")) N = val_to_int_ptr(_N) - NRHS = val_to_int_ptr(B.shape[1]) + NRHS = val_to_int_ptr(B_NDIM) LDA = val_to_int_ptr(_N) LDB = val_to_int_ptr(_N) INFO = val_to_int_ptr(0) @@ -249,6 +257,8 @@ def impl(A, B, trans=0, lower=False, unit_diagonal=False): INFO, ) + if B_is_1d: + return B_copy[:, 0] return B_copy return impl diff --git a/tests/link/numba/test_nlinalg.py b/tests/link/numba/test_nlinalg.py index 7bc60d1313..955e976200 100644 --- a/tests/link/numba/test_nlinalg.py +++ b/tests/link/numba/test_nlinalg.py @@ -110,40 +110,6 @@ def test_Solve(A, x, lower, exc): ) -@pytest.mark.parametrize( - "A, x, lower, exc", - [ - ( - set_test_value( - at.dmatrix(), - (lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")), - ), - set_test_value(at.dvector(), rng.random(size=(3,)).astype("float64")), - "sym", - UserWarning, - ), - ], -) -def test_SolveTriangular(A, x, lower, exc): - g = slinalg.SolveTriangular(lower)(A, x) - - if isinstance(g, list): - g_fg = FunctionGraph(outputs=g) - else: - g_fg = FunctionGraph(outputs=[g]) - - cm = contextlib.suppress() if exc is None else pytest.warns(exc) - with cm: - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - @pytest.mark.parametrize( "x, exc", [ diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py index e423fb3325..b0e648e51f 100644 --- a/tests/link/numba/test_slinalg.py +++ b/tests/link/numba/test_slinalg.py @@ -1,4 +1,6 @@ + import numpy as np +import pytest import pytensor import pytensor.tensor as pt @@ -7,11 +9,17 @@ ATOL = 0 if config.floatX.endswith("64") else 1e-6 RTOL = 1e-7 if config.floatX.endswith("64") else 1e-6 +rng = np.random.default_rng(42849) -def test_solve_triangular(): +@pytest.mark.parametrize( + "b_func, b_size", + [(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))], + ids=["b_col_vec", "b_matrix", "b_vec"], +) +def test_solve_triangular(b_func, b_size): A = pt.matrix("A") - b = pt.matrix("b") + b = b_func("b") X = pt.linalg.solve_triangular(A, b, lower=True) f = pytensor.function([A, b], X, mode="NUMBA") @@ -20,7 +28,7 @@ def test_solve_triangular(): A_sym = A_val @ A_val.T A_tri = np.linalg.cholesky(A_sym) - b = np.random.normal(size=(5, 1)).astype(config.floatX) + b = np.random.normal(size=b_size).astype(config.floatX) X_np = f(A_tri, b) np.testing.assert_allclose(A_tri @ X_np, b, atol=ATOL, rtol=RTOL) From 74f39654f51fa3267b3cbbf9f0c2daa6c2fc6828 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Mon, 28 Aug 2023 03:42:44 +0200 Subject: [PATCH 07/15] Allow b to be 1d array Remove test_SolveTriangular from numba\test_nlinalg.py --- pyproject.toml | 4 ++++ tests/link/numba/test_slinalg.py | 1 - 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 7cfa7da8e7..0f397f79fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -117,6 +117,10 @@ pytensor = [ "js/*", ] +[tool.pytest.ini_options] +#env = "PYTENSOR_FLAGS=floatX=float32,gcc__cxxflags='-march=core2'" + + [tool.coverage.run] omit = [ "pytensor/_version.py", diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py index b0e648e51f..a957c90ae1 100644 --- a/tests/link/numba/test_slinalg.py +++ b/tests/link/numba/test_slinalg.py @@ -1,4 +1,3 @@ - import numpy as np import pytest From 3e1c85ab341fb87963a54e412351b652d4b9bfc0 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Mon, 28 Aug 2023 03:45:46 +0200 Subject: [PATCH 08/15] revert local change to pyproject.toml --- pyproject.toml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0f397f79fb..7cfa7da8e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -117,10 +117,6 @@ pytensor = [ "js/*", ] -[tool.pytest.ini_options] -#env = "PYTENSOR_FLAGS=floatX=float32,gcc__cxxflags='-march=core2'" - - [tool.coverage.run] omit = [ "pytensor/_version.py", From 15e23b973b9fa53d3477ef8a1189d48dd2e4804a Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Mon, 28 Aug 2023 04:51:46 +0200 Subject: [PATCH 09/15] add numba importorskip to test_slinalg.py --- tests/link/numba/test_slinalg.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py index a957c90ae1..7a22cd3029 100644 --- a/tests/link/numba/test_slinalg.py +++ b/tests/link/numba/test_slinalg.py @@ -6,6 +6,9 @@ from pytensor import config +numba = pytest.importorskip("numba") + + ATOL = 0 if config.floatX.endswith("64") else 1e-6 RTOL = 1e-7 if config.floatX.endswith("64") else 1e-6 rng = np.random.default_rng(42849) From 7649d59cf3f42d5a64499bdf2e43a6c9b499a0ef Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 29 Aug 2023 01:01:03 +0200 Subject: [PATCH 10/15] Test all parameterizations of solve_triangular --- pytensor/link/numba/dispatch/slinalg.py | 10 ++--- tests/link/numba/test_slinalg.py | 55 +++++++++++++++++++++---- 2 files changed, 49 insertions(+), 16 deletions(-) diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index afe8730058..c345834c75 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -4,7 +4,7 @@ import numpy as np from numba.core import cgutils, types from numba.extending import get_cython_function_address, intrinsic, overload -from numba.np.linalg import _blas_kinds, _copy_to_fortran_order, ensure_lapack +from numba.np.linalg import _copy_to_fortran_order, ensure_lapack, get_blas_kind from scipy import linalg from pytensor.link.numba.dispatch import basic as numba_basic @@ -114,7 +114,7 @@ def _get_underlying_float(dtype): def _get_addr_and_float_pointer(dtype, name): - d = _blas_kinds[dtype] + d = get_blas_kind(dtype) func_name = f"{d}{name}" float_pointer = _get_float_pointer_for_dtype(d) addr = get_cython_function_address("scipy.linalg.cython_lapack", func_name) @@ -154,10 +154,6 @@ class _LAPACK: def __init__(self): ensure_lapack() - @classmethod - def test_blas_kinds(cls, dtype): - return _blas_kinds[dtype] - @classmethod def numba_xtrtrs(cls, dtype): """ @@ -233,7 +229,7 @@ def impl(A, B, trans=0, lower=False, unit_diagonal=False): else: transval = ord("C") - B_NDIM = 1 if B_is_1d else B.shape[1] + B_NDIM = 1 if B_is_1d else int(B.shape[1]) UPLO = val_to_int_ptr(ord("L") if lower else ord("U")) TRANS = val_to_int_ptr(transval) diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py index 7a22cd3029..4ab314d1f0 100644 --- a/tests/link/numba/test_slinalg.py +++ b/tests/link/numba/test_slinalg.py @@ -14,23 +14,60 @@ rng = np.random.default_rng(42849) +def transpose_func(x, trans): + if trans == 0: + return x + if trans == 1: + return x.conj().T + if trans == 2: + return x.T + + @pytest.mark.parametrize( "b_func, b_size", [(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))], ids=["b_col_vec", "b_matrix", "b_vec"], ) -def test_solve_triangular(b_func, b_size): - A = pt.matrix("A") - b = b_func("b") +@pytest.mark.parametrize("lower", [True, False], ids=["lower=True", "lower=False"]) +@pytest.mark.parametrize("trans", [0, 1, 2], ids=["trans=N", "trans=C", "trans=T"]) +@pytest.mark.parametrize( + "unit_diag", [True, False], ids=["unit_diag=True", "unit_diag=False"] +) +# @pytest.mark.parametrize('complex', [True, False], ids=['complex', 'real']) +def test_solve_triangular(b_func, b_size, lower, trans, unit_diag, complex=False): + # TODO: Complex raises ValueError: To change to a dtype of a different size, the last axis must be contiguous, why? + complex_dtype = "complex64" if config.floatX.endswith("32") else "complex128" + dtype = complex_dtype if complex else config.floatX - X = pt.linalg.solve_triangular(A, b, lower=True) + A = pt.matrix("A", dtype=dtype) + b = b_func("b", dtype=dtype) + + X = pt.linalg.solve_triangular( + A, b, lower=lower, trans=trans, unit_diagonal=unit_diag + ) f = pytensor.function([A, b], X, mode="NUMBA") - A_val = np.random.normal(size=(5, 5)).astype(config.floatX) - A_sym = A_val @ A_val.T - A_tri = np.linalg.cholesky(A_sym) + A_val = np.random.normal(size=(5, 5)) + b = np.random.normal(size=b_size) + + if complex: + A_val = A_val + np.random.normal(size=(5, 5)) * 1j + b = b + np.random.normal(size=b_size) * 1j + A_sym = A_val @ A_val.conj().T + + A_tri = np.linalg.cholesky(A_sym).astype(dtype) + if unit_diag: + adj_mat = np.ones((5, 5)) + adj_mat[np.diag_indices(5)] = 1 / np.diagonal(A_tri) + A_tri = A_tri * adj_mat + + A_tri = A_tri.astype(dtype) + b = b.astype(dtype) - b = np.random.normal(size=b_size).astype(config.floatX) + if not lower: + A_tri = A_tri.T X_np = f(A_tri, b) - np.testing.assert_allclose(A_tri @ X_np, b, atol=ATOL, rtol=RTOL) + np.testing.assert_allclose( + transpose_func(A_tri, trans) @ X_np, b, atol=ATOL, rtol=RTOL + ) From c7a1f28416a1ce441a3631ce26b7b860cf3b3d26 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 29 Aug 2023 15:39:44 +0200 Subject: [PATCH 11/15] Raise when inputs are complex Add informative message to error raised by check_finite=True --- pytensor/link/numba/dispatch/slinalg.py | 12 ++++++-- tests/link/numba/test_slinalg.py | 37 +++++++++++++++++++++++-- 2 files changed, 43 insertions(+), 6 deletions(-) diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index c345834c75..ce837c606b 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -192,8 +192,12 @@ def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False): _check_scipy_linalg_matrix(B, "solve_triangular") dtype = A.dtype - w_type = _get_underlying_float(dtype) + if str(dtype) in ["complex64", "complex128"]: + raise ValueError( + "Complex inputs not currently supported by solve_triangular in Numba mode" + ) + w_type = _get_underlying_float(dtype) numba_trtrs = _LAPACK().numba_xtrtrs(dtype) def impl(A, B, trans=0, lower=False, unit_diagonal=False): @@ -271,8 +275,10 @@ def numba_funcify_SolveTriangular(op, node, **kwargs): def solve_triangular(a, b): res = _solve_triangular(a, b, trans, lower, unit_diagonal) if check_finite: - if np.any(np.isinf(res)): - raise ValueError + if np.any(np.bitwise_or(np.isinf(res), np.isnan(res))): + raise ValueError( + "Non-numeric values (nan or inf) returned by solve_triangular" + ) return res return solve_triangular diff --git a/tests/link/numba/test_slinalg.py b/tests/link/numba/test_slinalg.py index 4ab314d1f0..75e016f1e0 100644 --- a/tests/link/numba/test_slinalg.py +++ b/tests/link/numba/test_slinalg.py @@ -1,3 +1,5 @@ +import re + import numpy as np import pytest @@ -33,9 +35,16 @@ def transpose_func(x, trans): @pytest.mark.parametrize( "unit_diag", [True, False], ids=["unit_diag=True", "unit_diag=False"] ) -# @pytest.mark.parametrize('complex', [True, False], ids=['complex', 'real']) -def test_solve_triangular(b_func, b_size, lower, trans, unit_diag, complex=False): - # TODO: Complex raises ValueError: To change to a dtype of a different size, the last axis must be contiguous, why? +@pytest.mark.parametrize("complex", [True, False], ids=["complex", "real"]) +@pytest.mark.filterwarnings( + 'ignore:Cannot cache compiled function "numba_funcified_fgraph"' +) +def test_solve_triangular(b_func, b_size, lower, trans, unit_diag, complex): + if complex: + # TODO: Complex raises ValueError: To change to a dtype of a different size, the last axis must be contiguous, + # why? + pytest.skip("Complex inputs currently not supported to solve_triangular") + complex_dtype = "complex64" if config.floatX.endswith("32") else "complex128" dtype = complex_dtype if complex else config.floatX @@ -71,3 +80,25 @@ def test_solve_triangular(b_func, b_size, lower, trans, unit_diag, complex=False np.testing.assert_allclose( transpose_func(A_tri, trans) @ X_np, b, atol=ATOL, rtol=RTOL ) + + +@pytest.mark.parametrize("value", [np.nan, np.inf]) +@pytest.mark.filterwarnings( + 'ignore:Cannot cache compiled function "numba_funcified_fgraph"' +) +def test_solve_triangular_raises_on_nan_inf(value): + A = pt.matrix("A") + b = pt.matrix("b") + + X = pt.linalg.solve_triangular(A, b, check_finite=True) + f = pytensor.function([A, b], X, mode="NUMBA") + A_val = np.random.normal(size=(5, 5)) + A_sym = A_val @ A_val.conj().T + + A_tri = np.linalg.cholesky(A_sym).astype(config.floatX) + b = np.full((5, 1), value) + + with pytest.raises( + ValueError, match=re.escape("Non-numeric values (nan or inf) returned ") + ): + f(A_tri, b) From 296fec3672d43047a83e4c438cd014b86c7737e9 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 29 Aug 2023 16:37:15 +0200 Subject: [PATCH 12/15] simplify check for complex input types --- pytensor/link/numba/dispatch/slinalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index ce837c606b..2f9736f4d6 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -192,7 +192,7 @@ def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False): _check_scipy_linalg_matrix(B, "solve_triangular") dtype = A.dtype - if str(dtype) in ["complex64", "complex128"]: + if str(dtype).startswith('complex'): raise ValueError( "Complex inputs not currently supported by solve_triangular in Numba mode" ) From a77968723ca0eec550dbeae7206d06cbe31fd159 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 29 Aug 2023 16:37:15 +0200 Subject: [PATCH 13/15] simplify check for complex input types --- pytensor/link/numba/dispatch/slinalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index ce837c606b..1f85eb44c2 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -192,7 +192,7 @@ def solve_triangular_impl(A, B, trans=0, lower=False, unit_diagonal=False): _check_scipy_linalg_matrix(B, "solve_triangular") dtype = A.dtype - if str(dtype) in ["complex64", "complex128"]: + if str(dtype).startswith("complex"): raise ValueError( "Complex inputs not currently supported by solve_triangular in Numba mode" ) From dd8cfedc456b181d67e35fa794b9173da3585ee1 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 29 Aug 2023 18:22:36 +0200 Subject: [PATCH 14/15] Rename _get_addr_and_float_pointer to _get_lapack_ptr_and_ptr_type Rename addr to lapack_ptr --- pytensor/link/numba/dispatch/slinalg.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index 1f85eb44c2..a11722855d 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -113,13 +113,13 @@ def _get_underlying_float(dtype): return np.dtype(out_type) -def _get_addr_and_float_pointer(dtype, name): +def _get_lapack_ptr_and_ptr_type(dtype, name): d = get_blas_kind(dtype) func_name = f"{d}{name}" float_pointer = _get_float_pointer_for_dtype(d) - addr = get_cython_function_address("scipy.linalg.cython_lapack", func_name) + lapack_ptr = get_cython_function_address("scipy.linalg.cython_lapack", func_name) - return addr, float_pointer + return lapack_ptr, float_pointer def _check_scipy_linalg_matrix(a, func_name): @@ -159,7 +159,7 @@ def numba_xtrtrs(cls, dtype): """ Called by scipy.linalg.solve_triangular """ - addr, float_pointer = _get_addr_and_float_pointer(dtype, "trtrs") + lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "trtrs") functype = ctypes.CFUNCTYPE( None, @@ -175,7 +175,7 @@ def numba_xtrtrs(cls, dtype): _ptr_int, # INFO ) - return functype(addr) + return functype(lapack_ptr) def _solve_triangular(A, B, trans=0, lower=False, unit_diagonal=False): From 519b1c538e5319d3916156fb7263bf721051d1c0 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Tue, 29 Aug 2023 18:39:45 +0200 Subject: [PATCH 15/15] Don't copy A matrix in overload func Don't copy B matrix when B is array in overload func --- pytensor/link/numba/dispatch/slinalg.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/pytensor/link/numba/dispatch/slinalg.py b/pytensor/link/numba/dispatch/slinalg.py index a11722855d..ad8065defd 100644 --- a/pytensor/link/numba/dispatch/slinalg.py +++ b/pytensor/link/numba/dispatch/slinalg.py @@ -210,20 +210,11 @@ def impl(A, B, trans=0, lower=False, unit_diagonal=False): if A.shape[0] != B.shape[0]: raise linalg.LinAlgError("Dimensions of A and B do not conform") - A_copy = _copy_to_fortran_order(A) - - # Need to expand B here; I tried everywhere else and it doesn't work if B_is_1d: - B_copy = _copy_to_fortran_order(np.expand_dims(B, -1)) + B_copy = np.asfortranarray(np.expand_dims(B, -1)) else: B_copy = _copy_to_fortran_order(B) - # if isinstance(trans, str): - # if trans not in ['N', 'C', 'T']: - # raise ValueError('Parameter "trans" should be one of N, C, T or 0, 1, 2') - # transval = ord(trans) - - # else: if trans not in [0, 1, 2]: raise ValueError('Parameter "trans" should be one of N, C, T or 0, 1, 2') if trans == 0: @@ -250,7 +241,7 @@ def impl(A, B, trans=0, lower=False, unit_diagonal=False): DIAG, N, NRHS, - A_copy.view(w_type).ctypes, + np.asfortranarray(A).T.view(w_type).ctypes, LDA, B_copy.view(w_type).ctypes, LDB, @@ -258,7 +249,7 @@ def impl(A, B, trans=0, lower=False, unit_diagonal=False): ) if B_is_1d: - return B_copy[:, 0] + return B_copy[..., 0] return B_copy return impl