Problem:
This blog post describes how to implement sparse matrix multiplication in JAX as follows:
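```python
import functools

import jax


# `shape` (the value of N) is marked static because segment_sum's output size
# must be known at trace time under jit.
@functools.partial(jax.jit, static_argnums=(2,))
def sp_matmul(A, B, shape):
    """
    Arguments:
        A: (N, M) sparse matrix represented as a tuple (indexes, values)
        B: (M, K) dense matrix
        shape: value of N
    Returns:
        (N, K) dense matrix
    """
    assert B.ndim == 2
    indexes, values = A
    rows, cols = indexes
    in_ = B.take(cols, axis=0)      # gather the row of B matching each nonzero's column
    prod = in_ * values[:, None]    # scale each gathered row by its nonzero value
    res = jax.ops.segment_sum(prod, rows, shape)  # sum contributions into output rows
    return res
```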
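When checking a port of this routine, it can help to have a plain NumPy reference of the same three steps (gather rows of `B`, scale by the nonzero values, scatter-add into output rows) and compare it with a dense product. This is a sketch for illustration only: the name `coo_matmul_ref` and the random test matrix are hypothetical and not part of the project.

```python
import numpy as np


def coo_matmul_ref(rows, cols, values, B, n_rows):
    """Reference (N, K) = A @ B for A stored as COO triples (rows, cols, values)."""
    gathered = B[cols]                   # gather: row cols[i] of B for each nonzero i
    scaled = gathered * values[:, None]  # scale each gathered row by its nonzero value
    out = np.zeros((n_rows, B.shape[1]), dtype=scaled.dtype)
    np.add.at(out, rows, scaled)         # scatter-add into output rows (what segment_sum does)
    return out


# Hypothetical test data: a random 5x4 sparse matrix with 6 nonzeros.
rng = np.random.default_rng(0)
N, M, K, nnz = 5, 4, 3, 6
rows = rng.integers(0, N, size=nnz)
cols = rng.integers(0, M, size=nnz)
values = rng.standard_normal(nnz)
B = rng.standard_normal((M, K))

# Dense copy of A; repeated (row, col) pairs are summed, matching COO semantics.
A_dense = np.zeros((N, M))
np.add.at(A_dense, (rows, cols), values)

result = coo_matmul_ref(rows, cols, values, B, N)
print(np.allclose(result, A_dense @ B))  # prints True
```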
This approach has been used for the ERI in `sparse_symmetric_ERI.py`, in `sequentialized_iter(...)`.
Alex generated a profile for it and thinks it is not as efficient as it should be.
Solution: Identify the source of the inefficiency in `sequentialized_iter(...)` and find a workaround.
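One way to start locating the inefficiency is to time the sparse product in isolation and compare it with a dense product on the same data. Because JAX dispatches work asynchronously and the first call to a jitted function includes compilation, the sketch below discards a warm-up call and waits on `block_until_ready()` before stopping the clock. The helper `time_fn`, the problem sizes, and the random data are hypothetical stand-ins, not the actual ERI workload.

```python
import functools
import time

import jax
import jax.numpy as jnp
import numpy as np


@functools.partial(jax.jit, static_argnums=(2,))
def sp_matmul(A, B, shape):
    # Same gather / scale / segment_sum scheme as the blog post version.
    (rows, cols), values = A
    prod = B.take(cols, axis=0) * values[:, None]
    return jax.ops.segment_sum(prod, rows, shape)


def time_fn(fn, *args, repeats=10):
    """Median wall-clock seconds per call, excluding the first (compiling) call."""
    fn(*args).block_until_ready()  # warm-up: triggers tracing and XLA compilation
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args).block_until_ready()  # wait for the asynchronous computation
        times.append(time.perf_counter() - start)
    return float(np.median(times))


# Hypothetical problem size; substitute the real shapes when profiling.
N, M, K, nnz = 2000, 2000, 64, 50_000
rng = np.random.default_rng(0)
rows = jnp.asarray(rng.integers(0, N, size=nnz))
cols = jnp.asarray(rng.integers(0, M, size=nnz))
values = jnp.asarray(rng.standard_normal(nnz), dtype=jnp.float32)
B = jnp.asarray(rng.standard_normal((M, K)), dtype=jnp.float32)

# Dense copy of the same matrix for comparison (duplicates are summed).
A_dense = jnp.zeros((N, M), jnp.float32).at[rows, cols].add(values)
dense_matmul = jax.jit(jnp.matmul)

t_sparse = time_fn(sp_matmul, ((rows, cols), values), B, N)
t_dense = time_fn(dense_matmul, A_dense, B)
print(f"sparse: {t_sparse * 1e3:.3f} ms, dense: {t_dense * 1e3:.3f} ms")
```

For a per-operation breakdown rather than wall-clock totals, `jax.profiler.trace` can capture a trace for inspection in TensorBoard or Perfetto.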
Alex identified that the main problem is actually in `eigh`, but there may still be opportunities to speed up `sequentialized_iter(...)`.
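To confirm how much of the runtime `eigh` accounts for, it can be timed on its own at roughly the matrix size used in the loop. The sketch below assumes a dense symmetric `float32` matrix of hypothetical size `n` filled with random data; it does not reproduce the project's actual inputs. Since `jnp.linalg.eigh` returns a tuple of arrays, the helper waits on every output before stopping the clock.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


def block(tree):
    """Wait for every array in a (possibly nested) output to finish computing."""
    return jax.tree_util.tree_map(lambda x: x.block_until_ready(), tree)


def time_fn(fn, *args, repeats=5):
    """Median seconds per call after one warm-up (compiling) call."""
    block(fn(*args))
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        block(fn(*args))
        times.append(time.perf_counter() - start)
    return float(np.median(times))


n = 512  # hypothetical matrix dimension; use the real size when profiling
rng = np.random.default_rng(0)
X = rng.standard_normal((n, n)).astype(np.float32)
H = jnp.asarray((X + X.T) / 2)  # symmetrize, since eigh assumes a symmetric input

eigh = jax.jit(jnp.linalg.eigh)
t_eigh = time_fn(eigh, H)
w, v = eigh(H)  # eigenvalues (ascending) and eigenvectors
print(f"eigh({n}x{n}): {t_eigh * 1e3:.3f} ms")
```

Comparing this figure against the timing of the sparse product on matching sizes gives a rough split of where the loop spends its time.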