Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add atomic overlap interface + host implementation #81

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/gauxc/xc_integrator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class XCIntegrator {
using exc_vxc_type_uks = std::tuple< value_type, matrix_type, matrix_type >;
using exc_grad_type = std::vector< value_type >;
using exx_type = matrix_type;
using atomic_overlap_type = matrix_type;

private:

Expand All @@ -58,6 +59,7 @@ class XCIntegrator {
exc_grad_type eval_exc_grad( const MatrixType& );
exx_type eval_exx ( const MatrixType&,
const IntegratorSettingsEXX& = IntegratorSettingsEXX{} );
// Evaluate the atomic overlap matrix; the result type is atomic_overlap_type.
// iAtom defaults to -1.
// NOTE(review): in the reference host implementation a non-negative iAtom
// restricts the quadrature to tasks whose iParent equals iAtom, while a
// negative value integrates every task -- confirm other backends agree.
atomic_overlap_type eval_atomic_overlap(int64_t iAtom = -1);


const util::Timer& get_timings() const;
Expand Down
7 changes: 7 additions & 0 deletions include/gauxc/xc_integrator/impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ typename XCIntegrator<MatrixType>::exx_type
return pimpl_->eval_exx(P,settings);
};

// Public entry point for the atomic overlap integration.
// Raises GAUXC_PIMPL_NOT_INITIALIZED() when no implementation object is
// attached, otherwise forwards iAtom unchanged to the implementation and
// returns its result.
template <typename MatrixType>
typename XCIntegrator<MatrixType>::atomic_overlap_type
XCIntegrator<MatrixType>::eval_atomic_overlap(int64_t iAtom) {
if( not pimpl_ ) GAUXC_PIMPL_NOT_INITIALIZED();
return pimpl_->eval_atomic_overlap(iAtom);
};

template <typename MatrixType>
const util::Timer& XCIntegrator<MatrixType>::get_timings() const {
if( not pimpl_ ) GAUXC_PIMPL_NOT_INITIALIZED();
Expand Down
16 changes: 16 additions & 0 deletions include/gauxc/xc_integrator/replicated/impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,5 +123,21 @@ typename ReplicatedXCIntegrator<MatrixType>::exx_type

}

/**
 *  Matrix-level driver for the atomic overlap integration.
 *
 *  Allocates a square matrix whose dimension is the number of basis
 *  functions reported by the load balancer's basis set, hands its raw
 *  storage to the replicated implementation and returns the filled matrix.
 *
 *  @param[in] iAtom  Atom selector, forwarded unchanged to the implementation.
 *  @returns   The nbf x nbf matrix populated by the implementation.
 */
template <typename MatrixType>
typename ReplicatedXCIntegrator<MatrixType>::atomic_overlap_type
  ReplicatedXCIntegrator<MatrixType>::eval_atomic_overlap_( int64_t iAtom ) {

  // Nothing can be evaluated without an implementation object
  if( not pimpl_ ) GAUXC_PIMPL_NOT_INITIALIZED();

  // Dimension of the result == number of basis functions
  const size_t n_basis = get_load_balancer_().basis().nbf();

  // Output storage; the matrix dimension doubles as the leading dimension
  matrix_type overlap( n_basis, n_basis );
  pimpl_->eval_atomic_overlap( iAtom, n_basis, n_basis, overlap.data(), n_basis );

  return overlap;

}


}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ class ReplicatedXCIntegratorImpl {
int64_t ldp, value_type* K, int64_t ldk,
const IntegratorSettingsEXX& settings ) = 0;

// Implementation hook for the atomic overlap integration; the public
// eval_atomic_overlap forwards all of its arguments here unchanged.
//   iAtom : atom selector
//   m, n  : row / column dimensions of S
//   S     : output buffer
//   LDS   : leading dimension of S
// NOTE(review): the reference host implementation requires m == n == nbf
// and LDS >= nbf -- confirm other implementations enforce the same contract.
virtual void eval_atomic_overlap_(int64_t iAtom, int64_t m, int64_t n,
value_type* S, int64_t LDS) = 0;

public:

