ENH: cluster: JAX support (non-jitted) #22255

Merged: 1 commit, Jan 6, 2025

crusaderky
Contributor

@crusaderky crusaderky commented Jan 6, 2025

@@ -137,8 +137,8 @@ def whiten(obs, check_finite=True):
     obs = _asarray(obs, check_finite=check_finite, xp=xp)
     std_dev = xp.std(obs, axis=0)
     zero_std_mask = std_dev == 0
-    if xp.any(zero_std_mask):
-        std_dev[zero_std_mask] = 1.0
+    std_dev = xpx.at(std_dev, zero_std_mask).set(1.0)
Contributor Author

This fails on jax.jit. My current intention is to change jax.jit itself to special-case arr.at[idx].set(value) when idx is a boolean mask and value is a scalar, so that it can be rewritten as jnp.where(idx, value, arr). Failing that, I can implement the same special case in array-api-extra.
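
The rewrite proposed here can be illustrated with NumPy, whose three-argument `where` has the same form as `jnp.where`. This is only a sketch of the equivalence for a boolean mask and a scalar value, not the code in JAX or array-api-extra; the sample `std_dev` values are made up.

```python
import numpy as np

std_dev = np.array([0.0, 2.0, 0.0, 0.5])  # made-up sample values
zero_std_mask = std_dev == 0

# In-place form: requires a mutable array, so it cannot be used on JAX arrays.
in_place = std_dev.copy()
in_place[zero_std_mask] = 1.0

# Functional form: an elementwise select whose output shape never depends
# on how many mask entries are True.
functional = np.where(zero_std_mask, 1.0, std_dev)

print(functional)
```

Because the select form never needs to know which entries of the mask are set, it is the shape that a tracing compiler can handle without concrete values.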

Member

@jakevdp jakevdp Jan 6, 2025

My current intention is to change jax.jit itself

I don't think changing jax.jit itself is a viable path here – I wouldn't suggest starting on that route.

It might be viable to make arr.at[idx].set(value) lower to lax.select rather than lax.scatter in the specific case of a boolean idx. I've tried that in the past, but it's really tricky to properly handle all corner cases of broadcasted and/or multi-dimensional indices, correctly implementing autodiff and batching rules, etc.

The easiest thing would probably be to do this at the level of xp.at, though boolean indices were specifically excluded from the initial discussions there.
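
One of the corner cases mentioned above, a boolean index with fewer dimensions than the array, can be sketched in NumPy. The example is an assumption-laden illustration rather than anything taken from JAX or array-api-extra: it shows that indexing applies the mask to the leading axis, while `where` broadcasts from the trailing axis, so the mask has to be expanded first.

```python
import numpy as np

arr = np.arange(12.0).reshape(3, 4)
row_mask = np.array([True, False, True])  # indexes axis 0 only

# Indexing semantics: every element of the selected rows is overwritten.
indexed = arr.copy()
indexed[row_mask] = -1.0

# where() aligns shapes from the trailing axis, so the mask needs explicit
# trailing axes before it lines up with axis 0 of `arr`.
expanded = row_mask.reshape(row_mask.shape + (1,) * (arr.ndim - row_mask.ndim))
selected = np.where(expanded, -1.0, arr)
```

Passing `row_mask` unexpanded would fail to broadcast against the (3, 4) array, which is the kind of detail a general lowering has to get right for every index shape.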

Contributor Author

I've tried that in the past

Do you have a (partial, non-functioning) PR I could start from?

-    if xp.any(zero_std_mask):
-        std_dev[zero_std_mask] = 1.0
+    std_dev = xpx.at(std_dev, zero_std_mask).set(1.0)
+    if check_finite and xp.any(zero_std_mask):
Contributor Author

@crusaderky crusaderky Jan 6, 2025

This fails on jax.jit with an error message that's not useful to SciPy end users, unless you explicitly pass check_finite=False.
data-apis/array-api-compat#225 is a blocker for a more robust fix.

@lucascolley lucascolley added the enhancement (A new feature or improvement) and array types (Items related to array API support and input array validation, see gh-18286) labels Jan 6, 2025
@crusaderky
Contributor Author

All CI failures are unrelated

Member

@lucascolley lucascolley left a comment

thanks @crusaderky, LGTM!

Are there any potential improvements to make to

if is_jax(xp):
    # calculate_cluster_sizes doesn't accept read-only arrays
    Zpart = np.array(Zpart, copy=True)
else:
    Zpart = np.asarray(Zpart)
at this time?

@lucascolley lucascolley added this to the 1.16.0 milestone Jan 6, 2025
@lucascolley
Member

also, have you run the full CUDA tests locally? I can, if not.

@crusaderky
Contributor Author

thanks @crusaderky, LGTM!

Are there any potential improvements to make to

if is_jax(xp):
    # calculate_cluster_sizes doesn't accept read-only arrays
    Zpart = np.array(Zpart, copy=True)
else:
    Zpart = np.asarray(Zpart)

at this time?

Not yet. I intend to replace this and all other calls to np.asarray in the middle of xp-aware code with xpx.apply_numpy_function(func, arg), which internally will call jax.pure_callback(func, arg) for JAX and xp.asarray(func(np.asarray(arg))) for everything else.
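
The non-JAX branch of that plan can be sketched as follows. The helper below is hypothetical: its name mirrors the comment, but the explicit `xp` parameter and the stand-in `double_on_host` function are assumptions for illustration, and the eventual array-api-extra API may look different.

```python
import numpy as np


def apply_numpy_function(func, arg, xp):
    """Hypothetical helper: run a NumPy-only `func` on an array from `xp`.

    Covers only the non-JAX branch described above. A JAX branch would go
    through jax.pure_callback, which additionally needs the result shape
    and dtype to be declared up front.
    """
    host_result = func(np.asarray(arg))  # may copy device memory to host
    return xp.asarray(host_result)       # back to the caller's array type


def double_on_host(a):
    # Stand-in for a compiled kernel that only accepts NumPy arrays.
    return a * 2


out = apply_numpy_function(double_on_host, np.array([1.0, 2.0]), xp=np)
```

Here `xp` is NumPy itself, so the round trip is trivial; with a device-backed namespace the two conversions are where host and device transfers would occur.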

@crusaderky
Contributor Author

also, have you run the full CUDA tests locally? I can, if not.

Yes. All green functionally. This said, there are points where arrays are yanked from device memory to host memory by np.asarray, run through a Cython kernel (search for _vq.), and sent back to jax.

@lucascolley
Member

This said, there are points where arrays are yanked from device memory to host memory by np.asarray, run through a Cython kernel (search for _vq.), and sent back to jax.

Ah, does JAX not raise/warn there? Do we want to error out there @rgommers as per gh-18286?

@jakevdp
Member

jakevdp commented Jan 6, 2025

Ah, does JAX not raise/warn there?

JAX will raise there during traced execution (e.g. under JIT), but in eager execution np.asarray is just fine. If you want to prevent this, you can enable the transfer guard.

@rgommers
Member

rgommers commented Jan 6, 2025

Using the "disallow" transfer guard in scipy/conftest.py would seem reasonable to me.
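
A minimal sketch of what such a guard could look like, assuming it is wrapped in a small context manager. The name `disallow_implicit_transfers` is hypothetical, and the optional import is an assumption so that the helper is a no-op where JAX is not installed; `jax.transfer_guard` is JAX's own context manager. Which transfers are actually caught on a given backend is not verified here.

```python
import contextlib


@contextlib.contextmanager
def disallow_implicit_transfers():
    """Hypothetical helper: forbid implicit JAX transfers inside the block."""
    try:
        import jax
    except ImportError:
        jax = None  # JAX is optional; fall through to a no-op

    if jax is None:
        yield
    else:
        # "disallow" makes implicit transfers raise instead of passing silently.
        with jax.transfer_guard("disallow"):
            yield


with disallow_implicit_transfers():
    guarded_block_ran = True
```

In a conftest.py, the same `with` block could sit inside an autouse pytest fixture so that every test runs under the guard.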

@crusaderky
Contributor Author

Using the "disallow" transfer guard in scipy/conftest.py would seem reasonable to me.

I'll investigate this but I expect it to result in a substantial amount of xfails. It would also make sense to figure out the equivalent setting for torch and cupy within the same scope. Either way, I'd rather leave it to a later PR...

@lucascolley
Member

I'll investigate this but I expect it to result in a substantial amount of xfails. It would also make sense to figure out the equivalent setting for torch and cupy within the same scope. Either way, I'd rather leave it to a later PR...

It should line up with existing cpu_only skips, no? CuPy errors on np.asarray

@rgommers
Member

rgommers commented Jan 6, 2025

It should line up with existing cpu_only skips, no? CuPy errors on np.asarray

That's what I expect as well. Same for PyTorch, it does not do implicit transfers.

@lucascolley
Member

Using the "disallow" transfer guard in scipy/conftest.py would seem reasonable to me.

Reminder to follow up on this - thanks!

@lucascolley lucascolley merged commit a32bc49 into scipy:main Jan 6, 2025
35 of 37 checks passed
@crusaderky crusaderky deleted the jax_cluster branch January 7, 2025 00:51
Labels
array types Items related to array API support and input array validation (see gh-18286) enhancement A new feature or improvement scipy.cluster