ENH: cluster: JAX support (non-jitted) #22255
Conversation
crusaderky commented Jan 6, 2025 (edited)
- Part of ENH: tracking issue for JAX support #22246
- See also ENH: special: JAX support (non-jitted) #22256
@@ -137,8 +137,8 @@ def whiten(obs, check_finite=True):
     obs = _asarray(obs, check_finite=check_finite, xp=xp)
     std_dev = xp.std(obs, axis=0)
     zero_std_mask = std_dev == 0
-    if xp.any(zero_std_mask):
-        std_dev[zero_std_mask] = 1.0
+    std_dev = xpx.at(std_dev, zero_std_mask).set(1.0)
This fails on jax.jit. My current intention is to change jax.jit itself to special-case arr.at[idx].set(value) when idx is a boolean mask and value is a scalar, so that it can be rewritten as jnp.where(idx, value, arr). Failing that, I can implement the same special case in array-api-extra.
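As an illustration of the rewrite discussed here (a minimal sketch; the function names are invented for the example and are not SciPy or JAX API), a boolean-mask `.at[...].set()` with a scalar value is semantically a select, which is what makes the `jnp.where` form jit-safe:

```python
import jax
import jax.numpy as jnp

def clamp_zeros_where(std_dev):
    # jit-safe: the boolean-mask update expressed as a select
    mask = std_dev == 0
    return jnp.where(mask, 1.0, std_dev)

def clamp_zeros_at(std_dev):
    # boolean-mask .at[...] update: fine eagerly, but under jax.jit the mask
    # is a traced value and boolean indexing raises NonConcreteBooleanIndexError
    mask = std_dev == 0
    return std_dev.at[mask].set(1.0)

x = jnp.array([0.0, 2.0, 0.0])
print(jax.jit(clamp_zeros_where)(x))  # [1. 2. 1.]
# jax.jit(clamp_zeros_at)(x)          # raises under tracing
```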
> My current intention is to change jax.jit itself

I don't think changing jax.jit itself is a viable path here – I wouldn't suggest starting on that route.
It might be viable to make arr.at[idx].set(value) lower to lax.select rather than lax.scatter in the specific case of a boolean idx. I've tried that in the past, but it's really tricky to properly handle all corner cases of broadcasted and/or multi-dimensional indices, correctly implementing autodiff and batching rules, etc.
The easiest thing would probably be to do this at the level of xp.at, though boolean indices were specifically excluded from the initial discussions there.
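To make that idea concrete, here is a hypothetical wrapper sketching the special case at the `xp.at` level (this is not array-api-extra's actual API, just an illustration under that assumption): detect a boolean mask and dispatch to a select instead of an indexed update so the operation stays traceable under `jax.jit`.

```python
import jax.numpy as jnp

def masked_set(arr, idx, value):
    # Hypothetical helper, not array-api-extra's real signature:
    # a boolean mask with a scalar/broadcastable value is expressed as a
    # select, which JAX can trace; any other index falls back to .at[...].
    if getattr(idx, "dtype", None) == jnp.bool_:
        return jnp.where(idx, value, arr)
    return arr.at[idx].set(value)

x = jnp.array([0.0, 2.0, 0.0])
print(masked_set(x, x == 0, 1.0))          # [1. 2. 1.]
print(masked_set(x, jnp.array([1]), 5.0))  # [0. 5. 0.]
```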
> I've tried that in the past

Do you have a (partial, non-functioning) PR I could start from?
-    if xp.any(zero_std_mask):
-        std_dev[zero_std_mask] = 1.0
+    std_dev = xpx.at(std_dev, zero_std_mask).set(1.0)
+    if check_finite and xp.any(zero_std_mask):
This fails on jax.jit with an error message that's not useful to final scipy users, unless you explicitly pass check_finite=False.
data-apis/array-api-compat#225 is a blocker for a more robust fix.
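For context, a minimal reproduction of this failure mode (not the SciPy code): under `jax.jit` the result of `xp.any` is a traced value, so branching on it in Python raises a tracer-concretization error that points at JAX internals rather than at the caller's data.

```python
import jax
import jax.numpy as jnp

@jax.jit
def warn_on_zero_columns(std_dev):
    mask = std_dev == 0
    # jnp.any(mask) is abstract while tracing; coercing it to a Python bool
    # in the `if` raises TracerBoolConversionError instead of returning a value
    if jnp.any(mask):
        return jnp.ones_like(std_dev)
    return std_dev

# warn_on_zero_columns(jnp.array([0.0, 2.0]))  # raises TracerBoolConversionError
```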
All CI failures are unrelated.
thanks @crusaderky, LGTM!
Are there any potential improvements to make to scipy/scipy/cluster/hierarchy.py, lines 1839 to 1843 in 7a81451:
if is_jax(xp):
    # calculate_cluster_sizes doesn't accept read-only arrays
    Zpart = np.array(Zpart, copy=True)
else:
    Zpart = np.asarray(Zpart)
also, have you run the full CUDA tests locally? I can, if not.
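For background on the read-only comment in that snippet (a sketch, not part of this PR): converting a JAX array to NumPy yields an array marked non-writeable, which is why the JAX branch makes an explicit copy before handing it to code that mutates in place.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.zeros(3)
view = np.asarray(x)           # conversion to NumPy; the result is read-only
print(view.flags.writeable)    # False -> in-place mutation raises ValueError
copy = np.array(x, copy=True)  # independent, writeable copy
print(copy.flags.writeable)    # True
```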
Not yet. I intend to replace this and all other calls to ...
Yes. All green functionally. This said, there are points where arrays are yanked from device memory to host memory by ...
JAX will raise there during traced execution (e.g. under JIT), but in eager execution ...
Using the ...
I'll investigate this but I expect it to result in a substantial amount of xfails. It would also make sense to figure out the equivalent setting for torch and cupy within the same scope. Either way, I'd rather leave it to a later PR...
It should line up with existing ...
That's what I expect as well. Same for PyTorch: it does not do implicit transfers.
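Assuming the setting alluded to above is JAX's transfer guard (my guess; the exact reference is elided in the thread), a rough sketch of how implicit device-to-host transfers can be made to fail loudly during eager execution:

```python
import jax
import jax.numpy as jnp
import numpy as np

x = jnp.arange(3)

# Under the "disallow" transfer guard, implicit device-to-host copies raise
# instead of silently moving data off the accelerator. (Most relevant when x
# lives on a GPU/TPU; a CPU-only run may have nothing to guard.)
with jax.transfer_guard("disallow"):
    try:
        np.asarray(x)  # implicit conversion to host memory
    except Exception as exc:
        print(f"blocked transfer: {type(exc).__name__}")
```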
Reminder to follow up on this - thanks!