Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: getrs serial internal implementations #2488

Merged
merged 3 commits into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 31 additions & 33 deletions batched/dense/impl/KokkosBatched_Getrs_Serial_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,48 +30,46 @@ struct SerialGetrsInternal {

//// Non-transpose ////
template <>
struct SerialGetrsInternal<Trans::NoTranspose, Algo::Getrs::Unblocked> {
template <typename AViewType, typename PivViewType, typename BViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const PivViewType &piv, const BViewType &b) {
KokkosBatched::SerialLaswp<Direct::Forward>::invoke(piv, b);
KokkosBatched::SerialTrsm<Side::Left, Uplo::Lower, Trans::NoTranspose, Diag::Unit, Algo::Trsm::Unblocked>::invoke(
1.0, A, b);
KokkosBatched::SerialTrsm<Side::Left, Uplo::Upper, Trans::NoTranspose, Diag::NonUnit,
Algo::Trsm::Unblocked>::invoke(1.0, A, b);
template <typename AViewType, typename PivViewType, typename BViewType>
KOKKOS_INLINE_FUNCTION int SerialGetrsInternal<Trans::NoTranspose, Algo::Getrs::Unblocked>::invoke(
const AViewType &A, const PivViewType &piv, const BViewType &b) {
KokkosBatched::SerialLaswp<Direct::Forward>::invoke(piv, b);
KokkosBatched::SerialTrsm<Side::Left, Uplo::Lower, Trans::NoTranspose, Diag::Unit, Algo::Trsm::Unblocked>::invoke(
1.0, A, b);
KokkosBatched::SerialTrsm<Side::Left, Uplo::Upper, Trans::NoTranspose, Diag::NonUnit, Algo::Trsm::Unblocked>::invoke(
1.0, A, b);

return 0;
}
};
return 0;
}

//// Transpose ////
template <>
struct SerialGetrsInternal<Trans::Transpose, Algo::Getrs::Unblocked> {
template <typename AViewType, typename PivViewType, typename BViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const PivViewType &piv, const BViewType &b) {
KokkosBatched::SerialTrsm<Side::Left, Uplo::Upper, Trans::Transpose, Diag::NonUnit, Algo::Trsm::Unblocked>::invoke(
1.0, A, b);
KokkosBatched::SerialTrsm<Side::Left, Uplo::Lower, Trans::Transpose, Diag::Unit, Algo::Trsm::Unblocked>::invoke(
1.0, A, b);
KokkosBatched::SerialLaswp<Direct::Backward>::invoke(piv, b);
template <typename AViewType, typename PivViewType, typename BViewType>
KOKKOS_INLINE_FUNCTION int SerialGetrsInternal<Trans::Transpose, Algo::Getrs::Unblocked>::invoke(const AViewType &A,
const PivViewType &piv,
const BViewType &b) {
KokkosBatched::SerialTrsm<Side::Left, Uplo::Upper, Trans::Transpose, Diag::NonUnit, Algo::Trsm::Unblocked>::invoke(
1.0, A, b);
KokkosBatched::SerialTrsm<Side::Left, Uplo::Lower, Trans::Transpose, Diag::Unit, Algo::Trsm::Unblocked>::invoke(1.0,
A, b);
KokkosBatched::SerialLaswp<Direct::Backward>::invoke(piv, b);

return 0;
}
};
return 0;
}

//// Conj-Transpose ////
template <>
struct SerialGetrsInternal<Trans::ConjTranspose, Algo::Getrs::Unblocked> {
template <typename AViewType, typename PivViewType, typename BViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const PivViewType &piv, const BViewType &b) {
KokkosBatched::SerialTrsm<Side::Left, Uplo::Upper, Trans::ConjTranspose, Diag::NonUnit,
Algo::Trsm::Unblocked>::invoke(1.0, A, b);
KokkosBatched::SerialTrsm<Side::Left, Uplo::Lower, Trans::ConjTranspose, Diag::Unit, Algo::Trsm::Unblocked>::invoke(
1.0, A, b);
KokkosBatched::SerialLaswp<Direct::Backward>::invoke(piv, b);
template <typename AViewType, typename PivViewType, typename BViewType>
KOKKOS_INLINE_FUNCTION int SerialGetrsInternal<Trans::ConjTranspose, Algo::Getrs::Unblocked>::invoke(
const AViewType &A, const PivViewType &piv, const BViewType &b) {
KokkosBatched::SerialTrsm<Side::Left, Uplo::Upper, Trans::ConjTranspose, Diag::NonUnit,
Algo::Trsm::Unblocked>::invoke(1.0, A, b);
KokkosBatched::SerialTrsm<Side::Left, Uplo::Lower, Trans::ConjTranspose, Diag::Unit, Algo::Trsm::Unblocked>::invoke(
1.0, A, b);
KokkosBatched::SerialLaswp<Direct::Backward>::invoke(piv, b);

return 0;
}
};
return 0;
}

} // namespace Impl
} // namespace KokkosBatched
Expand Down
2 changes: 1 addition & 1 deletion batched/dense/impl/KokkosBatched_Laswp_Serial_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ struct SerialLaswp<Direct::Forward> {
template <>
struct SerialLaswp<Direct::Backward> {
template <typename PivViewType, typename AViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const PivViewType piv, const AViewType &A) {
KOKKOS_INLINE_FUNCTION static int invoke(const PivViewType &piv, const AViewType &A) {
auto info = KokkosBatched::Impl::checkLaswpInput(piv, A);
if (info) return info;

Expand Down
8 changes: 8 additions & 0 deletions batched/dense/impl/KokkosBatched_Laswp_Serial_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,14 @@ struct SerialLaswpVectorBackwardInternal {
template <typename IntType, typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(const int plen, const IntType *KOKKOS_RESTRICT p, const int ps0,
/* */ ValueType *KOKKOS_RESTRICT A, const int as0) {
// On H100 with CUDA 12.0.0, the compiler seems to apply
// an aggressive optimization that crashes this function.
// Disabling loop unrolling fixes the issue.
#if defined(KOKKOS_ENABLE_CUDA) && defined(KOKKOS_ARCH_HOPPER90)
#if CUDA_VERSION >= 12000 && CUDA_VERSION < 12100
#pragma unroll 1
#endif
#endif
for (int i = (plen - 1); i >= 0; --i) {
const int piv = p[i * ps0];
if (piv != i) {
Expand Down
2 changes: 0 additions & 2 deletions batched/dense/unit_test/Test_Batched_SerialGetrs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
#include <KokkosBatched_Getrs.hpp>
#include "Test_Batched_DenseUtils.hpp"

using namespace KokkosBatched;

namespace Test {
namespace Getrs {

Expand Down
Loading