From bce6474ca53bfc37fdcd4e70a72c5ab5c89ee18e Mon Sep 17 00:00:00 2001
From: decade-afk <3995409050@qq.com>
Date: Thu, 13 Feb 2025 09:55:09 +0800
Subject: [PATCH] lu_solve: fix batch strides, add ROCm (rocSOLVER) support and
 a `left` option

---
 paddle/phi/backends/dynload/rocsolver.h       |   2 +
 paddle/phi/kernels/cpu/lu_solve_kernel.cc     |   8 +-
 .../kernels/funcs/lapack/lapack_function.cc   |   4 +-
 .../phi/kernels/gpu/lu_solve_grad_kernel.cu   |   5 -
 paddle/phi/kernels/gpu/lu_solve_kernle.cu     | 106 ++++++--
 .../kernels/impl/lu_solve_grad_kernel_impl.h  |   2 +-
 paddle/phi/ops/yaml/backward.yaml             |   4 +-
 paddle/phi/ops/yaml/ops.yaml                  |   2 +-
 python/paddle/tensor/linalg.py                |  78 ++++--
 test/legacy_test/test_lu_solve_op.py          | 257 +++++++++++++++++-
 10 files changed, 397 insertions(+), 71 deletions(-)

diff --git a/paddle/phi/backends/dynload/rocsolver.h b/paddle/phi/backends/dynload/rocsolver.h
index 522620073ee7d..91de0bb26122d 100644
--- a/paddle/phi/backends/dynload/rocsolver.h
+++ b/paddle/phi/backends/dynload/rocsolver.h
@@ -45,6 +45,8 @@ extern void *rocsolver_dso_handle;
   __macro(rocsolver_dpotrs);            \
   __macro(rocsolver_cpotrs);            \
   __macro(rocsolver_zpotrs);            \
+  __macro(rocsolver_sgetrs);            \
+  __macro(rocsolver_dgetrs);            \
   __macro(rocsolver_sgetrf);            \
   __macro(rocsolver_dgetrf);            \
   __macro(rocsolver_cgetrf);            \
diff --git a/paddle/phi/kernels/cpu/lu_solve_kernel.cc b/paddle/phi/kernels/cpu/lu_solve_kernel.cc
index 012fa9606727b..3ebfac6eba45d 100644
--- a/paddle/phi/kernels/cpu/lu_solve_kernel.cc
+++ b/paddle/phi/kernels/cpu/lu_solve_kernel.cc
@@ -57,8 +57,8 @@ void LuSolveKernel(const Context& dev_ctx,
       reinterpret_cast<int*>(const_cast<int*>(pivots.data<int>()));
   for (int i = 0; i < batchsize; i++) {
-    auto* out_data_item = &out_data[i * lda * n_int];
-    auto* lu_data_item = &lu_data[i * ldb * nrhs_int];
+    auto* out_data_item = &out_data[i * lda * nrhs_int];
+    auto* lu_data_item = &lu_data[i * ldb * n_int];
     auto* pivots_data_item = &pivots_data[i * n_int];
     phi::funcs::lapackLuSolve<T>(trans_char,
                                  n_int,
@@ -73,8 +73,8 @@ void LuSolveKernel(const Context& dev_ctx,
       info,
       0,
       phi::errors::PreconditionNotMet(
-        "LU solve failed with error code %d. Check if matrix is singular.",
-        info));
+          "LU solve failed with error code %d. Check if matrix is singular.",
+          info));
   }
   *out = Transpose2DTo6D<Context, T>(dev_ctx, *out);
 }
diff --git a/paddle/phi/kernels/funcs/lapack/lapack_function.cc b/paddle/phi/kernels/funcs/lapack/lapack_function.cc
index 78f685d43383f..9ea4d5ec67425 100644
--- a/paddle/phi/kernels/funcs/lapack/lapack_function.cc
+++ b/paddle/phi/kernels/funcs/lapack/lapack_function.cc
@@ -41,7 +41,7 @@ void lapackLuSolve<double>(char trans,
                            double *b,
                            int ldb,
                            int *info) {
-   dynload::dgetrs_(&trans, &n, &nrhs, a, &lda, ipiv, b, &ldb, info);
+  dynload::dgetrs_(&trans, &n, &nrhs, a, &lda, ipiv, b, &ldb, info);
 }
 
 template <>
@@ -54,7 +54,7 @@ void lapackLuSolve<float>(char trans,
                           float *b,
                           int ldb,
                           int *info) {
-   dynload::sgetrs_(&trans, &n, &nrhs, a, &lda, ipiv, b, &ldb, info);
+  dynload::sgetrs_(&trans, &n, &nrhs, a, &lda, ipiv, b, &ldb, info);
 }
 
 // eigh
diff --git a/paddle/phi/kernels/gpu/lu_solve_grad_kernel.cu b/paddle/phi/kernels/gpu/lu_solve_grad_kernel.cu
index be82bc7bb384c..ce69312b3c9a5 100644
--- a/paddle/phi/kernels/gpu/lu_solve_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/lu_solve_grad_kernel.cu
@@ -12,9 +12,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#ifndef PADDLE_WITH_HIP
-// HIP not support cusolver
-
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 
@@ -23,5 +20,3 @@
 
 PD_REGISTER_KERNEL(
     lu_solve_grad, GPU, ALL_LAYOUT, phi::LuSolveGradKernel, float, double) {}
-
-#endif  // not PADDLE_WITH_HIP
diff --git a/paddle/phi/kernels/gpu/lu_solve_kernle.cu b/paddle/phi/kernels/gpu/lu_solve_kernle.cu
index ec3107a7e26df..6bd47d9734239 100644
--- a/paddle/phi/kernels/gpu/lu_solve_kernle.cu
+++ b/paddle/phi/kernels/gpu/lu_solve_kernle.cu
@@ -12,10 +12,12 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#ifndef PADDLE_WITH_HIP
-// HIP not support cusolver
-
+#ifdef PADDLE_WITH_HIP
+#include "paddle/phi/backends/dynload/rocsolver.h"
+#else
 #include "paddle/phi/backends/dynload/cusolver.h"
+#endif
+
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/core/kernel_registry.h"
 
@@ -25,6 +27,48 @@
 
 namespace phi {
 
+#ifdef PADDLE_WITH_HIP
+// Unlike cusolverDn*getrs, rocSOLVER's getrs has no device `info` output.
+template <typename T>
+void rocsolver_getrs(const rocsolver_handle& handle,
+                     rocblas_operation trans,
+                     int n,
+                     int nrhs,
+                     T* a,
+                     int lda,
+                     int* ipiv,
+                     T* b,
+                     int ldb);
+
+template <>
+void rocsolver_getrs<float>(const rocsolver_handle& handle,
+                            rocblas_operation trans,
+                            int n,
+                            int nrhs,
+                            float* a,
+                            int lda,
+                            int* ipiv,
+                            float* b,
+                            int ldb) {
+  PADDLE_ENFORCE_GPU_SUCCESS(dynload::rocsolver_sgetrs(
+      handle, trans, n, nrhs, a, lda, ipiv, b, ldb));
+}
+
+template <>
+void rocsolver_getrs<double>(const rocsolver_handle& handle,
+                             rocblas_operation trans,
+                             int n,
+                             int nrhs,
+                             double* a,
+                             int lda,
+                             int* ipiv,
+                             double* b,
+                             int ldb) {
+  PADDLE_ENFORCE_GPU_SUCCESS(dynload::rocsolver_dgetrs(
+      handle, trans, n, nrhs, a, lda, ipiv, b, ldb));
+}
+
+#else
 template <typename T>
 void cusolver_getrs(const cusolverDnHandle_t& cusolverH,
                     cublasOperation_t trans,
@@ -39,7 +83,7 @@ void cusolver_getrs(const cusolverDnHandle_t& cusolverH,
 
 template <>
 void cusolver_getrs<float>(const cusolverDnHandle_t& cusolverH,
-                    cublasOperation_t trans,
+                           cublasOperation_t trans,
                            int n,
                            int nrhs,
                            float* a,
                            int lda,
                            int* ipiv,
                            float* b,
                            int ldb,
                            int* info) {
   PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnSgetrs(
       cusolverH, trans, n, nrhs, a, lda, ipiv, b, ldb, info));
 }
 
@@ -54,18 +98,19 @@ void cusolver_getrs(const cusolverDnHandle_t& cusolverH,
 template <>
 void cusolver_getrs<double>(const cusolverDnHandle_t& cusolverH,
-                    cublasOperation_t trans,
-                    int n,
-                    int nrhs,
-                    double* a,
-                    int lda,
-                    int* ipiv,
-                    double* b,
-                    int ldb,
-                    int* info) {
+                            cublasOperation_t trans,
+                            int n,
+                            int nrhs,
+                            double* a,
+                            int lda,
+                            int* ipiv,
+                            double* b,
+                            int ldb,
+                            int* info) {
   PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnDgetrs(
       cusolverH, trans, n, nrhs, a, lda, ipiv, b, ldb, info));
 }
+#endif
 
 template <typename T, typename Context>
 void LuSolveKernel(const Context& dev_ctx,
@@ -82,6 +127,19 @@ void LuSolveKernel(const Context& dev_ctx,
 
   auto x_dims = x.dims();
   auto lu_dims = lu.dims();
+#ifdef PADDLE_WITH_HIP
+  rocblas_operation trans_op;
+  if (trans == "N") {
+    trans_op = rocblas_operation_none;
+  } else if (trans == "T") {
+    trans_op = rocblas_operation_transpose;
+  } else if (trans == "C") {
+    trans_op = rocblas_operation_conjugate_transpose;
+  } else {
+    PADDLE_THROW(phi::errors::InvalidArgument(
+        "trans must be one of ['N', 'T', 'C'], but got %s", trans));
+  }
+#else
   cublasOperation_t trans_op;
   if (trans == "N") {
     trans_op = CUBLAS_OP_N;
   } else if (trans == "T") {
     trans_op = CUBLAS_OP_T;
   } else if (trans == "C") {
     trans_op = CUBLAS_OP_C;
   } else {
@@ -93,6 +151,7 @@ void LuSolveKernel(const Context& dev_ctx,
     PADDLE_THROW(phi::errors::InvalidArgument(
         "trans must be one of ['N', 'T', 'C'], but got %s", trans));
   }
+#endif
   int n = static_cast<int>(lu_dims[lu_dims.size() - 1]);
   int nrhs = static_cast<int>(x_dims[x_dims.size() - 1]);
   int lda = std::max(1, n);
@@ -112,9 +171,21 @@ void LuSolveKernel(const Context& dev_ctx,
       reinterpret_cast<int*>(const_cast<int*>(pivots.data<int>()));
   for (int i = 0; i < batchsize; i++) {
     auto handle = dev_ctx.cusolver_dn_handle();
-    auto* out_data_item = &out_data[i * n * n];
-    auto* lu_data_item = &lu_data[i * n * n];
+    auto* out_data_item = &out_data[i * lda * nrhs];
+    auto* lu_data_item = &lu_data[i * ldb * n];
     auto* pivots_data_item = &pivots_data[i * n];
+#ifdef PADDLE_WITH_HIP
+    rocsolver_getrs<T>(handle,
+                       trans_op,
+                       n,
+                       nrhs,
+                       lu_data_item,
+                       lda,
+                       pivots_data_item,
+                       out_data_item,
+                       ldb);
+    // rocsolver_getrs reports failure through its return status; no info.
+#else
     cusolver_getrs<T>(handle,
                       trans_op,
                       n,
                       nrhs,
                       lu_data_item,
                       lda,
                       pivots_data_item,
                       out_data_item,
                       ldb,
                       d_info);
+#endif
   }
-  // Synchronize to ensure the solve is complete
-  dev_ctx.Wait();
   *out = Transpose2DTo6D<Context, T>(dev_ctx, *out);
 }
 
@@ -135,5 +205,3 @@ void LuSolveKernel(const Context& dev_ctx,
 
 PD_REGISTER_KERNEL(
     lu_solve, GPU, ALL_LAYOUT, phi::LuSolveKernel, float, double) {}
-
-#endif  // not PADDLE_WITH_HIP
diff --git a/paddle/phi/kernels/impl/lu_solve_grad_kernel_impl.h b/paddle/phi/kernels/impl/lu_solve_grad_kernel_impl.h
index 8cd3956d3c357..893b39e3b606b 100644
--- a/paddle/phi/kernels/impl/lu_solve_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/lu_solve_grad_kernel_impl.h
@@ -60,7 +60,7 @@ void LuSolveGradKernel(const Context& dev_ctx,
   auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
 
   auto out_grad_dims = out_grad.dims();
-  auto mat_dim_l = 
+  auto mat_dim_l =
       phi::funcs::CreateMatrixDescriptor(out_grad_dims, 0, false);
   auto out_mH_dims = out_mH.dims();
   auto mat_dim_g = phi::funcs::CreateMatrixDescriptor(out_mH_dims, 0, true);
diff --git a/paddle/phi/ops/yaml/backward.yaml b/paddle/phi/ops/yaml/backward.yaml
index 157e042097e68..a41ec9266b01e 100644
--- a/paddle/phi/ops/yaml/backward.yaml
+++ b/paddle/phi/ops/yaml/backward.yaml
@@ -1994,8 +1994,8 @@
   inplace : (out_grad -> x_grad)
 
 - backward_op : lu_solve_grad
-  forward : lu_solve (Tensor x, Tensor lu, Tensor pivots, str trans = "N") -> Tensor(out)
-  args : (Tensor x, Tensor lu, Tensor pivots, Tensor out, Tensor out_grad, str trans = "N")
+  forward : lu_solve (Tensor x, Tensor lu, Tensor pivots, str trans) -> Tensor(out)
+  args : (Tensor x, Tensor lu, Tensor pivots, Tensor out, Tensor out_grad, str trans)
   output : Tensor(x_grad), Tensor(lu_grad)
   infer_meta :
     func : GeneralBinaryGradInferMeta
diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml
index 326b7353c273a..ffbe505744bd0 100755
--- a/paddle/phi/ops/yaml/ops.yaml
+++ b/paddle/phi/ops/yaml/ops.yaml
@@ -3149,7 +3149,7 @@
   interfaces : paddle::dialect::InferSymbolicShapeInterface
 
 - op : lu_solve
-  args : (Tensor x, Tensor lu, Tensor pivots, str trans = "N")
+  args : (Tensor x, Tensor lu, Tensor pivots, str trans)
   output : Tensor(out)
   infer_meta :
     func : UnchangedInferMeta
diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py
index 3e1f936d4b58d..37873dbc92706 100644
--- a/python/paddle/tensor/linalg.py
+++ b/python/paddle/tensor/linalg.py
@@ -22,6 +22,7 @@
 from paddle import _C_ops
 from paddle.base.libpaddle import DataType
 from paddle.common_ops_import import VarDesc
+from paddle.tensor.math import broadcast_shape
 from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only
 
 from ..base.data_feeder import (
@@ -3576,7 +3577,12 @@ def lu(
 
 
 def lu_solve(
-    b: Tensor, lu_data: Tensor, pivots: Tensor, trans: str = "N", name=None
+    b: Tensor,
+    lu_data: Tensor,
+    pivots: Tensor,
+    trans: str = "N",
+    left: bool = False,
+    name=None,
 ):
     r"""
     Computes the solution y to the system of linear equations :math:`Ay = b`,
@@ -3590,31 +3596,44 @@ def lu_solve(
         pivots (Tensor): Permutation matrix P of LU decomposition. It has
             shape :math:`(*, m)`, where :math:`*` is batch dimensions, that can
             be converted to a permutation matrix P, with data type int32.
-        trans (str): The transpose of the matrix A. It can be "N" or "T",
+        trans (str, optional): The operation applied to the matrix A. It can
+            be "N" (no transpose), "T" (transpose) or "C" (conjugate
+            transpose). Default: "N".
+        left (bool, optional): If True, solve :math:`x\,op(A) = b` instead of
+            :math:`op(A)\,y = b`. Default: False.
+
     Returns:
         Tensor, the same data type as the `b` and `lu_data`.
 
     Examples:
         >>> import paddle
         >>> import numpy as np
 
-        >>> b = paddle.to_tensor(np.array([[1], [3], [3]]), paddle.float32)
-        >>> LU_data = paddle.to_tensor(np.array([[2, 1, 1], [0.5, 1, 1.5], [0.5, 0, 2.5]]), paddle.float32)
-        >>> LU_pivots = paddle.to_tensor(np.array([2, 2, 3]), paddle.int32)
-        >>> y = paddle.lu_solve(b, LU_data, LU_pivots)
+        >>> # Ax = b
+        >>> A = paddle.to_tensor([[3, 1], [1, 2]], dtype="float64")
+        >>> b = paddle.to_tensor([[9, 8], [9, 8]], dtype="float64")
+        >>> lu, p = paddle.linalg.lu(A)
+        >>> y = paddle.lu_solve(b, lu, p)
+        >>> print(y)
+        Tensor(shape=[2, 2], dtype=float64, place=Place(cpu), stop_gradient=True,
+               [[1.80000000, 1.60000000],
+                [3.60000000, 3.20000000]])
+
+        >>> # xA = b ==> A^T x^T = b^T
+        >>> A = paddle.to_tensor([[3, 1], [1, 2]], dtype="float64")
+        >>> b = paddle.to_tensor([[9, 8], [9, 8]], dtype="float64")
+        >>> lu, p = paddle.linalg.lu(A)
+        >>> y = paddle.lu_solve(b, lu, p, left=True)
         >>> print(y)
-        [[ 1.9000002]
-         [-1.4000001]
-         [ 0.6      ]]
+        Tensor(shape=[2, 2], dtype=float64, place=Place(cpu), stop_gradient=True,
+               [[2., 3.],
+                [2., 3.]])
     """
-    b = (
-        b
-        if b.shape[:-2] == lu_data.shape[:-2]
-        else paddle.broadcast_to(b, lu_data.shape[:-2] + b.shape[-2:])
-    )
-    pivots = (
-        pivots
-        if pivots.shape[:-1] == lu_data.shape[:-2]
-        else paddle.broadcast_to(pivots, lu_data.shape[:-2] + pivots.shape[-1:])
-    )
+    if left:
+        # x @ op(A) = b is solved as op(A)^T @ x^T = b^T, so transpose the
+        # right-hand side first and flip `trans` (real dtypes, so "C" == "T").
+        b = _transpose_last_2dim(b)
+        trans = "T" if trans == "N" else "N"
     if b.ndim < 2:
         raise ValueError(
            f'`b` must have at least 2 dimensions, but got {len(b.shape)}'
         )
@@ -3639,7 +3658,26 @@ def lu_solve(
         raise ValueError(
             f'`pivots` shape[-1] must be equal to `lu_data` shape[-1], but got {pivots.shape[-1]} and {lu_data.shape[-1]}'
         )
-
+    temp_shape = broadcast_shape(b.shape[:-2], lu_data.shape[:-2])
+    batch_shape = broadcast_shape(temp_shape, pivots.shape[:-1])
+    b = (
+        b
+        if b.shape[:-2] == batch_shape
+        else paddle.broadcast_to(b, batch_shape + b.shape[-2:])
+    )
+    # For real dtypes the conjugate transpose equals the plain transpose.
+    trans = trans if trans == "N" else "T"
+    pivots = (
+        pivots
+        if pivots.shape[:-1] == batch_shape
+        else paddle.broadcast_to(pivots, batch_shape + pivots.shape[-1:])
+    )
+    lu_data = (
+        lu_data
+        if lu_data.shape[:-2] == batch_shape
+        else paddle.broadcast_to(lu_data, batch_shape + lu_data.shape[-2:])
+    )
+    pivots.stop_gradient = True
     if in_dynamic_or_pir_mode():
         out = _C_ops.lu_solve(b, lu_data, pivots, trans)
     else:
@@ -3652,10 +3690,12 @@ def lu_solve(
         out = helper.create_variable_for_type_inference(dtype=b.dtype)
         helper.append_op(
             type='lu_solve',
-            inputs={'x': b, 'lu': lu_data, 'pivots': pivots},
-            outputs={'out': out},
-            attrs={'trans':trans},
+            inputs={'X': b, 'Lu': lu_data, 'Pivots': pivots},
+            outputs={'Out': out},
+            attrs={'trans': trans},
         )
+    if left:
+        out = _transpose_last_2dim(out)
     return out
 
 
diff --git a/test/legacy_test/test_lu_solve_op.py b/test/legacy_test/test_lu_solve_op.py
index d1ea68398c046..474b784ecd3c4 100644
--- a/test/legacy_test/test_lu_solve_op.py
+++ b/test/legacy_test/test_lu_solve_op.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,9 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import copy
-import itertools
-import os
 import unittest
 
 import numpy as np
@@ -25,31 +22,263 @@
 from paddle.base import core
 
 
-class TestSolveOpAPI_1(unittest.TestCase):
+def _transpose_last_2dim(x):
+    """Transpose the last two dimensions of a tensor."""
+    x_new_dims = list(range(len(x.shape)))
+    x_new_dims[-1], x_new_dims[-2] = x_new_dims[-2], x_new_dims[-1]
+    x = paddle.transpose(x, x_new_dims)
+    return x
+
+
+def get_inandout(A, b, trans, left):
+    """Build LU factors and a reference solution via paddle.linalg.solve."""
+    paddle.disable_static(base.CPUPlace())
+    A = paddle.randn(A, dtype='float64')
+    b = paddle.randn(b, dtype='float64')
+    # Factor A before any transpose so the pivots match lu_solve's input.
+    LU, pivots = paddle.linalg.lu(A)
+    if trans == "N" and not left:  # Ax = b
+        out = paddle.linalg.solve(A, b)
+    elif trans == "T" and not left:  # A^T x = b
+        out = paddle.linalg.solve(_transpose_last_2dim(A), b)
+    elif trans == "N" and left:  # xA = b ==> A^T x^T = b^T
+        out = paddle.linalg.solve(A, b, left=False)
+    elif trans == "T" and left:  # xA^T = b ==> Ax^T = b^T
+        out = paddle.linalg.solve(_transpose_last_2dim(A), b, left=False)
+    LU = LU.numpy()
+    pivots = pivots.numpy()
+    b = b.numpy()
+    out = out.numpy()
+    paddle.enable_static()
+    return LU, pivots, b, out
+
+
+# TODO: enable the OpTest-based checks once lu_solve supports them.
+# class TestLuSolveOp(OpTest):
+#     def setUp(self):
+#         self.python_api = paddle.linalg.lu_solve
+#         self.op_type = "lu_solve"
+#         self.init_value()
+#         self.LU, self.pivots, self.b, self.out = get_inandout(
+#             self.A_shape, self.b_shape, self.trans, self.left
+#         )
+#         self.inputs = {
+#             'X': self.b,
+#             'Lu': self.LU,
+#             'Pivots': self.pivots,
+#         }
+#         self.attrs = {'trans': self.trans, 'left': self.left}
+#         self.outputs = {'Out': self.out}
+
+#     def init_value(self):
+#         self.A_shape = [15, 15]
+#         self.b_shape = [15, 10]
+#         self.trans = "N"
+#         self.left = False
+
+#     def test_check_output(self):
+#         paddle.enable_static()
+#         self.check_output(check_pir=True)
+#         paddle.disable_static()
+
+#     def test_check_grad(self):
+#         paddle.enable_static()
+#         self.check_grad(['X', 'Lu'], ['Out'], check_pir=True)
+#         paddle.disable_static()
+
+
+class TestLuSolveOpAPI(unittest.TestCase):
     def setUp(self):
-        self.A = paddle.to_tensor([[2, 5],[4, 8]], dtype='float32')
-        self.b = paddle.to_tensor([[1, 1],[1, 1]], dtype='float32')
-        self.LU, self.pivots = paddle.linalg.lu(self.A)
+        self.init_value()
+        self.LU, self.pivots, self.b, self.out = get_inandout(
+            self.A_shape, self.b_shape, self.trans, self.left
+        )
         self.place = []
         self.place.append(base.CPUPlace())
-        # if core.is_compiled_with_cuda():
-        #     self.place.append(base.CUDAPlace(0))
+        if core.is_compiled_with_cuda():
+            self.place.append(base.CUDAPlace(0))
+
+    def init_value(self):
+        # Ax = b
+        self.A_shape = [15, 15]
+        self.b_shape = [15, 10]
+        self.trans = "N"
+        self.left = False
 
     def test_dygraph(self):
         def run(place):
             paddle.disable_static(place)
-            lu_solve_x = paddle.linalg.lu_solve(self.b, self.LU, self.pivots)
-            solve_x = paddle.linalg.solve(self.A, self.b)
+            LU = paddle.to_tensor(self.LU)
+            pivots = paddle.to_tensor(self.pivots)
+            b = paddle.to_tensor(self.b)
+            lu_solve_x = paddle.linalg.lu_solve(
+                b, LU, pivots, self.trans, self.left
+            )
             np.testing.assert_allclose(
-                lu_solve_x.numpy(), solve_x.numpy(), rtol=1e-05
+                lu_solve_x.numpy(), self.out, rtol=1e-05
             )
             paddle.enable_static()
 
         for place in self.place:
             run(place)
+
+    def test_static(self):
+        def run(place):
+            paddle.enable_static()
+            with paddle.static.program_guard(
+                paddle.static.Program(), paddle.static.Program()
+            ):
+                b = paddle.static.data(
+                    name='X', shape=self.b_shape, dtype='float64'
+                )
+                LU = paddle.static.data(
+                    name='Lu', shape=self.LU.shape, dtype='float64'
+                )
+                pivots = paddle.static.data(
+                    name='Pivots', shape=self.pivots.shape, dtype='int32'
+                )
+                lu_solve_x = paddle.linalg.lu_solve(
+                    b, LU, pivots, self.trans, self.left
+                )
+                exe = base.Executor(place)
+                fetches = exe.run(
+                    paddle.static.default_main_program(),
+                    feed={
+                        'X': self.b,
+                        'Lu': self.LU,
+                        'Pivots': self.pivots,
+                    },
+                    fetch_list=[lu_solve_x],
+                )
+                np.testing.assert_allclose(fetches[0], self.out, rtol=1e-05)
+            paddle.disable_static()
+
+        for place in self.place:
+            run(place)
+
+
+class TestLuSolveOpAPI2(TestLuSolveOpAPI):
+    def init_value(self):
+        # Ax = b, batched with broadcasting
+        self.A_shape = [2, 15, 15]
+        self.b_shape = [1, 15, 10]
+        self.trans = "N"
+        self.left = False
+
+
+class TestLuSolveOpAPI3(TestLuSolveOpAPI):
+    def init_value(self):
+        # xA = b
+        self.A_shape = [15, 15]
+        self.b_shape = [10, 15]
+        self.trans = "N"
+        self.left = True
+
+
+class TestLuSolveOpAPI4(TestLuSolveOpAPI):
+    def init_value(self):
+        # xA = b, batched with broadcasting
+        self.A_shape = [2, 15, 15]
+        self.b_shape = [1, 10, 15]
+        self.trans = "N"
+        self.left = True
+
+
+class TestLuSolveOpAPI5(TestLuSolveOpAPI):
+    def init_value(self):
+        # xA^T = b
+        self.A_shape = [15, 15]
+        self.b_shape = [10, 15]
+        self.trans = "T"
+        self.left = True
+
+
+class TestLuSolveOpAPI6(TestLuSolveOpAPI):
+    def init_value(self):
+        # xA^T = b, batched with broadcasting
+        self.A_shape = [2, 15, 15]
+        self.b_shape = [1, 10, 15]
+        self.trans = "T"
+        self.left = True
+
+
+class TestLuSolveOpAPI7(TestLuSolveOpAPI):
+    def init_value(self):
+        # A^T x = b
+        self.A_shape = [15, 15]
+        self.b_shape = [15, 10]
+        self.trans = "T"
+        self.left = False
+
+
+class TestLuSolveOpAPI8(TestLuSolveOpAPI):
+    def init_value(self):
+        # A^T x = b, batched with broadcasting
+        self.A_shape = [2, 15, 15]
+        self.b_shape = [1, 15, 10]
+        self.trans = "T"
+        self.left = False
+
+
+class TestLuSolveError(unittest.TestCase):
+    def test_errors(self):
+        with paddle.base.dygraph.guard():
+            # `b` must have at least 2 dimensions.
+            def test_b_size():
+                b = paddle.randn([3])
+                lu = paddle.randn([3, 3])
+                pivots = paddle.randn([3])
+                paddle.linalg.lu_solve(b, lu, pivots)
+
+            self.assertRaises(ValueError, test_b_size)
+
+            # `lu` must have at least 2 dimensions.
+            def test_lu_size():
+                b = paddle.randn([3, 1])
+                lu = paddle.randn([3])
+                pivots = paddle.randn([3])
+                paddle.linalg.lu_solve(b, lu, pivots)
+
+            self.assertRaises(ValueError, test_lu_size)
+
+            # `pivots` must have at least 1 dimension.
+            def test_pivots_size():
+                b = paddle.randn([3, 1])
+                lu = paddle.randn([3, 3])
+                pivots = paddle.randn([])
+                paddle.linalg.lu_solve(b, lu, pivots)
+
+            self.assertRaises(ValueError, test_pivots_size)
+
+            # b.shape[-2] must equal lu.shape[-2].
+            def test_b_lu_shape():
+                b = paddle.randn([1, 3])
+                lu = paddle.randn([3, 3])
+                pivots = paddle.randn([3])
+                paddle.linalg.lu_solve(b, lu, pivots)
+
+            self.assertRaises(ValueError, test_b_lu_shape)
+
+            # lu.shape[-1] must equal pivots.shape[-1].
+            def test_b_pivots_shape():
+                b = paddle.randn([3, 1])
+                lu = paddle.randn([3, 3])
+                pivots = paddle.randn([2])
+                paddle.linalg.lu_solve(b, lu, pivots)
+
+            self.assertRaises(ValueError, test_b_pivots_shape)
+
+            # lu.shape[-2] must equal lu.shape[-1].
+            def test_lu_shape():
+                b = paddle.randn([3, 1])
+                lu = paddle.randn([3, 2])
+                pivots = paddle.randn([3])
+                paddle.linalg.lu_solve(b, lu, pivots)
+
+            self.assertRaises(ValueError, test_lu_shape)
 
 
 if __name__ == "__main__":
-    paddle.enable_static()
+    paddle.seed(2025)
     unittest.main()
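
Usage sketch (not part of the patch): a minimal example of the API added above, assuming the `left` handling implemented in python/paddle/tensor/linalg.py in this series, where `left=True` transposes the right-hand side and flips `trans` before calling the kernel. Values follow the docstring example.

    import paddle

    A = paddle.to_tensor([[3.0, 1.0], [1.0, 2.0]], dtype="float64")
    b = paddle.to_tensor([[9.0, 8.0], [9.0, 8.0]], dtype="float64")

    # Factor once, then reuse the LU factors for several solves.
    lu, p = paddle.linalg.lu(A)

    y = paddle.linalg.lu_solve(b, lu, p)                  # solves A @ y == b
    x = paddle.linalg.lu_solve(b, lu, p, left=True)       # solves x @ A == b
    z = paddle.linalg.lu_solve(b, lu, p, "T", left=True)  # solves z @ A.T == b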