From 3e429b4af63b48182d629afef0c001cc94743d86 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Sat, 1 Jun 2024 10:06:46 +0200 Subject: [PATCH 1/7] Add arrow import/export to make shuffle work --- dask_geopandas/backends.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/dask_geopandas/backends.py b/dask_geopandas/backends.py index ee88e115..d96d140a 100644 --- a/dask_geopandas/backends.py +++ b/dask_geopandas/backends.py @@ -19,7 +19,12 @@ from dask.dataframe.utils import meta_nonempty from dask.dataframe.extensions import make_array_nonempty, make_scalar from dask.base import normalize_token -from dask.dataframe.dispatch import make_meta_dispatch, pyarrow_schema_dispatch +from dask.dataframe.dispatch import ( + make_meta_dispatch, + pyarrow_schema_dispatch, + to_pyarrow_table_dispatch, + from_pyarrow_table_dispatch, +) from dask.dataframe.backends import _nonempty_index, meta_nonempty_dataframe import shapely.geometry @@ -86,3 +91,21 @@ def get_pyarrow_schema_geopandas(obj): for col in obj.columns[obj.dtypes == "geometry"]: df[col] = obj[col].to_wkb() return pa.Schema.from_pandas(df) + + +@to_pyarrow_table_dispatch.register((geopandas.GeoDataFrame,)) +def get_pyarrow_table_from_geopandas(obj, **kwargs): + # `kwargs` must be supported by `pyarrow.Table.to_pandas` + import pyarrow as pa + + return pa.table(obj.to_arrow()) + # return pa.Table.from_pandas(obj, **kwargs) + + +@from_pyarrow_table_dispatch.register((pd.DataFrame,)) +def get_geopandas_geodataframe_from_pyarrow(meta, table, **kwargs): + # `kwargs` must be supported by `pyarrow.Table.to_pandas` + try: + return geopandas.GeoDataFrame.from_arrow(table) + except: + return table.to_pandas(**kwargs) From 46b1d823fe45f3e60c94f08a75f7bf7e7304b5fe Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Mon, 3 Jun 2024 17:06:25 +0200 Subject: [PATCH 2/7] add test --- dask_geopandas/backends.py | 24 ++++++++++---- dask_geopandas/tests/test_distributed.py | 41 
++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 7 deletions(-) create mode 100644 dask_geopandas/tests/test_distributed.py diff --git a/dask_geopandas/backends.py b/dask_geopandas/backends.py index d96d140a..d56a9d42 100644 --- a/dask_geopandas/backends.py +++ b/dask_geopandas/backends.py @@ -95,17 +95,27 @@ def get_pyarrow_schema_geopandas(obj): @to_pyarrow_table_dispatch.register((geopandas.GeoDataFrame,)) def get_pyarrow_table_from_geopandas(obj, **kwargs): - # `kwargs` must be supported by `pyarrow.Table.to_pandas` + # `kwargs` must be supported by `pyarrow.Table.from_pandas` import pyarrow as pa - return pa.table(obj.to_arrow()) - # return pa.Table.from_pandas(obj, **kwargs) + if Version(geopandas.__version__).major < 1: + return pa.Table.from_pandas(obj.to_wkb(), **kwargs) + else: + # TODO handle kwargs? + return pa.table(obj.to_arrow()) -@from_pyarrow_table_dispatch.register((pd.DataFrame,)) +@from_pyarrow_table_dispatch.register((geopandas.GeoDataFrame,)) def get_geopandas_geodataframe_from_pyarrow(meta, table, **kwargs): # `kwargs` must be supported by `pyarrow.Table.to_pandas` - try: + if Version(geopandas.__version__).major < 1: + df = table.to_pandas(**kwargs) + + for col in meta.columns[meta.dtypes == "geometry"]: + df[col] = geopandas.GeoSeries.from_wkb(df[col], crs=meta[col].crs) + + return df + + else: + # TODO handle kwargs? 
return geopandas.GeoDataFrame.from_arrow(table) - except: - return table.to_pandas(**kwargs) diff --git a/dask_geopandas/tests/test_distributed.py b/dask_geopandas/tests/test_distributed.py new file mode 100644 index 00000000..65fd08c1 --- /dev/null +++ b/dask_geopandas/tests/test_distributed.py @@ -0,0 +1,41 @@ +import pytest + +import geopandas +import dask_geopandas + +from geopandas.testing import assert_geodataframe_equal + +distributed = pytest.importorskip("distributed") + + +from distributed import LocalCluster, Client + +# from distributed.utils_test import gen_cluster + + +def test_spatial_shuffle(naturalearth_cities): + df_points = geopandas.read_file(naturalearth_cities) + + with LocalCluster(n_workers=1) as cluster: + with Client(cluster): + ddf_points = dask_geopandas.from_geopandas(df_points, npartitions=4) + + ddf_result = ddf_points.spatial_shuffle( + by="hilbert", calculate_partitions=False + ) + result = ddf_result.compute() + + expected = df_points.sort_values("geometry").reset_index(drop=True) + assert_geodataframe_equal(result.reset_index(drop=True), expected) + + +# @gen_cluster(client=True) +# async def test_spatial_shuffle(c, s, a, b, naturalearth_cities): +# df_points = geopandas.read_file(naturalearth_cities) +# ddf_points = dask_geopandas.from_geopandas(df_points, npartitions=4) + +# ddf_result = ddf_points.spatial_shuffle(by="hilbert", calculate_partitions=False) +# result = (await c.compute(ddf_result)).val + +# expected = df_points.sort_values("geometry").reset_index(drop=True) +# assert_geodataframe_equal(result.reset_index(drop=True), expected) From b413ee6283b0427cd904f8823cd2c46de6546a0e Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Tue, 18 Jun 2024 08:58:52 +0200 Subject: [PATCH 3/7] only register dispatch for newer dask --- dask_geopandas/backends.py | 48 +++++++++++++++++++++----------------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/dask_geopandas/backends.py b/dask_geopandas/backends.py index 
d56a9d42..29823c22 100644 --- a/dask_geopandas/backends.py +++ b/dask_geopandas/backends.py @@ -15,6 +15,7 @@ QUERY_PLANNING_ON = True +import dask from dask.dataframe.core import get_parallel_type from dask.dataframe.utils import meta_nonempty from dask.dataframe.extensions import make_array_nonempty, make_scalar @@ -22,8 +23,6 @@ from dask.dataframe.dispatch import ( make_meta_dispatch, pyarrow_schema_dispatch, - to_pyarrow_table_dispatch, - from_pyarrow_table_dispatch, ) from dask.dataframe.backends import _nonempty_index, meta_nonempty_dataframe @@ -93,29 +92,34 @@ def get_pyarrow_schema_geopandas(obj): return pa.Schema.from_pandas(df) -@to_pyarrow_table_dispatch.register((geopandas.GeoDataFrame,)) -def get_pyarrow_table_from_geopandas(obj, **kwargs): - # `kwargs` must be supported by `pyarrow.Table.from_pandas` - import pyarrow as pa +if Version(dask.__version__) >= Version("2023.6.1"): + from dask.dataframe.dispatch import ( + to_pyarrow_table_dispatch, + from_pyarrow_table_dispatch, + ) - if Version(geopandas.__version__).major < 1: - return pa.Table.from_pandas(obj.to_wkb(), **kwargs) - else: - # TODO handle kwargs? - return pa.table(obj.to_arrow()) + @to_pyarrow_table_dispatch.register((geopandas.GeoDataFrame,)) + def get_pyarrow_table_from_geopandas(obj, **kwargs): + # `kwargs` must be supported by `pyarrow.Table.from_pandas` + import pyarrow as pa + if Version(geopandas.__version__).major < 1: + return pa.Table.from_pandas(obj.to_wkb(), **kwargs) + else: + # TODO handle kwargs? 
+ return pa.table(obj.to_arrow()) -@from_pyarrow_table_dispatch.register((geopandas.GeoDataFrame,)) -def get_geopandas_geodataframe_from_pyarrow(meta, table, **kwargs): - # `kwargs` must be supported by `pyarrow.Table.to_pandas` - if Version(geopandas.__version__).major < 1: - df = table.to_pandas(**kwargs) + @from_pyarrow_table_dispatch.register((geopandas.GeoDataFrame,)) + def get_geopandas_geodataframe_from_pyarrow(meta, table, **kwargs): + # `kwargs` must be supported by `pyarrow.Table.to_pandas` + if Version(geopandas.__version__).major < 1: + df = table.to_pandas(**kwargs) - for col in meta.columns[meta.dtypes == "geometry"]: - df[col] = geopandas.GeoSeries.from_wkb(df[col], crs=meta[col].crs) + for col in meta.columns[meta.dtypes == "geometry"]: + df[col] = geopandas.GeoSeries.from_wkb(df[col], crs=meta[col].crs) - return df + return df - else: - # TODO handle kwargs? - return geopandas.GeoDataFrame.from_arrow(table) + else: + # TODO handle kwargs? + return geopandas.GeoDataFrame.from_arrow(table) From 4d53aa3483831989f5ad55a1176c0c754b1e77ec Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Tue, 18 Jun 2024 09:07:45 +0200 Subject: [PATCH 4/7] skip older geopandas --- dask_geopandas/tests/test_distributed.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/dask_geopandas/tests/test_distributed.py b/dask_geopandas/tests/test_distributed.py index 65fd08c1..aa318565 100644 --- a/dask_geopandas/tests/test_distributed.py +++ b/dask_geopandas/tests/test_distributed.py @@ -1,3 +1,5 @@ +from packaging.version import Version + import pytest import geopandas @@ -13,6 +15,10 @@ # from distributed.utils_test import gen_cluster +@pytest.mark.skipif( + Version(geopandas.__version__) < Version("0.13"), + reason="geopandas < 0.13 does not implement sorting geometries", +) def test_spatial_shuffle(naturalearth_cities): df_points = geopandas.read_file(naturalearth_cities) From 9ed2d6ccaf2c426e4fcb53b06b767b0799c864cd Mon Sep 17 00:00:00 2001 From: Joris Van 
den Bossche Date: Tue, 18 Jun 2024 09:10:37 +0200 Subject: [PATCH 5/7] add skip for older distributed --- dask_geopandas/tests/test_distributed.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/dask_geopandas/tests/test_distributed.py b/dask_geopandas/tests/test_distributed.py index aa318565..caa0b499 100644 --- a/dask_geopandas/tests/test_distributed.py +++ b/dask_geopandas/tests/test_distributed.py @@ -16,7 +16,12 @@ @pytest.mark.skipif( - Version(geopandas.__version__) < Version("0.13"), + Version(distributed.__version__) < Version("2024.6.0"), + reason="distributed < 2024.6 has a wrong assertion", + # https://github.com/dask/distributed/pull/8667 +) +@pytest.mark.skipif( + Version(geopandas.__version__) < Version("0.13"), reason="geopandas < 0.13 does not implement sorting geometries", ) def test_spatial_shuffle(naturalearth_cities): From 8c58e7ec8a8ad840183c531b40856f62760066e3 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Mon, 24 Jun 2024 12:37:53 +0200 Subject: [PATCH 6/7] fixup linting --- dask_geopandas/backends.py | 3 ++- dask_geopandas/tests/test_distributed.py | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/dask_geopandas/backends.py b/dask_geopandas/backends.py index 2301a605..60408a1e 100644 --- a/dask_geopandas/backends.py +++ b/dask_geopandas/backends.py @@ -1,6 +1,7 @@ import uuid from packaging.version import Version +import dask from dask import config # Check if dask-dataframe is using dask-expr (default of None means True as well) QUERY_PLANNING_ON = True @@ -88,8 +89,8 @@ def get_pyarrow_schema_geopandas(obj): if Version(dask.__version__) >= Version("2023.6.1"): from dask.dataframe.dispatch import ( - to_pyarrow_table_dispatch, from_pyarrow_table_dispatch, + to_pyarrow_table_dispatch, ) @to_pyarrow_table_dispatch.register((geopandas.GeoDataFrame,)) diff --git a/dask_geopandas/tests/test_distributed.py b/dask_geopandas/tests/test_distributed.py index caa0b499..e74e2ca1 100644 --- 
a/dask_geopandas/tests/test_distributed.py +++ b/dask_geopandas/tests/test_distributed.py @@ -1,16 +1,16 @@ from packaging.version import Version -import pytest - import geopandas + import dask_geopandas +import pytest from geopandas.testing import assert_geodataframe_equal distributed = pytest.importorskip("distributed") -from distributed import LocalCluster, Client +from distributed import Client, LocalCluster # from distributed.utils_test import gen_cluster From f4a37e017aed9a8ec2209801a2bef025e3a03f2e Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Mon, 24 Jun 2024 15:09:30 +0200 Subject: [PATCH 7/7] cleanup --- dask_geopandas/tests/test_distributed.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/dask_geopandas/tests/test_distributed.py b/dask_geopandas/tests/test_distributed.py index e74e2ca1..9222df3d 100644 --- a/dask_geopandas/tests/test_distributed.py +++ b/dask_geopandas/tests/test_distributed.py @@ -12,8 +12,6 @@ from distributed import Client, LocalCluster -# from distributed.utils_test import gen_cluster - @pytest.mark.skipif( Version(distributed.__version__) < Version("2024.6.0"), @@ -38,15 +36,3 @@ def test_spatial_shuffle(naturalearth_cities): expected = df_points.sort_values("geometry").reset_index(drop=True) assert_geodataframe_equal(result.reset_index(drop=True), expected) - - -# @gen_cluster(client=True) -# async def test_spatial_shuffle(c, s, a, b, naturalearth_cities): -# df_points = geopandas.read_file(naturalearth_cities) -# ddf_points = dask_geopandas.from_geopandas(df_points, npartitions=4) - -# ddf_result = ddf_points.spatial_shuffle(by="hilbert", calculate_partitions=False) -# result = (await c.compute(ddf_result)).val - -# expected = df_points.sort_values("geometry").reset_index(drop=True) -# assert_geodataframe_equal(result.reset_index(drop=True), expected)