diff --git a/tests/tensor/linalg/test_rewriting.py b/tests/tensor/linalg/test_rewriting.py index f1ea2e1af3..c6f5fae851 100644 --- a/tests/tensor/linalg/test_rewriting.py +++ b/tests/tensor/linalg/test_rewriting.py @@ -251,3 +251,53 @@ def test_decomposition_reused_preserves_check_finite(assume_a, counter): assert fn_opt(A_valid, b1_valid * np.nan, b2_valid) with pytest.raises(ValueError, match="array must not contain infs or NaNs"): assert fn_opt(A_valid * np.nan, b1_valid, b2_valid) + + +@pytest.mark.parametrize( + "lower_first", [True, False], ids=["lower_first", "upper_first"] +) +def test_cho_solve_handles_lower_flags(lower_first): + rewrite_name = reuse_decomposition_multiple_solves.__name__ + A = tensor("A", shape=(5, None)) + b = tensor("b", shape=(5,)) + + x1 = solve(A, b, assume_a="pos", lower=lower_first, check_finite=False) + x2 = solve(A.mT, b, assume_a="pos", lower=not lower_first, check_finite=False) + + dx1_dA = grad(x1.sum(), A) + dx2_dA = grad(x2.sum(), A) + + fn = function([A, b], [x1, dx1_dA, x2, dx2_dA]) + fn_no_rewrite = function( + [A, b], + [x1, dx1_dA, x2, dx2_dA], + mode=get_default_mode().excluding(rewrite_name), + ) + + rng = np.random.default_rng() + L_values = rng.normal(size=(5, 5)).astype(config.floatX) + A_values = L_values @ L_values.T # Ensure A is positive definite + + if lower_first: + A_values[np.triu_indices(5, k=1)] = np.nan + else: + A_values[np.tril_indices(5, k=-1)] = np.nan + + b_values = rng.normal(size=(5,)).astype(config.floatX) + + # This computation should not raise an error, and none of them should be NaN + res = fn(A_values, b_values) + expected_res = fn_no_rewrite(A_values, b_values) + + for x, expected_x in zip(res, expected_res): + assert np.isfinite(x).all() + np.testing.assert_allclose( + x, + expected_x, + atol=1e-6 if config.floatX == "float64" else 1e-3, + rtol=1e-6 if config.floatX == "float64" else 1e-3, + ) + + # If we put the NaN in the wrong place, it should raise an error + with pytest.raises(np.linalg.LinAlgError): + fn(A_values.T, b_values)