Skip to content

Commit 27cbfbf

Browse files
High-performance optimization for force calculation
1 parent bc17385 commit 27cbfbf

4 files changed

Lines changed: 924 additions & 760 deletions

File tree

source/module_hamilt_pw/hamilt_pwdft/forces.cpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
#include "module_hamilt_general/module_surchem/surchem.h"
1717
#include "module_hamilt_general/module_vdw/vdw.h"
1818
#include "kernels/force_op.h"
19-
19+
#include <type_traits>
2020
#ifdef _OPENMP
2121
#include <omp.h>
2222
#endif
@@ -579,7 +579,7 @@ void Forces<FPTYPE, Device>::cal_force_loc(const UnitCell& ucell,
579579
syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, forcelc_d, forcelc.c, this->nat * 3);
580580
syncmem_var_h2d_op()(this->ctx, this->cpu_ctx, vloc_d, vloc.c, vloc.nr * vloc.nc);
581581

582-
hamilt::cal_force_loc_op<FPTYPE, Device>()(
582+
hamilt::cal_force_loc_op<FPTYPE, Device>()(
583583
this->nat,
584584
rho_basis->npw,
585585
ucell.tpiba * ucell.omega,
@@ -591,6 +591,8 @@ void Forces<FPTYPE, Device>::cal_force_loc(const UnitCell& ucell,
591591
vloc_d,
592592
vloc.nc,
593593
forcelc_d);
594+
595+
594596
syncmem_var_d2h_op()(this->cpu_ctx, this->ctx, forcelc.c, forcelc_d, this->nat * 3);
595597

596598
delmem_int_op()(this->ctx,iat2it_d);
@@ -799,6 +801,7 @@ void Forces<FPTYPE, Device>::cal_force_ew(const UnitCell& ucell,
799801
aux_d,
800802
forceion_d);
801803

804+
802805
syncmem_var_d2h_op()(this->cpu_ctx, this->ctx, forceion.c, forceion_d, this->nat * 3);
803806
delmem_int_op()(this->ctx,iat2it_d);
804807
delmem_var_op()(this->ctx,gcar_d);
@@ -917,8 +920,25 @@ void Forces<FPTYPE, Device>::cal_force_ew(const UnitCell& ucell,
917920
return;
918921
}
919922

923+
// Explicit instantiation definitions of the GPU force functors, for both
// double and float, placed in namespace hamilt.
//  - cal_force_ew_sincos_op / cal_force_loc_sincos_op are instantiated only
//    when __ROCM or __HIP_PLATFORM_AMD__ is defined.
//  - cal_force_ew_op / cal_force_loc_op are instantiated only when __CUDA or
//    __NVCC__ is defined.
// NOTE(review): force_op.h guards the DEVICE_GPU specializations with
// `__CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM`, which is not the same
// macro set tested here (__HIP_PLATFORM_AMD__, __NVCC__) -- confirm that the
// specializations are always visible whenever these branches are compiled.
// NOTE(review): the DEVICE_GPU sincos specializations only declare operator()
// in the header; an explicit instantiation of the class presumably does not
// emit that member unless its definition is visible in this translation
// unit -- confirm where the operator() bodies are compiled.
namespace hamilt {
924+
925+
#if defined(__ROCM) || defined(__HIP_PLATFORM_AMD__)
926+
template struct cal_force_ew_sincos_op<double, base_device::DEVICE_GPU>;
927+
template struct cal_force_ew_sincos_op<float, base_device::DEVICE_GPU>;
920928

929+
template struct cal_force_loc_sincos_op<double, base_device::DEVICE_GPU>;
930+
template struct cal_force_loc_sincos_op<float, base_device::DEVICE_GPU>;
931+
#endif
932+
933+
#if defined(__CUDA) || defined(__NVCC__)
934+
template struct cal_force_ew_op<double, base_device::DEVICE_GPU>;
935+
template struct cal_force_ew_op<float, base_device::DEVICE_GPU>;
936+
937+
template struct cal_force_loc_op<double, base_device::DEVICE_GPU>;
938+
template struct cal_force_loc_op<float, base_device::DEVICE_GPU>;
939+
#endif
921940

941+
} // namespace hamilt
922942
template class Forces<double, base_device::DEVICE_CPU>;
923943
#if ((defined __CUDA) || (defined __ROCM))
924944
template class Forces<double, base_device::DEVICE_GPU>;

source/module_hamilt_pw/hamilt_pwdft/kernels/force_op.h

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,35 @@ struct cal_force_ew_op{
179179
FPTYPE* forceion
180180
) {};
181181
};
182+
183+
// Primary (device-generic) template of the cal_force_loc_sincos_op functor.
//
// operator() has an empty body here, so calling this primary template does
// nothing and never writes to `force`. A partial specialization for
// base_device::DEVICE_GPU is declared later in this header (inside the
// CUDA/ROCm preprocessor guard) with a declaration-only operator().
//
// Parameters (meanings inferred from names -- TODO confirm against the kernel
// implementation):
//   ctx            device context pointer
//   nat            presumably the number of atoms
//   npw            presumably the number of plane waves
//   ntype          presumably the number of atom types
//   gcar, tau      read-only FPTYPE arrays (layout not visible here)
//   vloc_per_type  read-only FPTYPE array (layout not visible here)
//   aux            read-only complex array
//   scale_factor   scalar passed by const reference
//   force          output array
//
// NOTE(review): nat/npw/ntype are taken by value here but by `const int&` in
// the DEVICE_GPU specialization; callers are unaffected, but the two
// signatures could be aligned.
template <typename FPTYPE, typename Device>
184+
struct cal_force_loc_sincos_op{
185+
void operator()(
186+
const Device* ctx,
187+
const int nat,
188+
const int npw,
189+
const int ntype,
190+
const FPTYPE* gcar,
191+
const FPTYPE* tau,
192+
const FPTYPE* vloc_per_type,
193+
const std::complex<FPTYPE>* aux,
194+
const FPTYPE& scale_factor,
195+
FPTYPE* force) {};
196+
};
197+
198+
// Primary (device-generic) template of the cal_force_ew_sincos_op functor.
//
// operator() has an empty body here, so calling this primary template does
// nothing and never writes to `force`. A partial specialization for
// base_device::DEVICE_GPU is declared later in this header (inside the
// CUDA/ROCm preprocessor guard) with a declaration-only operator().
//
// Parameters (meanings inferred from names -- TODO confirm against the kernel
// implementation):
//   ctx        device context pointer
//   nat        presumably the number of atoms
//   npw        presumably the number of plane waves
//   ig_gge0    presumably the index of the G = 0 plane wave -- verify
//   gcar, tau  read-only FPTYPE arrays (layout not visible here)
//   it_facts   read-only FPTYPE array (layout not visible here)
//   aux        read-only complex array
//   force      output array
//
// NOTE(review): nat/npw/ig_gge0 are taken by value here but by `const int&`
// in the DEVICE_GPU specialization; callers are unaffected, but the two
// signatures could be aligned.
template <typename FPTYPE, typename Device>
199+
struct cal_force_ew_sincos_op{
200+
void operator()(
201+
const Device* ctx,
202+
const int nat,
203+
const int npw,
204+
const int ig_gge0,
205+
const FPTYPE* gcar,
206+
const FPTYPE* tau,
207+
const FPTYPE* it_facts,
208+
const std::complex<FPTYPE>* aux,
209+
FPTYPE* force) {};
210+
};
182211
#if __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
183212
template <typename FPTYPE>
184213
struct cal_vkb1_nl_op<FPTYPE, base_device::DEVICE_GPU>
@@ -335,6 +364,32 @@ struct cal_force_ew_op<FPTYPE, base_device::DEVICE_GPU>{
335364
FPTYPE* forceion
336365
);
337366
};
367+
// Partial specialization of cal_force_loc_sincos_op for
// base_device::DEVICE_GPU.
//
// Only the declaration of operator() appears here; its definition is not in
// this part of the header (presumably in the CUDA/ROCm kernel sources --
// TODO confirm). The parameter list mirrors the primary template, except
// that nat, npw and ntype are passed as `const int&` instead of by value.
template <typename FPTYPE>
368+
struct cal_force_loc_sincos_op<FPTYPE, base_device::DEVICE_GPU> {
369+
void operator()(const base_device::DEVICE_GPU* ctx,
370+
const int& nat,
371+
const int& npw,
372+
const int& ntype,
373+
const FPTYPE* gcar,
374+
const FPTYPE* tau,
375+
const FPTYPE* vloc_per_type,
376+
const std::complex<FPTYPE>* aux,
377+
const FPTYPE& scale_factor,
378+
FPTYPE* force);
379+
};
380+
381+
// Partial specialization of cal_force_ew_sincos_op for
// base_device::DEVICE_GPU.
//
// Only the declaration of operator() appears here; its definition is not in
// this part of the header (presumably in the CUDA/ROCm kernel sources --
// TODO confirm). The parameter list mirrors the primary template, except
// that nat, npw and ig_gge0 are passed as `const int&` instead of by value.
template <typename FPTYPE>
382+
struct cal_force_ew_sincos_op<FPTYPE, base_device::DEVICE_GPU> {
383+
void operator()(const base_device::DEVICE_GPU* ctx,
384+
const int& nat,
385+
const int& npw,
386+
const int& ig_gge0,
387+
const FPTYPE* gcar,
388+
const FPTYPE* tau,
389+
const FPTYPE* it_facts,
390+
const std::complex<FPTYPE>* aux,
391+
FPTYPE* force);
392+
};
338393
#endif // __CUDA || __UT_USE_CUDA || __ROCM || __UT_USE_ROCM
339394
} // namespace hamilt
340-
#endif // W_ABACUS_DEVELOP_ABACUS_DEVELOP_SOURCE_source_pw_HAMILT_PWDFT_KERNELS_FORCE_OP_H
395+
#endif // W_ABACUS_DEVELOP_ABACUS_DEVELOP_SOURCE_source_pw_HAMILT_PWDFT_KERNELS_FORCE_OP_H

0 commit comments

Comments
 (0)