diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 1ce84904623..e78261d485f 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -7052,15 +7052,22 @@ def map( if keep_attrs: for k, v in variables.items(): - v._copy_attrs_from(self.data_vars[k]) + if not (v.attrs and v.attrs != self.data_vars[k].attrs): + v._copy_attrs_from(self.data_vars[k]) for k, v in coords.items(): if k in self.coords: - v._copy_attrs_from(self.coords[k]) + if not (v.attrs and v.attrs != self.coords[k].attrs): + v._copy_attrs_from(self.coords[k]) else: - for v in variables.values(): - v.attrs = {} - for v in coords.values(): - v.attrs = {} + for k, v in variables.items(): + if not (v.attrs and v.attrs != self.data_vars[k].attrs): + v.attrs = {} + for k, v in coords.items(): + if k in self.coords: + if not (v.attrs and v.attrs != self.coords[k].attrs): + v.attrs = {} + else: + v.attrs = {} attrs = self.attrs if keep_attrs else None return type(self)(variables, coords=coords, attrs=attrs) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 762c647f15c..652dcfe1377 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -6572,6 +6572,40 @@ def func(arr): ds["x"].attrs["y"] = "x" assert ds["x"].attrs != actual["x"].attrs + def test_map_preserves_func_attrs(self) -> None: + # Regression test for GH#11356 + # Dataset.map should preserve attrs explicitly set by the mapped function + ds = xr.Dataset( + { + "a": ("x", [1, 2, 3], {"units": "kg"}), + "b": ("x", [4, 5, 6], {"units": "kg"}), + } + ) + + # keep_attrs=True, func sets attrs -> func's attrs preserved + result = ds.map( + lambda x: (x / x.sum()).assign_attrs(units="unitless"), keep_attrs=True + ) + assert result["a"].attrs == {"units": "unitless"} + assert result["b"].attrs == {"units": "unitless"} + + # keep_attrs=False, func sets attrs -> func's attrs preserved + result = ds.map( + lambda x: (x / x.sum()).assign_attrs(units="unitless"), keep_attrs=False + ) + assert result["a"].attrs == {"units": "unitless"} + assert result["b"].attrs == {"units": "unitless"} + + # keep_attrs=True, func doesn't set attrs -> original attrs restored + result = ds.map(lambda x: x / x.sum(), keep_attrs=True) + assert result["a"].attrs == {"units": "kg"} + assert result["b"].attrs == {"units": "kg"} + + # keep_attrs=False, func doesn't set attrs -> attrs wiped + result = ds.map(lambda x: x / x.sum(), keep_attrs=False) + assert result["a"].attrs == {} + assert result["b"].attrs == {} + def test_map_non_dataarray_outputs(self) -> None: # Test that map handles non-DataArray outputs by converting them # Regression test for GH10835