Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions source/module_hamilt_pw/hamilt_pwdft/forces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#include "module_hamilt_general/module_surchem/surchem.h"
#include "module_hamilt_general/module_vdw/vdw.h"
#include "kernels/force_op.h"

#include <type_traits>
#ifdef _OPENMP
#include <omp.h>
#endif
Expand Down Expand Up @@ -579,7 +579,7 @@ void Forces<FPTYPE, Device>::cal_force_loc(const UnitCell& ucell,
syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, forcelc_d, forcelc.c, this->nat * 3);
syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, vloc_d, vloc.c, vloc.nr * vloc.nc);

hamilt::cal_force_loc_op<FPTYPE, Device>()(
hamilt::cal_force_loc_op<FPTYPE, Device>()(
this->nat,
rho_basis->npw,
ucell.tpiba * ucell.omega,
Expand All @@ -591,6 +591,8 @@ void Forces<FPTYPE, Device>::cal_force_loc(const UnitCell& ucell,
vloc_d,
vloc.nc,
forcelc_d);


syncmem_var_d2h_op()(this->cpu_ctx, this->ctx, forcelc.c, forcelc_d, this->nat * 3);

delmem_int_op()(this->ctx,iat2it_d);
Expand Down Expand Up @@ -799,6 +801,7 @@ void Forces<FPTYPE, Device>::cal_force_ew(const UnitCell& ucell,
aux_d,
forceion_d);


syncmem_var_d2h_op()(this->cpu_ctx, this->ctx, forceion.c, forceion_d, this->nat * 3);
delmem_int_op()(this->ctx,iat2it_d);
delmem_var_op()(this->ctx,gcar_d);
Expand Down Expand Up @@ -917,8 +920,25 @@ void Forces<FPTYPE, Device>::cal_force_ew(const UnitCell& ucell,
return;
}

namespace hamilt {

#if defined(__ROCM) || defined(__HIP_PLATFORM_AMD__)
template struct cal_force_ew_sincos_op<double, base_device::DEVICE_GPU>;
template struct cal_force_ew_sincos_op<float, base_device::DEVICE_GPU>;

template struct cal_force_loc_sincos_op<double, base_device::DEVICE_GPU>;
template struct cal_force_loc_sincos_op<float, base_device::DEVICE_GPU>;
#endif

#if defined(__CUDA) || defined(__NVCC__)
template struct cal_force_ew_op<double, base_device::DEVICE_GPU>;
template struct cal_force_ew_op<float, base_device::DEVICE_GPU>;

template struct cal_force_loc_op<double, base_device::DEVICE_GPU>;
template struct cal_force_loc_op<float, base_device::DEVICE_GPU>;
#endif

} // namespace hamilt
template class Forces<double, base_device::DEVICE_CPU>;
#if ((defined __CUDA) || (defined __ROCM))
template class Forces<double, base_device::DEVICE_GPU>;
Expand Down
57 changes: 56 additions & 1 deletion source/module_hamilt_pw/hamilt_pwdft/kernels/force_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,35 @@ struct cal_force_ew_op{
FPTYPE* forceion
) {};
};

template <typename FPTYPE, typename Device>
struct cal_force_loc_sincos_op{
void operator()(
const Device* ctx,
const int nat,
const int npw,
const int ntype,
const FPTYPE* gcar,
const FPTYPE* tau,
const FPTYPE* vloc_per_type,
const std::complex<FPTYPE>* aux,
const FPTYPE& scale_factor,
FPTYPE* force) {};
};

template <typename FPTYPE, typename Device>
struct cal_force_ew_sincos_op{
void operator()(
const Device* ctx,
const int nat,
const int npw,
const int ig_gge0,
const FPTYPE* gcar,
const FPTYPE* tau,
const FPTYPE* it_facts,
const std::complex<FPTYPE>* aux,
FPTYPE* force) {};
};
#if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
template <typename FPTYPE>
struct cal_vkb1_nl_op<FPTYPE, base_device::DEVICE_GPU>
Expand Down Expand Up @@ -335,6 +364,32 @@ struct cal_force_ew_op<FPTYPE, base_device::DEVICE_GPU>{
FPTYPE* forceion
);
};
template <typename FPTYPE>
struct cal_force_loc_sincos_op<FPTYPE, base_device::DEVICE_GPU> {
void operator()(const base_device::DEVICE_GPU* ctx,
const int& nat,
const int& npw,
const int& ntype,
const FPTYPE* gcar,
const FPTYPE* tau,
const FPTYPE* vloc_per_type,
const std::complex<FPTYPE>* aux,
const FPTYPE& scale_factor,
FPTYPE* force);
};

template <typename FPTYPE>
struct cal_force_ew_sincos_op<FPTYPE, base_device::DEVICE_GPU> {
void operator()(const base_device::DEVICE_GPU* ctx,
const int& nat,
const int& npw,
const int& ig_gge0,
const FPTYPE* gcar,
const FPTYPE* tau,
const FPTYPE* it_facts,
const std::complex<FPTYPE>* aux,
FPTYPE* force);
};
#endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
} // namespace hamilt
#endif // W_ABACUS_DEVELOP_ABACUS_DEVELOP_SOURCE_source_pw_HAMILT_PWDFT_KERNELS_FORCE_OP_H
#endif // W_ABACUS_DEVELOP_ABACUS_DEVELOP_SOURCE_source_pw_HAMILT_PWDFT_KERNELS_FORCE_OP_H
Loading
Loading