Improve efficiency of sparse symmetric ERI #86

Open
blazejba opened this issue Sep 14, 2023 · 1 comment

@blazejba
Contributor

Problem:

A blog post describes how to implement sparse matrix multiplication in JAX as follows:

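```python
from functools import partial

import jax


# `jax.partial` was an alias of `functools.partial` and is no longer available
# in current JAX releases, so `functools.partial` is used instead.
@partial(jax.jit, static_argnums=(2,))
def sp_matmul(A, B, shape):
    """
    Arguments:
        A: (N, M) sparse matrix in COO form, represented as a tuple
           (indexes, values) with indexes = (rows, cols)
        B: (M, K) dense matrix
        shape: value of N (static, so segment_sum can size its output)
    Returns:
        (N, K) dense matrix
    """
    assert B.ndim == 2
    indexes, values = A
    rows, cols = indexes
    # Gather the row of B selected by each nonzero's column index: (nnz, K)
    in_ = B.take(cols, axis=0)
    # Scale each gathered row by its nonzero value
    prod = in_ * values[:, None]
    # Sum the scaled rows into their output rows: (N, K)
    res = jax.ops.segment_sum(prod, rows, shape)
    return res
```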
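A quick way to sanity-check the scheme is to compare it against a dense product on a small matrix. The sketch below restates `sp_matmul` compactly so it runs on its own, builds the COO `(indexes, values)` tuple with `np.nonzero`, and checks the result; the test matrix is arbitrary.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


@partial(jax.jit, static_argnums=(2,))
def sp_matmul(A, B, shape):
    # Same COO gather / scale / segment_sum scheme as above
    (rows, cols), values = A
    prod = B.take(cols, axis=0) * values[:, None]
    return jax.ops.segment_sum(prod, rows, shape)


# Small (N, M) = (3, 4) matrix with a few nonzeros
dense_A = np.array(
    [[1.0, 0.0, 0.0, 2.0],
     [0.0, 0.0, 3.0, 0.0],
     [4.0, 0.0, 0.0, 5.0]],
    dtype=np.float32,
)
rows, cols = np.nonzero(dense_A)                   # COO coordinates of nonzeros
values = dense_A[rows, cols]                       # matching nonzero values
B = np.arange(8, dtype=np.float32).reshape(4, 2)   # (M, K) = (4, 2)

A_sparse = ((jnp.asarray(rows), jnp.asarray(cols)), jnp.asarray(values))
sparse_result = sp_matmul(A_sparse, jnp.asarray(B), dense_A.shape[0])
dense_result = dense_A @ B
print(np.allclose(sparse_result, dense_result))  # True
```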

This has been done for the ERI in `sparse_symmetric_ERI.py`, in `sequentialized_iter(...)`.
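The body of `sequentialized_iter(...)` is not reproduced in this issue. As a hypothetical illustration only, one way to sequentialize the same gather / scale / `segment_sum` scheme is to split the nonzeros into fixed-size chunks and accumulate them with `jax.lax.scan`, which caps the `(nnz, K)` intermediate at `(chunk_size, K)`. The names `chunked_sp_matmul` and `chunk_size` are assumptions, not taken from the repository.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


@partial(jax.jit, static_argnums=(2, 3))
def chunked_sp_matmul(A, B, shape, chunk_size):
    """Hypothetical chunked variant of sp_matmul; A is ((rows, cols), values)."""
    (rows, cols), values = A
    nnz = values.shape[0]
    num_chunks = -(-nnz // chunk_size)  # ceil division
    pad = num_chunks * chunk_size - nnz
    # Pad with zero-valued entries at (0, 0) so every chunk is full; they add
    # 0 * B[0] to output row 0, i.e. nothing (assuming B is finite).
    rows = jnp.pad(rows, (0, pad)).reshape(num_chunks, chunk_size)
    cols = jnp.pad(cols, (0, pad)).reshape(num_chunks, chunk_size)
    values = jnp.pad(values, (0, pad)).reshape(num_chunks, chunk_size)

    def step(acc, chunk):
        r, c, v = chunk
        # Same gather / scale / segment_sum as sp_matmul, on one chunk only
        prod = B.take(c, axis=0) * v[:, None]
        return acc + jax.ops.segment_sum(prod, r, shape), None

    init = jnp.zeros((shape, B.shape[1]), dtype=jnp.result_type(values, B))
    res, _ = jax.lax.scan(step, init, (rows, cols, values))
    return res


# Usage: compare against a dense product on random data
rng = np.random.default_rng(0)
dense_A = rng.random((5, 7)).astype(np.float32)
dense_A[dense_A < 0.6] = 0.0
rows, cols = np.nonzero(dense_A)
values = dense_A[rows, cols]
B = rng.random((7, 3)).astype(np.float32)

A_sparse = ((jnp.asarray(rows), jnp.asarray(cols)), jnp.asarray(values))
out = chunked_sp_matmul(A_sparse, jnp.asarray(B), 5, 4)
print(np.allclose(out, dense_A @ B, atol=1e-5))  # True
```

A smaller `chunk_size` lowers peak memory but runs more sequential steps, so whether this trade-off helps depends on what the profile shows.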

Alex generated a profile for it and thinks it is not as efficient as it should be.

Solution:
Identify the source of inefficiency in `sequentialized_iter(...)` and find a workaround.

@blazejba
Contributor Author

blazejba commented Sep 14, 2023

Alex identified that the main problem is actually in `eigh`, but there might still be opportunities to speed up `sequentialized_iter(...)`.
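To confirm where time is spent, each stage can be timed in isolation. A minimal sketch, assuming a random symmetric matrix as a stand-in for whatever is passed to `eigh` and random COO data for the contraction (sizes are arbitrary, and `time_fn` is a hypothetical helper). Because JAX dispatches work asynchronously, `jax.block_until_ready` is called before reading the clock, and a warm-up call keeps compilation out of the measurement.

```python
import time
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


@partial(jax.jit, static_argnums=(2,))
def sp_matmul(A, B, shape):
    # COO gather / scale / segment_sum, as in the issue
    (rows, cols), values = A
    prod = B.take(cols, axis=0) * values[:, None]
    return jax.ops.segment_sum(prod, rows, shape)


eigh = jax.jit(jnp.linalg.eigh)


def time_fn(fn, *args, repeats=5):
    """Hypothetical helper: mean wall-clock seconds per call, excluding compilation."""
    jax.block_until_ready(fn(*args))  # warm-up call triggers compilation
    start = time.perf_counter()
    for _ in range(repeats):
        jax.block_until_ready(fn(*args))  # wait for async dispatch to finish
    return (time.perf_counter() - start) / repeats


rng = np.random.default_rng(0)
n, nnz = 256, 20_000
H = rng.standard_normal((n, n)).astype(np.float32)
H = jnp.asarray((H + H.T) / 2)  # symmetric input for eigh

rows = jnp.asarray(rng.integers(0, n, nnz), dtype=jnp.int32)
cols = jnp.asarray(rng.integers(0, n, nnz), dtype=jnp.int32)
values = jnp.asarray(rng.standard_normal(nnz), dtype=jnp.float32)
B = jnp.asarray(rng.standard_normal((n, n)), dtype=jnp.float32)

t_eigh = time_fn(eigh, H)
t_sp = time_fn(lambda b: sp_matmul(((rows, cols), values), b, n), B)
print(f"eigh: {t_eigh * 1e3:.2f} ms, sp_matmul: {t_sp * 1e3:.2f} ms")
```

For a per-operation breakdown rather than end-to-end wall time, JAX's built-in profiler (`jax.profiler.trace`) can capture a trace for viewing in TensorBoard.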
