Skip to content

Commit

Permalink
use eval_vxc
Browse files (browse the repository at this point in the history)
  • Loading branch information
wxj6000 committed Feb 2, 2025
1 parent b565366 commit a42a070
Showing 1 changed file with 71 additions and 83 deletions.
154 changes: 71 additions & 83 deletions gpu4pyscf/dft/numint.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -338,29 +338,32 @@ def eval_vxc(mol, ao, wv, idx, vmat, xctype_code):
assert vmat.flags['C_CONTIGUOUS']
assert idx.dtype == np.int32

vmat = vmat.reshape(-1, nao, nao)
nset = vmat.shape[0]
ngrids, nao_mask = ao.shape[-1], ao.shape[-2]
wv = wv.reshape(nset, -1, ngrids)
cublas_handle = cupy.cuda.device.get_cublas_handle()
stream = cupy.cuda.get_current_stream()
buf1 = cupy.empty([nao_mask, ngrids])
buf2 = cupy.empty([nao_mask, nao_mask])
for i in range(nset):
err = libgdft.GDFTeval_vxc(
ctypes.cast(stream.ptr, ctypes.c_void_p),
ctypes.cast(cublas_handle, ctypes.c_void_p),
ctypes.c_int(xctype_code),
ctypes.cast(ao.data.ptr, ctypes.c_void_p),
ctypes.cast(wv[i].data.ptr, ctypes.c_void_p),
ctypes.cast(buf1.data.ptr, ctypes.c_void_p),
ctypes.cast(buf2.data.ptr, ctypes.c_void_p),
ctypes.c_int(ngrids),
ctypes.c_int(nao_mask),
ctypes.c_int(nao),
ctypes.cast(idx.data.ptr, ctypes.c_void_p),
ctypes.cast(vmat[i].data.ptr, ctypes.c_void_p),
)

err = libgdft.GDFTeval_vxc(
ctypes.cast(stream.ptr, ctypes.c_void_p),
ctypes.cast(cublas_handle, ctypes.c_void_p),
ctypes.c_int(xctype_code),
ctypes.cast(ao.data.ptr, ctypes.c_void_p),
ctypes.cast(wv.data.ptr, ctypes.c_void_p),
ctypes.cast(buf1.data.ptr, ctypes.c_void_p),
ctypes.cast(buf2.data.ptr, ctypes.c_void_p),
ctypes.c_int(ngrids),
ctypes.c_int(nao_mask),
ctypes.c_int(nao),
ctypes.cast(idx.data.ptr, ctypes.c_void_p),
ctypes.cast(vmat.data.ptr, ctypes.c_void_p),
)

if err != 0:
raise RuntimeError('CUDA Error in GDFTeval_vxc')
if err != 0:
raise RuntimeError('CUDA Error in GDFTeval_vxc')
return

