
Conditionally run health checks on jitted JAX arrays and dask arrays #225

Open
crusaderky opened this issue Dec 31, 2024 · 9 comments · May be fixed by #228

@crusaderky
Contributor

crusaderky commented Dec 31, 2024

I've found cases of functions in scipy, e.g.

https://github.com/scipy/scipy/blob/4758525cae48f9cfb6971be6702fbb412e783aa5/scipy/cluster/vq.py#L332-L334

that crash within the very first few lines on JAX when they're inside @jax.jit:

>>> import jax
>>> import jax.numpy as xp
>>> from scipy.cluster.vq import kmeans
>>> a = xp.asarray([[1.,2.],[3.,4.]])
>>> jax.jit(kmeans)(a, 2)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function kmeans at /home/crusaderky/github/scipy/build-install/lib/python3.12/site-packages/scipy/cluster/vq.py:332 for jit. This concrete value was not available in Python because it depends on the value of the argument obs.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

The issue is that there is a health check enabled by default, check_finite=True, which triggers this code:

https://github.com/scipy/scipy/blob/4758525cae48f9cfb6971be6702fbb412e783aa5/scipy/_lib/_array_api.py#L104-L108

A JAX jitted array crashes on bool(); a dask array is quietly computed instead, which is possibly even worse.

There are two issues here:

  1. the default behaviour of the function is to inspect the contents of the array, and
  2. the error message is incomprehensible to an end user, as it is triggered by code deep inside the scipy implementation.

My proposal:

  1. in array-api-compat, add two functions:
def is_jax_jitted_array(x):
    # DynamicJaxprTracer comes from JAX internals; see the caveat below.
    return isinstance(x, DynamicJaxprTracer)

def is_material_array(x):
    """Return True if x has contents in memory at the moment of calling this function,
    which are cheap to retrieve as long as they're small in size.
    Return False if x is a future or it would be otherwise impossible or expensive to
    read its contents, regardless of their size.
    """
    return not is_dask_array(x) and not is_jax_jitted_array(x)
  2. in scipy, change kmeans(..., check_finite=True) to kmeans(..., check_finite=None), which will mean "check if possible", and replace
if check_finite:
    _check_finite(x, xp)

with

def _check_material_array(x: Array, check: bool | None, check_name: str) -> bool:
    if check is None:
        return is_material_array(x)
    if check and not is_material_array(x):
        raise TypeError(f"Can't check non-material array {type(x)}. Please set {check_name} to None or False.")
    return check

...

if _check_material_array(x, check_finite, "check_finite"):
    _check_finite(x, xp)

However, @jakevdp mentioned elsewhere that DynamicJaxprTracer is not part of the public API of JAX and there is no public method to test for jitting. Not sure I can see a way forward without this information.

Also CC @rgommers @lucascolley

@crusaderky
Contributor Author

crusaderky commented Dec 31, 2024

Upon further inspection, kmeans specifically contains so much code that can only run on material CPU arrays (namely, a Cython kernel) that I suspect it must be reimplemented in JAX and then dispatched. The question of data validation is generic and stands regardless.

@ev-br
Contributor

ev-br commented Jan 1, 2025

From the scipy side of things, check_finite is mainly there to fail with a Python error when, say, feeding an array with NaNs to a LAPACK function, which is undefined behaviour and may well segfault (in fact, there were multiple bug reports to this effect, which is why check_finite=True is now the default).
Tweaking check_finite to skip the check for "immaterial" arrays is fine, as long as the JITted code handles these internally.
Which I honestly doubt, so maybe these checks should somehow be done before jitting things.
Anyhow, this is more of a scipy question.

For the array-api-compat side of the proposal, I see no blockers in principle. It would be nice, however, to have something to this effect in the spec. So maybe you could take a look at relevant discussions in the array-api spec repository, or start a new one if need be?
data-apis/array-api#748 looks somewhat related, maybe there are more.

@crusaderky
Contributor Author

crusaderky commented Jan 1, 2025

This problem was already extensively discussed in data-apis/array-api#642, which culminated in data-apis/array-api#652.

What I'm suggesting here is a function you can call instead of blindly calling __bool__ to see whether it raises.

Also, dask won't raise on __bool__; instead it will quietly compute the whole graph, discard the result, and then recompute it again later. In a converging algorithm with 20 calls to __bool__ along the way, dask will compute everything 21 times. I'm not a fan of this behaviour, but it would be a major breaking change for dask to switch to raising ValueError instead.

is_material_array / is_lazy_array would allow us to deal with it.
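The recompute cost described above is easy to model with a toy lazy container. This is not dask's actual API, just an illustration of why each __bool__ silently triggers a full, uncached graph evaluation:

```python
# Toy model of a dask-like lazy array: __bool__ quietly triggers a full
# compute, and the result is not cached, so every boolean test recomputes.
class ToyLazyArray:
    def __init__(self, compute_fn):
        self._compute_fn = compute_fn
        self.n_computes = 0  # instrumentation for this example only

    def compute(self):
        self.n_computes += 1
        return self._compute_fn()

    def __bool__(self):
        # Quietly computes the whole graph, then discards the result.
        return bool(self.compute())

x = ToyLazyArray(lambda: 1.0)
for _ in range(20):    # 20 convergence checks along the way...
    if x:
        pass
final = x.compute()    # ...plus the final compute
print(x.n_computes)    # 21
```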

I suppose that the function could be made to work on generic compliant arrays:

def is_lazy_array(x):
    """Return True if x is a future or it would be otherwise impossible or expensive
    to eagerly read its contents, e.g. by calling ``bool(x)`` or ``float(x)``,
    regardless of their size.
    Return False otherwise; e.g. ``bool(x)`` etc. is guaranteed to succeed and to be
    cheap as long as the array has the right dtype.
    """
    if is_numpy_array(x) or is_cupy_array(x) or is_torch_array(x) or is_pydata_sparse_array(x):
        return False
    if is_dask_array(x) or is_ndonnx_array(x):
        # Note: dask quietly computes the array when calling eager methods such as __bool__,
        # but it can be extremely expensive and should be avoided.
        return True
    if is_jax_array(x):
        return is_jax_jitted_array(x)

    # Unknown array api compatible array
    xp = array_namespace(x)
    canary = xp.asarray(False, device=device(x))
    try:
        return bool(canary)
    except ValueError:
        return True
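The canary probe at the end can be exercised in isolation with toy stand-ins. These stub classes are not real Array API arrays; they only mimic the __bool__ behaviour the fallback path relies on:

```python
class EagerArrayStub:
    """Stand-in for an eager 0-d boolean array: bool() succeeds cheaply."""
    def __bool__(self):
        return False

class LazyArrayStub:
    """Stand-in for a lazy array that refuses eager conversion."""
    def __bool__(self):
        raise ValueError("cannot convert a lazy array to bool")

def probe_is_lazy(canary):
    # Generic fallback: an array is considered lazy iff converting a
    # 0-d boolean canary to bool raises ValueError.
    try:
        bool(canary)
        return False
    except ValueError:
        return True

print(probe_is_lazy(EagerArrayStub()))  # False
print(probe_is_lazy(LazyArrayStub()))   # True
```

Note that this probe only works for libraries that raise ValueError on __bool__; dask, as discussed above, computes instead of raising, which is why it needs the explicit is_dask_array branch.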

@rgommers
Member

rgommers commented Jan 2, 2025

Refreshing my memory here, this is the summary which I think captures the discussion from last year: data-apis/array-api#642 (comment)

What I'm suggesting here is a function you call to avoid blindly calling __bool__ and see if it raises or not.

The question is if in my summary linked above, I was right about this or not: "this should be rare, and I don't think it is necessary at this point to either support a special API like cond for this, nor to have scikit-learn & co worry about this as a problem right now."

This was always a situation of "not blocking now, we can reconsider in the future if needed". If there are too many places where data validation is problematic, then the function you're suggesting here is basically cond, which can call torch.cond and jax.lax.cond under the hood. Depending on urgency, we can:

  1. Reconsider adding cond to the standard,
  2. Add cond to array-api-extra
  3. Delay this for a bit longer

What is your assessment here @crusaderky? I think the route here should probably be (3) -> (2) -> (1), so if it's a bottleneck now then consider (2)?

@crusaderky
Contributor Author

I think that the path of least resistance, at least for now, is to relax the default behaviour of those functions that have optional data validation which is on by default, to be on only when possible.

This means less validation by default on jitted jax, but if the user really cares about it they'll explicitly specify check_finite=True, which will then raise a meaningful error prompting them to pre-validate their inputs before feeding them to the jit. Note that the validation will still happen on non-jitted jax, which is friendly to what I suspect is a typical development workflow (prototype without jit, then enable jit and debug jit-specific issues at the end).

@rgommers
Member

rgommers commented Jan 2, 2025

I believe that by design you can't/shouldn't distinguish between jitted and eager JAX usage, which is why you said above:

However, @jakevdp mentioned elsewhere that DynamicJaxprTracer is not part of the public API of JAX and there is no public method to test for jitting. Not sure I can see a way forward without this information.

Supporting JAX's JIT is where all this really pays off, so I think using cond will be necessary sooner or later. A path of least resistance here is to delay this; it seems to me that implementing is_jax_jitted_array is more of a detour, since it's not something that is desirable to have long-term.

@rgommers
Member

rgommers commented Jan 8, 2025

I'm looking at gh-228 now. The is_lazy_array function there seems reasonable, I just wonder if it fixes this issue. You can simply bypass all input validation in SciPy functions, if nothing else that's relevant to the actual functionality is happening. For any other if-statements (e.g., if some_check: return x else return y) we are probably still in the same situation of needing cond. I don't have a good feeling for what fraction of functionality is in the latter category.

@crusaderky
Contributor Author

crusaderky commented Jan 8, 2025

For any other if-statements (e.g., if some_check: return x else return y) we are probably still in the same situation of needing cond. I don't have a good feeling for what fraction of functionality is in the latter category.

I've been dealing with a very similar test to is_finite here: scipy/scipy#22274
Across the whole PR, it seems to me that all the cases of

if reduce(x):
    return expensive1(x)
else:
    return expensive2(x)

could be rewritten in the worst case as

if is_lazy_array(x):
    return xp.where(reduce(x), expensive1(x), expensive2(x))
elif reduce(x):
    return expensive1(x)
else:
    return expensive2(x)

which is just a special case of

return _lazywhere(reduce(x), x, expensive1, expensive2)

(assuming, of course, that the output of the two code paths has the same shape and dtype).
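The dispatch above can be sketched as a small runnable toy. Here is_lazy_array, the Eager/Lazy types, and the branch selection are all stand-ins (the real helpers live in scipy/array-api-compat); the point is that the eager path runs only one branch, while the lazy path evaluates both and selects:

```python
calls = []  # records which expensive branches actually ran

def expensive1(x):
    calls.append("expensive1")
    return x + 1

def expensive2(x):
    calls.append("expensive2")
    return x - 1

class Eager(float):
    lazy = False  # stand-in for a material array

class Lazy(float):
    lazy = True   # stand-in for a lazy/traced array

def is_lazy_array(x):
    return getattr(x, "lazy", False)  # stub

def lazywhere(cond, x, f1, f2):
    """Toy _lazywhere: one branch eagerly, or both + select when lazy."""
    if is_lazy_array(x):
        # Lazy path: evaluate both branches, then select elementwise
        # (a plain conditional stands in for xp.where here).
        r1, r2 = f1(x), f2(x)
        return r1 if cond else r2
    return f1(x) if cond else f2(x)

print(lazywhere(True, Eager(1.0), expensive1, expensive2))  # 2.0; only expensive1 ran
print(lazywhere(True, Lazy(1.0), expensive1, expensive2))   # 2.0; both branches ran
```

The cost of the lazy path is evaluating both branches unconditionally, which is exactly the trade-off stated above: it works only when the two code paths agree on shape and dtype.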

Practical example:
https://github.com/scipy/scipy/blob/185233e854a44415d7fc629d9726da9b89c51357/scipy/stats/_morestats.py#L3881-L3883
In the above case, we could reach the end of the function without testing contains_nan, and then on the last line overwrite the output:

- if contains_nan and nan_policy == 'propagate':
-     return MedianTestResult(np.nan, np.nan, np.nan, None)
+ if nan_policy == 'propagate' and not is_lazy_array(contains_nan) and contains_nan:
+    return MedianTestResult(np.nan, np.nan, np.nan, None)
...
- return MedianTestResult(stat, p, grand_median, table)
+ out = stat, p, grand_median, table
+ if nan_policy == 'propagate' and is_lazy_array(contains_nan):
+     out = tuple(xp.where(contains_nan, xp.nan, v) for v in out)
+ return MedianTestResult(*out)

I may, of course, be wrong and we could find a case that can only be resolved by cond, but I expect such a case to be uncommon at least.

@rgommers
Member

I may, of course, be wrong and we could find a case that can only be resolved by cond, but I expect such a case to be uncommon at least.

Thanks for the analysis! I'm not sure I am fully convinced by the MedianTestResult example - that would be cleaner with cond I think - however you are probably correct regarding the big picture, and that with is_lazy_array and _lazywhere most things can be covered. So I'm happy to move ahead with your plan here; we can always reconsider cond later.
