161 changes: 161 additions & 0 deletions xarray/backends/api.py
@@ -22,6 +22,7 @@
T_PathFileOrDataStore,
_find_absolute_paths,
_normalize_path,
datatree_from_dict_with_io_cleanup,
)
from xarray.coders import CFDatetimeCoder, CFTimedeltaCoder
from xarray.core import dtypes, indexing
@@ -385,6 +386,35 @@ def _datatree_from_backend_datatree(
return tree


async def _maybe_create_default_indexes_async(ds):
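    """Concurrently load dimension coordinates without an index, then index them."""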
import asyncio

# Determine which coords need default indexes
to_index_names = [
name
for name, coord in ds.coords.items()
if coord.dims == (name,) and name not in ds.xindexes
]

if to_index_names:

async def load_var(var):
try:
return await var.load_async()
except NotImplementedError:
return await asyncio.to_thread(var.load)

await asyncio.gather(
*[load_var(ds.coords[name].variable) for name in to_index_names]
)

# Build indexes (now data is in-memory so no remote I/O per coord)
to_index = {name: ds.coords[name].variable for name in to_index_names}
if to_index:
return ds.assign_coords(Coordinates(to_index))
return ds


def open_dataset(
filename_or_obj: T_PathFileOrDataStore,
*,
@@ -1099,6 +1129,137 @@ def open_datatree(
return tree


async def open_datatree_async(
filename_or_obj: T_PathFileOrDataStore,
*,
engine: T_Engine = None,
chunks: T_Chunks = None,
cache: bool | None = None,
decode_cf: bool | None = None,
mask_and_scale: bool | Mapping[str, bool] | None = None,
decode_times: bool
| CFDatetimeCoder
| Mapping[str, bool | CFDatetimeCoder]
| None = None,
decode_timedelta: bool
| CFTimedeltaCoder
| Mapping[str, bool | CFTimedeltaCoder]
| None = None,
use_cftime: bool | Mapping[str, bool] | None = None,
concat_characters: bool | Mapping[str, bool] | None = None,
decode_coords: Literal["coordinates", "all"] | bool | None = None,
drop_variables: str | Iterable[str] | None = None,
create_default_indexes: bool = True,
inline_array: bool = False,
chunked_array_type: str | None = None,
from_array_kwargs: dict[str, Any] | None = None,
backend_kwargs: dict[str, Any] | None = None,
**kwargs,
) -> DataTree:
"""Async version of open_datatree that concurrently builds default indexes.

Supports the "zarr" engine (both Zarr v2 and v3). For other engines, a
ValueError is raised.
"""
import asyncio

if cache is None:
cache = chunks is None

if backend_kwargs is not None:
kwargs.update(backend_kwargs)

if engine is None:
engine = plugins.guess_engine(filename_or_obj)

if from_array_kwargs is None:
from_array_kwargs = {}

# Only zarr supports async lazy loading at present
if engine != "zarr":
raise ValueError(f"Engine {engine!r} does not support asynchronous operations")

backend = plugins.get_backend(engine)

decoders = _resolve_decoders_kwargs(
decode_cf,
open_backend_dataset_parameters=backend.open_dataset_parameters,
mask_and_scale=mask_and_scale,
decode_times=decode_times,
decode_timedelta=decode_timedelta,
concat_characters=concat_characters,
use_cftime=use_cftime,
decode_coords=decode_coords,
)

overwrite_encoded_chunks = kwargs.pop("overwrite_encoded_chunks", None)

# Prefer backend async group opening if available (currently zarr only)
if hasattr(backend, "open_groups_as_dict_async"):
groups_dict = await backend.open_groups_as_dict_async(
filename_or_obj,
drop_variables=drop_variables,
**decoders,
**kwargs,
)
backend_tree = datatree_from_dict_with_io_cleanup(groups_dict)
else:
backend_tree = backend.open_datatree(
filename_or_obj,
drop_variables=drop_variables,
**decoders,
**kwargs,
)

# Protect variables for caching behavior consistency
_protect_datatree_variables_inplace(backend_tree, cache)

# For each dataset in the tree, concurrently create default indexes (if requested)
results: dict[str, Dataset] = {}

async def process_node(path: str, node_ds: Dataset) -> tuple[str, Dataset]:
ds = node_ds
if create_default_indexes:
ds = await _maybe_create_default_indexes_async(ds)
# Optional chunking (synchronous)
if chunks is not None:
ds = _chunk_ds(
ds,
filename_or_obj,
engine,
chunks,
overwrite_encoded_chunks,
inline_array,
chunked_array_type,
from_array_kwargs,
node=path,
**decoders,
**kwargs,
)
return path, ds

# Build tasks
tasks = [
process_node(path, node.dataset)
for path, [node] in group_subtrees(backend_tree)
]

# Execute concurrently and collect
for fut in asyncio.as_completed(tasks):
path, ds = await fut
results[path] = ds

# Build DataTree
tree = DataTree.from_dict(results)

    # Carry over close handlers from the backend tree. Unlike the sync path,
    # the tree here is always rebuilt from ``results``, so this must happen
    # unconditionally or the returned tree would leak open stores.
    for path, [node] in group_subtrees(backend_tree):
        tree[path].set_close(node._close)

return tree
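
A minimal usage sketch (not part of this diff; the store URL is a hypothetical placeholder): because open_datatree_async is a coroutine, it has to be driven by an event loop, for example with asyncio.run.

import asyncio

from xarray.backends.api import open_datatree_async


async def main() -> None:
    # Hypothetical remote Zarr store; only engine="zarr" is supported by this coroutine.
    tree = await open_datatree_async(
        "s3://some-bucket/some-dataset.zarr",
        engine="zarr",
    )
    print(tree)


asyncio.run(main())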


def open_groups(
filename_or_obj: T_PathFileOrDataStore,
*,
75 changes: 75 additions & 0 deletions xarray/backends/zarr.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import base64
import json
import os
@@ -1791,6 +1792,80 @@ def open_groups_as_dict(
groups_dict[group_name] = group_ds
return groups_dict

async def open_groups_as_dict_async(
self,
filename_or_obj: T_PathFileOrDataStore,
*,
mask_and_scale=True,
decode_times=True,
concat_characters=True,
decode_coords=True,
drop_variables: str | Iterable[str] | None = None,
use_cftime=None,
decode_timedelta=None,
group: str | None = None,
mode="r",
synchronizer=None,
consolidated=None,
chunk_store=None,
storage_options=None,
zarr_version=None,
zarr_format=None,
) -> dict[str, Dataset]:
"""Asynchronously open each group into a Dataset concurrently.

This mirrors open_groups_as_dict but parallelizes per-group Dataset opening,
which can significantly reduce latency on high-RTT object stores.
"""
filename_or_obj = _normalize_path(filename_or_obj)

# Determine parent group path context
if group:
parent = str(NodePath("/") / NodePath(group))
else:
parent = str(NodePath("/"))

# Discover group stores (synchronous metadata step)
stores = ZarrStore.open_store(
filename_or_obj,
group=parent,
mode=mode,
synchronizer=synchronizer,
consolidated=consolidated,
consolidate_on_close=False,
chunk_store=chunk_store,
storage_options=storage_options,
zarr_version=zarr_version,
zarr_format=zarr_format,
)

async def open_one(path_group: str, store) -> tuple[str, Dataset]:
store_entrypoint = StoreBackendEntrypoint()

def _load_sync():
with close_on_error(store):
return store_entrypoint.open_dataset(
store,
mask_and_scale=mask_and_scale,
decode_times=decode_times,
concat_characters=concat_characters,
decode_coords=decode_coords,
drop_variables=drop_variables,
use_cftime=use_cftime,
decode_timedelta=decode_timedelta,
)

ds = await asyncio.to_thread(_load_sync)
if group:
group_name = str(NodePath(path_group).relative_to(parent))
else:
group_name = str(NodePath(path_group))
return group_name, ds

tasks = [open_one(path_group, store) for path_group, store in stores.items()]
results = await asyncio.gather(*tasks)
return dict(results)
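
A hedged sketch of calling the new entrypoint method directly (the store path is a hypothetical placeholder; in normal use open_datatree_async invokes this on the caller's behalf):

import asyncio

from xarray.backends.zarr import ZarrBackendEntrypoint


async def main() -> None:
    backend = ZarrBackendEntrypoint()
    # Hypothetical local Zarr store containing nested groups.
    groups = await backend.open_groups_as_dict_async("example-hierarchy.zarr")
    for name, ds in groups.items():
        print(name, list(ds.data_vars))


asyncio.run(main())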


def _iter_zarr_groups(root: ZarrGroup, parent: str = "/") -> Iterable[str]:
parent_nodepath = NodePath(parent)