def _vv10nlc(rho, coords, vvrho, vvweight, vvcoords, nlc_pars):
Expand Down Expand Up @@ -525,15 +528,15 @@ def _nr_rks_task(ni, mol, grids, xc_code, dms, mo_coeff, mo_occ,
# libxc calls are still running on default stream
nelec = cupy.zeros(nset)
excsum = cupy.zeros(nset)
wv = []
wv = cupy.empty_like(rho_tot)
for i in range(nset):
exc, vxc = ni.eval_xc_eff(xc_code, rho_tot[i], deriv=1, xctype=xctype)[:2]
vxc = cupy.asarray(vxc, order='C')
exc = cupy.asarray(exc, order='C')
den = rho_tot[i][0] * weights
nelec[i] = den.sum()
excsum[i] = cupy.dot(den, exc[:,0])
wv.append(vxc * weights)
wv[i] = vxc * weights
if xctype == 'GGA':
wv[i][0] *= .5
if xctype == 'MGGA':
Expand All @@ -553,8 +556,7 @@ def _nr_rks_task(ni, mol, grids, xc_code, dms, mo_coeff, mo_occ,
max_memory=None,
grid_range=(grid_start, grid_end)):
p1 = p0 + weight.size
for i in range(nset):
eval_vxc(_sorted_mol, ao_mask, wv[i][:,p0:p1], idx, vmat[i], xctype_code)
eval_vxc(_sorted_mol, ao_mask, wv[:,:,p0:p1], idx, vmat, xctype_code)
p0 = p1
t0 = log.timer_debug1(f'eval integration on {device_id}', *t0)
return vmat, nelec.get(), excsum.get()
Expand Down Expand Up @@ -903,56 +905,36 @@ def _nr_uks_task(ni, mol, grids, xc_code, dms, mo_coeff, mo_occ,
rho_b = eval_rho(_sorted_mol, ao_mask, dmb[i][idx[:,None],idx], xctype=xctype, hermi=hermi)
else:
mo_coeff_mask = mo_coeff[:, idx,:]
#rho_a = eval_rho2(_sorted_mol, ao_mask, mo_coeff_mask[0], mo_occ[0], None, xctype)
#rho_b = eval_rho2(_sorted_mol, ao_mask, mo_coeff_mask[1], mo_occ[1], None, xctype)
rho_a = eval_rho2_fast(ao_mask, mo_coeff_mask[0], mo_occ[1], None, xctype, with_lapl)
rho_b = eval_rho2_fast(ao_mask, mo_coeff_mask[0], mo_occ[1], None, xctype, with_lapl)
rho = cupy.stack([rho_a, rho_b], axis=0)
exc, vxc = ni.eval_xc_eff(xc_code, rho, deriv=1, xctype=xctype)[:2]
t1 = log.timer_debug1('eval vxc', *t0)

if xctype == 'LDA':
den_a = rho_a * weight
den_b = rho_b * weight
wv = vxc[:,0] * weight
eval_vxc(_sorted_mol, ao_mask, wv[0], idx, vmata[i], xctype_code)
eval_vxc(_sorted_mol, ao_mask, wv[1], idx, vmatb[i], xctype_code)
#va = ao_mask.dot(_scale_ao(ao_mask, wv[0]).T)
#vb = ao_mask.dot(_scale_ao(ao_mask, wv[1]).T)
#add_sparse(vmata[i], va, idx)
#add_sparse(vmatb[i], vb, idx)

elif xctype == 'GGA':
den_a = rho_a[0] * weight
den_b = rho_b[0] * weight
wv = vxc * weight
wv[:,0] *= .5
eval_vxc(_sorted_mol, ao_mask, wv[0], idx, vmata[i], xctype_code)
eval_vxc(_sorted_mol, ao_mask, wv[1], idx, vmatb[i], xctype_code)
#va = ao_mask[0].dot(_scale_ao(ao_mask, wv[0]).T)
#vb = ao_mask[0].dot(_scale_ao(ao_mask, wv[1]).T)
#add_sparse(vmata[i], va, idx)
#add_sparse(vmatb[i], vb, idx)
elif xctype == 'NLC':
raise NotImplementedError('NLC')
elif xctype == 'MGGA':
den_a = rho_a[0] * weight
den_b = rho_b[0] * weight
wv = vxc * weight
wv[:,[0, 4]] *= .5
eval_vxc(_sorted_mol, ao_mask, wv[0], idx, vmata[i], xctype_code)
eval_vxc(_sorted_mol, ao_mask, wv[1], idx, vmatb[i], xctype_code)
#va = ao_mask[0].dot(_scale_ao(ao_mask[:4], wv[0,:4]).T)
#vb = ao_mask[0].dot(_scale_ao(ao_mask[:4], wv[1,:4]).T)
#va += _tau_dot(ao_mask, ao_mask, wv[0,4])
#vb += _tau_dot(ao_mask, ao_mask, wv[1,4])
#add_sparse(vmata[i], va, idx)
#add_sparse(vmatb[i], vb, idx)
elif xctype == 'HF':
pass
else:
raise NotImplementedError(f'numint.nr_uks for functional {xc_code}')

eval_vxc(_sorted_mol, ao_mask, wv[0], idx, vmata[i], xctype_code)
eval_vxc(_sorted_mol, ao_mask, wv[1], idx, vmatb[i], xctype_code)

if xctype == "LDA":
den_a = rho_a * weight
den_b = rho_b * weight
else:
den_a = rho_a[0] * weight
den_b = rho_b[0] * weight
nelec[0,i] += den_a.sum()
nelec[1,i] += den_b.sum()
excsum[i] += cupy.dot(den_a, exc[:,0])
Expand Down Expand Up @@ -1067,7 +1049,8 @@ def get_rho(ni, mol, dm, grids, max_memory=2000, verbose=None):
if mo_coeff is None:
rho[p0:p1] = eval_rho(_sorted_mol, ao, dm, xctype='LDA', hermi=1)
else:
rho[p0:p1] = eval_rho2(_sorted_mol, ao, mo_coeff, mo_occ, None, 'LDA')
#rho[p0:p1] = eval_rho2(_sorted_mol, ao, mo_coeff, mo_occ, None, 'LDA')
rho[p0:p1] = eval_rho2_fast(ao, mo_coeff, mo_occ, None, 'LDA')
t1 = log.timer_debug2('eval rho slice', *t1)
t0 = log.timer_debug1('eval rho', *t0)

Expand All @@ -1088,6 +1071,9 @@ def _nr_rks_fxc_task(ni, mol, grids, xc_code, fxc, dms, mo1, occ_coeff,
xctype = ni._xc_type(xc_code)
opt = getattr(ni, 'gdftopt', None)

if xctype == 'NLC':
raise NotImplementedError('NLC')

_sorted_mol = opt.mol
nao = mol.nao
dms = cupy.asarray(dms)
Expand All @@ -1099,6 +1085,13 @@ def _nr_rks_fxc_task(ni, mol, grids, xc_code, fxc, dms, mo1, occ_coeff,
else:
ao_deriv = 1

if xctype in ("LDA", "HF"):
xctype_code = 0
elif xctype in ("GGA", "NLC"):
xctype_code = 1
else:
xctype_code = 2

ngrids_glob = grids.coords.shape[0]
ngrids_per_device = (ngrids_glob + _num_devices - 1) // _num_devices
ngrids_per_device = (ngrids_per_device + MIN_BLK_SIZE - 1) // MIN_BLK_SIZE * MIN_BLK_SIZE
Expand Down Expand Up @@ -1136,21 +1129,13 @@ def _nr_rks_fxc_task(ni, mol, grids, xc_code, fxc, dms, mo1, occ_coeff,
fxc_w = fxc[:,:,p0:p1] * weights
wv = contract('axg,xyg->ayg', rho1, fxc_w)

for i in range(nset):
if xctype == 'LDA':
vmat_tmp = ao.dot(_scale_ao(ao, wv[i]).T)
elif xctype == 'GGA':
wv[i,0] *= .5
aow = _scale_ao(ao, wv[i])
vmat_tmp = aow.dot(ao[0].T)
elif xctype == 'NLC':
raise NotImplementedError('NLC')
else:
wv[i,0] *= .5
wv[i,4] *= .5
vmat_tmp = ao[0].dot(_scale_ao(ao[:4], wv[i,:4]).T)
vmat_tmp+= _tau_dot(ao, ao, wv[i,4])
add_sparse(vmat[i], vmat_tmp, mask)
if xctype == 'GGA':
wv[:,0] *= .5
if xctype == 'MGGA':
wv[:,0] *= .5
wv[:,4] *= .5

eval_vxc(_sorted_mol, ao, wv, mask, vmat, xctype_code)

t1 = log.timer_debug2('integration', *t1)
ao = rho1 = None
Expand Down Expand Up @@ -1251,6 +1236,13 @@ def _nr_uks_fxc_task(ni, mol, grids, xc_code, fxc, dms, mo1, occ_coeff,
else:
ao_deriv = 1

if xctype in ("LDA", "HF"):
xctype_code = 0
elif xctype in ("GGA", "NLC"):
xctype_code = 1
else:
xctype_code = 2

ngrids_glob = grids.coords.shape[0]
ngrids_per_device = (ngrids_glob + _num_devices - 1) // _num_devices
ngrids_per_device = (ngrids_per_device + MIN_BLK_SIZE - 1) // MIN_BLK_SIZE * MIN_BLK_SIZE
Expand Down Expand Up @@ -1295,25 +1287,19 @@ def _nr_uks_fxc_task(ni, mol, grids, xc_code, fxc, dms, mo1, occ_coeff,
wv_a+= contract('xg,xyg->yg', rho1b[i], fxc_w[1,:,0])
wv_b = contract('xg,xyg->yg', rho1a[i], fxc_w[0,:,1])
wv_b+= contract('xg,xyg->yg', rho1b[i], fxc_w[1,:,1])
if xctype == 'LDA':
va = ao.dot(_scale_ao(ao, wv_a[0]).T)
vb = ao.dot(_scale_ao(ao, wv_b[0]).T)
elif xctype == 'GGA':

if xctype == 'GGA':
wv_a[0] *= .5 # for transpose_sum at the end
wv_b[0] *= .5
va = ao[0].dot(_scale_ao(ao, wv_a).T)
vb = ao[0].dot(_scale_ao(ao, wv_b).T)
elif xctype == 'NLC':
raise NotImplementedError('NLC')
else:
#elif xctype == 'NLC':
# raise NotImplementedError('NLC')
if xctype == 'MGGA':
wv_a[[0,4]] *= .5 # for transpose_sum at the end
wv_b[[0,4]] *= .5
va = ao[0].dot(_scale_ao(ao[:4], wv_a[:4]).T)
vb = ao[0].dot(_scale_ao(ao[:4], wv_b[:4]).T)
va += _tau_dot(ao, ao, wv_a[4])
vb += _tau_dot(ao, ao, wv_b[4])
add_sparse(vmata[i], va, mask)
add_sparse(vmatb[i], vb, mask)

eval_vxc(_sorted_mol, ao, wv_a, mask, vmata[i], xctype_code)
eval_vxc(_sorted_mol, ao, wv_b, mask, vmatb[i], xctype_code)

t1 = log.timer_debug2('integration', *t1)
t0 = log.timer_debug1('vxc', *t0)
return vmata, vmatb
Expand Down Expand Up @@ -1465,15 +1451,17 @@ def nr_nlc_vxc(ni, mol, grids, xc_code, dms, relativity=0, hermi=1,
vv_vxc = xc_deriv.transform_vxc(rho, vxc, 'GGA', spin=0)
t1 = log.timer_debug1('transform vxc', *t1)

xctype_code = 1
vmat = cupy.zeros((nao,nao))
p1 = 0
for ao, mask, weight, coords \
in ni.block_loop(_sorted_mol, grids, nao, ao_deriv, max_memory=max_memory):
p0, p1 = p1, p1 + weight.size
wv = vv_vxc[:,p0:p1] * weight
wv[0] *= .5
aow = _scale_ao(ao, wv)
add_sparse(vmat, ao[0].dot(aow.T), mask)
#aow = _scale_ao(ao, wv)
#add_sparse(vmat, ao[0].dot(aow.T), mask)
eval_vxc(_sorted_mol, ao, wv, mask, vmat, xctype_code)
t1 = log.timer_debug1('integration', *t1)

transpose_sum(vmat)
Expand Down

0 comments on commit a42a070

Please sign in to comment.