From 6eecafa104475d762baa8696af9ccc95c7ff54c6 Mon Sep 17 00:00:00 2001 From: C1-BA-B1-F3 Date: Fri, 26 Jun 2026 21:58:00 +0800 Subject: [PATCH] Fix Dataset.map to preserve attrs set by the applied function Previously, when a function passed to Dataset.map explicitly set attributes on a DataArray (e.g., via .assign_attrs()), those attrs were lost because: - With keep_attrs=True: original attrs were copied back, overwriting func's changes - With keep_attrs=False: all attrs were wiped to empty dicts Now, if the function returns a DataArray with non-empty attrs that differ from the original, those attrs are preserved regardless of the keep_attrs setting. This allows users to update variable attributes through Dataset.map, which was previously impossible. Fixes #11356 --- xarray/core/dataset.py | 19 +++++++++++++------ xarray/tests/test_dataset.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 6 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 1ce84904623..e78261d485f 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -7052,15 +7052,22 @@ def map( if keep_attrs: for k, v in variables.items(): - v._copy_attrs_from(self.data_vars[k]) + if not (v.attrs and v.attrs != self.data_vars[k].attrs): + v._copy_attrs_from(self.data_vars[k]) for k, v in coords.items(): if k in self.coords: - v._copy_attrs_from(self.coords[k]) + if not (v.attrs and v.attrs != self.coords[k].attrs): + v._copy_attrs_from(self.coords[k]) else: - for v in variables.values(): - v.attrs = {} - for v in coords.values(): - v.attrs = {} + for k, v in variables.items(): + if not (v.attrs and v.attrs != self.data_vars[k].attrs): + v.attrs = {} + for k, v in coords.items(): + if k in self.coords: + if not (v.attrs and v.attrs != self.coords[k].attrs): + v.attrs = {} + else: + v.attrs = {} attrs = self.attrs if keep_attrs else None return type(self)(variables, coords=coords, attrs=attrs) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 762c647f15c..652dcfe1377 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -6572,6 +6572,40 @@ def func(arr): ds["x"].attrs["y"] = "x" assert ds["x"].attrs != actual["x"].attrs + def test_map_preserves_func_attrs(self) -> None: + # Regression test for GH#11356 + # Dataset.map should preserve attrs explicitly set by the mapped function + ds = xr.Dataset( + { + "a": ("x", [1, 2, 3], {"units": "kg"}), + "b": ("x", [4, 5, 6], {"units": "kg"}), + } + ) + + # keep_attrs=True, func sets attrs -> func's attrs preserved + result = ds.map( + lambda x: (x / x.sum()).assign_attrs(units="unitless"), keep_attrs=True + ) + assert result["a"].attrs == {"units": "unitless"} + assert result["b"].attrs == {"units": "unitless"} + + # keep_attrs=False, func sets attrs -> func's attrs preserved + result = ds.map( + lambda x: (x / x.sum()).assign_attrs(units="unitless"), keep_attrs=False + ) + assert result["a"].attrs == {"units": "unitless"} + assert result["b"].attrs == {"units": "unitless"} + + # keep_attrs=True, func doesn't set attrs -> original attrs restored + result = ds.map(lambda x: x / x.sum(), keep_attrs=True) + assert result["a"].attrs == {"units": "kg"} + assert result["b"].attrs == {"units": "kg"} + + # keep_attrs=False, func doesn't set attrs -> attrs wiped + result = ds.map(lambda x: x / x.sum(), keep_attrs=False) + assert result["a"].attrs == {} + assert result["b"].attrs == {} + def test_map_non_dataarray_outputs(self) -> None: # Test that map handles non-DataArray outputs by converting them # Regression test for GH10835