Skip to content

Commit

Permalink
[SCF] minor cleanup, enable gauxc device execution
Browse files: browse the repository at this point in the history
  • Loading branch information
ajaypanyala committed Mar 28, 2024
1 parent 8e83bf1 commit f33d199
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 17 deletions.
31 changes: 22 additions & 9 deletions exachem/scf/scf_common.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -446,8 +446,6 @@ void scf_restart_test(const ExecutionContext& ec, const SystemData& sys_data, bo
const auto rank = ec.pg().rank();
const bool is_uhf = (sys_data.is_unrestricted);

int rstatus = 1;

std::string movecsfile_alpha = files_prefix + ".alpha.movecs";
std::string densityfile_alpha = files_prefix + ".alpha.density";
std::string movecsfile_beta = files_prefix + ".beta.movecs";
Expand All @@ -458,12 +456,11 @@ void scf_restart_test(const ExecutionContext& ec, const SystemData& sys_data, bo
status = fs::exists(movecsfile_alpha) && fs::exists(densityfile_alpha);
if(is_uhf) status = status && fs::exists(movecsfile_beta) && fs::exists(densityfile_beta);
}
rstatus = status;
ec.pg().barrier();
ec.pg().broadcast(&rstatus, 0);
ec.pg().broadcast(&status, 0);
std::string fnf = movecsfile_alpha + "; " + densityfile_alpha;
if(is_uhf) fnf = fnf + "; " + movecsfile_beta + "; " + densityfile_beta;
if(rstatus == 0) tamm_terminate("Error reading one or all of the files: [" + fnf + "]");
if(!status) tamm_terminate("Error reading one or all of the files: [" + fnf + "]");
}

