diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 783e1f46a9..3cc652007d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ exclude: | )$ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v4.4.0 hooks: - id: debug-statements exclude: | @@ -25,9 +25,10 @@ repos: - id: black language_version: python3 - repo: https://github.com/pycqa/flake8 - rev: 5.0.4 + rev: 6.0.0 hooks: - id: flake8 + language_version: python39 - repo: https://github.com/pycqa/isort rev: 5.10.1 hooks: diff --git a/pytensor/link/numba/dispatch/__init__.py b/pytensor/link/numba/dispatch/__init__.py index 1e7d40b1b1..c7cb2632a1 100644 --- a/pytensor/link/numba/dispatch/__init__.py +++ b/pytensor/link/numba/dispatch/__init__.py @@ -9,5 +9,6 @@ import pytensor.link.numba.dispatch.random import pytensor.link.numba.dispatch.elemwise import pytensor.link.numba.dispatch.scan +import pytensor.link.numba.dispatch.sparse # isort: on diff --git a/pytensor/link/numba/dispatch/sparse.py b/pytensor/link/numba/dispatch/sparse.py new file mode 100644 index 0000000000..d07e029501 --- /dev/null +++ b/pytensor/link/numba/dispatch/sparse.py @@ -0,0 +1,142 @@ +import scipy as sp +import scipy.sparse +from numba.core import cgutils, types +from numba.extending import ( + NativeValue, + box, + make_attribute_wrapper, + models, + register_model, + typeof_impl, + unbox, +) + + +class CSMatrixType(types.Type): + """A Numba `Type` modeled after the base class `scipy.sparse.compressed._cs_matrix`.""" + + name: str + instance_class: type + + def __init__(self, dtype): + self.dtype = dtype + self.data = types.Array(dtype, 1, "A") + self.indices = types.Array(types.int32, 1, "A") + self.indptr = types.Array(types.int32, 1, "A") + self.shape = types.UniTuple(types.int64, 2) + super().__init__(self.name) + + +make_attribute_wrapper(CSMatrixType, "data", "data") +make_attribute_wrapper(CSMatrixType, "indices", "indices") +make_attribute_wrapper(CSMatrixType, "indptr", "indptr") +make_attribute_wrapper(CSMatrixType, "shape", "shape") + + +class CSRMatrixType(CSMatrixType): + name = "csr_matrix" + + @staticmethod + def instance_class(data, indices, indptr, shape): + return sp.sparse.csr_matrix((data, indices, indptr), shape, copy=False) + + +class CSCMatrixType(CSMatrixType): + name = "csc_matrix" + + @staticmethod + def instance_class(data, indices, indptr, shape): + return sp.sparse.csc_matrix((data, indices, indptr), shape, copy=False) + + +@typeof_impl.register(sp.sparse.csc_matrix) +def typeof_csc_matrix(val, c): + data = typeof_impl(val.data, c) + return CSCMatrixType(data.dtype) + + +@typeof_impl.register(sp.sparse.csr_matrix) +def typeof_csr_matrix(val, c): + data = typeof_impl(val.data, c) + return CSRMatrixType(data.dtype) + + +@register_model(CSRMatrixType) +class CSRMatrixModel(models.StructModel): + def __init__(self, dmm, fe_type): + members = [ + ("data", fe_type.data), + ("indices", fe_type.indices), + ("indptr", fe_type.indptr), + ("shape", fe_type.shape), + ] + super().__init__(dmm, fe_type, members) + + +@register_model(CSCMatrixType) +class CSCMatrixModel(models.StructModel): + def __init__(self, dmm, fe_type): + members = [ + ("data", fe_type.data), + ("indices", fe_type.indices), + ("indptr", fe_type.indptr), + ("shape", fe_type.shape), + ] + super().__init__(dmm, fe_type, members) + + +@unbox(CSCMatrixType) +@unbox(CSRMatrixType) +def unbox_matrix(typ, obj, c): + + struct_ptr = cgutils.create_struct_proxy(typ)(c.context, c.builder) + + data = c.pyapi.object_getattr_string(obj, "data") + indices = c.pyapi.object_getattr_string(obj, "indices") + indptr = c.pyapi.object_getattr_string(obj, "indptr") + shape = c.pyapi.object_getattr_string(obj, "shape") + + struct_ptr.data = c.unbox(typ.data, data).value + struct_ptr.indices = c.unbox(typ.indices, indices).value + struct_ptr.indptr = c.unbox(typ.indptr, indptr).value + struct_ptr.shape = c.unbox(typ.shape, shape).value + + c.pyapi.decref(data) + c.pyapi.decref(indices) + c.pyapi.decref(indptr) + c.pyapi.decref(shape) + + is_error_ptr = cgutils.alloca_once_value(c.builder, cgutils.false_bit) + is_error = c.builder.load(is_error_ptr) + + res = NativeValue(struct_ptr._getvalue(), is_error=is_error) + + return res + + +@box(CSCMatrixType) +@box(CSRMatrixType) +def box_matrix(typ, val, c): + struct_ptr = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val) + + data_obj = c.box(typ.data, struct_ptr.data) + indices_obj = c.box(typ.indices, struct_ptr.indices) + indptr_obj = c.box(typ.indptr, struct_ptr.indptr) + shape_obj = c.box(typ.shape, struct_ptr.shape) + + c.pyapi.incref(data_obj) + c.pyapi.incref(indices_obj) + c.pyapi.incref(indptr_obj) + c.pyapi.incref(shape_obj) + + cls_obj = c.pyapi.unserialize(c.pyapi.serialize_object(typ.instance_class)) + obj = c.pyapi.call_function_objargs( + cls_obj, (data_obj, indices_obj, indptr_obj, shape_obj) + ) + + c.pyapi.decref(data_obj) + c.pyapi.decref(indices_obj) + c.pyapi.decref(indptr_obj) + c.pyapi.decref(shape_obj) + + return obj diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index d3f7b6e493..eefade378f 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -54,6 +54,7 @@ import numpy as np import pytensor +import pytensor.link.utils as link_utils from pytensor import tensor as at from pytensor.compile.builders import construct_nominal_fgraph, infer_shape from pytensor.compile.function.pfunc import pfunc @@ -75,7 +76,6 @@ from pytensor.graph.utils import InconsistencyError, MissingInputError from pytensor.link.c.basic import CLinker from pytensor.link.c.exceptions import MissingGXX -from pytensor.link.utils import raise_with_op from pytensor.printing import op_debug_information from pytensor.scan.utils import ScanProfileStats, Validator, forced_replace, safe_new from pytensor.tensor.basic import as_tensor_variable @@ -1627,7 +1627,7 @@ def p(node, inputs, outputs): if hasattr(self.fn.vm, "position_of_error") and hasattr( self.fn.vm, "thunks" ): - raise_with_op( + link_utils.raise_with_op( self.fn.maker.fgraph, self.fn.vm.nodes[self.fn.vm.position_of_error], self.fn.vm.thunks[self.fn.vm.position_of_error], @@ -1930,7 +1930,7 @@ def perform(self, node, inputs, output_storage, params=None): # done by raise_with_op is not implemented in C. if hasattr(vm, "thunks"): # For the CVM - raise_with_op( + link_utils.raise_with_op( self.fn.maker.fgraph, vm.nodes[vm.position_of_error], vm.thunks[vm.position_of_error], @@ -1940,7 +1940,7 @@ def perform(self, node, inputs, output_storage, params=None): # We don't have access from python to all the # temps values So for now, we just don't print # the extra shapes/strides info - raise_with_op( + link_utils.raise_with_op( self.fn.maker.fgraph, vm.nodes[vm.position_of_error] ) else: diff --git a/tests/link/numba/test_sparse.py b/tests/link/numba/test_sparse.py new file mode 100644 index 0000000000..39227fb19f --- /dev/null +++ b/tests/link/numba/test_sparse.py @@ -0,0 +1,40 @@ +import numba +import numpy as np +import scipy as sp + +# Load Numba customizations +import pytensor.link.numba.dispatch.sparse # noqa: F401 + + +def test_sparse_unboxing(): + @numba.njit + def test_unboxing(x, y): + return x.shape, y.shape + + x_val = sp.sparse.csr_matrix(np.eye(100)) + y_val = sp.sparse.csc_matrix(np.eye(101)) + + res = test_unboxing(x_val, y_val) + + assert res == (x_val.shape, y_val.shape) + + +def test_sparse_boxing(): + @numba.njit + def test_boxing(x, y): + return x, y + + x_val = sp.sparse.csr_matrix(np.eye(100)) + y_val = sp.sparse.csc_matrix(np.eye(101)) + + res_x_val, res_y_val = test_boxing(x_val, y_val) + + assert np.array_equal(res_x_val.data, x_val.data) + assert np.array_equal(res_x_val.indices, x_val.indices) + assert np.array_equal(res_x_val.indptr, x_val.indptr) + assert res_x_val.shape == x_val.shape + + assert np.array_equal(res_y_val.data, y_val.data) + assert np.array_equal(res_y_val.indices, y_val.indices) + assert np.array_equal(res_y_val.indptr, y_val.indptr) + assert res_y_val.shape == y_val.shape