map_over_datasets throws error on nodes without datasets #9693

Open
dhruvbalwada opened this issue Oct 29, 2024 · 10 comments
Labels
topic-DataTree Related to the implementation of a DataTree class

Comments

@dhruvbalwada

dhruvbalwada commented Oct 29, 2024

map_over_datasets, a way to compute over DataTrees, currently seems to try to operate even on nodes that contain no data, and consequently raises an error.
This seems to be a new issue, and was not a problem when this function was called map_over_subtree, which was part of the experimental datatree versions.

An example to reproduce this problem is below:

import numpy as np
import xarray as xr


## Generate datatree, using example from documentation
def time_stamps(n_samples, T):
    """Create an array of evenly-spaced time stamps"""
    return xr.DataArray(
        data=np.linspace(0, 2 * np.pi * T, n_samples), dims=["time"]
    )


def signal_generator(t, f, A, phase):
    """Generate an example electrical-like waveform"""
    return A * np.sin(f * t.data + phase)


time_stamps1 = time_stamps(n_samples=15, T=1.5)

time_stamps2 = time_stamps(n_samples=10, T=1.0)

voltages = xr.DataTree.from_dict(
    {
        "/oscilloscope1": xr.Dataset(
            {
                "potential": (
                    "time",
                    signal_generator(time_stamps1, f=2, A=1.2, phase=0.5),
                ),
                "current": (
                    "time",
                    signal_generator(time_stamps1, f=2, A=1.2, phase=1),
                ),
            },
            coords={"time": time_stamps1},
        ),
        "/oscilloscope2": xr.Dataset(
            {
                "potential": (
                    "time",
                    signal_generator(time_stamps2, f=1.6, A=1.6, phase=0.2),
                ),
                "current": (
                    "time",
                    signal_generator(time_stamps2, f=1.6, A=1.6, phase=0.7),
                ),
            },
            coords={"time": time_stamps2},
        ),
    }
)

## Write some functions to add resistance
def add_resistance_only_do(dtree):
    def calculate_resistance(ds):
        ds_new = ds.copy()
        ds_new['resistance'] = ds_new['potential'] / ds_new['current']
        return ds_new

    dtree = dtree.map_over_datasets(calculate_resistance)
    return dtree


def add_resistance_try(dtree):
    def calculate_resistance(ds):
        ds_new = ds.copy()
        try:
            ds_new['resistance'] = ds_new['potential'] / ds_new['current']
        except KeyError:
            # Nodes without data (e.g. the root) have no 'potential' or 'current'
            pass
        return ds_new

    dtree = dtree.map_over_datasets(calculate_resistance)
    return dtree

Calling voltages = add_resistance_only_do(voltages) raises the error:

KeyError: "No variable named 'potential'. Variables on the dataset include []"
Raised whilst mapping function over node with path '.'

This can easily be worked around with a try statement (e.g. voltages = add_resistance_try(voltages)), but we know that Yoda would not recommend try (right, @TomNicholas?).

Could this be built in as default behavior of map_over_datasets, since many DataTrees will have nodes without data?

@dhruvbalwada dhruvbalwada added the needs triage Issue that has not been reviewed by xarray team member label Oct 29, 2024
@shoyer
Member

shoyer commented Oct 29, 2024

This was an intentional change, because a special case to skip empty nodes felt surprising to me.

On the other hand, maybe it does make sense to skip nodes without datasets specifically for a method that maps over datasets (but not for a method that maps over nodes). So I'm open to changing this. The other option would be to add a new keyword argument to map_over_datasets for controlling this, something like skip_empty_nodes=True.

For what it's worth, the canonical way to write this today would be something like:

def add_resistance_try(dtree): 
    def calculate_resistance(ds):
        if not ds:
            return None
        ds_new = ds.copy()
        ds_new['resistance'] = ds_new['potential']/ds_new['current']
        return ds_new 

    dtree = dtree.map_over_datasets(calculate_resistance)
    return dtree

@TomNicholas
Member

Thanks for raising this @dhruvbalwada !

I would be in favor of changing this. It has come up before for users, so I'm not surprised it has come up again almost immediately.

I think it's reasonable for "map over datasets" to not map over a node where there is no dataset by default. The subtleties are with inherited variables and attrs. There are multiple issues on the old repo discussing this.

@TomNicholas TomNicholas added topic-DataTree Related to the implementation of a DataTree class and removed needs triage Issue that has not been reviewed by xarray team member labels Oct 29, 2024
@dcherian
Contributor

> The other option would be to add a new keyword argument to map_over_datasets for controlling this, something like skip_empty_nodes=True.

I like this idea with default False. With deep hierarchies it can be easy to miss that a node might be unexpectedly empty. So it'd be good to force users to opt in.

@kmuehlbauer
Contributor

kmuehlbauer commented Oct 30, 2024

I can see use cases for both skip_empty_nodes=False and skip_empty_nodes=True, so we won't make all users happy with either default.

But I think we should not add a skip_empty_nodes kwarg at all. Instead, we could encourage users to work with solutions along the lines of @shoyer's suggestion above. In more complex scenarios users will need such solutions anyway, since their functions might only work on particular nodes: tree layouts can differ significantly, and nodes won't be equivalent in terms of their content.

To assist users with that task, xarray could provide the functionality the OP is looking for as a simple decorator (update: now tested, finally):

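import functools


def skip_empty_nodes(func):
    # Wrap func so that nodes without data variables are returned unchanged
    @functools.wraps(func)
    def _func(ds, *args, **kwargs):
        if not ds:  # bool(Dataset) is False when there are no data variables
            return ds
        return func(ds, *args, **kwargs)
    return _func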

def add_resistance_try(dtree):
    @skip_empty_nodes
    def calculate_resistance(ds):
        ds_new = ds.copy()
        ds_new['resistance'] = ds_new['potential']/ds_new['current']
        return ds_new 

    dtree = dtree.map_over_datasets(calculate_resistance)
    return dtree
    
    
voltages = add_resistance_try(voltages)

Anyway, if the keyword-argument solution is preferred, I'd opt for skip_empty_nodes=False.
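Following the point that real functions often only apply to particular nodes, a variant of the decorator above can key on the variables the function needs rather than on emptiness, so nodes that hold data but not the right variables are also passed through. requires_variables is a hypothetical helper, not xarray API:

```python
import functools

import xarray as xr


def requires_variables(*names):
    """Hypothetical decorator: call func only on datasets containing every
    variable in `names`; return any other dataset unchanged."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(ds, *args, **kwargs):
            if not all(name in ds.data_vars for name in names):
                return ds
            return func(ds, *args, **kwargs)
        return wrapper
    return decorator


@requires_variables("potential", "current")
def calculate_resistance(ds):
    ds_new = ds.copy()
    ds_new["resistance"] = ds_new["potential"] / ds_new["current"]
    return ds_new


# Usage on a tree would be: voltages = voltages.map_over_datasets(calculate_resistance)
full = xr.Dataset({"potential": ("time", [2.0, 4.0]), "current": ("time", [1.0, 2.0])})
print(calculate_resistance(full)["resistance"].values)
```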

@shoyer
Member

shoyer commented Oct 30, 2024

I don't think we need extensive helper functions or options in map_over_datasets. It's a convenience function, which is why I'm OK skipping empty nodes by default.

For cases where users need control, they can just iterate over DataTree.subtree_with_keys or xarray.group_subtrees() themselves.

@kmuehlbauer
Contributor

Fine with that, too. Are Datasets with only attrs considered empty?

@shoyer
Member

shoyer commented Oct 30, 2024

> Fine with that, too. Are Datasets with only attrs considered empty?

There are a few different edge cases:

  • Only attrs
  • Only coordinates/attrs
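For these two edge cases, note how Dataset truthiness behaves: bool(ds) only looks at data variables, so the `if not ds` guards used earlier in this thread treat attrs-only and coordinates-only datasets alike as empty. A quick check:

```python
import xarray as xr

attrs_only = xr.Dataset(attrs={"units": "V"})
coords_only = xr.Dataset(coords={"time": [0.0, 1.0]}, attrs={"units": "V"})
with_data = xr.Dataset({"potential": ("time", [1.0, 2.0])})

# Only the dataset with a data variable is truthy.
for name, ds in [("attrs only", attrs_only), ("coords only", coords_only), ("with data", with_data)]:
    print(name, bool(ds))
```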

The original map_over_subtree had special logic for empty nodes: it propagated their attributes forward without calling the mapped function. That seems reasonable to me.

I'm not sure whether or not to call the mapped function for nodes that only define coordinates. Certainly I would not blindly copy coordinates from otherwise empty nodes onto the result, because those coordinates may no longer be relevant on the result.
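A rough sketch of that attrs-only behavior as a decorator. propagate_attrs_if_empty and double are hypothetical names, not xarray API, and this is not a reconstruction of the original map_over_subtree code; it simply keeps attributes on empty nodes and, per the point above, drops their coordinates:

```python
import functools

import xarray as xr


def propagate_attrs_if_empty(func):
    """Hypothetical decorator: on a dataset with no data variables, skip func
    and return a new dataset carrying only the attributes (no coordinates)."""
    @functools.wraps(func)
    def wrapper(ds, *args, **kwargs):
        if not ds:
            return xr.Dataset(attrs=dict(ds.attrs))
        return func(ds, *args, **kwargs)
    return wrapper


@propagate_attrs_if_empty
def double(ds):
    return ds * 2


empty_node = xr.Dataset(coords={"x": [1, 2]}, attrs={"title": "root"})
out = double(empty_node)
print(out.attrs, list(out.coords))
```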

@kmuehlbauer
Contributor

Thanks @shoyer for the details. Good to see that there are solutions for many use-cases already built-in or available via external helper functions.

I'm diverting a bit from the issue now. I've had to do this kind of wrapping to feed kwargs to my mapping function. What is the canonical way to feed kwargs to map_over_datasets? I should open a separate issue for that.

@shoyer
Member

shoyer commented Oct 30, 2024

> I'm diverting a bit from the issue now. I've had to do this kind of wrapping to feed kwargs to my mapping function. What is the canonical way to feed kwargs to map_over_datasets? I should open a separate issue for that.

You can pass in a helper function or use functools.partial. We could also add a kwargs argument like xarray.apply_ufunc.

@keewis
Collaborator

keewis commented Oct 30, 2024

> or use functools.wraps

shouldn't that be functools.partial?
