Helper functionality to work around how different libraries handle copies vs. mutation and/or views? #146
Regarding "What would helper APIs for 'mutate if you can' look like?": you'd have to be really careful using any API like that. A mutation and a copy are very different things, and you'd need to make sure you write code that works correctly in both cases. I guess the way to handle it is to never actually rely on mutation semantically. Instead, mutation should be treated purely as an implementation detail for performance. In other words, write code that treats everything as immutable but uses functions that can mutate or "copy on write" when possible. JAX effectively does this internally, where non-aliased arrays are free to be mutated. I wonder whether this request effectively amounts to "rewrite JAX on top of any array API library". If so, it might be very difficult or even impossible to achieve. For instance, even in NumPy we don't currently have the ability to fully track aliasing (views are only tracked in one direction). CC @rgommers
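A quick NumPy illustration of the one-directional tracking (assuming NumPy is installed): a view records the array it came from in its base attribute, but the original array keeps no record of the views taken from it. So nothing on x itself tells a helper whether mutating x would be visible somewhere else.

```python
import numpy as np

x = np.arange(6)        # x owns its data
v = x[::2]              # basic slicing: a view into x
c = x[::2].copy()       # an independent copy

# The view knows its base, and shares memory with it.
print(v.base is x)               # True
print(np.shares_memory(x, v))    # True
# The copy owns its data and shares nothing with x.
print(c.base is None)            # True
print(np.shares_memory(x, c))    # False

# x has no attribute listing v, yet mutating x is visible through v.
x[0] = 100
print(v[0], c[0])                # 100 0
```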
Ah, I don't think I noticed that. Good to know.
Isn't that problematic for the same reasons you mentioned? "You'd have to be really careful using any API like that." Sure, we developers would need to be careful. In SciPy, we wouldn't pass the dangers on to the user. That is, the function would always return arrays that don't share memory with the inputs, regardless of whether we are doing mutations or copies internally as part of the calculation. It would only be used in circumstances where either behavior is acceptable. The alternative is just that JAX is left without that functionality.
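A minimal sketch of that contract, using NumPy only: the hypothetical function clip_negative makes one private copy up front, mutates only that copy, and so never modifies its input or returns an array that aliases it. On a backend without mutation (e.g. JAX), the in-place line would have to go through a helper that falls back to a functional update; that part is not shown here.

```python
import numpy as np


def clip_negative(x):
    """Hypothetical example: return x with negative entries replaced by 0.

    The input is never modified, and the result never shares memory with it.
    """
    out = np.array(x, copy=True)  # private copy, so mutating it is safe
    out[out < 0] = 0              # in-place update of the private copy only
    return out


x = np.array([-1.0, 2.0, -3.0])
y = clip_negative(x)
print(x.tolist())               # [-1.0, 2.0, -3.0]  (input unchanged)
print(y.tolist())               # [0.0, 2.0, 0.0]
print(np.shares_memory(x, y))   # False
```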
See the comment linked above:
I would say
But you're right that this all comes down to the fact that the standard is agnostic about views vs. copies, so portable code always has to be written in a way that doesn't rely on mutation. See https://data-apis.org/array-api/latest/design_topics/copies_views_and_mutation.html (I pinged Ralf here because I know he has thought about that particular document a lot).
And he wrote the linked comment with the prototype helper.
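So you're referring to this function?

    from __future__ import annotations  # Array is only used in annotations

    from types import ModuleType

    from array_api_compat import array_namespace, is_jax_namespace


    def at_set(
        x: Array,
        idx: Array | int | slice,
        val: Array | int | float | complex,
        *,
        xp: ModuleType | None = None,
    ) -> Array:
        """In-place update. Use only if no views are involved."""
        xp = array_namespace(x) if xp is None else xp
        if is_jax_namespace(xp):
            # JAX arrays are immutable, so build a new array instead.
            # int and slice indices have no dtype, so check before isdtype.
            if hasattr(idx, "dtype") and xp.isdtype(idx.dtype, "bool"):
                # Boolean mask: take val where idx is True, keep x elsewhere.
                x = xp.where(idx, val, x)
            else:
                x = x.at[idx].set(val)
        else:
            # Backends that support mutation: update x in place.
            x[idx] = val
        return x

That seems fine (except I would make a few minor changes), as long as it's well documented that it might or might not actually mutate.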
Yes, and documenting that would be key.
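For the boolean-mask case, the functional fallback has to reproduce exactly what the masked assignment does. A small NumPy check (assuming NumPy is installed) shows the argument order that matters: where(mask, val, x) takes val where the mask is True and keeps x elsewhere, matching x[mask] = val.

```python
import numpy as np

x = np.array([1, 2, 3, 4])
mask = np.array([True, False, True, False])

# Mutating form (NumPy, PyTorch, ...), applied to a copy so x stays intact.
mutated = x.copy()
mutated[mask] = 0

# Functional form (also works on immutable arrays such as JAX's).
functional = np.where(mask, 0, x)

print(mutated.tolist())      # [0, 2, 0, 4]
print(functional.tolist())   # [0, 2, 0, 4]
print(x.tolist())            # [1, 2, 3, 4]
```

Swapping the last two arguments (np.where(mask, x, 0)) would instead give [1, 0, 3, 0], i.e. it would overwrite the unselected entries.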
Follow-up to #144 (comment)
Sometimes, to make existing code compatible with backends that are not fully standard-compliant, we would need to create copies where the original code would not.
For example, as a workaround for gh-144 (PyTorch doesn't support a negative step), we could (sometimes) make the replacement:
As a workaround for JAX not supporting mutation, we could sometimes make replacements like:
However, making substitutions like this could decrease performance for array types that do support the desired operation (returning a view or mutating the original, in these cases).
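The replacement snippets themselves are missing from this copy of the issue, so the following is only an assumed illustration of the mutation case, using NumPy: swapping an in-place masked assignment for a where call gives the same values, but on NumPy it allocates a whole new array where the original code simply updated the existing one.

```python
import numpy as np

x = np.arange(5.0)
mask = x > 2.0

# Original style: in-place update of an existing array (not possible in JAX).
y = x.copy()
y_before = y          # keep a reference to the same array object
y[mask] = -1.0        # writes into y's existing buffer
print(y is y_before)  # True: no new array was created

# JAX-compatible substitute: always returns a freshly allocated array,
# even on backends that could have mutated in place.
z = np.where(mask, -1.0, x)
print(np.shares_memory(z, x))     # False
print(y.tolist() == z.tolist())   # True: same values either way
```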
Functions that perform the desired operation when possible and the substitute otherwise (e.g. scipy/scipy#20085 (comment)) have been proposed. Do such things belong in array_api_compat?
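As a rough sketch of what such a "do it when possible, substitute otherwise" helper could look like for the negative-step case: the helper below is hypothetical (not an existing array_api_compat function), xp is assumed to be an array API compatible namespace, and it assumes PyTorch rejects a negative slice step with ValueError. Checking the backend up front (e.g. with array_api_compat's is_torch_array) would be an alternative to catching the exception.

```python
import numpy as np


def reverse_first_axis(x, xp):
    """Hypothetical helper: reverse x along axis 0.

    Tries a negative-step slice first, which is a cheap view on backends that
    support it (e.g. NumPy). Falls back to xp.flip, which may copy, on backends
    that reject negative steps (assumed to raise ValueError, as PyTorch does).
    The result may or may not share memory with x, so callers must not mutate it.
    """
    try:
        return x[::-1]
    except ValueError:
        return xp.flip(x, axis=0)


x = np.arange(4)
r = reverse_first_axis(x, np)
print(r.tolist())              # [3, 2, 1, 0]
print(np.shares_memory(x, r))  # True: NumPy took the view path
```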