Commit bce6474: c++ test

decade-afk committed Feb 14, 2025
1 parent 8f3a34b
Showing 10 changed files with 397 additions and 71 deletions.
2 changes: 2 additions & 0 deletions paddle/phi/backends/dynload/rocsolver.h
@@ -45,6 +45,8 @@ extern void *rocsolver_dso_handle;
__macro(rocsolver_dpotrs); \
__macro(rocsolver_cpotrs); \
__macro(rocsolver_zpotrs); \
+ __macro(rocsolver_sgetrs); \
+ __macro(rocsolver_dgetrs); \
__macro(rocsolver_sgetrf); \
__macro(rocsolver_dgetrf); \
__macro(rocsolver_cgetrf); \
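Note: the two new entries register lazy dynamic-load stubs for rocSOLVER's single- and double-precision LU solvers (getrs), consumed by the new HIP path in lu_solve_kernle.cu below. A minimal Python sketch of the load-on-first-use pattern such dynload wrappers follow (the library name and ctypes wiring here are illustrative assumptions, not Paddle's actual mechanism):

import ctypes

_rocsolver = None

def _load_rocsolver():
    # Resolve the shared library once, on first use.
    global _rocsolver
    if _rocsolver is None:
        _rocsolver = ctypes.CDLL("librocsolver.so")  # assumed library name
    return _rocsolver

def rocsolver_sgetrs(*args):
    # Forward to the lazily resolved symbol, as the __macro stubs do.
    return _load_rocsolver().rocsolver_sgetrs(*args)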
8 changes: 4 additions & 4 deletions paddle/phi/kernels/cpu/lu_solve_kernel.cc
@@ -57,8 +57,8 @@ void LuSolveKernel(const Context& dev_ctx,
reinterpret_cast<int*>(const_cast<int*>(pivots.data<int>()));

for (int i = 0; i < batchsize; i++) {
- auto* out_data_item = &out_data[i * lda * n_int];
- auto* lu_data_item = &lu_data[i * ldb * nrhs_int];
+ auto* out_data_item = &out_data[i * lda * nrhs_int];
+ auto* lu_data_item = &lu_data[i * ldb * n_int];
auto* pivots_data_item = &pivots_data[i * n_int];
phi::funcs::lapackLuSolve<T>(trans_char,
n_int,
@@ -73,8 +73,8 @@
info,
0,
phi::errors::PreconditionNotMet(
- "LU solve failed with error code %d. Check if matrix is singular.",
- info));
+ "LU solve failed with error code %d. Check if matrix is singular.",
+ info));
}
*out = Transpose2DTo6D<Context, T>(dev_ctx, *out);
}
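Note on the fix above: LAPACK getrs treats the LU factor as an n x n matrix and the right-hand-side block as n x nrhs, so the per-batch strides are ldb * n for the factor and lda * nrhs for the solution buffer; the old code had the two swapped. A NumPy/SciPy sketch of the corrected offset arithmetic (the flattened buffers and names below are assumptions for illustration, not Paddle's internal layout):

import numpy as np
from scipy.linalg import lu_factor, lu_solve

n, nrhs, batch = 3, 2, 4
lda = ldb = max(1, n)
rng = np.random.default_rng(0)
a = rng.standard_normal((batch, n, n)) + 3 * np.eye(n)  # well conditioned
b = rng.standard_normal((batch, n, nrhs))

lu_flat = np.empty(batch * ldb * n)            # factor stride: ldb * n
piv_flat = np.empty(batch * n, dtype=np.intp)  # pivot stride: n
out_flat = b.ravel().copy()                    # rhs/out stride: lda * nrhs

for i in range(batch):
    lu_i, piv_i = lu_factor(a[i])
    lu_flat[i * ldb * n:(i + 1) * ldb * n] = lu_i.ravel()
    piv_flat[i * n:(i + 1) * n] = piv_i

for i in range(batch):
    lu_i = lu_flat[i * ldb * n:(i + 1) * ldb * n].reshape(n, n)
    piv_i = piv_flat[i * n:(i + 1) * n]
    rhs = out_flat[i * lda * nrhs:(i + 1) * lda * nrhs].reshape(n, nrhs)
    out_flat[i * lda * nrhs:(i + 1) * lda * nrhs] = (
        lu_solve((lu_i, piv_i), rhs).ravel()
    )

assert np.allclose(out_flat.reshape(batch, n, nrhs), np.linalg.solve(a, b))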
4 changes: 2 additions & 2 deletions paddle/phi/kernels/funcs/lapack/lapack_function.cc
@@ -41,7 +41,7 @@ void lapackLuSolve<double>(char trans,
double *b,
int ldb,
int *info) {
- dynload::dgetrs_(&trans, &n, &nrhs, a, &lda, ipiv, b, &ldb, info);
+ dynload::dgetrs_(&trans, &n, &nrhs, a, &lda, ipiv, b, &ldb, info);
}

template <>
@@ -54,7 +54,7 @@ void lapackLuSolve<float>(char trans,
float *b,
int ldb,
int *info) {
- dynload::sgetrs_(&trans, &n, &nrhs, a, &lda, ipiv, b, &ldb, info);
+ dynload::sgetrs_(&trans, &n, &nrhs, a, &lda, ipiv, b, &ldb, info);
}

// eigh
5 changes: 0 additions & 5 deletions paddle/phi/kernels/gpu/lu_solve_grad_kernel.cu
@@ -12,9 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.

-#ifndef PADDLE_WITH_HIP
-// HIP not support cusolver
-
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

@@ -23,5 +20,3 @@

PD_REGISTER_KERNEL(
lu_solve_grad, GPU, ALL_LAYOUT, phi::LuSolveGradKernel, float, double) {}

-#endif  // not PADDLE_WITH_HIP
106 changes: 87 additions & 19 deletions paddle/phi/kernels/gpu/lu_solve_kernle.cu
@@ -12,10 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.

-#ifndef PADDLE_WITH_HIP
-// HIP not support cusolver
-
+#ifdef PADDLE_WITH_HIP
+#include "paddle/phi/backends/dynload/rocsolver.h"
+#else
+#include "paddle/phi/backends/dynload/cusolver.h"
+#endif

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

@@ -25,6 +27,47 @@

namespace phi {

+ #ifdef PADDLE_WITH_HIP
+ template <typename T>
+ void rocsolver_getrs(const rocsolver_handle& cusolverH,
+ rocsolver_operation_t trans,
+ int n,
+ int nrhs,
+ T* a,
+ int lda,
+ int* ipiv,
+ T* b,
+ int ldb);
+
+ template <>
+ void rocsolver_getrs<float>(const rocsolver_handle& cusolverH,
+ rocsolver_operation_t trans,
+ int n,
+ int nrhs,
+ float* a,
+ int lda,
+ int* ipiv,
+ float* b,
+ int ldb) {
+ PADDLE_ENFORCE_GPU_SUCCESS(dynload::rocsolver_sgetrs(
+ cusolverH, trans, n, nrhs, a, lda, ipiv, b, ldb));
+ }
+
+ template <>
+ void rocsolver_getrs<double>(const rocsolver_handle& cusolverH,
+ rocsolver_operation_t trans,
+ int n,
+ int nrhs,
+ double* a,
+ int lda,
+ int* ipiv,
+ double* b,
+ int ldb) {
+ PADDLE_ENFORCE_GPU_SUCCESS(dynload::rocsolver_dgetrs(
+ cusolverH, trans, n, nrhs, a, lda, ipiv, b, ldb));
+ }
+
+ #else
template <typename T>
void cusolver_getrs(const cusolverDnHandle_t& cusolverH,
cublasOperation_t trans,
@@ -39,7 +82,7 @@ void cusolver_getrs(const cusolverDnHandle_t& cusolverH,

template <>
void cusolver_getrs<float>(const cusolverDnHandle_t& cusolverH,
- cublasOperation_t trans,
+ cublasOperation_t trans,
int n,
int nrhs,
float* a,
@@ -54,25 +97,27 @@ void cusolver_getrs<float>(const cusolverDnHandle_t& cusolverH,

template <>
void cusolver_getrs<double>(const cusolverDnHandle_t& cusolverH,
- cublasOperation_t trans,
- int n,
- int nrhs,
- double* a,
- int lda,
- int* ipiv,
- double* b,
- int ldb,
- int* info) {
+ cublasOperation_t trans,
+ int n,
+ int nrhs,
+ double* a,
+ int lda,
+ int* ipiv,
+ double* b,
+ int ldb,
+ int* info) {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnDgetrs(
cusolverH, trans, n, nrhs, a, lda, ipiv, b, ldb, info));
}
+ #endif

template <typename T, typename Context>
void LuSolveKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& lu,
const DenseTensor& pivots,
const std::string& trans,
+ const bool left,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
// Copy x to out since cusolverDn*getrs overwrites the input
@@ -82,6 +127,19 @@ void LuSolveKernel(const Context& dev_ctx,
auto x_dims = x.dims();
auto lu_dims = lu.dims();

+ #ifdef PADDLE_WITH_HIP
+ rocsolver_operation_t trans_op;
+ if (trans == "N") {
+ trans_op = rocblas_operation_none;
+ } else if (trans == "T") {
+ trans_op = rocblas_operation_transpose;
+ } else if (trans == "C") {
+ trans_op = rocblas_operation_conjugate_transpose;
+ } else {
+ PADDLE_THROW(phi::errors::InvalidArgument(
+ "trans must be one of ['N', 'T', 'C'], but got %s", trans));
+ }
+ #else
cublasOperation_t trans_op;
if (trans == "N") {
trans_op = CUBLAS_OP_N;
@@ -93,6 +151,7 @@ void LuSolveKernel(const Context& dev_ctx,
PADDLE_THROW(phi::errors::InvalidArgument(
"trans must be one of ['N', 'T', 'C'], but got %s", trans));
}
+ #endif
int n = static_cast<int>(lu_dims[lu_dims.size() - 1]);
int nrhs = static_cast<int>(x_dims[x_dims.size() - 1]);
int lda = std::max(1, n);
@@ -112,9 +171,21 @@
reinterpret_cast<int*>(const_cast<int*>(pivots.data<int>()));
for (int i = 0; i < batchsize; i++) {
auto handle = dev_ctx.cusolver_dn_handle();
- auto* out_data_item = &out_data[i * n * n];
- auto* lu_data_item = &lu_data[i * n * n];
+ auto* out_data_item = &out_data[i * lda * nrhs];
+ auto* lu_data_item = &lu_data[i * ldb * n];
auto* pivots_data_item = &pivots_data[i * n];
+ #ifdef PADDLE_WITH_HIP
+ rocsolver_getrs<T>(handle,
+ trans_op,
+ n,
+ nrhs,
+ lu_data_item,
+ lda,
+ pivots_data_item,
+ out_data_item,
+ ldb);
+ d_info = 0;
+ #else
cusolver_getrs<T>(handle,
trans_op,
n,
@@ -125,15 +196,12 @@ void cusolver_getrs<float>(const cusolverDnHandle_t& cusolverH,
out_data_item,
ldb,
d_info);
+ #endif
}
// Synchronize to ensure the solve is complete
dev_ctx.Wait();
*out = Transpose2DTo6D<Context, T>(dev_ctx, *out);
}

} // namespace phi

PD_REGISTER_KERNEL(
lu_solve, GPU, ALL_LAYOUT, phi::LuSolveKernel, float, double) {}

-#endif  // not PADDLE_WITH_HIP
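Note: unlike cusolverDn*getrs, the rocsolver_getrs wrappers above take no info output, which is why the HIP branch sets d_info = 0 by hand. The trans_op argument selects which system a single factorization solves; a SciPy sketch of that contract (scipy.linalg.lu_solve's trans codes 0/1/2 correspond to "N"/"T"/"C"):

import numpy as np
from scipy.linalg import lu_factor, lu_solve

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 4)) + 4 * np.eye(4)  # well conditioned
B = rng.standard_normal((4, 2))
lu_piv = lu_factor(A)  # factor once, solve many

x_n = lu_solve(lu_piv, B, trans=0)  # "N": solve A x = B
x_t = lu_solve(lu_piv, B, trans=1)  # "T": solve A^T x = B
assert np.allclose(A @ x_n, B)
assert np.allclose(A.T @ x_t, B)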
2 changes: 1 addition & 1 deletion paddle/phi/kernels/impl/lu_solve_grad_kernel_impl.h
@@ -60,7 +60,7 @@ void LuSolveGradKernel(const Context& dev_ctx,

auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
auto out_grad_dims = out_grad.dims();
- auto mat_dim_l =
+ auto mat_dim_l =
phi::funcs::CreateMatrixDescriptor(out_grad_dims, 0, false);
auto out_mH_dims = out_mH.dims();
auto mat_dim_g = phi::funcs::CreateMatrixDescriptor(out_mH_dims, 0, true);
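Note: for reference, the backward pass composes another LU solve with matmuls. A hedged NumPy sketch of the standard solve-gradient identities (assuming trans="N" and a single matrix, so out = A^{-1} x; the kernel's batching, trans handling, and mapping onto the packed LU factors may differ):

import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 3)) + 3 * np.eye(3)  # well conditioned
x = rng.standard_normal((3, 2))
out = np.linalg.solve(A, x)            # forward: out = A^{-1} x
dout = rng.standard_normal(out.shape)

dx = np.linalg.solve(A.T, dout)        # x_grad: another solve, with trans="T"
dA = -dx @ out.T                       # gradient w.r.t. A itself

# Finite-difference spot check on one entry of A.
eps = 1e-6
A2 = A.copy()
A2[0, 1] += eps
num = ((np.linalg.solve(A2, x) - out) * dout).sum() / eps
assert np.isclose(num, dA[0, 1], rtol=1e-3, atol=1e-4)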
4 changes: 2 additions & 2 deletions paddle/phi/ops/yaml/backward.yaml
@@ -1994,8 +1994,8 @@
inplace : (out_grad -> x_grad)

- backward_op : lu_solve_grad
-  forward : lu_solve (Tensor x, Tensor lu, Tensor pivots, str trans = "N") -> Tensor(out)
-  args : (Tensor x, Tensor lu, Tensor pivots, Tensor out, Tensor out_grad, str trans = "N")
+  forward : lu_solve (Tensor x, Tensor lu, Tensor pivots, str trans) -> Tensor(out)
+  args : (Tensor x, Tensor lu, Tensor pivots, Tensor out, Tensor out_grad, str trans)
output : Tensor(x_grad), Tensor(lu_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
2 changes: 1 addition & 1 deletion paddle/phi/ops/yaml/ops.yaml
@@ -3149,7 +3149,7 @@
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : lu_solve
-  args : (Tensor x, Tensor lu, Tensor pivots, str trans = "N")
+  args : (Tensor x, Tensor lu, Tensor pivots, str trans)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
78 changes: 55 additions & 23 deletions python/paddle/tensor/linalg.py
@@ -22,6 +22,7 @@
from paddle import _C_ops
from paddle.base.libpaddle import DataType
from paddle.common_ops_import import VarDesc
+from paddle.tensor.math import broadcast_shape
from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only

from ..base.data_feeder import (
@@ -3576,7 +3577,12 @@ def lu(


def lu_solve(
-    b: Tensor, lu_data: Tensor, pivots: Tensor, trans: str = "N", name=None
+    b: Tensor,
+    lu_data: Tensor,
+    pivots: Tensor,
+    trans: str = "N",
+    left: bool = False,
+    name=None
):
r"""
Computes the solution y to the system of linear equations :math:`Ay = b` ,
@@ -3590,31 +3596,36 @@
pivots (Tensor): Permutation matrix P of LU decomposition. It has
shape :math:`(*, m)`, where :math:`*` is batch dimensions, that can be converted
to a permutation matrix P, with data type int32.
-    trans (str): The transpose of the matrix A. It can be "N" or "T",
+    trans (str): The transpose of the matrix A. It can be "N", "T" or "C".
+    left (bool): If True, solve the equation :math:`xA = b`. Default: False.
Returns:
Tensor, the same data type as the `b` and `lu_data`.
Examples:
>>> import paddle
-    >>> import numpy as np
-    >>> b = paddle.to_tensor(np.array([[1], [3], [3]]), paddle.float32)
-    >>> LU_data = paddle.to_tensor(np.array([[2, 1, 1], [0.5, 1, 1.5], [0.5, 0, 2.5]]), paddle.float32)
-    >>> LU_pivots = paddle.to_tensor(np.array([2, 2, 3]), paddle.int32)
-    >>> y = paddle.lu_solve(b, LU_data, LU_pivots)
-    >>> print(y)
-    [[ 1.9000002]
-    [-1.4000001]
-    [ 0.6 ]]
+    >>> # Ax = b
+    >>> A = paddle.to_tensor([[3, 1], [1, 2]], dtype="float64")
+    >>> b = paddle.to_tensor([[9, 8], [9, 8]], dtype="float64")
+    >>> lu, p = paddle.linalg.lu(A)
+    >>> y = paddle.lu_solve(b, lu, p)
+    >>> print(y)
+    Tensor(shape=[2, 2], dtype=float64, place=Place(cpu), stop_gradient=True,
+    [[1.80000000, 1.60000000],
+     [3.60000000, 3.20000000]])
+    >>> # xA = b ==> A^T x^T = b^T
+    >>> A = paddle.to_tensor([[3, 1], [1, 2]], dtype="float64")
+    >>> b = paddle.to_tensor([[9, 8], [9, 8]], dtype="float64")
+    >>> lu, p = paddle.linalg.lu(A)
+    >>> y = paddle.lu_solve(b, lu, p, trans="T", left=True)
+    >>> print(y)
+    Tensor(shape=[2, 2], dtype=float64, place=Place(cpu), stop_gradient=True,
+    [[2., 3.],
+     [2., 3.]])
"""
-    b = (
-        b
-        if b.shape[:-2] == lu_data.shape[:-2]
-        else paddle.broadcast_to(b, lu_data.shape[:-2] + b.shape[-2:])
-    )
-    pivots = (
-        pivots
-        if pivots.shape[:-1] == lu_data.shape[:-2]
-        else paddle.broadcast_to(pivots, lu_data.shape[:-2] + pivots.shape[-1:])
-    )
if b.ndim < 2:
raise ValueError(
f'`b` dimension must be at least 2, but got {len(b.shape)}'
@@ -3639,7 +3650,26 @@
raise ValueError(
f'`pivots` shape[-1] must be equal to `lu_data` shape[-1], but got {pivots.shape[-1]} and {lu_data.shape[-1]}'
)

+    temp_shape = broadcast_shape(b.shape[:-2], lu_data.shape[:-2])
+    batch_shape = broadcast_shape(temp_shape, pivots.shape[:-1])
+    b = (
+        b
+        if b.shape[:-2] == batch_shape
+        else paddle.broadcast_to(b, batch_shape + b.shape[-2:])
+    )
+    # For real matrices, the conjugate transpose equals the transpose
+    trans = trans if trans == "N" else "T"
+    pivots = (
+        pivots
+        if pivots.shape[:-1] == batch_shape
+        else paddle.broadcast_to(pivots, batch_shape + pivots.shape[-1:])
+    )
+    lu_data = (
+        lu_data
+        if lu_data.shape[:-2] == batch_shape
+        else paddle.broadcast_to(lu_data, batch_shape + lu_data.shape[-2:])
+    )
+    pivots.stop_gradient = True
if in_dynamic_or_pir_mode():
out = _C_ops.lu_solve(b, lu_data, pivots, trans)
else:
@@ -3652,10 +3682,12 @@
out = helper.create_variable_for_type_inference(dtype=b.dtype)
helper.append_op(
type='lu_solve',
-            inputs={'x': b, 'lu': lu_data, 'pivots': pivots},
-            outputs={'out': out},
-            attrs={'trans':trans},
+            inputs={'X': b, 'Lu': lu_data, 'Pivots': pivots},
+            outputs={'Out': out},
+            attrs={'trans': trans},
)
+    if left:
+        out = _transpose_last_2dim(out)
return out


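Note on the wrapper logic above: for the real dtypes this op registers, the conjugate transpose equals the plain transpose (hence folding trans="C" into "T"), and the left-hand-sided solve rests on the identity x A = b <=> A^T x^T = b^T used by the docstring example. A NumPy sketch of both facts:

import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((3, 3)) + 3 * np.eye(3)
b = rng.standard_normal((2, 3))

# 1) Real matrices: "C" and "T" coincide.
assert np.array_equal(A.conj().T, A.T)

# 2) Solve x A = b by solving A^T x^T = b^T, then transposing back.
x = np.linalg.solve(A.T, b.T).T
assert np.allclose(x @ A, b)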
(diff for 1 more changed file not shown)
