Note: in the water STO-3G test case the highest angular momentum actually present is 1, while LMAX is fixed at 4. This means the padded Ci, Cj, Ck give (4*4+1)^3 = 4913 combinations instead of (1+1)^3 = 8.
As we discussed, I think there's a way to circumvent padding to L_MAX in Jax without resorting to C++.
Problem: different primitives have different L (in our case L=0 for hydrogen and up to L=1 for oxygen). The resulting Ci, Cj, Ck have shape 1 for L=0 and 3 for L=1, so the broadcast product of Ci, Cj, Ck can take any of the shapes (1,1,1), (1,1,3), (1,3,1), (3,1,1), (3,3,1), (3,1,3), (1,3,3) and (3,3,3).
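The shape combinatorics can be checked with a few lines of Python. The mapping `SIZE_FOR_L` (1 for L=0, 3 for L=1) is taken from the shapes quoted above and is an assumption of this sketch. Grouping the shapes by their total number of elements shows which ones do the same amount of work per call.

```python
import itertools
from collections import Counter

# Assumed mapping from angular momentum L to the length of each C array,
# following the shapes quoted above: 1 for L=0 (s) and 3 for L=1 (p).
SIZE_FOR_L = {0: 1, 1: 3}

def shape_combinations(ls):
    """All (Ci, Cj, Ck) shape triples when each index can take any L in `ls`."""
    sizes = sorted({SIZE_FOR_L[l] for l in ls})
    return list(itertools.product(sizes, repeat=3))

shapes = shape_combinations([0, 1])
print(len(shapes))  # 8 distinct shapes, from (1, 1, 1) to (3, 3, 3)

# Group the shapes by the total number of elements in the broadcast product.
by_size = Counter(a * b * c for a, b, c in shapes)
print(sorted(by_size.items()))  # [(1, 1), (3, 3), (9, 3), (27, 1)]
```

The three shapes of size 3 and the three of size 9 are permutations of each other, which is why they are natural candidates for sharing a batch.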
Current solution: pad everything to LMAX=4. This works, but increases memory and compute (and possibly trace time) by a factor of several hundred.
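To make the cost of padding concrete, here is a minimal sketch on a toy problem. The kernel `term` (a sum over the broadcast product of three coefficient vectors) and the helper `padded_batch` are hypothetical stand-ins, not code from `integrals.py`. The point is only that zero padding keeps a single `vmap` shape-uniform at the price of evaluating `width**3` products per item.

```python
import jax
import jax.numpy as jnp

def term(ci, cj, ck):
    # Hypothetical per-item kernel: reduce the broadcast product of three vectors.
    return jnp.sum(ci[:, None, None] * cj[None, :, None] * ck[None, None, :])

def padded_batch(items, width):
    """Zero-pad every coefficient vector to `width`, then evaluate with one vmap.

    Zero padding leaves the sums unchanged, but every item now costs
    width**3 products regardless of its true shape.
    """
    def pad(c):
        return jnp.pad(c, (0, width - c.shape[0]))

    ci = jnp.stack([pad(a) for a, _, _ in items])
    cj = jnp.stack([pad(b) for _, b, _ in items])
    ck = jnp.stack([pad(c) for _, _, c in items])
    return jax.vmap(term)(ci, cj, ck)

# Usage: items with true shapes (1,1,1), (3,1,1), (1,3,3), (1,1,1), padded to width 3.
items = [
    (jnp.ones(1), jnp.ones(1), jnp.ones(1)),
    (jnp.ones(3), jnp.ones(1), jnp.ones(1)),
    (jnp.ones(1), jnp.ones(3), jnp.ones(3)),
    (jnp.ones(1), jnp.ones(1), jnp.ones(1)),
]
padded = padded_batch(items, width=3)  # values 1, 3, 9, 1
```

For these four items the padded version computes 4 × 27 = 108 products, versus 1 + 3 + 9 + 1 = 14 for the unpadded shapes.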
Other solution: batch together calls with the same shape. For example, evaluate all the (1,1,1) calls together, then the (1,1,3), (1,3,1) and (3,1,1) calls together, and so on. For inspiration, this is done here in ~50 lines of Jax. The (counts, sizes) list looks like [(13271, 1), (32711, 3), (57121, 9), ...], which corresponds to the case (1,1,1,1), then the cases (1,1,3,1) and (1,3,1,1), and so on.
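A minimal sketch of shape bucketing, using a hypothetical kernel `term` that reduces the broadcast product of three coefficient vectors. Unlike the grouping described above, which merges permutations such as (1,1,3), (1,3,1) and (3,1,1), this simpler version buckets by the exact shape triple, so it may produce more (smaller) groups; merging permutations would additionally require permuting inputs and results into a canonical order. `batched_by_shape` is an illustrative name and is not taken from the linked code.

```python
from collections import defaultdict

import jax
import jax.numpy as jnp
import numpy as np

def term(ci, cj, ck):
    # Hypothetical per-item kernel: reduce the broadcast product of three vectors.
    return jnp.sum(ci[:, None, None] * cj[None, :, None] * ck[None, None, :])

# jit retraces once per distinct input shape, i.e. once per bucket.
vterm = jax.jit(jax.vmap(term))

def batched_by_shape(items):
    """Evaluate `term` on a list of (ci, cj, ck) tuples without any padding.

    Items are bucketed by their exact shape triple; each bucket is stacked and
    evaluated with one vmapped call. Results are returned in the input order.
    """
    buckets = defaultdict(list)  # shape triple -> indices of items with that shape
    for idx, (ci, cj, ck) in enumerate(items):
        buckets[(ci.shape[0], cj.shape[0], ck.shape[0])].append(idx)

    out = np.zeros(len(items), dtype=np.float32)
    for idxs in buckets.values():
        ci = jnp.stack([items[i][0] for i in idxs])
        cj = jnp.stack([items[i][1] for i in idxs])
        ck = jnp.stack([items[i][2] for i in idxs])
        out[idxs] = np.asarray(vterm(ci, cj, ck))  # scatter back to original positions
    return out

items = [
    (jnp.ones(1), jnp.ones(1), jnp.ones(1)),
    (jnp.ones(3), jnp.ones(1), jnp.ones(1)),
    (jnp.ones(1), jnp.ones(3), jnp.ones(3)),
    (jnp.ones(1), jnp.ones(1), jnp.ones(1)),
]
result = batched_by_shape(items)  # 1, 3, 9, 1 using three buckets
```

The trade-off is one compilation per distinct shape instead of a single padded compilation; the number of buckets is bounded by the number of distinct shape combinations, which stays small when only s and p functions are present.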
@awf Happy to clarify in person. TL;DR: it looks like for this case we should be able to get performant Jax code.
The padding introduced to the `c_term` here (pyscf-ipu/pyscf_ipu/experimental/integrals.py, line 215 at commit 023e114) could instead be derived from the input primitives, as long as we are careful to only use `vmap` over multiple primitives of the same total angular momentum (e.g. evaluate the ERI by shell).

This issue will be used to investigate removing the padding used within `eri_primitives`. Note that a similar pattern is used in the evaluation of the nuclear attraction integrals, which could also be improved.
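A minimal sketch of the shell-wise idea, assuming (hypothetically) that the length of the per-primitive term can be computed from the shell's angular momentum. The functions `shell_terms` and `term_length` and the placeholder body `alpha ** k` are illustrative, not the code in `integrals.py`. The key point is that JAX needs concrete array shapes at trace time, so the length is passed as a static Python int; because every primitive in one `vmap` call belongs to the same shell, they all share that length and no padding to a global LMAX is needed.

```python
import functools

import jax
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnums=0)
def shell_terms(n, alphas):
    """Hypothetical per-primitive term of static length `n`, vmapped over one shell.

    `n` is a static argument, so it must be a hashable Python int; jit
    compiles one version per distinct `n`. All primitives in `alphas`
    belong to the same shell, so they share the same `n`.
    """
    def one(alpha):
        k = jnp.arange(n)  # exactly n entries, no padding
        return alpha ** k  # placeholder for the real c_term expression
    return jax.vmap(one)(alphas)

def term_length(l_total):
    # Hypothetical rule: derive the length from the shell's total angular
    # momentum instead of a global LMAX.
    return l_total + 1

# Usage: two primitives from a shell with total angular momentum 2.
alphas = jnp.array([0.5, 2.0])
out = shell_terms(term_length(2), alphas)
print(out.shape)  # (2, 3)
```

Shells with different angular momenta trigger separate compilations, which is the same per-shape trade-off as batching by shape.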