Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions pytensor/tensor/_linalg/solve/rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,22 @@
from pytensor.tensor.variable import TensorVariable


def decompose_A(A, assume_a):
def decompose_A(A, assume_a, check_finite):
if assume_a == "gen":
return lu_factor(A, check_finite=False)
return lu_factor(A, check_finite=check_finite)
else:
raise NotImplementedError


def solve_lu_decomposed_system(A_decomp, b, b_ndim, assume_a, transposed=False):
if assume_a == "gen":
return lu_solve(A_decomp, b, b_ndim=b_ndim, trans=transposed)
def solve_lu_decomposed_system(A_decomp, b, transposed=False, *, core_solve_op: Solve):
if core_solve_op.assume_a == "gen":
return lu_solve(
A_decomp,
b,
trans=transposed,
b_ndim=core_solve_op.b_ndim,
check_finite=core_solve_op.check_finite,
)
else:
raise NotImplementedError

Expand Down Expand Up @@ -102,14 +108,19 @@ def find_solve_clients(var, assume_a):
):
return None

A_decomp = decompose_A(A, assume_a=assume_a)
# If any Op had check_finite=True, we also do it for the LU decomposition
check_finite_decomp = False
for client, _ in A_solve_clients_and_transpose:
if client.op.core_op.check_finite:
check_finite_decomp = True
break
A_decomp = decompose_A(A, assume_a=assume_a, check_finite=check_finite_decomp)

replacements = {}
for client, transposed in A_solve_clients_and_transpose:
_, b = client.inputs
b_ndim = client.op.core_op.b_ndim
new_x = solve_lu_decomposed_system(
A_decomp, b, b_ndim=b_ndim, assume_a=assume_a, transposed=transposed
A_decomp, b, transposed=transposed, core_solve_op=client.op.core_op
)
[old_x] = client.outputs
new_x = atleast_Nd(new_x, n=old_x.type.ndim).astype(old_x.type.dtype)
Expand Down
2 changes: 1 addition & 1 deletion pytensor/tensor/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,7 @@ def tensor(
try:
# Help catching errors with the new tensor API
# Many single letter strings are valid sctypes
if str(name) == "floatX" or (len(str(name)) > 1 and np.dtype(name).type):
if str(name) == "floatX" or (len(str(name)) > 2 and np.dtype(name).type):
raise ValueError(
f"The first and only positional argument of tensor is now `name`. Got {name}.\n"
"This name looks like a dtype, which you should pass as a keyword argument only."
Expand Down
33 changes: 33 additions & 0 deletions tests/tensor/linalg/test_rewriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,36 @@ def test_lu_decomposition_reused_scan(transposed):
resx1 = fn_opt(A_test, x0_test)
rtol = 1e-7 if config.floatX == "float64" else 1e-6
np.testing.assert_allclose(resx0, resx1, rtol=rtol)


def test_lu_decomposition_reused_preserves_check_finite():
# Check that the LU decomposition rewrite preserves the check_finite flag
rewrite_name = reuse_lu_decomposition_multiple_solves.__name__

A = tensor("A", shape=(2, 2))
b1 = tensor("b1", shape=(2,))
b2 = tensor("b2", shape=(2,))

x1 = solve(A, b1, assume_a="gen", check_finite=True)
x2 = solve(A, b2, assume_a="gen", check_finite=False)
fn_opt = function(
[A, b1, b2], [x1, x2], mode=get_default_mode().including(rewrite_name)
)
opt_nodes = fn_opt.maker.fgraph.apply_nodes
assert count_vanilla_solve_nodes(opt_nodes) == 0
assert count_lu_decom_nodes(opt_nodes) == 1
assert count_lu_solve_nodes(opt_nodes) == 2

# We should get an error if A or b1 is non finite
A_valid = np.array([[1, 0], [0, 1]], dtype=A.type.dtype)
b1_valid = np.array([1, 1], dtype=b1.type.dtype)
b2_valid = np.array([1, 1], dtype=b2.type.dtype)

assert fn_opt(A_valid, b1_valid, b2_valid) # Fine
assert fn_opt(
A_valid, b1_valid, b2_valid * np.nan
) # Should not raise (also fine on most LAPACK implementations?)
with pytest.raises(ValueError, match="array must not contain infs or NaNs"):
assert fn_opt(A_valid, b1_valid * np.nan, b2_valid)
with pytest.raises(ValueError, match="array must not contain infs or NaNs"):
assert fn_opt(A_valid * np.nan, b1_valid, b2_valid)