diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 51daaf7e7..5ef119a18 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -71,7 +71,7 @@ jobs: id: status run: | uv run --no-dev python -c "import xarray; xarray.show_versions()" || true - uv run --no-dev pytest --durations=20 --durations-min=0.5 -n auto --cov=./ --cov-report=xml --hypothesis-profile ci + uv run --no-dev pytest --durations=20 --durations-min=0.5 -n auto --cov=./ --cov-report=xml --hypothesis-profile ci --log-disable=flox - name: Upload code coverage to Codecov uses: codecov/codecov-action@v5.5.1 with: diff --git a/.gitignore b/.gitignore index d0fdf4f0b..d8e029f52 100644 --- a/.gitignore +++ b/.gitignore @@ -113,3 +113,22 @@ venv.bak/ # Git worktrees worktrees/ + +# Auto-generated version file +flox/_version.py + +# Temporary files +Untitled.ipynb +*.rej +*.py.rej +mutmut-cache +.mutmut-cache +mydask.png +profile.json +profile.html +test.png +uv.lock +devel/ + +# Claude Code +.claude/ diff --git a/CLAUDE.md b/CLAUDE.md index 830ee55d6..28dcefad3 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -149,3 +149,70 @@ asv preview - Integration testing with xarray upstream development branch - **Python Support**: Minimum version 3.11 (updated from 3.10) - **Git Worktrees**: `worktrees/` directory is ignored for development workflows +- **Running Tests**: Always use `uv run pytest` to run tests (not just `pytest`) + +## Key Implementation Details + +### Map-Reduce Combine Strategies (`flox/dask.py`) + +There are two strategies for combining intermediate results in dask's tree reduction: + +1. **`_simple_combine`**: Used for most reductions. Tree-reduces the reduction itself (not the groupby-reduction) for performance. Requirements: + + - All blocks must contain all groups after blockwise step (reindex.blockwise=True) + - Must know expected_groups + - Inserts DUMMY_AXIS=-2 via `_expand_dims`, reduces along it, then squeezes it out + - Used when: not an arg reduction, not first/last with non-float dtype, and labels are known + +1. **`_grouped_combine`**: More general solution that tree-reduces the groupby-reduction itself. Used for: + + - Arg reductions (argmax, argmin, etc.) 
+ - When labels are unknown (dask arrays without expected_groups) + - First/last reductions with non-float dtypes + +### Aggregations with New Dimensions + +Some aggregations add new dimensions to the output (e.g., topk, quantile): + +- **`new_dims_func`**: Function that returns tuple of Dim objects for new dimensions +- These MUST use `_simple_combine` because intermediate results have an extra dimension that needs to be reduced along DUMMY_AXIS +- Check if `new_dims_func(**finalize_kwargs)` returns non-empty tuple to determine if aggregation actually adds dimensions +- **Note**: argmax/argmin have `new_dims_func` but return empty tuple, so they use `_grouped_combine` + +### topk Implementation + +The topk aggregation is special: + +- Uses `_simple_combine` (has non-empty new_dims_func) +- First intermediate (topk values) combines along axis 0, not DUMMY_AXIS +- Does NOT squeeze out DUMMY_AXIS in final aggregate step +- `_expand_dims` only expands non-topk intermediates (the second one, nanlen) + +### Axis Parameter Handling + +- **`_simple_combine`**: Always receives axis as tuple (e.g., `(-2,)` for DUMMY_AXIS) +- **numpy functions**: Most accept both tuple and integer axis (e.g., np.max, np.sum) +- **Exception**: argmax/argmin don't accept tuple axis, but these use `_grouped_combine` +- **Custom functions**: Like `_var_combine` should normalize axis to tuple if needed for iteration + +### Test Organization + +- **`test_groupby_reduce_all`**: Comprehensive test for all aggregations with various parameters (nby, chunks, etc.) + + - Tests both with and without NaN handling + - For topk: sorts results along axis 0 before comparison (k dimension is at axis 0) + - Uses `np.moveaxis` not `np.swapaxes` for topk to avoid swapping other dimensions + +- **`test_groupby_reduce_axis_subset_against_numpy`**: Tests reductions over subsets of axes + + - Compares dask results against numpy results + - Tests various axis combinations: None, single int, tuples + - Skip arg reductions with axis=None or multiple axes (not supported) + +### Common Pitfalls + +1. **Axis transformations for topk**: Use `np.moveaxis(expected, src, 0)` not `np.swapaxes(expected, src, 0)` to move k dimension to position 0 without reordering other dimensions + +1. **new_dims_func checking**: Check if it returns non-empty dimensions, not just if it exists (argmax has one that returns `()`) + +1. **Axis parameter types**: Custom combine functions should handle both tuple and integer axis by normalizing at the start diff --git a/devel b/devel new file mode 120000 index 000000000..13dedd1ad --- /dev/null +++ b/devel @@ -0,0 +1 @@ +../devel/flox \ No newline at end of file diff --git a/docs/source/aggregations.md b/docs/source/aggregations.md index d3591d2dc..82562cc3a 100644 --- a/docs/source/aggregations.md +++ b/docs/source/aggregations.md @@ -9,19 +9,16 @@ the `func` kwarg: - `"mean"`, `"nanmean"` - `"var"`, `"nanvar"` - `"std"`, `"nanstd"` -- `"argmin"` -- `"argmax"` +- `"argmin"`, `"nanargmax"` +- `"argmax"`, `"nanargmin"` - `"first"`, `"nanfirst"` - `"last"`, `"nanlast"` - `"median"`, `"nanmedian"` - `"mode"`, `"nanmode"` - `"quantile"`, `"nanquantile"` +- `"topk"` -```{tip} -We would like to add support for `cumsum`, `cumprod` ([issue](https://github.com/xarray-contrib/flox/issues/91)). Contributions are welcome! -``` - -## Custom Aggregations +## Custom Reductions `flox` also allows you to specify a custom Aggregation (again inspired by dask.dataframe), though this might not be fully functional at the moment. 
See `aggregations.py` for examples. @@ -46,3 +43,7 @@ mean = Aggregation( final_fill_value=np.nan, ) ``` + +## Custom Scans + +Coming soon! diff --git a/flox/aggregate_flox.py b/flox/aggregate_flox.py index 127506453..798a14d8b 100644 --- a/flox/aggregate_flox.py +++ b/flox/aggregate_flox.py @@ -47,14 +47,32 @@ def _lerp(a, b, *, t, dtype, out=None): return out -def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=None): +def quantile_or_topk( + array, + inv_idx, + *, + q=None, + k=None, + axis, + skipna, + group_idx, + dtype=None, + out=None, + fill_value=None, +): + assert q is not None or k is not None + assert axis == -1 + inv_idx = np.concatenate((inv_idx, [array.shape[-1]])) array_validmask = notnull(array) actual_sizes = np.add.reduceat(array_validmask, inv_idx[:-1], axis=axis) newshape = (1,) * (array.ndim - 1) + (inv_idx.size - 1,) - full_sizes = np.reshape(np.diff(inv_idx), newshape) - nanmask = full_sizes != actual_sizes + if k is not None: + nanmask = actual_sizes < abs(k) + else: + full_sizes = np.reshape(np.diff(inv_idx), newshape) + nanmask = full_sizes != actual_sizes # The approach here is to use (complex_array.partition) because # 1. The full np.lexsort((array, labels), axis=-1) is slow and unnecessary @@ -72,36 +90,63 @@ def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=Non # So we determine which indices we need using the fact that NaNs get sorted to the end. # This *was* partly inspired by https://krstn.eu/np.nanpercentile()-there-has-to-be-a-faster-way/ # but not any more now that I use partition and avoid replacing NaNs - qin = q - q = np.atleast_1d(qin) - q = np.reshape(q, (len(q),) + (1,) * array.ndim) + if k is not None: + is_scalar_param = False + param = np.sort(np.arange(abs(k)) * np.sign(k)) + else: + is_scalar_param = is_scalar(q) + param = np.atleast_1d(q) + param = np.reshape(param, (param.size,) + (1,) * array.ndim) # This is numpy's method="linear" # TODO: could support all the interpolations here offset = actual_sizes.cumsum(axis=-1) - actual_sizes -= 1 - virtual_index = q * actual_sizes - # virtual_index is relative to group starts, so now offset that - virtual_index[..., 1:] += offset[..., :-1] - - is_scalar_q = is_scalar(qin) - if is_scalar_q: - virtual_index = virtual_index.squeeze(axis=0) - idxshape = array.shape[:-1] + (actual_sizes.shape[-1],) - else: - idxshape = (q.shape[0],) + array.shape[:-1] + (actual_sizes.shape[-1],) + # For topk(.., k=+1 or -1), we always return the singleton dimension. 
+ idxshape = (param.shape[0],) + array.shape[:-1] + (actual_sizes.shape[-1],) - lo_ = np.floor( - virtual_index, - casting="unsafe", - out=np.empty(virtual_index.shape, dtype=np.int64), - ) - hi_ = np.ceil( - virtual_index, - casting="unsafe", - out=np.empty(virtual_index.shape, dtype=np.int64), - ) - kth = np.unique(np.concatenate([lo_.reshape(-1), hi_.reshape(-1)])) + if q is not None: + # This is numpy's method="linear" + # TODO: could support all the interpolations here + actual_sizes -= 1 + virtual_index = param * actual_sizes + # virtual_index is relative to group starts, so now offset that + virtual_index[..., 1:] += offset[..., :-1] + + if is_scalar_param: + virtual_index = virtual_index.squeeze(axis=0) + idxshape = array.shape[:-1] + (actual_sizes.shape[-1],) + + lo_ = np.floor(virtual_index, casting="unsafe", out=np.empty(virtual_index.shape, dtype=np.int64)) + hi_ = np.ceil(virtual_index, casting="unsafe", out=np.empty(virtual_index.shape, dtype=np.int64)) + kth = np.unique(np.concatenate([lo_.reshape(-1), hi_.reshape(-1)])) + + else: + virtual_index = (actual_sizes - k) if k > 0 else (np.zeros_like(actual_sizes) + abs(k) - 1) + # virtual_index is relative to group starts, so now offset that + virtual_index[..., 1:] += offset[..., :-1] + k_offset = param.reshape((abs(k),) + (1,) * virtual_index.ndim) + lo_ = k_offset + virtual_index[np.newaxis, ...] + # For groups with fewer than k elements, clamp extraction indices to valid range + # and mark out-of-bounds positions for filling with fill_value. + # Compute group boundaries: starts = [0, offset[:-1]], ends = offset + # We prepend 0 to offset[:-1] to get group start positions + group_starts = np.insert(offset[..., :-1], 0, 0, axis=-1) + + # Mark positions outside group boundaries (before clamping to detect invalid indices) + # Broadcasting happens implicitly in comparison + badmask = (lo_ < group_starts) | (lo_ >= offset) + + # Clamp lo_ in-place to [group_starts, array.shape[axis]-1] + # Using out= avoids intermediate array allocations + np.clip(lo_, group_starts, array.shape[axis] - 1, out=lo_) + # Note: we don't include nanmask here because for intermediate chunk results, + # we want to keep partial results. nanmask is used separately for final output. + # kth must include ALL indices we'll extract, not just the starting index per group. + # np.partition only guarantees correct values at kth positions; other positions may + # have elements from different groups due to how introselect works with complex numbers. 
+ kth = np.unique(np.concatenate([np.unique(offset), np.unique(lo_)])) + kth = kth[kth >= 0] + kth[kth >= array.shape[axis]] = array.shape[axis] - 1 # partition the complex array in-place labels_broadcast = np.broadcast_to(group_idx, array.shape) @@ -111,20 +156,33 @@ def quantile_(array, inv_idx, *, q, axis, skipna, group_idx, dtype=None, out=Non # a simple (labels + 1j * array) will yield `nan+inf * 1j` instead of `0 + inf * j` cmplx.real = labels_broadcast cmplx.partition(kth=kth, axis=-1) - if is_scalar_q: - a_ = cmplx.imag - else: - a_ = np.broadcast_to(cmplx.imag, (q.shape[0],) + array.shape) - # get bounds, Broadcast to (num quantiles, ..., num labels) - loval = np.take_along_axis(a_, np.broadcast_to(lo_, idxshape), axis=axis) - hival = np.take_along_axis(a_, np.broadcast_to(hi_, idxshape), axis=axis) + a_ = cmplx.imag + if not is_scalar_param: + a_ = np.broadcast_to(cmplx.imag, (param.shape[0],) + array.shape) - # TODO: could support all the interpolations here - gamma = np.broadcast_to(virtual_index, idxshape) - lo_ - result = _lerp(loval, hival, t=gamma, out=out, dtype=dtype) - if not skipna and np.any(nanmask): - result[..., nanmask] = np.nan + if array.dtype.kind in "Mm": + a_ = a_.view(array.dtype) + + loval = np.take_along_axis(a_, np.broadcast_to(lo_, idxshape), axis=axis) + if q is not None: + # get bounds, Broadcast to (num quantiles, ..., num labels) + hival = np.take_along_axis(a_, np.broadcast_to(hi_, idxshape), axis=axis) + + # TODO: could support all the interpolations here + gamma = np.broadcast_to(virtual_index, idxshape) - lo_ + result = _lerp(loval, hival, t=gamma, out=out, dtype=dtype) + if not skipna and np.any(nanmask): + result[..., nanmask] = fill_value + else: + result = loval + if badmask.any(): + result[badmask] = fill_value + + if k is not None: + result = result.astype(dtype, copy=False) + if out is not None: + np.copyto(out, result) return result @@ -158,12 +216,14 @@ def _np_grouped_op( if out is None: q = kwargs.get("q", None) - if q is None: + k = kwargs.get("k", None) + if q is None and k is None: out = np.full(array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype) else: - nq = len(np.atleast_1d(q)) + nq = len(np.atleast_1d(q)) if q is not None else abs(k) out = np.full((nq,) + array.shape[:-1] + (size,), fill_value=fill_value, dtype=dtype) kwargs["group_idx"] = group_idx + kwargs["fill_value"] = fill_value if (len(uniques) == size) and (uniques == np.arange(size, like=aux)).all(): # The previous version of this if condition @@ -200,10 +260,11 @@ def _nan_grouped_op(group_idx, array, func, fillna, *args, **kwargs): nanmax = partial(_nan_grouped_op, func=max, fillna=dtypes.NINF) min = partial(_np_grouped_op, op=np.minimum.reduceat) nanmin = partial(_nan_grouped_op, func=min, fillna=dtypes.INF) -quantile = partial(_np_grouped_op, op=partial(quantile_, skipna=False)) -nanquantile = partial(_np_grouped_op, op=partial(quantile_, skipna=True)) -median = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_, skipna=False)) -nanmedian = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_, skipna=True)) +topk = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=True)) +quantile = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=False)) +nanquantile = partial(_np_grouped_op, op=partial(quantile_or_topk, skipna=True)) +median = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_or_topk, skipna=False)) +nanmedian = partial(partial(_np_grouped_op, q=0.5), op=partial(quantile_or_topk, skipna=True)) # TODO: 
all, any diff --git a/flox/aggregations.py b/flox/aggregations.py index e29feb758..217342ffb 100644 --- a/flox/aggregations.py +++ b/flox/aggregations.py @@ -140,7 +140,7 @@ def _atleast_1d(inp, min_length: int = 1): return inp -def returns_empty_tuple(*args, **kwargs): +def returns_empty_tuple(*args, **kwargs) -> tuple: return () @@ -287,6 +287,7 @@ def __dask_tokenize__(self): self.finalize, self.fill_value, self.dtype, + tuple(sorted(self.finalize_kwargs.items())) if self.finalize_kwargs else (), ) def __repr__(self) -> str: @@ -390,6 +391,10 @@ def var_chunk( def _var_combine(array, axis, keepdims=True): + # Ensure axis is always a tuple for iteration + if not isinstance(axis, tuple): + axis = (axis,) + def clip_last(array, ax, n=1): """Return array except the last element along axis Purely included to tidy up the adj_terms line @@ -689,6 +694,10 @@ def quantile_new_dims_func(q) -> tuple[Dim]: return (Dim(name="quantile", values=q),) +def topk_new_dims_func(k) -> tuple[Dim]: + return (Dim(name="k", values=np.arange(abs(k))),) + + # if the input contains integers or floats smaller than float64, # the output data-type is float64. Otherwise, the output data-type is the same as that # of the input. @@ -712,6 +721,30 @@ def quantile_new_dims_func(q) -> tuple[Dim]: nanmode = Aggregation(name="nanmode", fill_value=dtypes.NA, chunk=None, combine=None, preserves_dtype=True) +def _topk_finalize(values, counts, *, k): + """Convert -inf fill values back to NaN for topk results.""" + import numpy as np + + # After combine with nantopk, -inf values need to be converted to NaN + # k > 0: -inf was used as fill value + # k < 0: +inf was used as fill value + fill_val = -np.inf if k > 0 else np.inf + return np.where(values == fill_val, np.nan, values) + + +topk = Aggregation( + name="topk", + fill_value=(dtypes.NINF, 0), + final_fill_value=dtypes.NA, + # FIXME: set numpy + chunk=("topk", "nanlen"), + combine=(xrutils.nantopk, "sum"), + finalize=_topk_finalize, + new_dims_func=topk_new_dims_func, + preserves_dtype=True, +) + + @dataclass class Scan: # This dataclass is separate from Aggregations since there's not much in common @@ -910,6 +943,7 @@ def scan_binary_op(left_state: ScanState, right_state: ScanState, *, agg: Scan) "nanquantile": nanquantile, "mode": mode, "nanmode": nanmode, + "topk": topk, "cumsum": cumsum, "nancumsum": nancumsum, "ffill": ffill, @@ -964,6 +998,12 @@ def _initialize_aggregation( ), } + if finalize_kwargs is not None: + assert isinstance(finalize_kwargs, dict) + agg.finalize_kwargs = finalize_kwargs + + if agg.name == "topk" and agg.finalize_kwargs.get("k", 1) < 0: + agg.fill_value["intermediate"] = (dtypes.INF, 0) # Replace sentinel fill values according to dtype agg.fill_value["user"] = fill_value agg.fill_value["intermediate"] = tuple( @@ -978,10 +1018,6 @@ def _initialize_aggregation( else: agg.fill_value["numpy"] = (agg.fill_value[func],) - if finalize_kwargs is not None: - assert isinstance(finalize_kwargs, dict) - agg.finalize_kwargs = finalize_kwargs - # This is needed for the dask pathway. # Because we use intermediate fill_value since a group could be # absent in one block, but present in another block @@ -1018,6 +1054,11 @@ def _initialize_aggregation( else: simple_combine.append(getattr(np, combine)) else: + # TODO: bah, we need to pass `k` to the combine topk function + # this is ugly. 
+ if agg.name == "topk" and not isinstance(combine, str): + assert combine is not None + combine = partial(combine, **agg.finalize_kwargs) simple_combine.append(combine) agg.simple_combine = tuple(simple_combine) diff --git a/flox/core.py b/flox/core.py index ec101cf0b..36b8bf94e 100644 --- a/flox/core.py +++ b/flox/core.py @@ -27,6 +27,7 @@ generic_aggregate, is_var_chunk_reduction, quantile_new_dims_func, + topk_new_dims_func, ) from .factorize import ( _factorize_multiple, @@ -94,12 +95,6 @@ T = TypeVar("T") -# This dummy axis is inserted using np.expand_dims -# and then reduced over during the combine stage by -# _simple_combine. -DUMMY_AXIS = -2 - - logger = logging.getLogger("flox") @@ -347,7 +342,8 @@ def chunk_reduce( for reduction, fv, kw, dt in zip(funcs, fill_values, kwargss, dtypes): # UGLY! but this is because the `var` breaks our design assumptions if empty and not is_var_chunk_reduction(reduction): - result = np.full(shape=final_array_shape, fill_value=fv, like=array) + empty_shape = (abs(kw["k"]), *final_array_shape) if reduction == "topk" else final_array_shape + result = np.full(shape=empty_shape, fill_value=fv, like=array) elif _is_nanlen(reduction) and _is_nanlen(previous_reduction): result = results["intermediates"][-1] else: @@ -383,6 +379,8 @@ def chunk_reduce( # TODO: Figure out how to generalize this if reduction in ("quantile", "nanquantile"): new_dims_shape = tuple(dim.size for dim in quantile_new_dims_func(**kw) if not dim.is_scalar) + elif reduction == "topk": + new_dims_shape = tuple(dim.size for dim in topk_new_dims_func(**kw) if not dim.is_scalar) else: new_dims_shape = tuple() result = result.reshape(new_dims_shape + final_array_shape[:-1] + found_groups_shape) @@ -713,7 +711,7 @@ def _choose_engine(by, agg: Aggregation): not_arg_reduce = not _is_arg_reduction(agg) - if agg.name in ["quantile", "nanquantile", "median", "nanmedian"]: + if agg.name in ["quantile", "nanquantile", "median", "nanmedian", "topk"]: logger.debug(f"_choose_engine: Choosing 'flox' since {agg.name}") return "flox" @@ -764,7 +762,7 @@ def groupby_reduce( equality check are for dimensions of size 1 in ``by``. func : {"all", "any", "count", "sum", "nansum", "mean", "nanmean", \ "max", "nanmax", "min", "nanmin", "argmax", "nanargmax", "argmin", "nanargmin", \ - "quantile", "nanquantile", "median", "nanmedian", "mode", "nanmode", \ + "quantile", "nanquantile", "median", "nanmedian", "topk", "mode", "nanmode", \ "first", "nanfirst", "last", "nanlast"} or Aggregation Single function name or an Aggregation instance expected_groups : (optional) Sequence @@ -840,6 +838,11 @@ def groupby_reduce( finalize_kwargs : dict, optional Kwargs passed to finalize the reduction such as ``ddof`` for var, std or ``q`` for quantile. + Notes + ----- + ``topk`` and ``quantile`` are implemented by converting to a complex number and so are limited to values between +-``2**53-1`` + i.e. the limit of a ``float64`` dtype. Offset your data appropriately if you need the larger range. + Returns ------- result @@ -876,6 +879,8 @@ def groupby_reduce( "Use engine='flox' instead (it is also much faster), " "or set engine=None to use the default." 
) + if func == "topk" and (finalize_kwargs is None or "k" not in finalize_kwargs): + raise ValueError("Please pass `k` in ``finalize_kwargs`` for topk calculations.") bys: T_Bys = tuple(np.asarray(b) if not is_duck_array(b) else b for b in by) nby = len(bys) @@ -902,6 +907,12 @@ def groupby_reduce( if not is_duck_array(array): array = np.asarray(array) + # topk with reindex=False not yet supported + if func == "topk" and reindex is False: + raise NotImplementedError( + "topk with reindex=False is not yet supported. Use reindex=True or reindex=None." + ) + reindex = _validate_reindex( reindex, func, diff --git a/flox/dask.py b/flox/dask.py index fda455e41..e49a504bd 100644 --- a/flox/dask.py +++ b/flox/dask.py @@ -4,6 +4,7 @@ import itertools import operator +import warnings from collections.abc import Callable, Sequence from functools import partial from numbers import Integral @@ -21,7 +22,6 @@ from .types import DaskArray, Graph, IntermediateDict, T_By from .core import ( - DUMMY_AXIS, _get_chunk_reduction, _reduce_blockwise, _unique, @@ -37,6 +37,11 @@ from .types import FinalResultsDict, IntermediateDict from .xrutils import is_duck_dask_array, notnull +# This dummy axis is inserted using np.expand_dims +# and then reduced over during the combine stage by +# _simple_combine. +DUMMY_AXIS = -2 + def listify_groups(x: IntermediateDict): return list(np.atleast_1d(x["groups"].squeeze())) @@ -99,8 +104,6 @@ def _simple_combine( DUMMY_AXIS 4. At the final aggregate step, we squeeze out DUMMY_AXIS """ - import warnings - from dask.array.core import deepfirst from dask.utils import deepmap @@ -122,15 +125,19 @@ def _simple_combine( results: IntermediateDict = {"groups": unique_groups} results["intermediates"] = [] - axis_ = axis[:-1] + (DUMMY_AXIS,) for idx, combine in enumerate(agg.simple_combine): + if agg.name == "topk" and idx == 0: + axis_ = axis[:-1] + (0,) + else: + axis_ = axis[:-1] + (DUMMY_AXIS,) + array = _conc2(x_chunk, key1="intermediates", key2=idx, axis=axis_) assert array.ndim >= 2 with warnings.catch_warnings(): warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered") assert callable(combine) result = combine(array, axis=axis_, keepdims=True) - if is_aggregate: + if is_aggregate and agg.name != "topk": # squeeze out DUMMY_AXIS if this is the last step i.e. 
called from _aggregate # can't just pass DUMMY_AXIS, because of sparse.COO result = result.squeeze(range(result.ndim)[DUMMY_AXIS]) @@ -145,10 +152,16 @@ def _extract_result(result_dict: FinalResultsDict, key) -> np.ndarray: return deepfirst(result_dict)[key] -def _expand_dims(results: IntermediateDict) -> IntermediateDict: - results["intermediates"] = tuple( - np.expand_dims(array, axis=DUMMY_AXIS) for array in results["intermediates"] - ) +def _expand_dims(results: IntermediateDict, agg: Aggregation) -> IntermediateDict: + if agg.name == "topk": + # don't expand the topk intermediates, but expand all else + results["intermediates"] = tuple(results["intermediates"][:1]) + tuple( + np.expand_dims(array, axis=DUMMY_AXIS) for array in results["intermediates"][1:] + ) + else: + results["intermediates"] = tuple( + np.expand_dims(array, axis=DUMMY_AXIS) for array in results["intermediates"] + ) return results @@ -374,7 +387,12 @@ def dask_groupby_agg( # This allows us to discover groups at compute time, support argreductions, lower intermediate # memory usage (but method="cohorts" would also work to reduce memory in some cases) labels_are_unknown = is_duck_dask_array(by_input) and expected_groups is None - do_grouped_combine = ( + # For reductions with new_dims_func that actually add dimensions (quantile, topk), + # we must use _simple_combine because the intermediate results have an extra dimension + # that needs to be reduced along DUMMY_AXIS, not along the groups axis. + # Check if new_dims_func actually returns non-empty dimensions + must_use_simple_combine = agg.num_new_vector_dims > 0 + do_grouped_combine = not must_use_simple_combine and ( _is_arg_reduction(agg) or labels_are_unknown or (_is_first_last_reduction(agg) and array.dtype.kind != "f") @@ -385,6 +403,9 @@ def dask_groupby_agg( # use the "non dask" code path, but applied blockwise blockwise_method = partial(_reduce_blockwise, agg=agg, fill_value=fill_value, reindex=reindex) else: + extra = {} + if agg.name == "topk": + extra["kwargs"] = (agg.finalize_kwargs, *(({},) * (len(agg.chunk) - 1))) # choose `chunk_reduce` or `chunk_argreduce` blockwise_method = partial( _get_chunk_reduction(agg.reduction_type), @@ -393,10 +414,11 @@ def dask_groupby_agg( fill_value=agg.fill_value["intermediate"], dtype=agg.dtype["intermediate"], user_dtype=agg.dtype["user"], + **extra, ) if do_simple_combine: # Add a dummy dimension that then gets reduced over - blockwise_method = tlz.compose(_expand_dims, blockwise_method) + blockwise_method = tlz.compose(partial(_expand_dims, agg=agg), blockwise_method) # apply reduction on chunk intermediate = dask.array.blockwise( diff --git a/flox/lib.py b/flox/lib.py index 3cca5532e..380fde3ec 100644 --- a/flox/lib.py +++ b/flox/lib.py @@ -110,7 +110,7 @@ def _is_bool_supported_reduction(func: T_Agg) -> bool: if isinstance(func, Aggregation): func = func.name return ( - func in ["all", "any"] + func in ["all", "any", "topk"] # TODO: enable in npg # or _is_first_last_reduction(func) # or _is_minmax_reduction(func) diff --git a/flox/xarray.py b/flox/xarray.py index ff4b12bd6..82ef0494a 100644 --- a/flox/xarray.py +++ b/flox/xarray.py @@ -9,7 +9,13 @@ import xarray as xr from packaging.version import Version -from .aggregations import Aggregation, Dim, _atleast_1d, quantile_new_dims_func +from .aggregations import ( + Aggregation, + Dim, + _atleast_1d, + quantile_new_dims_func, + topk_new_dims_func, +) from .core import ( _convert_expected_groups_to_index, _get_expected_groups, @@ -92,7 +98,7 @@ def xarray_reduce( 
Variables with which to group by ``obj`` func : {"all", "any", "count", "sum", "nansum", "mean", "nanmean", \ "max", "nanmax", "min", "nanmin", "argmax", "nanargmax", "argmin", "nanargmin", \ - "quantile", "nanquantile", "median", "nanmedian", "mode", "nanmode", \ + "quantile", "nanquantile", "median", "nanmedian", "topk", "mode", "nanmode", \ "first", "nanfirst", "last", "nanlast"} or Aggregation Single function name or an Aggregation instance expected_groups : str or sequence @@ -181,6 +187,11 @@ def xarray_reduce( DataArray or Dataset Reduced object + Notes + ----- + ``topk`` and ``quantile`` are implemented by converting to a complex number and so are limited to values between +-``2**53-1`` + i.e. the limit of a ``float64`` dtype. Offset your data appropriately if you need the larger range. + See Also -------- flox.core.groupby_reduce @@ -372,16 +383,20 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs): result, *groups = groupby_reduce(array, *by, func=func, **kwargs) - # Transpose the new quantile dimension to the end. This is ugly. + # Transpose the new quantile or topk dimension to the end. This is ugly. # but new core dimensions are expected at the end :/ # but groupby_reduce inserts them at the beginning if func in ["quantile", "nanquantile"]: (newdim,) = quantile_new_dims_func(**finalize_kwargs) - if not newdim.is_scalar: - # NOTE: _restore_dim_order will move any new dims to the end anyway. - # This transpose is simply makes it easy to specify output_core_dims - # output dim order: (*broadcast_dims, *group_dims, quantile_dim) - result = np.moveaxis(result, 0, -1) + elif func == "topk": + (newdim,) = topk_new_dims_func(**finalize_kwargs) + else: + newdim = None + if newdim is not None and not newdim.is_scalar: + # NOTE: _restore_dim_order will move any new dims to the end anyway. + # This transpose is simply makes it easy to specify output_core_dims + # output dim order: (*broadcast_dims, *group_dims, quantile_dim) + result = np.moveaxis(result, 0, -1) return result @@ -401,9 +416,11 @@ def wrapper(array, *by, func, skipna, core_dims, **kwargs): input_core_dims = [[d for d in grouper_dims if d not in dim_tuple] + list(dim_tuple)] input_core_dims += [list(b.dims) for b in by_da] - newdims: tuple[Dim, ...] = ( - quantile_new_dims_func(**finalize_kwargs) if func in ["quantile", "nanquantile"] else () - ) + newdims: tuple[Dim, ...] 
= () + if func in ["quantile", "nanquantile"]: + newdims = quantile_new_dims_func(**finalize_kwargs) + elif func == "topk": + newdims = topk_new_dims_func(**finalize_kwargs) output_core_dims = [d for d in input_core_dims[0] if d not in dim_tuple] output_core_dims.extend(group_names) diff --git a/flox/xrdtypes.py b/flox/xrdtypes.py index b9a83a87a..cfcc65eb2 100644 --- a/flox/xrdtypes.py +++ b/flox/xrdtypes.py @@ -109,6 +109,9 @@ def get_pos_infinity(dtype, max_for_int=False): if issubclass(dtype.type, np.complexfloating): return np.inf + 1j * np.inf + if issubclass(dtype.type, np.bool_): + return True + return INF @@ -142,6 +145,9 @@ def get_neg_infinity(dtype, min_for_int=False): if issubclass(dtype.type, np.complexfloating): return -np.inf - 1j * np.inf + if issubclass(dtype.type, np.bool_): + return False + return NINF diff --git a/flox/xrutils.py b/flox/xrutils.py index 4adc56b2c..2e398c1c0 100644 --- a/flox/xrutils.py +++ b/flox/xrutils.py @@ -9,6 +9,7 @@ import numpy as np import pandas as pd +from numpy.lib.array_utils import normalize_axis_tuple from packaging.version import Version @@ -410,3 +411,36 @@ def nanlast(values, axis, keepdims=False): return np.expand_dims(result, axis=axis) else: return result + + +def topk(a: np.ndarray, k: int, axis, keepdims: bool = True) -> np.ndarray: + """Chunk and combine function of topk + + Extract the k largest elements from a on the given axis. + If k is negative, extract the -k smallest elements instead. + Note that, unlike in the parent function, the returned elements + are not sorted internally. + + NOTE: This function was copied from the dask project under the terms + of their LICENSE. + """ + assert keepdims is True + (axis,) = normalize_axis_tuple(axis, a.ndim) # type: ignore[misc] + if abs(k) >= a.shape[axis]: + return a + + a.partition(-k, axis=axis) + k_slice = slice(-k, None) if k > 0 else slice(-k) + result = a[tuple(k_slice if i == axis else slice(None) for i in range(a.ndim))] + return result.astype(a.dtype, copy=False) + + +def nantopk(a: np.ndarray, k: int, axis, keepdims: bool = True) -> np.ndarray: + """NaN-aware version of topk. + + Replaces NaN with -inf (for k > 0) or +inf (for k < 0) before calling topk, + so that NaN values don't end up in the top/bottom k results. 
+ """ + fill_val = -np.inf if k > 0 else np.inf + a = np.where(isnull(a), fill_val, a) + return topk(a, k=k, axis=axis, keepdims=keepdims) diff --git a/tests/test_core.py b/tests/test_core.py index 97f5c965f..08581102b 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -90,7 +90,7 @@ def npfunc(x, **kwargs): x = np.asarray(x) return (~xrutils.isnull(x)).sum(**kwargs) - elif func in ["nanfirst", "nanlast"]: + elif func in ["nanfirst", "nanlast", "topk"]: npfunc = getattr(xrutils, func) elif func in SCIPY_STATS_FUNCS: @@ -266,6 +266,10 @@ def test_groupby_reduce_all(to_sparse, nby, size, chunks, func, add_nan_by, engi ] fill_value = None tolerance = None + elif func == "topk": + finalize_kwargs = [{"k": 3}, {"k": -3}] + fill_value = None + tolerance = None else: fill_value = None tolerance = None @@ -276,6 +280,8 @@ def test_groupby_reduce_all(to_sparse, nby, size, chunks, func, add_nan_by, engi for kwargs in finalize_kwargs: if "quantile" in func and isinstance(kwargs["q"], list) and engine != "flox": continue + if "topk" in func and engine != "flox": + continue flox_kwargs = dict(func=func, engine=engine, finalize_kwargs=kwargs, fill_value=fill_value) with np.errstate(invalid="ignore", divide="ignore"): with warnings.catch_warnings(): @@ -295,6 +301,12 @@ def test_groupby_reduce_all(to_sparse, nby, size, chunks, func, add_nan_by, engi expected = getattr(np, func_)(array_, axis=-1, **kwargs) else: expected = array_func(array_[..., ~nanmask], axis=-1, **kwargs) + if func == "topk": + if (~nanmask).sum(axis=-1) < kwargs["k"]: + # FIXME: update this expectation + assert False + expected = np.full(expected.shape[:-1] + (abs(kwargs["k"]),), np.nan) + expected = np.sort(np.swapaxes(expected, array.ndim - 1, 0), axis=0) for _ in range(nby): expected = np.expand_dims(expected, axis=-1) @@ -302,7 +314,7 @@ def test_groupby_reduce_all(to_sparse, nby, size, chunks, func, add_nan_by, engi assert chunks == -1 actual, *groups = groupby_reduce(array, *by, **flox_kwargs) - if "quantile" in func and isinstance(kwargs["q"], list): + if ("quantile" in func and isinstance(kwargs["q"], list)) or func == "topk": assert actual.ndim == expected.ndim == (array.ndim + nby) else: assert actual.ndim == expected.ndim == (array.ndim + nby - 1) @@ -312,9 +324,12 @@ def test_groupby_reduce_all(to_sparse, nby, size, chunks, func, add_nan_by, engi assert_equal(actual_group, expect) if "arg" in func: assert actual.dtype.kind == "i" + if func == "topk": + actual = np.sort(actual, axis=0) assert_equal(expected, actual, tolerance) - if "nan" not in func and "arg" not in func: + # FIXME: topk vs nantopk + if "nan" not in func and "arg" not in func and "topk" not in func: # test non-NaN skipping behaviour when NaNs are present nanned = array_.copy() # remove nans in by to reduce complexity @@ -324,6 +339,10 @@ def test_groupby_reduce_all(to_sparse, nby, size, chunks, func, add_nan_by, engi nanned.reshape(-1)[0] = np.nan actual, *_ = groupby_reduce(nanned, *by_, **flox_kwargs) expected_0 = array_func(nanned, axis=-1, **kwargs) + if func == "topk": + expected_0 = np.sort(np.swapaxes(expected_0, array.ndim - 1, 0), axis=-1) + actual = np.sort(actual, axis=-1) + for _ in range(nby): expected_0 = np.expand_dims(expected_0, -1) assert_equal(expected_0, actual, tolerance) @@ -358,6 +377,11 @@ def test_groupby_reduce_all(to_sparse, nby, size, chunks, func, add_nan_by, engi with pytest.raises(NotImplementedError): call() continue + if func == "topk" and reindex is False: + # topk with reindex=False not yet supported + with 
pytest.raises(NotImplementedError, match="topk with reindex=False"): + call() + continue if method == "blockwise": # no combine necessary @@ -378,6 +402,8 @@ def test_groupby_reduce_all(to_sparse, nby, size, chunks, func, add_nan_by, engi assert_equal(actual_group, expect, tolerance) if "arg" in func: assert actual.dtype.kind == "i" + if func == "topk": + actual = np.sort(actual, axis=0) if isinstance(reindex, ReindexStrategy): import sparse @@ -2187,6 +2213,106 @@ def raise_error(self): assert mocked_reindex_func.call_count > 1 +@pytest.mark.parametrize( + "k,expected", + [ + # k=3: top 3 largest from each group (NaNs excluded) + # Group 0: [5.0, 2.0, nan, 8.0] -> top 3: [8.0, 5.0, 2.0] + # Group 1: [1.0, nan, 9.0, 3.0] -> top 3: [9.0, 3.0, 1.0] + (3, np.array([[8.0, 5.0, 2.0], [9.0, 3.0, 1.0]])), + # k=-3: bottom 3 smallest from each group (NaNs excluded) + # Group 0: [5.0, 2.0, nan, 8.0] -> bottom 3: [2.0, 5.0, 8.0] + # Group 1: [1.0, nan, 9.0, 3.0] -> bottom 3: [1.0, 3.0, 9.0] + (-3, np.array([[2.0, 5.0, 8.0], [1.0, 3.0, 9.0]])), + ], +) +def test_topk_with_nan(k, expected): + """Test topk handles NaN values correctly for both k > 0 and k < 0.""" + # Test data with NaNs + array = np.array([5.0, 2.0, np.nan, 8.0, 1.0, np.nan, 9.0, 3.0]) + by = np.array([0, 0, 0, 0, 1, 1, 1, 1]) + + actual, groups = groupby_reduce(array, by, func="topk", finalize_kwargs={"k": k}) + + # Verify shape: should be (abs(k), num_groups) + assert actual.shape == (abs(k), 2) + + # Sort for comparison since order within topk is not guaranteed + actual_sorted = np.sort(actual, axis=0) + expected_sorted = np.sort(expected.T, axis=0) + assert_equal(actual_sorted, expected_sorted) + + # Verify no NaNs in the results + assert not np.isnan(actual).any() + + +@pytest.mark.parametrize( + "k,expected", + [ + (5, np.array([[2.0, 5.0, 8.0, np.nan, np.nan]])), + (-5, np.array([[2.0, 5.0, 8.0, np.nan, np.nan]])), + ], +) +def test_topk_fewer_than_k_elements(k, expected): + """Test topk when group has fewer than k elements.""" + # Group has only 3 elements but k=5 + array = np.array([5.0, 2.0, 8.0]) + by = np.array([0, 0, 0]) + + actual, groups = groupby_reduce(array, by, func="topk", finalize_kwargs={"k": k}) + + # Should return all elements plus NaN padding + assert actual.shape == (abs(k), 1) + + # Sort for comparison (order not guaranteed) + actual_sorted = np.sort(actual, axis=0) + expected_sorted = np.sort(expected.T, axis=0) + assert_equal(actual_sorted, expected_sorted) + + +@pytest.mark.parametrize( + "k,expected", + [ + (2, np.array([[1.0, 2.0], [np.nan, np.nan]])), + (-2, np.array([[1.0, 2.0], [np.nan, np.nan]])), + ], +) +def test_topk_all_nan_group(k, expected): + """Test topk when a group has all NaN values.""" + array = np.array([1.0, 2.0, np.nan, np.nan]) + by = np.array([0, 0, 1, 1]) + + actual, groups = groupby_reduce(array, by, func="topk", finalize_kwargs={"k": k}) + + assert actual.shape == (abs(k), 2) + + # Sort for comparison (order not guaranteed for group 0) + actual_sorted = np.sort(actual, axis=0) + expected_sorted = np.sort(expected.T, axis=0) + assert_equal(actual_sorted, expected_sorted) + + +@pytest.mark.parametrize("k", [3, -3]) +def test_topk_fill_value_correctness(k): + """Test that fill_value is correctly set based on sign of k.""" + # Create array with missing groups + array = np.array([5.0, 2.0, 8.0]) + by = np.array([0, 0, 2]) # Missing group 1 + + actual, groups = groupby_reduce( + array, + by, + func="topk", + finalize_kwargs={"k": k}, + expected_groups=([0, 1, 2],), + ) + + assert 
actual.shape == (abs(k), 3) + + # Group 1 (missing) should be all NaN (final_fill_value) + assert np.isnan(actual[:, 1]).all() + + @requires_dask def test_sparse_errors(): call = partial( diff --git a/tests/test_properties.py b/tests/test_properties.py index 26f564c01..4797e2d95 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -360,3 +360,28 @@ def test_agg_dtype_specified(func, array_dtype, dtype, engine): ) expected = getattr(np, func)(counts, keepdims=True, dtype=dtype) assert actual.dtype == expected.dtype + + +@given(data=st.data(), array=chunked_arrays()) +def test_topk_max_min(data, array): + "top 1 == nanmax; top -1 == nanmin" + + if array.dtype.kind in "iu": + # we cast to float and back, so this is the effective limit + assume((np.abs(array) < 2**53).all()) + elif array.dtype.kind in "Mm": + assume((np.abs(array.view(np.int64)) < 2**53).all()) + # we cast to float and back, so this is the effective limit + elif _contains_cftime_datetimes(array): + asint = datetime_to_numeric(array, datetime_unit="us") + assume((np.abs(asint.view(np.int64)) < 2**53).all()) + + size = array.shape[-1] + by = data.draw(by_arrays(shape=(size,))) + k, npfunc = data.draw(st.sampled_from([(1, "nanmax"), (-1, "nanmin")])) + + for a in (array, array.compute()): + actual, _ = groupby_reduce(a, by, func="topk", finalize_kwargs={"k": k}) + # TODO: do numbagg, flox + expected, _ = groupby_reduce(a, by, func=npfunc, engine="numpy") + assert_equal(actual, expected[np.newaxis, :]) diff --git a/tests/test_xarray.py b/tests/test_xarray.py index 866ceee5d..e0e6fa8b9 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -811,3 +811,82 @@ def test_resample_first_last_empty(): dims=["date"], ).chunk(date=(1, 1)) arr.resample(date="QE").last().compute() + + +def test_xarray_topk_basic(): + """Test basic topk functionality with xarray.""" + # Create test data with clear ordering + data = np.array([[5, 1, 3, 8], [2, 9, 4, 7], [6, 0, 10, 1]]) + + da = xr.DataArray( + data, + dims=("x", "y"), + coords={"labels": ("y", ["a", "a", "b", "b"])}, + ) + + # Test k=2 (top 2 values) + result = xarray_reduce( + da, + "labels", + func="topk", + k=2, + ) + + # Check dimensions are correct + assert "k" in result.dims + assert result.sizes["k"] == 2 + assert result.sizes["labels"] == 2 + assert result.sizes["x"] == 3 + + +def test_xarray_topk_negative_k(): + """Test topk with negative k (bottom k values).""" + data = np.array([[5, 1, 3, 8], [2, 9, 4, 7], [6, 0, 10, 1]]) + + da = xr.DataArray( + data, + dims=("x", "y"), + coords={"labels": ("y", ["a", "a", "b", "b"])}, + ) + + # Test k=-2 (bottom 2 values) + result = xarray_reduce( + da, + "labels", + func="topk", + k=-2, + ) + + # Check dimensions + assert "k" in result.dims + assert result.sizes["k"] == 2 + assert result.sizes["labels"] == 2 + + +@requires_dask +def test_xarray_topk_dask(): + """Test topk with dask arrays.""" + import dask.array as dask_array + + data = np.array([[5, 1, 3, 8], [2, 9, 4, 7], [6, 0, 10, 1]]) + + da = xr.DataArray( + dask_array.from_array(data, chunks=(2, 2)), + dims=("x", "y"), + coords={"labels": ("y", ["a", "a", "b", "b"])}, + ) + + result = xarray_reduce( + da, + "labels", + func="topk", + k=2, + ) + + # Force computation + result = result.compute() + + # Check dimensions + assert "k" in result.dims + assert result.sizes["k"] == 2 + assert result.sizes["labels"] == 2
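
For reference, a minimal usage sketch of the `topk` aggregation this patch introduces, mirroring the calls exercised in the new tests (`tests/test_core.py`, `tests/test_xarray.py`). The sample values and labels below are invented for illustration; the function names, keyword arguments, and output shapes come from the patch itself.

```python
import numpy as np
import xarray as xr
from flox import groupby_reduce
from flox.xarray import xarray_reduce

array = np.array([5.0, 2.0, np.nan, 8.0, 1.0, np.nan, 9.0, 3.0])
by = np.array([0, 0, 0, 0, 1, 1, 1, 1])

# `k` is passed through finalize_kwargs; the result gains a new leading
# dimension of length abs(k). k > 0 selects the k largest values per group,
# k < 0 the k smallest. NaNs are skipped, and groups with fewer than abs(k)
# valid elements are padded with NaN. Values along the k dimension are not
# guaranteed to be sorted.
top2, groups = groupby_reduce(array, by, func="topk", finalize_kwargs={"k": 2})
assert top2.shape == (2, 2)  # (abs(k), num_groups)

# Through the xarray layer, the new dimension is named "k".
da = xr.DataArray(
    np.arange(12.0).reshape(3, 4),
    dims=("x", "y"),
    coords={"labels": ("y", ["a", "a", "b", "b"])},
)
result = xarray_reduce(da, "labels", func="topk", k=2)
assert result.sizes["k"] == 2 and result.sizes["labels"] == 2

# Like quantile, topk routes values through a complex-number partition, so
# inputs should stay within +/- 2**53 - 1 (the float64 integer limit).
```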