Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 13 additions & 6 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7052,15 +7052,22 @@ def map(

if keep_attrs:
for k, v in variables.items():
v._copy_attrs_from(self.data_vars[k])
if not (v.attrs and v.attrs != self.data_vars[k].attrs):
v._copy_attrs_from(self.data_vars[k])
for k, v in coords.items():
if k in self.coords:
v._copy_attrs_from(self.coords[k])
if not (v.attrs and v.attrs != self.coords[k].attrs):
v._copy_attrs_from(self.coords[k])
else:
for v in variables.values():
v.attrs = {}
for v in coords.values():
v.attrs = {}
for k, v in variables.items():
if not (v.attrs and v.attrs != self.data_vars[k].attrs):
v.attrs = {}
for k, v in coords.items():
if k in self.coords:
if not (v.attrs and v.attrs != self.coords[k].attrs):
v.attrs = {}
else:
v.attrs = {}

attrs = self.attrs if keep_attrs else None
return type(self)(variables, coords=coords, attrs=attrs)
Expand Down
34 changes: 34 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6572,6 +6572,40 @@ def func(arr):
ds["x"].attrs["y"] = "x"
assert ds["x"].attrs != actual["x"].attrs

def test_map_preserves_func_attrs(self) -> None:
# Regression test for GH#11356
# Dataset.map should preserve attrs explicitly set by the mapped function
ds = xr.Dataset(
{
"a": ("x", [1, 2, 3], {"units": "kg"}),
"b": ("x", [4, 5, 6], {"units": "kg"}),
}
)

# keep_attrs=True, func sets attrs -> func's attrs preserved
result = ds.map(
lambda x: (x / x.sum()).assign_attrs(units="unitless"), keep_attrs=True
)
assert result["a"].attrs == {"units": "unitless"}
assert result["b"].attrs == {"units": "unitless"}

# keep_attrs=False, func sets attrs -> func's attrs preserved
result = ds.map(
lambda x: (x / x.sum()).assign_attrs(units="unitless"), keep_attrs=False
)
assert result["a"].attrs == {"units": "unitless"}
assert result["b"].attrs == {"units": "unitless"}

# keep_attrs=True, func doesn't set attrs -> original attrs restored
result = ds.map(lambda x: x / x.sum(), keep_attrs=True)
assert result["a"].attrs == {"units": "kg"}
assert result["b"].attrs == {"units": "kg"}

# keep_attrs=False, func doesn't set attrs -> attrs wiped
result = ds.map(lambda x: x / x.sum(), keep_attrs=False)
assert result["a"].attrs == {}
assert result["b"].attrs == {}

def test_map_non_dataarray_outputs(self) -> None:
# Test that map handles non-DataArray outputs by converting them
# Regression test for GH10835
Expand Down
Loading