Conditionally run health checks on jitted JAX arrays and dask arrays #225
Comments
Upon further inspection, |
From the scipy side of things, For the array-api-compat side of the proposal, I see no blockers in principle. It would be nice, however, to have something to this effect in the spec. So maybe you could take a look at the relevant discussions in the array-api spec repository, or start a new one if need be?
This problem was already extensively discussed in data-apis/array-api#642, which culminated in data-apis/array-api#652. What I'm suggesting here is a function you call to avoid blindly calling Also, dask won't raise on
I suppose that the function could be made to work on generic compliant arrays:

    def is_lazy_array(x):
        """Return True if x is a future or it would be otherwise impossible or expensive to
        eagerly read its contents, regardless of their size, e.g. by calling ``bool(x)`` or ``float(x)``.
        Return False otherwise; e.g. ``bool(x)`` etc. is guaranteed to succeed and to be cheap as
        long as the array is the right dtype.
        """
        if is_numpy_array(x) or is_cupy_array(x) or is_torch_array(x) or is_pydata_sparse_array(x):
            return False
        if is_dask_array(x) or is_ndonnx_array(x):
            # Note: dask quietly computes the array when calling eager methods such as __bool__,
            # but it can be extremely expensive and should be avoided.
            return True
        if is_jax_array(x):
            return is_jax_jitted_array(x)
        # Unknown array api compatible array
        xp = array_namespace(x)
        canary = xp.asarray(False, device=device(x))
        try:
            return bool(canary)
        except ValueError:
            return True
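The canary fallback at the end of that function can be tried in isolation. The following is only a minimal sketch: `EagerScalar`, `TracedScalar` and `is_lazy_by_canary` are hypothetical stand-ins, not part of array-api-compat or of any array library. It also catches `TypeError` in addition to `ValueError`, on the assumption that some backends raise that instead.

```python
class EagerScalar:
    """Hypothetical eager 0-d array: bool() is cheap and always succeeds."""

    def __init__(self, value):
        self.value = value

    def __bool__(self):
        return bool(self.value)


class TracedScalar:
    """Hypothetical traced 0-d array: its value is unknown until run time."""

    def __bool__(self):
        raise TypeError("cannot convert a traced value to bool")


def is_lazy_by_canary(canary):
    """Return True if bool(canary) raises, i.e. the contents cannot be read eagerly."""
    try:
        bool(canary)
    except (TypeError, ValueError):  # assumption: backends may raise either
        return True
    return False


print(is_lazy_by_canary(EagerScalar(False)))  # False
print(is_lazy_by_canary(TracedScalar()))      # True
```

Note that a canary of this kind cannot detect the dask-like case, where `bool()` succeeds but is expensive; that is why the function above special-cases known libraries first.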
Refreshing my memory here, this is the summary which I think captures the discussion from last year: data-apis/array-api#642 (comment)
The question is if in my summary linked above, I was right about this or not: "this should be rare, and I don't think it is necessary at this point to either support a special API like This was always a situation of "not blocking now, we can reconsider in the future if needed". If there are too many places where data validation is problematic, then the function you're suggesting here is basically
What is your assessment here @crusaderky? I think the route here should probably be (3) -> (2) -> (1), so if it's a bottleneck now, then consider (2)?
I think that the path of least resistance, at least for now, is to relax the default behaviour of those functions that have optional data validation which is on by default, to be on only when possible. This means less validation by default on jitted jax, but if the user really cares about it they'll explicitly specify |
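To make the "on only when possible" default concrete, here is a rough sketch of a tri-state flag. `resolve_check_finite` and `validated_sum` are hypothetical names, plain Python floats stand in for arrays, and the `lazy` argument stands in for whatever laziness test ends up being used; none of this is scipy's actual code.

```python
import math


def resolve_check_finite(check_finite, lazy):
    """Decide whether to validate.

    check_finite: True (always), False (never), None (only when it is cheap).
    lazy: hypothetical flag saying the input cannot be read eagerly.
    """
    if check_finite is None:
        return not lazy
    return bool(check_finite)


def validated_sum(values, check_finite=None, lazy=False):
    """Sum the values, validating finiteness according to the tri-state flag."""
    if resolve_check_finite(check_finite, lazy):
        if not all(math.isfinite(v) for v in values):
            raise ValueError("array must not contain infs or NaNs")
    return sum(values)
```

With the default `None`, eager inputs are still validated, while lazy inputs skip the check unless the caller explicitly asks for it.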
I believe that by design you can't/shouldn't distinguish between jitted and eager JAX usage, which is why you said above:
Supporting JAX's JIT is where all this really pays off, so I think using |
I'm looking at gh-228 now. The |
I've been dealing with a very similar test to is_finite here: scipy/scipy#22274

A branch such as

    if reduce(x):
        return expensive1(x)
    else:
        return expensive2(x)

could be rewritten in the worst case as

    if is_lazy_array(x):
        return xp.where(reduce(x), expensive1(x), expensive2(x))
    elif reduce(x):
        return expensive1(x)
    else:
        return expensive2(x)

which is just a special case of

    return _lazywhere(reduce(x), x, expensive1, expensive2)

(assuming, of course, that the output of the two code paths has the same shape and dtype).

Practical example:

    - if contains_nan and nan_policy == 'propagate':
    -     return MedianTestResult(np.nan, np.nan, np.nan, None)
    + if nan_policy == 'propagate' and not is_lazy_array(contains_nan) and contains_nan:
    +     return MedianTestResult(np.nan, np.nan, np.nan, None)
    ...
    - return MedianTestResult(stat, p, grand_median, table)
    + out = stat, p, grand_median, table
    + if nan_policy == 'propagate' and is_lazy_array(contains_nan):
    +     out = tuple([xp.where(contains_nan, xp.nan, v) for v in out])
    + return MedianTestResult(*out)

I may, of course, be wrong and we could find a case that can only be resolved by cond, but I expect such a case to be uncommon at least.
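The rewrite described in this comment can be illustrated without any array library. In the sketch below, `Deferred`, `where` and `stat_or_nan` are hypothetical toy stand-ins for a lazy array, `xp.where`, and the tail of a statistics function; they only show the control flow, not how JAX or dask behave internally.

```python
import math


class Deferred:
    """Hypothetical lazy scalar: holds a thunk and refuses eager bool()."""

    def __init__(self, thunk):
        self._thunk = thunk

    def __bool__(self):
        raise TypeError("value is not available until compute()")

    def compute(self):
        return self._thunk()


def _value(v):
    """Materialize v if it is deferred; pass eager values through."""
    return v.compute() if isinstance(v, Deferred) else v


def where(cond, a, b):
    """Lazy select: a and b are both already built; cond picks one at compute time."""
    return Deferred(lambda: _value(a) if _value(cond) else _value(b))


def stat_or_nan(stat, contains_nan):
    """Propagate NaN without calling bool() on a lazy flag."""
    if isinstance(contains_nan, Deferred):
        return where(contains_nan, math.nan, stat)
    return math.nan if contains_nan else stat
```

The eager path keeps its early exit, while the lazy path pays for both branches and defers the selection, which is the trade-off the "worst case" wording refers to.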
Thanks for the analysis! I'm not sure I am fully convinced by the |
I've found cases of functions in scipy, e.g. https://github.com/scipy/scipy/blob/4758525cae48f9cfb6971be6702fbb412e783aa5/scipy/cluster/vq.py#L332-L334 that crash on the very first few lines with JAX when they're inside @jax.jit:
The issue is that there is a health check enabled by default, check_finite=True, which triggers this code: https://github.com/scipy/scipy/blob/4758525cae48f9cfb6971be6702fbb412e783aa5/scipy/_lib/_array_api.py#L104-L108
A JAX jitted array crashes on bool(); a dask array is quietly computed when you do so - which is possibly even worse.
There are two issues here:
My proposal: kmeans(..., check_finite=True) to kmeans(..., check_finite=None), which will mean "check if possible", and replace with
However, @jakevdp mentioned elsewhere that DynamicJaxprTracer is not part of the public API of JAX and there is no public method to test for jitting. Not sure I can see a way forward without this information.
Also CC @rgommers @lucascolley