
Excessive padding used in eri_primitives #117

Open
hatemhelal opened this issue Oct 4, 2023 · 1 comment

@hatemhelal (Contributor)

The padding introduced to the c_term here:

```python
return segment_sum(c, index, num_segments=4 * LMAX + 1)
```

could instead be derived from the input primitives, provided we are careful to vmap only over primitives that share the same total angular momentum (e.g. by evaluating the ERI shell by shell).
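A minimal sketch of the difference, assuming (as the `4 * LMAX + 1` bound suggests) that `index` never exceeds the summed angular momentum of the four primitives. `c_term_padded`, `c_term_by_shell` and `total_l` are hypothetical names, not the project's code. Because `num_segments` fixes the output length, it has to be a static Python int, which is why one vmap batch can only share a single total angular momentum.

```python
import jax
import jax.numpy as jnp
from jax.ops import segment_sum

LMAX = 4  # padding bound used in the current code


def c_term_padded(c, index):
    # Output length is always 4 * LMAX + 1, whatever the inputs need.
    return segment_sum(c, index, num_segments=4 * LMAX + 1)


def c_term_by_shell(c, index, total_l):
    # Hypothetical: total_l is the summed angular momentum of the quartet
    # (a static int), so indices lie in [0, total_l] and total_l + 1 segments suffice.
    return segment_sum(c, index, num_segments=total_l + 1)


c = jnp.array([1.0, 2.0, 3.0, 4.0])
index = jnp.array([0, 1, 1, 2])

padded = c_term_padded(c, index)                # 17 entries, trailing zeros
compact = c_term_by_shell(c, index, total_l=2)  # 3 entries

# vmap works as long as every element of the batch uses the same total_l.
cs = jnp.stack([c, 2.0 * c])
idxs = jnp.stack([index, index])
batched = jax.vmap(lambda ci, ii: c_term_by_shell(ci, ii, total_l=2))(cs, idxs)
```

The nonzero entries are identical in both versions; the shell-wise variant simply drops the trailing zeros that padding to `LMAX` introduces.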

This issue will be used to investigate removing the padding within eri_primitives. Note that a similar pattern is used in the evaluation of the nuclear attraction integrals, which could also be improved.

hatemhelal self-assigned this Oct 4, 2023

@AlexanderMath (Contributor)

AlexanderMath commented Oct 4, 2023

Note: for the water STO-3G test case, LMAX is 1 rather than 4. Padding to LMAX = 4 means the broadcast Ci Cj Ck has (4*4+1)^3 = 17^3 = 4913 entries instead of (1+1)^3 = 8.

As we discussed, I think there's a way to avoid padding to LMAX in JAX without resorting to C++.

Problem: different primitives have different L (in our case L=0 for hydrogen and L=1 for oxygen). The resulting Ci, Cj, Ck have length 1 for L=0 and 3 for L=1, so the broadcast Ci Cj Ck can take the shapes (1,1,1), (1,1,3), (1,3,1), (3,1,1), (3,3,1), (3,1,3), ..., (3,3,3).

Current solution: pad everything to L=4. This works, but it increases memory and compute (and possibly trace time) by a factor of several hundred.
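The arithmetic can be checked directly. This sketch takes the per-axis lengths from the comment (1 for L=0, 3 for L=1) and the `4 * LMAX + 1` padding from the `segment_sum` call, then enumerates the eight possible broadcast shapes and the padding overhead for each.

```python
import math
from itertools import product

# Padded case: every factor has 4 * LMAX + 1 entries (num_segments in the
# segment_sum call), and Ci, Cj, Ck are broadcast against each other.
LMAX = 4
padded = (4 * LMAX + 1) ** 3  # 17**3 = 4913

# Unpadded case: per-axis lengths 1 (L=0, hydrogen) or 3 (L=1, oxygen).
shapes = list(product((1, 3), repeat=3))
sizes = {shape: math.prod(shape) for shape in shapes}
overhead = {shape: padded / n for shape, n in sizes.items()}

for shape in shapes:
    print(shape, sizes[shape], f"{overhead[shape]:.0f}x")
```

Even the largest unpadded case, (3,3,3) with 27 entries, is about 180x smaller than the padded 4913.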

Other solution: batch together calls with the same shape. For example, do the (1,1,1) calls together, then the (1,1,3), (1,3,1) and (3,1,1) calls together, and so on. For inspiration, this is done here in ~50 lines of JAX. The (counts, sizes) list looks like [(13271, 1), (32711, 3), (57121, 9), ...], corresponding to the cases (1,1,1,1), then (1,1,3,1) and (1,3,1,1), and so on.
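A minimal sketch of batching by shape, assuming each term is a triple of 1-D arrays (Ci, Cj, Ck) of varying length. `contract` is a hypothetical stand-in for the per-quartet kernel (the real ERI code would also weight the product, e.g. by Boys-function values), and `batched_by_shape` is not the linked implementation. Grouping happens in plain Python, so each group has static shapes and can be vmapped without padding.

```python
from collections import defaultdict

import jax
import jax.numpy as jnp


def contract(ci, cj, ck):
    # Hypothetical kernel: broadcast Ci, Cj, Ck into a 3-D tensor and reduce it.
    return jnp.sum(ci[:, None, None] * cj[None, :, None] * ck[None, None, :])


def batched_by_shape(terms):
    """terms: list of (ci, cj, ck) 1-D arrays whose lengths vary per term."""
    groups = defaultdict(list)
    for i, (ci, cj, ck) in enumerate(terms):
        groups[(ci.shape[0], cj.shape[0], ck.shape[0])].append(i)

    kernel = jax.vmap(contract)
    result = jnp.zeros(len(terms))
    for idxs in groups.values():
        # Stack only the terms sharing this shape: one vmap call per shape.
        ci = jnp.stack([terms[i][0] for i in idxs])
        cj = jnp.stack([terms[i][1] for i in idxs])
        ck = jnp.stack([terms[i][2] for i in idxs])
        # Scatter the group's results back to their original positions.
        result = result.at[jnp.array(idxs)].set(kernel(ci, cj, ck))
    return result


ones3 = jnp.ones(3)
terms = [
    (jnp.array([1.0]), jnp.array([2.0]), jnp.array([3.0])),            # (1,1,1)
    (jnp.array([1.0, 2.0, 3.0]), jnp.array([1.0]), jnp.array([1.0])),  # (3,1,1)
    (jnp.array([2.0]), jnp.array([1.0]), jnp.array([4.0])),            # (1,1,1)
    (ones3, ones3, ones3),                                             # (3,3,3)
]
result = batched_by_shape(terms)  # [6., 6., 8., 27.]
```

This groups by exact shape, so there is at most one trace per distinct shape (eight here). The (counts, sizes) list above appears to group by total size instead, which would put (1,1,3,1) and (1,3,1,1) in one batch at the cost of a transpose or reshape per case.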

@awf Happy to clarify in person. TL;DR: it looks like for this case we should be able to get performant JAX code without padding.
