diff --git a/src/parcels/_core/particlefile.py b/src/parcels/_core/particlefile.py index 0f1a2d2c5..24c270793 100644 --- a/src/parcels/_core/particlefile.py +++ b/src/parcels/_core/particlefile.py @@ -58,6 +58,8 @@ class ParticleFile: Interval which dictates the update frequency of file output while ParticleFile is given as an argument of ParticleSet.execute() It is either a numpy.timedelta64, a datimetime.timedelta object or a positive float (in seconds). + compression : {"zstd", "gzip", "snappy", "brotli", None}, optional + Compression codec applied to the Parquet output file; None writes the columns uncompressed. Default is "zstd". Returns ------- @@ -65,11 +67,14 @@ class ParticleFile: ParticleFile object that can be used to write particle data to file """ - def __init__(self, path: PathLike, outputdt): + def __init__( + self, path: PathLike, outputdt, compression: Literal["zstd", "gzip", "snappy", "brotli", None] = "zstd" + ): if not isinstance(outputdt, (np.timedelta64, timedelta, float)): raise ValueError( f"Expected outputdt to be a np.timedelta64, datetime.timedelta or float (in seconds), got {type(outputdt)}" ) + self._compression = compression outputdt = timedelta_to_float(outputdt) path = Path(path) @@ -133,7 +138,11 @@ def write(self, pset: ParticleSet, time, indices=None): if self._writer is None: assert not self.path.exists(), "If the file exists, the writer should already be set" - self._writer = pq.ParquetWriter(self.path, _get_schema(pclass, self.metadata, pset.fieldset.time_interval)) + self._writer = pq.ParquetWriter( + self.path, + _get_schema(pclass, self.metadata, pset.fieldset.time_interval), + compression=self._compression, + ) if isinstance(time, (np.timedelta64, np.datetime64)): time = timedelta_to_float(time - time_interval.left) diff --git a/tests/test_particlefile.py b/tests/test_particlefile.py index 5173cba7e..b0938db64 100755 --- a/tests/test_particlefile.py +++ b/tests/test_particlefile.py @@ -57,6 +57,23 @@ def test_metadata(fieldset, tmp_parquet): assert
tab.schema.metadata[b"parcels_kernels"].decode().lower() == "DoNothing".lower() +@pytest.mark.parametrize("compression", ["zstd", "gzip", "snappy", "brotli", None]) +def test_compression(fieldset, tmp_parquet, compression): + pset = ParticleSet(fieldset, pclass=Particle, lon=0, lat=0) + + ofile = ParticleFile(tmp_parquet, outputdt=np.timedelta64(1, "s"), compression=compression) + pset.execute(DoNothing, runtime=np.timedelta64(1, "s"), dt=np.timedelta64(1, "s"), output_file=ofile) + + tab = pq.ParquetFile(tmp_parquet) + for i in range(tab.num_row_groups): + row_group = tab.metadata.row_group(i) + for j in range(row_group.num_columns): + col = row_group.column(j) + assert col.compression.lower() == compression or ( + compression is None and col.compression.lower() == "uncompressed" + ) + + def test_write_fieldset_without_time(tmp_parquet): ds = peninsula_dataset() # DataSet without time assert "time" not in ds.dims