template<typename T>
Expand Down Expand Up @@ -909,6 +906,8 @@ Matrix compute_schwarz_ints(ExecutionContext& ec, const SCFVars& scf_vars,
Tensor<TensorType> schwarz{scf_vars.tAO, scf_vars.tAO};
Tensor<TensorType> schwarz_mat{tnsh, tnsh};
Tensor<TensorType>::allocate(&ec, schwarz_mat);
Scheduler sch{ec};
sch(schwarz_mat() = 0).execute();

auto compute_schwarz_matrix = [&](const IndexVector& blockid) {
auto bi0 = blockid[0];
Expand Down Expand Up @@ -1229,13 +1228,27 @@ gauxc_util::setup_gauxc(ExecutionContext& ec, const SystemData& sys_data, const
GauXC::RadialQuad::MuraKnowles, grid_type);
auto gauxc_molmeta = std::make_shared<GauXC::MolMeta>(gauxc_mol);

std::string lb_exec_space_str = "HOST";
std::string int_exec_space_str = "HOST";

#ifdef GAUXC_HAS_DEVICE
std::map<std::string, GauXC::ExecutionSpace> exec_space_map = {
{"HOST", GauXC::ExecutionSpace::Host}, {"DEVICE", GauXC::ExecutionSpace::Device}};

auto lb_exec_space = exec_space_map.at(lb_exec_space_str);
auto int_exec_space = exec_space_map.at(int_exec_space_str);
#else
auto lb_exec_space = GauXC::ExecutionSpace::Host;
auto int_exec_space = GauXC::ExecutionSpace::Host;
#endif

// Set the load balancer
GauXC::LoadBalancerFactory lb_factory(GauXC::ExecutionSpace::Host, "Replicated");
GauXC::LoadBalancerFactory lb_factory(lb_exec_space, "Replicated");
auto gauxc_lb = lb_factory.get_shared_instance(gauxc_rt, gauxc_mol, gauxc_molgrid, gauxc_basis);

// Modify the weighting algorithm from the input [Becke, SSF, LKO]
GauXC::MolecularWeightsSettings mw_settings = {GauXC::XCWeightAlg::LKO, false};
GauXC::MolecularWeightsFactory mw_factory(GauXC::ExecutionSpace::Host, "Default", mw_settings);
GauXC::MolecularWeightsFactory mw_factory(int_exec_space, "Default", mw_settings);
auto mw = mw_factory.get_instance();
mw.modify_weights(*gauxc_lb);

Expand Down Expand Up @@ -1273,8 +1286,8 @@ gauxc_util::setup_gauxc(ExecutionContext& ec, const SystemData& sys_data, const
GauXC::functional_type gauxc_func = GauXC::functional_type(kernels);

// Initialize GauXC integrator
GauXC::XCIntegratorFactory<Matrix> integrator_factory(GauXC::ExecutionSpace::Host, "Replicated",
"Default", "Default", "Default");
GauXC::XCIntegratorFactory<Matrix> integrator_factory(int_exec_space, "Replicated", "Default",
"Default", "Default");
auto gc2 = std::chrono::high_resolution_clock::now();
auto gc_time = std::chrono::duration_cast<std::chrono::duration<double>>((gc2 - gc1)).count();

Expand Down
13 changes: 7 additions & 6 deletions exachem/scf/scf_guess.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -722,18 +722,18 @@ void scf_diagonalize(Scheduler& sch, const SystemData& sys_data, SCFVars& scf_va

const int64_t Northo_a = sys_data.nbf; // X_a.cols();
// TODO: avoid eigen Fp
Matrix X_a;
std::vector<double> eps_a;
Matrix X_a;
if(rank == 0) {
// alpha
Matrix Fp = tamm_to_eigen_matrix(ttensors.F_alpha);
X_a = tamm_to_eigen_matrix(ttensors.X_alpha);
C_alpha.resize(N, Northo_a);
std::vector<double> eps_a(Northo_a);

blas::gemm(blas::Layout::ColMajor, blas::Op::NoTrans, blas::Op::Trans, N, Northo_a, N, 1.,
Fp.data(), N, X_a.data(), Northo_a, 0., C_alpha.data(), N);
blas::gemm(blas::Layout::ColMajor, blas::Op::NoTrans, blas::Op::NoTrans, Northo_a, Northo_a, N,
1., X_a.data(), Northo_a, C_alpha.data(), N, 0., Fp.data(), Northo_a);
eps_a.resize(Northo_a);
lapack::syevd(lapack::Job::Vec, lapack::Uplo::Lower, Northo_a, Fp.data(), Northo_a,
eps_a.data());
blas::gemm(blas::Layout::ColMajor, blas::Op::Trans, blas::Op::NoTrans, Northo_a, N, Northo_a,
Expand All @@ -747,12 +747,13 @@ void scf_diagonalize(Scheduler& sch, const SystemData& sys_data, SCFVars& scf_va
// beta
Matrix Fp = tamm_to_eigen_matrix(ttensors.F_beta);
C_beta.resize(N, Northo_b);
Matrix& X_b = X_a;
std::vector<double> eps_b(Northo_b);
Matrix& X_b = X_a;

blas::gemm(blas::Layout::ColMajor, blas::Op::NoTrans, blas::Op::Trans, N, Northo_b, N, 1.,
Fp.data(), N, X_b.data(), Northo_b, 0., C_beta.data(), N);
blas::gemm(blas::Layout::ColMajor, blas::Op::NoTrans, blas::Op::NoTrans, Northo_b, Northo_b,
N, 1., X_b.data(), Northo_b, C_beta.data(), N, 0., Fp.data(), Northo_b);
std::vector<double> eps_b(Northo_b);
lapack::syevd(lapack::Job::Vec, lapack::Uplo::Lower, Northo_b, Fp.data(), Northo_b,
eps_b.data());
blas::gemm(blas::Layout::ColMajor, blas::Op::Trans, blas::Op::NoTrans, Northo_b, N, Northo_b,
Expand All @@ -768,7 +769,7 @@ void scf_diagonalize(Scheduler& sch, const SystemData& sys_data, SCFVars& scf_va
hl_gap -= scf_vars.lshift;

if(!scf_vars.lshift_reset) {
sch.ec().pg().broadcast(&hl_gap, 1, 0);
sch.ec().pg().broadcast(&hl_gap, 0);
if(hl_gap < 1e-2) {
scf_vars.lshift_reset = true;
scf_vars.lshift = 0.5;
Expand Down
10 changes: 8 additions & 2 deletions exachem/scf/scf_iter.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -1011,7 +1011,10 @@ void scf_diis(ExecutionContext& ec, const TiledIndexSpace& tAO, Tensor<TensorTyp
}

while(info != 0) {
if(idim == 1) return;
if(idim == 1) {
Tensor<TensorType>::deallocate(dhi_trace);
return;
}

N = idim + 1;
std::vector<TensorType> AC(N * (N + 1) / 2);
Expand Down Expand Up @@ -1053,7 +1056,10 @@ void scf_diis(ExecutionContext& ec, const TiledIndexSpace& tAO, Tensor<TensorTyp
}

idim--;
if(idim == 1) return;
if(idim == 1) {
Tensor<TensorType>::deallocate(dhi_trace);
return;
}
Matrix A_pl = A.block(2, 2, idim, idim);
Matrix A_new = Matrix::Zero(idim + 1, idim + 1);

Expand Down

0 comments on commit f33d199

Please sign in to comment.