ReplicatedXCIntegratorImpl( std::shared_ptr< functional_type > func,
Expand Down Expand Up @@ -86,6 +89,9 @@ class ReplicatedXCIntegratorImpl {
int64_t ldp, value_type* K, int64_t ldk,
const IntegratorSettingsEXX& settings );

// Public, non-virtual entry point for the atomic overlap integration.
// Writes the result into the caller-provided buffer S (m x n, leading
// dimension LDS) by delegating to the virtual eval_atomic_overlap_ hook.
void eval_atomic_overlap(int64_t iAtom, int64_t m, int64_t n,
value_type* S, int64_t LDS);

inline const util::Timer& get_timings() const { return timer_; }

inline std::unique_ptr< LocalWorkDriver > release_local_work_driver() {
Expand Down
3 changes: 3 additions & 0 deletions include/gauxc/xc_integrator/replicated_xc_integrator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class ReplicatedXCIntegrator : public XCIntegratorImpl<MatrixType> {
using exc_vxc_type_uks = typename XCIntegratorImpl<MatrixType>::exc_vxc_type_uks;
using exc_grad_type = typename XCIntegratorImpl<MatrixType>::exc_grad_type;
using exx_type = typename XCIntegratorImpl<MatrixType>::exx_type;
using atomic_overlap_type = typename XCIntegratorImpl<MatrixType>::atomic_overlap_type;

private:

Expand All @@ -43,6 +44,8 @@ class ReplicatedXCIntegrator : public XCIntegratorImpl<MatrixType> {
exc_vxc_type_uks eval_exc_vxc_ ( const MatrixType&, const MatrixType& ) override;
exc_grad_type eval_exc_grad_( const MatrixType& ) override;
exx_type eval_exx_ ( const MatrixType&, const IntegratorSettingsEXX& ) override;
atomic_overlap_type eval_atomic_overlap_(int64_t iAtom) override;

const util::Timer& get_timings_() const override;
const LoadBalancer& get_load_balancer_() const override;
LoadBalancer& get_load_balancer_() override;
Expand Down
9 changes: 7 additions & 2 deletions include/gauxc/xc_integrator/xc_integrator_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class XCIntegratorImpl {
using exc_vxc_type_uks = typename XCIntegrator<MatrixType>::exc_vxc_type_uks;
using exc_grad_type = typename XCIntegrator<MatrixType>::exc_grad_type;
using exx_type = typename XCIntegrator<MatrixType>::exx_type;
using atomic_overlap_type = typename XCIntegrator<MatrixType>::atomic_overlap_type;

protected:

Expand All @@ -33,6 +34,8 @@ class XCIntegratorImpl {
virtual exc_grad_type eval_exc_grad_( const MatrixType& P ) = 0;
virtual exx_type eval_exx_ ( const MatrixType& P,
const IntegratorSettingsEXX& settings ) = 0;
// Implementation hook behind eval_atomic_overlap: evaluates the atomic
// overlap matrix for the atom selector iAtom and returns it by value.
virtual atomic_overlap_type eval_atomic_overlap_(int64_t iAtom) = 0;

virtual const util::Timer& get_timings_() const = 0;
virtual const LoadBalancer& get_load_balancer_() const = 0;
virtual LoadBalancer& get_load_balancer_() = 0;
Expand All @@ -56,8 +59,6 @@ class XCIntegratorImpl {
}

/** Integrate EXC / VXC (Mean field terms) for RKS
*
* TODO: add API for UKS/GKS
*
* @param[in] P The alpha density matrix
* @returns EXC / VXC in a combined structure
Expand Down Expand Up @@ -92,6 +93,10 @@ class XCIntegratorImpl {
return eval_exx_(P,settings);
}

/** Integrate the atomic overlap matrix
 *
 * Forwards to the eval_atomic_overlap_ implementation hook.
 *
 * @param[in] iAtom Atom selector passed through to the implementation
 * @returns The atomic overlap matrix produced by the implementation
 */
atomic_overlap_type eval_atomic_overlap(int64_t iAtom) {
return eval_atomic_overlap_(iAtom);
}

/** Get internal timers
*
* @returns Timer instance for internal timings
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "reference_replicated_xc_host_integrator_exc_vxc.hpp"
#include "reference_replicated_xc_host_integrator_exc_grad.hpp"
#include "reference_replicated_xc_host_integrator_exx.hpp"
#include "reference_replicated_xc_host_integrator_atomic_overlap.hpp"

namespace GauXC {
namespace detail {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ class ReferenceReplicatedXCHostIntegrator :
int64_t ldp, value_type* K, int64_t ldk,
const IntegratorSettingsEXX& settings ) override;

void eval_atomic_overlap_(int64_t iAtom, int64_t m, int64_t n,
value_type* S, int64_t lds) override;

void integrate_den_local_work_( const value_type* P, int64_t ldp,
value_type *N_EL );

Expand All @@ -61,6 +64,8 @@ class ReferenceReplicatedXCHostIntegrator :
void exx_local_work_( const value_type* P, int64_t ldp, value_type* K, int64_t ldk,
const IntegratorSettingsEXX& settings );

void atomic_overlap_local_work_(int64_t iAtom, value_type* S, int64_t lds);

public:

template <typename... Args>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
/**
* GauXC Copyright (c) 2020-2023, The Regents of the University of California,
* through Lawrence Berkeley National Laboratory (subject to receipt of
* any required approvals from the U.S. Dept. of Energy). All rights reserved.
*
* See LICENSE.txt for details
*/
#pragma once

#include "reference_replicated_xc_host_integrator.hpp"
#include "host/local_host_work_driver.hpp"
#include "host/blas.hpp"
#include <stdexcept>
#include <set>

namespace GauXC {
namespace detail {

/**
 *  Evaluate the (optionally atom-restricted) overlap matrix on the host and
 *  sum the per-process contributions.
 *
 *  The local contribution is produced by atomic_overlap_local_work_ and then
 *  combined across processes with an in-place sum reduction.
 *
 *  @param[in]  iAtom  Atom selector forwarded to atomic_overlap_local_work_
 *                     (non-negative: only tasks whose iParent equals iAtom,
 *                     negative: every task).
 *  @param[in]  m      Number of rows of S; must equal the number of basis
 *                     functions (nbf).
 *  @param[in]  n      Number of columns of S; must equal m.
 *  @param[out] S      Column-major output buffer (element (i,j) lives at
 *                     S[i + j*lds]).
 *  @param[in]  lds    Leading dimension of S; must be >= nbf.
 *
 *  Raises GAUXC_GENERIC_EXCEPTION when the dimensions are inconsistent or
 *  when the reduction driver does not accept host memory.
 */
template <typename ValueType>
void ReferenceReplicatedXCHostIntegrator<ValueType>::
  eval_atomic_overlap_( int64_t iAtom, int64_t m, int64_t n,
                        value_type* S, int64_t lds ) {

  const auto& basis = this->load_balancer_->basis();

  // Check that S is sane
  const int64_t nbf = basis.nbf();
  if( m != n )
    GAUXC_GENERIC_EXCEPTION(" S Must Be Square");
  if( m != nbf )
    GAUXC_GENERIC_EXCEPTION(" S Must Have Same Dimension as Basis");
  if( lds < nbf )
    GAUXC_GENERIC_EXCEPTION(" Invalid LDS");


  // Get Tasks (result intentionally unused; the local work routine fetches
  // the task list again itself)
  this->load_balancer_->get_tasks();

  // Compute Local contributions to atomic overlap
  this->timer_.time_op("XCIntegrator.LocalWork_AtomicOverlap", [&](){
    atomic_overlap_local_work_(iAtom, S, lds);
  });

  #ifdef GAUXC_ENABLE_MPI
  this->timer_.time_op("XCIntegrator.LocalWait_AtomicOverlap", [&](){
    MPI_Barrier( this->load_balancer_->runtime().comm() );
  });
  #endif

  // Reduce Results
  this->timer_.time_op("XCIntegrator.Allreduce_AtomicOverlap", [&](){

    if( not this->reduction_driver_->takes_host_memory() )
      GAUXC_GENERIC_EXCEPTION("This Module Only Works With Host Reductions");

    if( lds == nbf ) {
      // S is densely packed: a single reduction covers the whole matrix
      this->reduction_driver_->allreduce_inplace( S, nbf*nbf, ReductionOp::Sum );
    } else {
      // S is padded (lds > nbf): reducing nbf*nbf contiguous elements would
      // mix padding into the sum and miss the trailing columns, so reduce
      // each column separately.
      // NOTE(review): assumes every process passes the same lds so that the
      // number of collective calls matches across ranks.
      for( int64_t j = 0; j < nbf; ++j ) {
        this->reduction_driver_->allreduce_inplace( S + j*lds, nbf, ReductionOp::Sum );
      }
    }

  });

}










/**
 *  Accumulate this process's quadrature contribution to the overlap matrix.
 *
 *  For every selected task the basis functions are evaluated on the task's
 *  grid points (B, stored nbe x npts with leading dimension nbe), a copy Z
 *  is formed whose column for point j is scaled by 0.5 * weights[j], and the
 *  pair (B, Z) is handed to the work driver's inc_vxc to increment S. Once
 *  all tasks are processed, the lower triangle of S is mirrored into the
 *  upper triangle.
 *
 *  Side effect: the load balancer's task list is reordered in place
 *  (partitioned by atom when iAtom >= 0, then sorted by decreasing
 *  npts * nbe within the selected range).
 *
 *  @param[in]  iAtom  If non-negative, only tasks whose iParent equals iAtom
 *                     are integrated; if negative, all tasks are used.
 *  @param[out] S      Column-major nbf x nbf output (element (i,j) at
 *                     S[i + j*lds]); zeroed before accumulation.
 *  @param[in]  lds    Leading dimension of S.
 *
 *  Raises GAUXC_GENERIC_EXCEPTION when the local work driver is not a
 *  LocalHostWorkDriver or when the load balancer reports that modified
 *  partition weights are not stored.
 */
template <typename ValueType>
void ReferenceReplicatedXCHostIntegrator<ValueType>::
  atomic_overlap_local_work_( int64_t iAtom, value_type* S, int64_t lds ) {

  // Cast LWD to LocalHostWorkDriver; a failed cast would otherwise be
  // dereferenced inside the parallel region below
  auto* lwd = dynamic_cast<LocalHostWorkDriver*>(this->local_work_driver_.get());
  if( not lwd ) {
    GAUXC_GENERIC_EXCEPTION("Atomic Overlap Requires a LocalHostWorkDriver");
  }

  // Setup Aliases
  const auto& basis = this->load_balancer_->basis();
  const auto& mol   = this->load_balancer_->molecule();

  // Check that Partition Weights have been calculated (done before the
  // task list is touched so a failure leaves the task ordering unchanged)
  auto& lb_state = this->load_balancer_->state();
  if( not lb_state.modified_weights_are_stored ) {
    GAUXC_GENERIC_EXCEPTION("Weights Have Not Beed Modified");
  }

  // Get basis map
  BasisSetMap basis_map(basis,mol);

  const int32_t nbf = basis.nbf();

  // Sort tasks on size (XXX: maybe doesnt matter?)
  auto task_comparator = []( const XCTask& a, const XCTask& b ) {
    return (a.points.size() * a.bfn_screening.nbe) > (b.points.size() * b.bfn_screening.nbe);
  };

  auto& tasks = this->load_balancer_->get_tasks();
  auto task_begin = tasks.begin();
  auto task_end   = tasks.end();
  if(iAtom >= 0) {
    // Move the tasks belonging to the requested atom to the front
    task_end = std::partition(task_begin, task_end, [=](const auto& t){ return t.iParent == iAtom; });
  }
  std::sort( task_begin, task_end, task_comparator );

  // Zero out integrands
  for( int64_t j = 0; j < nbf; ++j )
  for( int64_t i = 0; i < nbf; ++i )
    S[i + j*lds] = 0.;


  const size_t ntasks = std::distance(task_begin, task_end);
  #pragma omp parallel
  {

    XCHostData<value_type> host_data; // Thread local host data

    #pragma omp for schedule(dynamic)
    for( size_t iT = 0; iT < ntasks; ++iT ) {

      // Alias current task (the selected range starts at tasks.begin(),
      // so indexing the container directly is equivalent)
      const auto& task = tasks[iT];

      // Get tasks constants
      const int32_t npts    = task.points.size();
      const int32_t nbe     = task.bfn_screening.nbe;
      const int32_t nshells = task.bfn_screening.shell_list.size();

      const auto* points        = task.points.data()->data();
      const auto* weights       = task.weights.data();
      const int32_t* shell_list = task.bfn_screening.shell_list.data();

      // Allocate enough memory for batch. Sizes are computed in size_t to
      // avoid 32-bit overflow. Z holds one nbe-long column per grid point,
      // so it needs npts * nbe elements (not nbe * nbe).
      const size_t batch_sz = static_cast<size_t>(npts) * static_cast<size_t>(nbe);
      host_data.basis_eval.resize( batch_sz );
      host_data.nbe_scr   .resize( static_cast<size_t>(nbe) * static_cast<size_t>(nbe) );
      host_data.zmat      .resize( batch_sz );

      // Alias/Partition out scratch memory
      auto* basis_eval = host_data.basis_eval.data();
      auto* nbe_scr    = host_data.nbe_scr.data();
      auto* zmat       = host_data.zmat.data();

      // Get the submatrix map for batch
      std::vector< std::array<int32_t, 3> > submat_map;
      std::tie(submat_map, std::ignore) =
        gen_compressed_submat_map(basis_map, task.bfn_screening.shell_list, nbf, nbf);

      // Evaluate Collocation
      lwd->eval_collocation( npts, nshells, nbe, points, basis, shell_list,
        basis_eval );

      // Copy BFN -> Z. Both are nbe x npts with leading dimension nbe, so
      // the row count is nbe and the column count is npts.
      blas::lacpy('A', nbe, npts, basis_eval, nbe, zmat, nbe);

      // Scale columns of Z by weights. Column j is the contiguous range
      // [zmat + j*nbe, zmat + (j+1)*nbe), hence unit stride. The factor 0.5
      // is kept from the original code.
      // NOTE(review): presumably inc_vxc forms the symmetric update
      // Z*B^T + B*Z^T, which the 0.5 would compensate -- confirm.
      for( int32_t j = 0; j < npts; ++j ) {
        blas::scal(nbe, 0.5 * weights[j], zmat + static_cast<size_t>(j)*nbe, 1);
      }

      // Increment LT of S (serialized: all threads update the same S)
      #pragma omp critical
      {
        lwd->inc_vxc( npts, nbf, nbe, basis_eval, submat_map, zmat, nbe, S, lds,
          nbe_scr );
      }

    } // Loop over tasks
  } // End OpenMP Scope

  // Symmetrize S: copy the lower triangle into the upper triangle
  for( int32_t j = 0;   j < nbf; ++j ) {
  for( int32_t i = j+1; i < nbf; ++i ) {
    S[ j + i*lds ] = S[ i + j*lds ];
  }
  }
}

}
}
10 changes: 10 additions & 0 deletions src/xc_integrator/replicated/replicated_xc_integrator_impl.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,16 @@ void ReplicatedXCIntegratorImpl<ValueType>::

}


// Public entry point for the atomic overlap integration on the replicated
// implementation: passes every argument through, unchanged, to the virtual
// eval_atomic_overlap_ hook supplied by the concrete integrator.
//   iAtom : atom selector
//   m, n  : row / column dimensions of S
//   S     : output buffer
//   lds   : leading dimension of S
template <typename ValueType>
void ReplicatedXCIntegratorImpl<ValueType>::
eval_atomic_overlap( int64_t iAtom, int64_t m, int64_t n, value_type* S,
int64_t lds ) {

eval_atomic_overlap_(iAtom,m,n,S,lds);

}

template class ReplicatedXCIntegratorImpl<double>;

}
Expand Down
Loading