Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion source/source_basis/module_pw/pw_transform_k_dsp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ void PW_Basis_K::recip2real_dsp(const std::complex<double>* in,
{
const int one = 1;
const std::complex<double> factor1 = std::complex<double>(factor, 0);
zaxpy_(&nrxx, &factor1, auxr, &one, out, &one);
BlasConnector::axpy(nrxx, factor1, auxr, one, out, one);
}
else
{
Expand Down
2 changes: 1 addition & 1 deletion source/source_io/unk_overlap_lcao.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,7 @@ std::complex<double> unkOverlap_lcao::det_berryphase(const UnitCell& ucell,

int* ipiv = new int[para_orb.nrow];
int info = 0;
pzgetrf_(&occBands, &occBands, out_matrix, &one, &one, para_orb.desc, ipiv, &info);
ScalapackConnector::getrf(occBands, occBands, out_matrix, one, one, para_orb.desc, ipiv, &info);

for (int i = 0; i < occBands; i++) // global
{
Expand Down
6 changes: 3 additions & 3 deletions source/source_lcao/module_deepks/deepks_orbpre.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,11 +247,11 @@ void DeePKS_domain::cal_orbital_precalc(const std::vector<TH>& dm_hl,
{
for (int m2 = 0; m2 < nm; ++m2) // m1 = 1 for s, 3 for p, 5 for d
{
accessor[ik][inl][m1][m2] += ddot_(&row_size,
accessor[ik][inl][m1][m2] += BlasConnector::dot(row_size,
p_g1dmt + index * row_size * nks,
&inc,
inc,
s_1t.data() + index * row_size,
&inc);
inc);
index++;
}
}
Expand Down
21 changes: 11 additions & 10 deletions source/source_lcao/module_deepks/deepks_pdm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "deepks_pdm.h"
#include "source_base/constants.h"
#include "source_base/libm/libm.h"
#include "source_base/module_external/blas_connector.h"
#include "source_base/timer.h"
#include "source_lcao/module_hcontainer/atom_pair.h"
#ifdef __MPI
Expand Down Expand Up @@ -402,11 +403,11 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
{
for (int m2 = 0; m2 < nm; ++m2) // m1 = 1 for s, 3 for p, 5 for d
{
accessor[m1][m2] += ddot_(&row_size,
g_1dmt.data() + index * row_size,
&inc,
s_1t.data() + index * row_size,
&inc);
accessor[m1][m2] += BlasConnector::dot(row_size,
g_1dmt.data() + index * row_size,
inc,
s_1t.data() + index * row_size,
inc);
index++;
}
}
Expand All @@ -428,11 +429,11 @@ void DeePKS_domain::cal_pdm(bool& init_pdm,
{
// ddot_: dot product of two vectors
// inc means the increment of the index
accessor[iproj * nproj + jproj] += ddot_(&row_size,
g_1dmt.data() + index * row_size,
&inc,
s_1t.data() + index * row_size,
&inc);
accessor[iproj * nproj + jproj] += BlasConnector::dot(row_size,
g_1dmt.data() + index * row_size,
inc,
s_1t.data() + index * row_size,
inc);
index++;
}
}
Expand Down
132 changes: 43 additions & 89 deletions source/source_lcao/module_dftu/dftu_force.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,52 +23,6 @@
#include <string.h>


extern "C"
{
// I'm not sure what's happenig here, but the interface in scalapack_connecter.h
// does not seem to work, so I'll use this one here
void pzgemm_(const char* transa,
const char* transb,
const int* M,
const int* N,
const int* K,
const std::complex<double>* alpha,
const std::complex<double>* A,
const int* IA,
const int* JA,
const int* DESCA,
const std::complex<double>* B,
const int* IB,
const int* JB,
const int* DESCB,
const std::complex<double>* beta,
std::complex<double>* C,
const int* IC,
const int* JC,
const int* DESCC);

void pdgemm_(const char* transa,
const char* transb,
const int* M,
const int* N,
const int* K,
const double* alpha,
const double* A,
const int* IA,
const int* JA,
const int* DESCA,
const double* B,
const int* IB,
const int* JB,
const int* DESCB,
const double* beta,
double* C,
const int* IC,
const int* JC,
const int* DESCC);
}


void Plus_U::force_stress(const UnitCell& ucell,
const Grid_Driver& gd,
std::vector<std::vector<double>>* dmk_d, // mohan modify 2025-11-02
Expand Down Expand Up @@ -161,10 +115,10 @@ void Plus_U::force_stress(const UnitCell& ucell,


#ifdef __MPI
pzgemm_(&transT, &transN, &nlocal, &nlocal, &nlocal,
&alpha, (*dmk_c)[ik].data(), &one_int, &one_int, // important to add (), 20251103
pv.desc, VU, &one_int, &one_int, pv.desc, &beta,
&rho_VU[0], &one_int, &one_int, pv.desc);
ScalapackConnector::gemm(transT, transN, nlocal, nlocal, nlocal,
alpha, (*dmk_c)[ik].data(), one_int, one_int,
pv.desc, VU, one_int, one_int, pv.desc, beta,
&rho_VU[0], one_int, one_int, pv.desc);
#endif

delete[] VU;
Expand Down Expand Up @@ -237,24 +191,24 @@ void Plus_U::cal_force_k(const UnitCell& ucell,
this->folding_matrix_k(ucell, gd, fsr, pv, ik, dim + 1, 0, &dSm_k[0], kvec_d);

#ifdef __MPI
pzgemm_(&transN,
&transC,
&PARAM.globalv.nlocal,
&PARAM.globalv.nlocal,
&PARAM.globalv.nlocal,
&one,
ScalapackConnector::gemm(transN,
transC,
PARAM.globalv.nlocal,
PARAM.globalv.nlocal,
PARAM.globalv.nlocal,
one,
&dSm_k[0],
&one_int,
&one_int,
one_int,
one_int,
pv.desc,
rho_VU,
&one_int,
&one_int,
one_int,
one_int,
pv.desc,
&zero,
zero,
&dm_VU_dSm[0],
&one_int,
&one_int,
one_int,
one_int,
pv.desc);
#endif

Expand All @@ -275,24 +229,24 @@ void Plus_U::cal_force_k(const UnitCell& ucell,
} // end ir

#ifdef __MPI
pzgemm_(&transN,
&transN,
&PARAM.globalv.nlocal,
&PARAM.globalv.nlocal,
&PARAM.globalv.nlocal,
&one,
ScalapackConnector::gemm(transN,
transN,
PARAM.globalv.nlocal,
PARAM.globalv.nlocal,
PARAM.globalv.nlocal,
one,
&dSm_k[0],
&one_int,
&one_int,
one_int,
one_int,
pv.desc,
rho_VU,
&one_int,
&one_int,
one_int,
one_int,
pv.desc,
&zero,
zero,
&dm_VU_dSm[0],
&one_int,
&one_int,
one_int,
one_int,
pv.desc);
#endif

Expand Down Expand Up @@ -371,24 +325,24 @@ void Plus_U::cal_stress_k(const UnitCell& ucell,
this->folding_matrix_k(ucell, gd, fsr, pv, ik, dim1 + 4, dim2, &dSR_k[0], kvec_d);

#ifdef __MPI
pzgemm_(&transN,
&transN,
&nlocal,
&nlocal,
&nlocal,
&minus_half,
ScalapackConnector::gemm(transN,
transN,
nlocal,
nlocal,
nlocal,
minus_half,
rho_VU,
&one_int,
&one_int,
one_int,
one_int,
pv.desc,
&dSR_k[0],
&one_int,
&one_int,
one_int,
one_int,
pv.desc,
&zero,
zero,
&dm_VU_sover[0],
&one_int,
&one_int,
one_int,
one_int,
pv.desc);
#endif

Expand Down
Loading