
Reduction operations fail with Weighted DataArrayWeighted #9841

Open
5 tasks done
3tilley opened this issue Nov 29, 2024 · 4 comments

Comments

@3tilley

3tilley commented Nov 29, 2024

What happened?

I have a Dataset containing several DataArrays. This setup is extremely useful to me because I can filter and perform operations over the whole dataset and all shared dimensions. One of the DataArrays is weighted (a DataArrayWeighted), and I was hoping the weights would be handled automatically in groupbys and general reduction operations. Instead, calling mean on the dataset raises the error below.

I'm happy to raise a PR to fix this if I can work out how, but I first want to make sure it's agreed that the current behaviour isn't correct.

What did you expect to happen?

I would like DataArrays that are unweighted to return the usual mean, and DataArrayWeighted variables to return a mean reflecting their weights, as if I'd just called da_weighted.mean(). This would allow me to calculate means in groupbys on the Dataset.

Minimal Complete Verifiable Example

import numpy as np
import xarray as xr

da = xr.DataArray(
    data=[[4.0, 5.0, 6.0], [1.0, 2.0, np.nan], [np.nan, np.nan, np.nan]],
    dims=["t", "x"],
    coords={"t": [0, 1, 2], "x": ["a", "b", "c"]}
)

dw = xr.DataArray(
    data=[0.1, 0.2, 0.3],
    dims=["t"],
    coords={"t": [0, 1, 2]}
)
db = da.copy(deep=True).weighted(dw)
ds = xr.Dataset({"a": da, "b": db, "w": dw})

# This works
print(db.mean())

# Errors
print(ds["b"].mean())

# Errors
print(ds.mean(dim="t"))

# Errors
print(ds.groupby_bins("t", bins=[0, 2.5]).mean())
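For reference, the result the reporter expects from ds.mean() can already be assembled by hand when the raw array and its weights are kept separate, rather than storing a DataArrayWeighted inside the Dataset. A minimal sketch using the MVCE data and reducing over all dimensions:

```python
import numpy as np
import xarray as xr

# Same data as the MVCE above.
da = xr.DataArray(
    data=[[4.0, 5.0, 6.0], [1.0, 2.0, np.nan], [np.nan, np.nan, np.nan]],
    dims=["t", "x"],
    coords={"t": [0, 1, 2], "x": ["a", "b", "c"]},
)
dw = xr.DataArray(data=[0.1, 0.2, 0.3], dims=["t"], coords={"t": [0, 1, 2]})

# Expected result of ds.mean(): plain means for "a" and "w",
# and a weighted mean (the same value as db.mean()) for "b".
expected = xr.Dataset(
    {
        "a": da.mean(),               # NaNs skipped: (4+5+6+1+2)/5 = 3.6
        "b": da.weighted(dw).mean(),  # weighted sum 2.1 / sum of valid weights 0.7 = 3.0
        "w": dw.mean(),               # (0.1+0.2+0.3)/3 = 0.2
    }
)
print(expected)
```

If "b" held the raw array instead, a plain mean would give 3.6 rather than 3.0; the question in this issue is whether the Dataset should produce the weighted value on its own.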

MVCE confirmation

  • Minimal example — the example is as focused as reasonably possible to demonstrate the underlying issue in xarray.
  • Complete example — the example is self-contained, including all data and the text of any traceback.
  • Verifiable example — the example copy & pastes into an IPython prompt or Binder notebook, returning the result.
  • New issue — a search of GitHub Issues suggests this is not a duplicate.
  • Recent environment — the issue occurs with the latest version of xarray and its dependencies.

Relevant log output

python weighted_demo.py
<xarray.DataArray ()> Size: 8B
array(3.)
Traceback (most recent call last):
  File <redact>, line 22, in <module>
    print(ds["b"].mean())
          ^^^^^^^^^^^^^^
  File "<redact>/.venv/lib/python3.12/site-packages/xarray/core/_aggregations.py", line 2982, in mean
    return self.reduce(
           ^^^^^^^^^^^^
  File "<redact>/.venv/lib/python3.12/site-packages/xarray/core/dataarray.py", line 3839, in reduce
    var = self.variable.reduce(func, dim, axis, keep_attrs, keepdims, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<redact>/.venv/lib/python3.12/site-packages/xarray/core/variable.py", line 1677, in reduce
    result = super().reduce(
             ^^^^^^^^^^^^^^^
  File "<redact>/.venv/lib/python3.12/site-packages/xarray/namedarray/core.py", line 918, in reduce
    data = func(self.data, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<redact>/.venv/lib/python3.12/site-packages/xarray/core/duck_array_ops.py", line 680, in mean
    return _mean(array, axis=axis, skipna=skipna, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<redact>/.venv/lib/python3.12/site-packages/xarray/core/duck_array_ops.py", line 447, in f
    return func(values, axis=axis, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<redact>/.venv/lib/python3.12/site-packages/xarray/core/nanops.py", line 124, in nanmean
    return _nanmean_ddof_object(0, a, axis=axis, dtype=dtype)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<redact>/.venv/lib/python3.12/site-packages/xarray/core/nanops.py", line 117, in _nanmean_ddof_object
    data = np.sum(value, axis=axis, dtype=dtype, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<redact>/.venv/lib/python3.12/site-packages/numpy/_core/fromnumeric.py", line 2389, in sum
    return _wrapreduction(
           ^^^^^^^^^^^^^^^
  File "<redact>/.venv/lib/python3.12/site-packages/numpy/_core/fromnumeric.py", line 86, in _wrapreduction
    return ufunc.reduce(obj, axis, dtype, out, **passkwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: float() argument must be a string or a real number, not 'DataArrayWeighted'
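The `_nanmean_ddof_object` frame suggests what is going on: the Dataset constructor does not treat a DataArrayWeighted as array data, so it appears to store it as a 0-d variable of object dtype, and mean then tries to sum the wrapped Python object as if it were a number. A quick check of that reading, reusing the MVCE setup (this is an interpretation of the traceback, not something confirmed in the thread):

```python
import numpy as np
import xarray as xr

da = xr.DataArray(
    data=[[4.0, 5.0, 6.0], [1.0, 2.0, np.nan], [np.nan, np.nan, np.nan]],
    dims=["t", "x"],
    coords={"t": [0, 1, 2], "x": ["a", "b", "c"]},
)
dw = xr.DataArray(data=[0.1, 0.2, 0.3], dims=["t"], coords={"t": [0, 1, 2]})
ds = xr.Dataset({"a": da, "b": da.weighted(dw), "w": dw})

b = ds["b"]
# "b" has no dimensions and an object dtype: it holds the Python object
# itself, not the weighted values, so numeric reductions cannot work on it.
print(b.dims, b.dtype)
print(type(b.item()).__name__)
```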

Anything else we need to know?

Several other functions, such as std, probably fall into the same category, but I think they could all be handled similarly.
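One reason the other reductions could plausibly be handled the same way: DataArrayWeighted already provides sum, mean, var and std, the same method names as on a plain DataArray, so a per-variable dispatch could look the method up by name on either kind of object. A small sketch; the values and weights are made up for illustration:

```python
import numpy as np
import xarray as xr

x = xr.DataArray([1.0, 2.0, 3.0], dims=["t"])
w = xr.DataArray([1.0, 1.0, 2.0], dims=["t"])

names = ("sum", "mean", "var", "std")

# Weighted reductions, looked up by method name.
weighted = x.weighted(w)
results = {how: float(getattr(weighted, how)(dim="t")) for how in names}

# The same names exist on the plain DataArray.
plain = {how: float(getattr(x, how)(dim="t")) for how in names}

print(results)
print(plain)
```

Note that the weighted var/std here are population (not sample) statistics: the weighted sum of squared deviations divided by the sum of weights.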

Environment

INSTALLED VERSIONS

commit: None
python: 3.12.3 (main, Jul 31 2024, 17:43:48) [GCC 13.2.0]
python-bits: 64
OS: Linux
OS-release: 5.15.153.1-microsoft-standard-WSL2
machine: x86_64
processor: x86_64
byteorder: little
LC_ALL: None
LANG: C.UTF-8
LOCALE: ('C', 'UTF-8')
libhdf5: 1.14.4
libnetcdf: None

xarray: 2024.10.0
pandas: 2.2.3
numpy: 2.0.2
scipy: 1.14.1
netCDF4: None
pydap: None
h5netcdf: 1.4.0
h5py: 3.12.1
zarr: None
cftime: None
nc_time_axis: None
iris: None
bottleneck: 1.4.2
dask: None
distributed: None
matplotlib: 3.9.2
cartopy: None
seaborn: 0.13.2
numbagg: None
fsspec: None
cupy: None
pint: None
sparse: 0.15.4
flox: None
numpy_groupies: None
setuptools: 75.3.0
pip: 24.0
conda: None
pytest: 8.3.3
mypy: 1.13.0
IPython: 8.29.0
sphinx: None

3tilley added the bug and needs triage labels on Nov 29, 2024

welcome bot commented Nov 29, 2024

Thanks for opening your first issue here at xarray! Be sure to follow the issue template!
If you have an idea for a solution, we would really welcome a Pull Request with proposed changes.
See the Contributing Guide for more.
It may take us a while to respond here, but we really value your contribution. Contributors like you help make xarray better.
Thank you!

@max-sixty
Collaborator

Very much agree with your suggestion.

It's likely worth spiking a quick example of what the code changes would look like.
Responding quickly from memory: it's plausible that the current Dataset code doesn't expect data variables to be of different types, and that the Dataset handles the .mean operation itself rather than delegating it down to each data variable. If that's the case, it might require fairly large changes with lots of if isweighted checks, which wouldn't be great.

If it does delegate the .mean to each data variable (or we could change it to do that), then this could work quite nicely. And might also be generalizable to reduction operations on other arrays, such as sparse arrays...

Does that make sense?
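The per-variable delegation described above can already be written by hand with Dataset.map, which applies a function to each data variable. For a Dataset of ordinary numeric arrays, mapping DataArray.mean over the variables gives the same result as calling Dataset.mean directly. This only illustrates delegation from the outside; it says nothing about how Dataset.mean is implemented internally.

```python
import numpy as np
import xarray as xr

da = xr.DataArray(
    [[4.0, 5.0, 6.0], [1.0, 2.0, np.nan], [np.nan, np.nan, np.nan]],
    dims=["t", "x"],
    coords={"t": [0, 1, 2], "x": ["a", "b", "c"]},
)
dw = xr.DataArray([0.1, 0.2, 0.3], dims=["t"], coords={"t": [0, 1, 2]})
ds = xr.Dataset({"a": da, "w": dw})

# Delegate the reduction to each data variable explicitly...
delegated = ds.map(lambda var: var.mean(dim="t"))
# ...and compare with the Dataset-level reduction.
direct = ds.mean(dim="t")

xr.testing.assert_allclose(delegated, direct)
print(delegated)
```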

max-sixty added the enhancement label and removed the bug and needs triage labels on Dec 1, 2024
@dcherian
Contributor

dcherian commented Dec 2, 2024

I would like DataArrays that are unweighted to return the usual mean,

I don't think this is a good idea. Consider the case when the weights Dataset is mistakenly missing a couple of data vars. Then you'll unintentionally get unweighted means and not know about it!

You might consider simply adding a scalar 1 in the weights Dataset for any missing data var.
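The suggestion above can be sketched with a hypothetical helper (strict_weighted_mean is not an xarray API) that takes one weights array per data variable and refuses to fall back silently to an unweighted mean. A variable that should be unweighted must be given an explicit scalar weight of 1, which reproduces the plain NaN-skipping mean.

```python
import numpy as np
import xarray as xr


def strict_weighted_mean(ds, weights, dim=None):
    """Hypothetical helper: weighted mean of every data variable.

    `weights` maps each data-variable name to a weights DataArray.
    A missing entry raises instead of silently giving an unweighted mean.
    """
    missing = set(ds.data_vars) - set(weights)
    if missing:
        raise KeyError(f"no weights given for: {sorted(missing)}")
    return xr.Dataset(
        {name: var.weighted(weights[name]).mean(dim=dim) for name, var in ds.data_vars.items()}
    )


da = xr.DataArray(
    [[4.0, 5.0, 6.0], [1.0, 2.0, np.nan], [np.nan, np.nan, np.nan]],
    dims=["t", "x"],
    coords={"t": [0, 1, 2], "x": ["a", "b", "c"]},
)
dw = xr.DataArray([0.1, 0.2, 0.3], dims=["t"], coords={"t": [0, 1, 2]})
ds = xr.Dataset({"a": da, "b": da})

# "a" is deliberately unweighted: a scalar weight of 1 gives the plain mean.
result = strict_weighted_mean(ds, {"a": xr.DataArray(1.0), "b": dw}, dim="t")
print(result)
```

Forgetting the entry for "a" raises a KeyError instead of quietly producing an unweighted result, which is the failure mode raised above.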

@max-sixty
Collaborator

Consider the case when the weights Dataset is mistakenly missing a couple of data vars. Then you'll unintentionally get unweighted means and not know about it!

I'm interpreting this differently — the dataset has some data variables that are weighted and some that are unweighted. There's no ds.weighted(ds_weights) where a missing data variable in ds_weights creates an unweighted data variable?

Instead it's db = da.weighted(dw), where db is an array, and that array is assigned to the dataset.

(when I'm confused during a discussion of ours, it's 3 times out of 4 me who's missing something, so asking from the perspective of likely being wrong but hopefully nonetheless helpful)
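For contrast, the existing Dataset.weighted method takes a single weights DataArray and applies it to every data variable; as far as its documented signature goes, it does not accept a per-variable weights Dataset. A small sketch with the MVCE arrays (the variable name "scaled" is made up for illustration):

```python
import numpy as np
import xarray as xr

da = xr.DataArray(
    [[4.0, 5.0, 6.0], [1.0, 2.0, np.nan], [np.nan, np.nan, np.nan]],
    dims=["t", "x"],
    coords={"t": [0, 1, 2], "x": ["a", "b", "c"]},
)
dw = xr.DataArray([0.1, 0.2, 0.3], dims=["t"], coords={"t": [0, 1, 2]})
ds = xr.Dataset({"a": da, "scaled": 2 * da})

# One weights array, applied to every data variable in the Dataset.
res = ds.weighted(dw).mean(dim="t")
print(res)
```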

Projects
None yet
Development

No branches or pull requests

3 participants