Helper functionality to work around how different libraries handle copies vs. mutation and/or views? #146

Open
mdhaber opened this issue Jun 11, 2024 · 6 comments

@mdhaber
mdhaber commented Jun 11, 2024

Follow-up to #144 (comment)

Sometimes, to make existing code compatible with backends that are not fully standard-compliant, we would need to create copies where the original code would not.

For example, as a workaround for gh-144 (PyTorch doesn't support a negative slice step), we could (sometimes) make the replacement:

```python
import array_api_compat.torch as xp

x = xp.arange(10)
# x[::-1]  # ValueError: step must be greater than zero
xp.flip(x)  # works, but torch.flip copies the data
```

As a workaround for JAX not supporting mutation, we could sometimes make replacements like:

```python
import jax.numpy as xp

x = xp.arange(5.0)
i = x > 2  # boolean mask with the same shape as x
# x[i] = 0  # TypeError: JAX arrays are immutable
x = xp.where(i, 0, x)
```

However, making substitutions like this could decrease performance for array types that do support the desired operation (returning a view or mutating the original, in these cases).

Functions that perform the desired operation when possible and the substitute otherwise (e.g. scipy/scipy#20085 (comment)) have been proposed. Do such things belong in array_api_compat?

@asmeurer
Member

asmeurer commented Jun 11, 2024

Regarding flip(x) vs. x[::-1], note that np.flip does return a view. Torch only doesn't because it doesn't support views with negative strides at all. So just using flip is portable in the manner you suggest. It just isn't as readable (though even that could be argued), but there's no real way around that since we don't want to wrap array/tensor objects in this library. More arbitrary slices would also be more cumbersome to write, and we can definitely build a helper around that if it is needed.
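As a rough illustration of the kind of slice helper mentioned here, a minimal sketch for 1-D arrays only. The name `negative_step_slice` is hypothetical (it is not part of array_api_compat). It rewrites a negative-step slice as a positive-step slice followed by `xp.flip`, so it needs only operations PyTorch supports; with NumPy the result is still a view, while `torch.flip` copies.

```python
import numpy as np


def negative_step_slice(x, s, xp):
    """Hypothetical helper: emulate x[s] for a 1-D array x.

    A negative-step slice is rewritten as a positive-step slice followed
    by xp.flip, so backends without negative-step support can run it.
    """
    start, stop, step = s.indices(x.shape[0])
    if step > 0:
        return x[s]  # nothing to emulate
    selected = range(start, stop, step)  # indices x[s] would pick, in order
    if len(selected) == 0:
        return x[0:0]  # empty selection
    # selected[-1] is the smallest index picked, selected[0] the largest.
    forward = x[selected[-1]:selected[0] + 1:-step]
    return xp.flip(forward)


x = np.arange(10)
print(negative_step_slice(x, slice(8, 2, -2), np))  # [8 6 4]
```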

What would helper APIs for "mutate if you can" look like?

You'd have to be really careful using any API like that. A mutation and a copy are very different things and you'd need to make sure you write code that works correctly in both instances. I guess the way to handle it is to never actually rely on mutation semantically. Rather, mutation should just be treated as an implementation detail for performance. In other words, write code that treats everything as immutable but using functions that can mutate or "copy on write" when possible.
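To make this concrete, a minimal NumPy sketch (the function names are made up for illustration): one implementation copies and one mutates, but both follow the same contract of "use the return value", so a caller that sticks to that contract gets the same result from either.

```python
import numpy as np


def zero_negatives_copy(x):
    # Out-of-place: allocates a new array and leaves x untouched.
    return np.where(x < 0, 0, x)


def zero_negatives_inplace(x):
    # In-place: overwrites x's memory, then returns the same array.
    x[x < 0] = 0
    return x


# A caller written in "immutable" style rebinds the name to the result and
# never reads the old array (or any view of it) again, so it cannot tell
# which implementation ran.
for impl in (zero_negatives_copy, zero_negatives_inplace):
    x = np.array([-1.0, 2.0, -3.0])
    x = impl(x)
    print(impl.__name__, x)
```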

JAX effectively does this internally, where non-aliased arrays are free to be mutated. I wonder if this request effectively amounts to "rewrite JAX on top of any array API library". If so, it might be very difficult or even impossible to achieve. For instance, even in NumPy we don't currently have the ability to fully track aliasing (views are only tracked in one direction).
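For example, in NumPy a view points back to the array that owns its memory through `.base`, but nothing points the other way. That is the one-directional tracking referred to here:

```python
import numpy as np

x = np.arange(6)
y = x[::2]  # basic slicing returns a view into x's memory

# The view records its owner...
print(y.base is x)     # True
# ...but the owner records nothing about its views: from x alone there is
# no way to find y, so a write to x cannot be checked for aliasing.
print(x.base is None)  # True
print(np.shares_memory(x, y))  # True, but only because we still hold y
```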

CC @rgommers

@mdhaber
Author

mdhaber commented Jun 11, 2024

> Regarding flip(x) vs. x[::-1], note that np.flip does return a view.

Ah, I don't think I noticed that. Good to know.

> Torch only doesn't because it doesn't support views with negative strides at all....

Isn't that problematic for the same reasons you mentioned? "You'd have to be really careful using any API like that."

Sure, we developers would need to be careful. In SciPy, we wouldn't pass the dangers on to the user. That is, the function would always return arrays that don't share memory with the inputs, regardless of whether we are doing mutations or copies internally as part of the calculation. It would only be for circumstances in which either one is acceptable. The alternative is just that JAX is left without that functionality.
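A minimal NumPy sketch of that pattern, using a made-up public function `standardize`: it makes one private copy of the input up front, so the in-place steps after it can never be observed by the caller, and the result never shares memory with the input.

```python
import numpy as np


def standardize(x):
    """Hypothetical public function returning (x - mean) / std.

    The result never shares memory with the input, so the in-place
    arithmetic below is an internal detail the caller cannot observe.
    """
    y = np.array(x, dtype=np.float64, copy=True)  # private copy; safe to mutate
    y -= y.mean()
    y /= y.std()
    return y


data = np.array([1.0, 2.0, 3.0])
z = standardize(data)
print(data)                       # unchanged
print(np.shares_memory(data, z))  # False
```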

> What would helper APIs for "mutate if you can" look like?

See the comment linked above:
scipy/scipy#20085 (comment)

@asmeurer
Member

I would say flip returning a view vs. a copy isn't quite the same because it's not actually mutating anything. It's only problematic if you later do an operation that could mutate the aliased memory. If you never do, the semantics are identical.

x[i] = 0 is different, because it itself is a mutation. So you're doing something that already has different semantics depending on whether x's memory is referenced by other arrays or not.
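A small NumPy illustration of this distinction: a flipped view and a flipped copy behave identically until something writes to the shared memory, and an assignment such as `x[0] = 100` is exactly such a write.

```python
import numpy as np

x = np.arange(5)
v = np.flip(x)         # NumPy: a view sharing x's memory
c = np.flip(x).copy()  # what a backend that copies would hand back

# While nothing is mutated, the view and the copy are indistinguishable.
print(np.array_equal(v, c))  # True

# Once x is mutated, the aliasing shows through the view only.
x[0] = 100
print(v[-1], c[-1])  # 100 0
```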

But you're right that this all comes down to the fact that the standard is agnostic about views vs. copies, so portable code always has to be written in a way that doesn't rely on mutation. See https://data-apis.org/array-api/latest/design_topics/copies_views_and_mutation.html (I pinged Ralf here because I know he has thought about that particular document a lot).

@mdhaber
Author

mdhaber commented Jun 11, 2024

And he wrote the linked comment with the prototype helper.

@asmeurer
Member

So you're referring to this function?

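```python
from __future__ import annotations

from types import ModuleType
from typing import Any

from array_api_compat import array_namespace, is_jax_namespace

Array = Any  # stand-in for any array-API-compatible array type


def at_set(
        x: Array,
        idx: Array | int | slice,
        val: Array | int | float | complex,
        *,
        xp: ModuleType | None = None,
    ) -> Array:
    """In-place update if possible. Use only if no views are involved.

    May or may not mutate x, so always use the return value.
    """
    xp = array_namespace(x) if xp is None else xp
    if is_jax_namespace(xp):
        # JAX arrays are immutable, so build a new array instead.
        if hasattr(idx, "dtype") and xp.isdtype(idx.dtype, 'bool'):
            # Boolean mask: take val where idx is True, x elsewhere.
            x = xp.where(idx, val, x)
        else:
            x = x.at[idx].set(val)
    else:
        x[idx] = val  # other backends: mutate in place
    return x
```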

That seems fine (except I would make a few minor changes), as long as it's well documented that it might or might not actually mutate x.

@mdhaber
Author

mdhaber commented Jun 11, 2024

Yes, and yeah, documenting that would be key.
