-
Notifications
You must be signed in to change notification settings - Fork 135
Allow explicit RNG and Sparse input types in JAX functions #278
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,38 +1,66 @@ | ||
import jax.experimental.sparse as jsp | ||
from scipy.sparse import spmatrix | ||
|
||
from pytensor.graph.basic import Constant | ||
from pytensor.graph.type import HasDataType | ||
from pytensor.link.jax.dispatch import jax_funcify, jax_typify | ||
from pytensor.sparse.basic import Dot, StructuredDot | ||
from pytensor.sparse.basic import Dot, StructuredDot, Transpose | ||
from pytensor.sparse.type import SparseTensorType | ||
from pytensor.tensor import TensorType | ||
|
||
|
||
@jax_typify.register(spmatrix) | ||
def jax_typify_spmatrix(matrix, dtype=None, **kwargs): | ||
# Note: This changes the type of the constants from CSR/CSC to BCOO | ||
# We could add BCOO as a PyTensor type but this would only be useful for JAX graphs | ||
# and it would break the premise of one graph -> multiple backends. | ||
# The same situation happens with RandomGenerators... | ||
return jsp.BCOO.from_scipy_sparse(matrix) | ||
|
||
|
||
class BCOOType(TensorType, HasDataType): | ||
"""JAX-compatible BCOO type. | ||
|
||
This type is not exposed to users directly. | ||
|
||
It is introduced by the JIT linker in place of any SparseTensorType input | ||
variables used in the original function. Nodes in the function graph will | ||
still show the original types as inputs and outputs. | ||
""" | ||
|
||
def filter(self, data, strict: bool = False, allow_downcast=None): | ||
if isinstance(data, jsp.BCOO): | ||
return data | ||
|
||
if strict: | ||
raise TypeError() | ||
|
||
return jax_typify(data) | ||
|
||
|
||
@jax_typify.register(SparseTensorType) | ||
def jax_typify_SparseTensorType(type): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is the idea to have BCOO be the default sparse tensor type, or is it a stopgap? I think some algorithms prefer different types, so it'd be good long term to have different subclasses for SparseTensorType (BCOO, CSC, etc.) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. From reading JAX docs it seems they are pushing for BCOO only at the moment There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Jax's sparse support isn't great though, I'm not sure they're the best lead to follow. Or, I guess this PR is about pytensor's Jax support only and not necessarily other backends? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is just JAX backend. AFAICT BCOO is the only thing somewhat supported. Their other format (CSC or CSR) doesn't allow for any of the other jax transformations (vmap, grad, jit?). They pushed a paper on BCOO so I think it's really what their planning publicly at least. Pytensor itself uses scipy formats as well as numba (haven't worked on it much tough) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah gotcha. I didn't know about the paper, will try and find that. And that will make it more difficult if Jax has a particular way of handling this vs scipy or numba. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fwiw pytensor only supports a subset of scipy formats (crs and csc, scipy has 7 formats listed). Numba supports the same formats pytensor does, but that's not a coincidence. My point is that there's room to redefine what pytensor's sparse formats should be, if it were advantageous to do so. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jessegrabowski that's definitely true. Still, it's unlikely that we will have a common set of types (RNG / Shared / Tuples / Whatever), that work for all backends. This PR is more focused on how we can allow specalized backend-only types and not about deciding which specific types we want to provide to users as default in PyTensor. |
||
return BCOOType( | ||
dtype=type.dtype, | ||
shape=type.shape, | ||
name=type.name, | ||
broadcastable=type.broadcastable, | ||
) | ||
|
||
|
||
@jax_funcify.register(Dot) | ||
@jax_funcify.register(StructuredDot) | ||
def jax_funcify_sparse_dot(op, node, **kwargs): | ||
for input in node.inputs: | ||
if isinstance(input.type, SparseTensorType) and not isinstance(input, Constant): | ||
raise NotImplementedError( | ||
"JAX sparse dot only implemented for constant sparse inputs" | ||
) | ||
|
||
if isinstance(node.outputs[0].type, SparseTensorType): | ||
raise NotImplementedError("JAX sparse dot only implemented for dense outputs") | ||
|
||
@jsp.sparsify | ||
def sparse_dot(x, y): | ||
out = x @ y | ||
if isinstance(out, jsp.BCOO): | ||
if isinstance(out, jsp.BCOO) and not isinstance( | ||
node.outputs[0].type, SparseTensorType | ||
): | ||
out = out.todense() | ||
ricardoV94 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return out | ||
|
||
return sparse_dot | ||
|
||
|
||
@jax_funcify.register(Transpose) | ||
def jax_funcify_sparse_transpose(op, **kwargs): | ||
def sparse_transpose(x): | ||
return x.T | ||
|
||
return sparse_transpose |
Uh oh!
There was an error while loading. Please reload this page.