Skip to content

Commit

Permalink
add cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
decade-afk committed Feb 8, 2025
1 parent 138c1cb commit 7a6ccd5
Show file tree
Hide file tree
Showing 17 changed files with 680 additions and 0 deletions.
1 change: 1 addition & 0 deletions cmake/operators.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ function(op_library TARGET)
list(REMOVE_ITEM hip_srcs "cholesky_op.cu")
list(REMOVE_ITEM hip_srcs "cholesky_solve_op.cu")
list(REMOVE_ITEM hip_srcs "lu_op.cu")
list(REMOVE_ITEM hip_srcs "lu_solve_op.cu")
list(REMOVE_ITEM hip_srcs "matrix_rank_op.cu")
list(REMOVE_ITEM hip_srcs "svd_op.cu")
list(REMOVE_ITEM hip_srcs "eigvalsh_op.cu")
Expand Down
2 changes: 2 additions & 0 deletions paddle/phi/backends/dynload/cusolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ CUSOLVER_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUSOLVER_WRAP);
__macro(cusolverDnSgesvdj); \
__macro(cusolverDnDgesvdj); \
__macro(cusolverDnSgetrf); \
__macro(cusolverDnSgetrs); \
__macro(cusolverDnDgetrs); \
__macro(cusolverDnDgetrf); \
__macro(cusolverDnCgetrf); \
__macro(cusolverDnZgetrf); \
Expand Down
22 changes: 22 additions & 0 deletions paddle/phi/backends/dynload/lapack.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,26 @@ extern "C" void dgetrf_(
extern "C" void sgetrf_(
int *m, int *n, float *a, int *lda, int *ipiv, int *info);

// getrs_: LAPACK linear solve that reuses an LU factorization (as produced by
// the getrf_ routines declared above).
// Single-precision variant. Fortran calling convention: every argument,
// including scalars, is passed by pointer.
//   trans - pointer to the transpose selector character
//   n     - order of the factored matrix
//   nrhs  - number of right-hand-side columns in b
//   a/lda - LU factors and their leading dimension
//   ipiv  - pivot indices that accompany the factorization
//   b/ldb - right-hand sides on entry, overwritten in place; leading dimension
//   info  - status output written by LAPACK
extern "C" void sgetrs_(char *trans,
int *n,
int *nrhs,
float *a,
int *lda,
int *ipiv,
float *b,
int *ldb,
int *info);
// Double-precision counterpart of sgetrs_; same argument order, with `a` and
// `b` as double arrays. All arguments are passed by pointer (Fortran ABI).
extern "C" void dgetrs_(char *trans,
int *n,
int *nrhs,
double *a,
int *lda,
int *ipiv,
double *b,
int *ldb,
int *info);

// evd
extern "C" void zheevd_(char *jobz,
char *uplo,
Expand Down Expand Up @@ -339,6 +359,8 @@ extern void *lapack_dso_handle;
#define LAPACK_ROUTINE_EACH(__macro) \
__macro(dgetrf_); \
__macro(sgetrf_); \
__macro(sgetrs_); \
__macro(dgetrs_); \
__macro(zheevd_); \
__macro(cheevd_); \
__macro(dsyevd_); \
Expand Down
24 changes: 24 additions & 0 deletions paddle/phi/kernels/cpu/lu_solve_grad_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

#include "paddle/phi/kernels/impl/lu_solve_grad_kernel_impl.h"
#include "paddle/phi/kernels/lu_solve_grad_kernel.h"


// Register the CPU backward kernel: exposes phi::LuSolveGradKernel under the
// name `lu_solve_grad` for the CPU backend, all layouts, with float and double
// element types. The trailing body is intentionally empty.
PD_REGISTER_KERNEL(
lu_solve_grad, CPU, ALL_LAYOUT, phi::LuSolveGradKernel, float, double) {}
84 changes: 84 additions & 0 deletions paddle/phi/kernels/cpu/lu_solve_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/lapack/lapack_function.h"

#include "paddle/phi/core/enforce.h"
#include "paddle/phi/kernels/impl/lu_kernel_impl.h"
#include "paddle/phi/kernels/lu_solve_kernel.h"

namespace phi {

template <typename T, typename Context>
void LuSolveKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& lu,
const DenseTensor& pivots,
const std::string& trans,
DenseTensor* out) {
// Get lu matrix dimensions
auto lu_dims = lu.dims();
// Get x matrix dimensions
auto x_dims = x.dims();

// Allocate output tensor
dev_ctx.template Alloc<T>(out);
// Copy RHS data to output (will be overwritten with solution)
// phi::Copy(dev_ctx, x, x.place(), false, out);
*out = Transpose2DTo6D<Context, T>(dev_ctx, x);

// Prepare LAPACK parameters
char trans_char = (trans == "N") ? 'N' : ((trans == "T") ? 'T' : 'C');
int n_int = lu_dims[lu_dims.size() - 1];
int nrhs_int = x_dims[x_dims.size() - 1];
int lda = std::max(1, n_int); // Leading dimension of A (LU matrix)
int ldb = std::max(1, n_int); // Leading dimension of B (RHS/solution matrix)
int info = 0;

auto outdims = out->dims();
auto outrank = outdims.size();
int batchsize = product(common::slice_ddim(outdims, 0, outrank - 2));
auto out_data = out->data<T>();
auto lu_data = reinterpret_cast<T*>(const_cast<T*>(lu.data<T>()));
auto pivots_data = reinterpret_cast<int*>(const_cast<int*>(pivots.data<int>()));

for (int i = 0; i < batchsize; i++) {
auto* out_data_item = &out_data[i * lda * n_int];
auto* lu_data_item = &lu_data[i * ldb * nrhs_int];
auto* pivots_data_item = &pivots_data[i * n_int];
phi::funcs::lapackLuSolve<T>(
trans_char,
n_int,
nrhs_int,
lu_data_item,
lda,
pivots_data_item,
out_data_item,
ldb,
&info);
PADDLE_ENFORCE_EQ(
info,
0,
phi::errors::PreconditionNotMet(
"LU solve failed with error code %d. Check if matrix is singular.",
info));
}
*out = Transpose2DTo6D<Context, T>(dev_ctx, *out);
}
} // namespace phi

// Register the CPU forward kernel: exposes phi::LuSolveKernel under the name
// `lu_solve` for the CPU backend, all layouts, with float and double element
// types. The trailing body is intentionally empty.
PD_REGISTER_KERNEL(
lu_solve, CPU, ALL_LAYOUT, phi::LuSolveKernel, float, double) {}
27 changes: 27 additions & 0 deletions paddle/phi/kernels/funcs/lapack/lapack_function.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,33 @@ void lapackLu<float>(int m, int n, float *a, int lda, int *ipiv, int *info) {
dynload::sgetrf_(&m, &n, a, &lda, ipiv, info);
}

// lu_solve
// Double-precision specialization: forwards to the dynamically loaded
// dgetrs_. Scalars are taken by value here and passed by address because the
// LAPACK entry point takes every argument as a pointer. `b` is the buffer
// LAPACK updates in place, and the status code is written to `*info`.
template <>
void lapackLuSolve<double>(char trans,
int n,
int nrhs,
double* a,
int lda,
int* ipiv,
double* b,
int ldb,
int* info) {
dynload::dgetrs_(&trans, &n, &nrhs, a, &lda, ipiv, b, &ldb, info);
}

// Single-precision specialization: forwards to the dynamically loaded
// sgetrs_, passing the by-value scalars by address as the LAPACK entry point
// requires. The status code is written to `*info`.
template <>
void lapackLuSolve<float>(char trans,
int n,
int nrhs,
float* a,
int lda,
int* ipiv,
float* b,
int ldb,
int* info) {
dynload::sgetrs_(&trans, &n, &nrhs, a, &lda, ipiv, b, &ldb, info);
}

// eigh
template <>
void lapackEigh<float>(char jobz,
Expand Down
12 changes: 12 additions & 0 deletions paddle/phi/kernels/funcs/lapack/lapack_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@ namespace funcs {
template <typename T>
void lapackLu(int m, int n, T *a, int lda, int *ipiv, int *info);

// Lu_solve
// Type-generic entry point for the LAPACK getrs solve. Only declared here;
// the per-type specializations live in the matching .cc file.
//   trans - transpose selector character forwarded to LAPACK
//   n     - order of the factored matrix
//   nrhs  - number of right-hand-side columns
//   a/lda - LU factors and their leading dimension
//   ipiv  - pivot indices
//   b/ldb - right-hand sides / solution buffer and its leading dimension
//   info  - receives the LAPACK status code
template <typename T>
void lapackLuSolve(char trans,
int n,
int nrhs,
T *a,
int lda,
int *ipiv,
T *b,
int ldb,
int *info);

// Eigh
template <typename T, typename ValueType = T>
void lapackEigh(char jobz,
Expand Down
22 changes: 22 additions & 0 deletions paddle/phi/kernels/gpu/lu_solve_grad_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

#include "paddle/phi/kernels/impl/lu_solve_grad_kernel_impl.h"
#include "paddle/phi/kernels/lu_solve_grad_kernel.h"

// Register the GPU backward kernel: exposes phi::LuSolveGradKernel under the
// name `lu_solve_grad` for the GPU backend, all layouts, with float and double
// element types. The trailing body is intentionally empty.
PD_REGISTER_KERNEL(lu_solve_grad, GPU, ALL_LAYOUT, phi::LuSolveGradKernel, float, double) {
}
138 changes: 138 additions & 0 deletions paddle/phi/kernels/gpu/lu_solve_kernle.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef PADDLE_WITH_HIP
// HIP does not support cuSOLVER, so this file is compiled for CUDA builds only.

#include "paddle/phi/backends/dynload/cusolver.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/kernels/lu_solve_kernel.h"
#include "paddle/phi/kernels/impl/lu_kernel_impl.h"

namespace phi {

// Type-generic wrapper around the cuSOLVER dense getrs solve. Only declared
// here; the float and double specializations below provide the definitions.
//   cusolverH - cuSOLVER dense handle
//   trans     - operation applied to the factored matrix
//   n, nrhs   - matrix order and number of right-hand-side columns
//   a/lda     - LU factors and their leading dimension
//   ipiv      - pivot indices
//   b/ldb     - right-hand sides / solution buffer and its leading dimension
//   info      - status output; the caller in this file passes a buffer
//               allocated through the device context
template <typename T>
void cusolver_getrs(const cusolverDnHandle_t& cusolverH,
cublasOperation_t trans,
int n,
int nrhs,
T *a,
int lda,
int *ipiv,
T *b,
int ldb,
int *info);

// Single-precision specialization: calls dynload::cusolverDnSgetrs with the
// arguments unchanged. PADDLE_ENFORCE_GPU_SUCCESS checks the status returned
// by the call itself; the value written to `info` is not inspected here.
template <>
void cusolver_getrs<float>(const cusolverDnHandle_t& cusolverH,
cublasOperation_t trans,
int n,
int nrhs,
float *a,
int lda,
int *ipiv,
float *b,
int ldb,
int *info) {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnSgetrs(
cusolverH, trans, n, nrhs, a, lda, ipiv, b, ldb, info));
}

// Double-precision specialization: calls dynload::cusolverDnDgetrs with the
// arguments unchanged. PADDLE_ENFORCE_GPU_SUCCESS checks the status returned
// by the call itself; the value written to `info` is not inspected here.
template <>
void cusolver_getrs<double>(const cusolverDnHandle_t& cusolverH,
cublasOperation_t trans,
int n,
int nrhs,
double *a,
int lda,
int *ipiv,
double *b,
int ldb,
int *info) {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cusolverDnDgetrs(
cusolverH, trans, n, nrhs, a, lda, ipiv, b, ldb, info));
}

// GPU forward kernel for lu_solve.
//
// Solves op(A) * X = B for each matrix in the batch with cuSOLVER getrs, where
// the LU factors of A are given by `lu` / `pivots` and B is `x`.
//
// Params:
//   dev_ctx - GPU context; supplies allocation and the cuSOLVER handle.
//   x       - right-hand sides; the last dimension is used as nrhs.
//   lu      - LU factors; the last dimension is used as n.
//   pivots  - pivot indices, read as `int`, n entries per batch item.
//   trans   - "N", "T" or "C"; any other value throws InvalidArgument.
//   out     - receives the solution, in the same layout as `x`.
//
// NOTE(review): the status written to the device `info` buffer is never
// copied back or checked, so only the cuSOLVER return code is validated.
// NOTE(review): `lu` is handed to cuSOLVER unchanged, i.e. it is interpreted
// as column-major - confirm this matches the layout the lu kernel produces.
template <typename T, typename Context>
void LuSolveKernel(const Context& dev_ctx,
                   const DenseTensor& x,
                   const DenseTensor& lu,
                   const DenseTensor& pivots,
                   const std::string& trans,
                   DenseTensor* out) {
  // getrs overwrites its right-hand side and expects column-major storage, so
  // the output starts out as the transposed right-hand side: each batch item
  // becomes a contiguous (nrhs x n) row-major block.
  dev_ctx.template Alloc<T>(out);
  *out = phi::Transpose2DTo6D<Context, T>(dev_ctx, x);

  const auto x_dims = x.dims();
  const auto lu_dims = lu.dims();

  cublasOperation_t trans_op = CUBLAS_OP_N;
  if (trans == "N") {
    trans_op = CUBLAS_OP_N;
  } else if (trans == "T") {
    trans_op = CUBLAS_OP_T;
  } else if (trans == "C") {
    trans_op = CUBLAS_OP_C;
  } else {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "trans must be one of ['N', 'T', 'C'], but got %s", trans));
  }

  const int n = static_cast<int>(lu_dims[lu_dims.size() - 1]);
  const int nrhs = static_cast<int>(x_dims[x_dims.size() - 1]);
  const int lda = std::max(1, n);
  const int ldb = std::max(1, n);

  // Single-int scratch buffer that cuSOLVER writes its status into.
  DenseTensor info_tensor;
  info_tensor.Resize({1});
  dev_ctx.template Alloc<int>(&info_tensor);
  int* d_info = info_tensor.data<int>();

  const auto outdims = out->dims();
  const auto outrank = outdims.size();
  const int64_t batchsize =
      product(common::slice_ddim(outdims, 0, outrank - 2));

  // Element counts of one batch item. The LU block is n x n, but the RHS
  // block is n x nrhs; computed in 64 bits so large batches cannot overflow.
  const int64_t lu_stride = static_cast<int64_t>(n) * n;
  const int64_t out_stride = static_cast<int64_t>(n) * nrhs;
  const int64_t pivots_stride = n;

  T* out_data = out->data<T>();
  // The cuSOLVER prototypes take non-const pointers even for read-only inputs.
  T* lu_data = const_cast<T*>(lu.data<T>());
  int* pivots_data = const_cast<int*>(pivots.data<int>());

  // The handle does not change between iterations; fetch it once.
  auto handle = dev_ctx.cusolver_dn_handle();
  for (int64_t i = 0; i < batchsize; ++i) {
    T* out_data_item = out_data + i * out_stride;
    T* lu_data_item = lu_data + i * lu_stride;
    int* pivots_data_item = pivots_data + i * pivots_stride;
    cusolver_getrs<T>(handle,
                      trans_op,
                      n,
                      nrhs,
                      lu_data_item,
                      lda,
                      pivots_data_item,
                      out_data_item,
                      ldb,
                      d_info);
  }
  // Synchronize to ensure the solve is complete
  dev_ctx.Wait();
  // Transpose back so `out` has the same layout as `x`.
  *out = Transpose2DTo6D<Context, T>(dev_ctx, *out);
}

} // namespace phi

// Register the GPU forward kernel: exposes phi::LuSolveKernel under the name
// `lu_solve` for the GPU backend, all layouts, with float and double element
// types. The trailing body is intentionally empty.
PD_REGISTER_KERNEL(
lu_solve, GPU, ALL_LAYOUT, phi::LuSolveKernel, float, double) {}

#endif // not PADDLE_WITH_HIP
Loading

0 comments on commit 7a6ccd5

Please sign in to comment.