diff --git a/.github/workflows/doc.yaml b/.github/workflows/doc.yaml index 97fb160..71d1878 100644 --- a/.github/workflows/doc.yaml +++ b/.github/workflows/doc.yaml @@ -66,29 +66,48 @@ jobs: shell: bash -l {0} run: | pip install altimetry-downloader-aviso - - name: Restore cached sample data - uses: actions/cache@v4 - id: restore_samples - with: - path: docs/implementations/data - key: samples-${{ hashFiles('docs/implementations/scripts/pull_data*.py') }} - name: Setup AVISO credentials - if: steps.restore_samples.outputs['cache-hit'] != 'true' run: | echo "machine tds-odatis.aviso.altimetry.fr login ${{ secrets.AVISO_USER }} password ${{ secrets.AVISO_PASSWORD }}" > ~/.netrc - - name: Pull Samples - if: steps.restore_samples.outputs['cache-hit'] != 'true' + # Cached retrieval of L2_LR_SSH data + - name: Restore cached sample data for L2_LR_SSH + uses: actions/cache@v4 + id: restore_samples_l2_lr_ssh + with: + path: docs/implementations/data_l2_lr_ssh + key: samples-${{ hashFiles('docs/implementations/scripts/pull_data_l2_lr_ssh.py') }} + - name: Pull L2_LR_SSH Samples + if: steps.restore_samples_l2_lr_ssh.outputs['cache-hit'] != 'true' shell: bash -l {0} run: | set -e python docs/implementations/scripts/pull_data_l2_lr_ssh.py + - name: Cache L2_LR_SSH data + uses: actions/cache/save@v4 + if: steps.restore_samples_l2_lr_ssh.outputs['cache-hit'] != 'true' + with: + path: docs/implementations/data_l2_lr_ssh + key: samples-${{ hashFiles('docs/implementations/scripts/pull_data_l2_lr_ssh.py') }} + # Cached retrieval of L3_LR_SSH data + - name: Restore cached sample data for L3_LR_SSH + uses: actions/cache@v4 + id: restore_samples_l3_lr_ssh + with: + path: docs/implementations/data_l3_lr_ssh + key: samples-${{ hashFiles('docs/implementations/scripts/pull_data_l3_lr_ssh.py') }} + - name: Pull L3_LR_SSH Samples + if: steps.restore_samples_l3_lr_ssh.outputs['cache-hit'] != 'true' + shell: bash -l {0} + run: | + set -e python 
docs/implementations/scripts/pull_data_l3_lr_ssh.py - - name: Save fetched data + - name: Cache L3_LR_SSH data uses: actions/cache/save@v4 - if: steps.restore_samples.outputs['cache-hit'] != 'true' + if: steps.restore_samples_l3_lr_ssh.outputs['cache-hit'] != 'true' with: - path: docs/implementations/data - key: samples-${{ hashFiles('docs/implementations/scripts/pull_data*.py') }} + path: docs/implementations/data_l3_lr_ssh + key: samples-${{ hashFiles('docs/implementations/scripts/pull_data_l3_lr_ssh.py') }} + # DOC BUILD - name: Build Sphinx HTML shell: bash -l {0} run: | diff --git a/docs/implementations/l2_lr_ssh.md b/docs/implementations/l2_lr_ssh.md index f2f2849..fcdcdf7 100644 --- a/docs/implementations/l2_lr_ssh.md +++ b/docs/implementations/l2_lr_ssh.md @@ -57,6 +57,10 @@ The following examples can be used to build complex queries ```python fc.query(cycle_number=slice(1, 4), pass_number=[1, 3]) ``` + - A specific orbit of the SWOT mission + ```python + fc.query(phase='CALVAL') + ``` ::: :::{tab-item} Periods - A time stamp diff --git a/docs/implementations/l3_lr_ssh.md b/docs/implementations/l3_lr_ssh.md index 02d201c..82660c2 100644 --- a/docs/implementations/l3_lr_ssh.md +++ b/docs/implementations/l3_lr_ssh.md @@ -53,6 +53,10 @@ The following examples can be used to build complex queries ```python fc.query(cycle_number=slice(1, 4), pass_number=[1, 3]) ``` + - A specific orbit of the SWOT mission + ```python + fc.query(phase='CALVAL') + ``` ::: :::{tab-item} Periods - A time stamp diff --git a/src/fcollections/core/__init__.py b/src/fcollections/core/__init__.py index 27c8b6e..f34ed6e 100644 --- a/src/fcollections/core/__init__.py +++ b/src/fcollections/core/__init__.py @@ -18,6 +18,7 @@ FileNameFieldDatetime, FileNameFieldEnum, FileNameFieldFloat, + FileNameFieldGeoBox, FileNameFieldInteger, FileNameFieldISODuration, FileNameFieldPeriod, @@ -26,7 +27,7 @@ from ._filesdb import ( Deduplicator, FilesDatabase, - IPredicate, + IFilterBuilder, 
NotExistingPathError, SubsetsUnmixer, ) @@ -74,6 +75,7 @@ "FileNameFieldString", "FileNameFieldPeriod", "FileNameFieldISODuration", + "FileNameFieldGeoBox", "FileNameConvention", "FileListingError", "IFilesReader", @@ -93,7 +95,7 @@ "VariableMetadata", "DiscreteTimesMixin", "ITemporalMixin", - "IPredicate", + "IFilterBuilder", "Layout", "DecodingError", "ICodec", diff --git a/src/fcollections/core/_filenames.py b/src/fcollections/core/_filenames.py index deb2e56..1195355 100644 --- a/src/fcollections/core/_filenames.py +++ b/src/fcollections/core/_filenames.py @@ -27,6 +27,7 @@ DateTimeTester, EnumTester, FloatTester, + GeoBoxTester, IntegerTester, ITester, PeriodTester, @@ -257,6 +258,18 @@ def sanitize(self, reference: str | ISODuration) -> ISODuration: return reference +class FileNameFieldGeoBox(FileNameField, StringCodec, GeoBoxTester): + """Geobox value. + + Geobox is a tuple (lon_min, lat_min, lon_max, lat_max) in °. + + Attributes + ---------- + name: str + name of the field + """ + + class FieldFormatter(string.Formatter): def __init__(self, fields: dict[str, FileNameField]): diff --git a/src/fcollections/core/_filesdb.py b/src/fcollections/core/_filesdb.py index 40cafcb..b51cb49 100644 --- a/src/fcollections/core/_filesdb.py +++ b/src/fcollections/core/_filesdb.py @@ -18,7 +18,7 @@ from fsspec import AbstractFileSystem from fsspec.implementations.local import LocalFileSystem -from ._filenames import FileNameConvention +from ._filenames import FileNameConvention, FileNameField from ._listing import DirNode, FileSystemMetadataCollector, Layout, LayoutMismatchError from ._metadata import GroupMetadata from ._readers import IFilesReader @@ -62,7 +62,7 @@ def __new__(cls, clsname, bases, attrs): setattr( new_class, "listing_parameters", - parameters["convention"][1] | parameters["predicates"][1], + parameters["convention"][1] | parameters["filter_builders"][1], ) method_parameters = _combine_parameters(new_class, parameters) @@ -109,7 +109,9 @@ def 
_extract_parameters( parameters["convention"] = _convention_parameters( new_class.layouts[0].conventions[-1] ) - parameters["predicates"] = _predicates_parameters(new_class.predicate_classes) + parameters["filter_builders"] = _filter_builders_parameters( + new_class.filter_builders + ) return parameters @@ -127,14 +129,14 @@ def _combine_parameters( ( parameters["reader"][0] | parameters["convention"][0] - | parameters["predicates"][0] + | parameters["filter_builders"][0] ).values() ) query_signature = list( ( parameters["reader"][1] | parameters["convention"][1] - | parameters["predicates"][1] + | parameters["filter_builders"][1] ).values() ) out["query"] = (query_docstring, query_signature) @@ -144,10 +146,10 @@ def _combine_parameters( # parameters and predicates parameters # self is included in the listing_signature_parameters files_docstring = list( - (parameters["convention"][0] | parameters["predicates"][0]).values() + (parameters["convention"][0] | parameters["filter_builders"][0]).values() ) files_signature = list( - (parameters["convention"][1] | parameters["predicates"][1]).values() + (parameters["convention"][1] | parameters["filter_builders"][1]).values() ) out["list_files"] = (files_docstring, files_signature) @@ -237,25 +239,33 @@ def _convention_parameters( return convention_docstring_parameters, convention_signature_parameters -def _predicates_parameters( - predicate_classes: list[IPredicate] | None, +def _filter_builders_parameters( + filter_builders: list[IFilterBuilder] | None, ) -> tuple[dict[str, dcs.DocstringParam], dict[str, inspect.Parameter]]: - if predicate_classes is None: + if filter_builders is None: return {}, {} docstring_parameters, signature_parameters = {}, {} - for predicate_builder in predicate_classes: - docstring_parameters |= { - p.arg_name: p - for p in dcs.parse(predicate_builder.__init__.__doc__).params - if p not in ["self", "indexes"] - } + for filter_builder in filter_builders: + field = filter_builder.parameter() + 
docstring_parameters[field.name] = dcs.DocstringParam( + ["param", field.name], + textwrap.fill(field.description), + field.name, + # Docstrings in the project do not repeat the typing in the + # Parameters section. We set None to comply with this implicit + # convention + None, + False, + None, + ) - signature_parameters |= { - k: p.replace(kind=inspect.Parameter.KEYWORD_ONLY) - for k, p in inspect.signature(predicate_builder.__init__).parameters.items() - if k not in ["self", "indexes"] - } + signature_parameters[field.name] = inspect.Parameter( + field.name, + default=inspect.Parameter.empty, + annotation=field.type, + kind=inspect.Parameter.KEYWORD_ONLY, + ) return docstring_parameters, signature_parameters @@ -355,12 +365,16 @@ class FilesDatabase(metaclass=FilesDatabaseMeta): The keys is the columns of the file metadata table, the value is a tuple of dimensions for insertion. """ - predicate_classes: list[type[IPredicate]] | None = None - """List of predicates that are built at each query. + filter_builders: list[type[IFilterBuilder]] | None = None + """Builders that will create or modify filters for each query. + + The builders intercept the input parameters to either build: - The predicates intercepts the input parameters to build a custom - record predicate. Usually, it is a complex test involving auxiliary - data, such as ground track footprints or half_orbit/periods tables. + - A custom record predicate. A predicate is a complex test involving + auxiliary data - such as ground track footprints or half_orbit/periods + tables - and is run on a file record. + - A filter converter, to convert an input filter to another filter present + in the layouts. For example: ``query(foo="segmentA") -> query(bar=slice(10, 20))``.
""" def __init__( @@ -421,7 +435,7 @@ def _files( sort: bool = False, deduplicate: bool = False, unmix: bool = False, - predicates: tp.Iterable[IPredicate] = (), + predicates: tp.Iterable[tp.Callable[[tuple[tp.Any, ...]], bool]] = (), stat_fields: tuple[str] = (), **kwargs, ) -> pda.DataFrame: @@ -472,37 +486,57 @@ def _files( # This docstring will be superseded by the metaclass bad_kwargs = [k for k in kwargs if k not in self.listing_parameters] if bad_kwargs != []: - raise ValueError( - f"list_files() got unexpected keyword argument(s): {bad_kwargs}" - ) + msg = f"list_files() got unexpected keyword argument(s): {bad_kwargs}" + raise ValueError(msg) - # Auto-build declared predicates. Parameters used by the predicates are - # expected to be independant of the other parameters from the file name - # convention + # Auto-build declared predicates and additionnal filters. predicates = list(predicates) - if self.predicate_classes is not None: - fields_names = list(map(lambda f: f.name, self.parser.fields)) - for predicate_builder in self.predicate_classes: - # Convert field name into indexes for the record predicate - record_indexes = [ - fields_names.index(requested_field) - for requested_field in predicate_builder.record_fields() - ] + if self.filter_builders is not None: + record_mapping = { + field.name: ii for ii, field in enumerate(self.parser.fields) + } + + for filter_builder in self.filter_builders: try: - predicate = predicate_builder( - record_indexes, - # Extract args from the parameters - *[kwargs.pop(p) for p in predicate_builder.parameters()], + filter_field = filter_builder.parameter() + sanitized_parameter = filter_field.sanitize( + kwargs[filter_field.name] + ) + except KeyError: + logger.debug( + "Predicate build skipped, parameter %s is missing", + filter_field.name, ) + continue + + try: + # Complex filter (predicate) that will be applied on the + # files' record + predicate = filter_builder.build_predicate( + record_mapping, sanitized_parameter + 
) + predicates.append(predicate) logger.debug( - "Added predicate over parameters %s", - predicate_builder.parameters(), + "Added predicate from parameter %s", + filter_field.name, ) - except KeyError: + except NotImplementedError: + # Simple converter from one filter to another + filters = filter_builder.build_filter(sanitized_parameter) + common_keys = filters.keys() & kwargs.keys() + if len(common_keys) > 0: + msg = ( + "Incompatible filters, cannot give both " + f"'{filter_field.name}' and {common_keys}." + ) + raise ValueError(msg) + + kwargs |= filters logger.debug( - "Predicate build skipped, missing one of the following parameters %s", - predicate_builder.parameters(), + "Converted filter '%s' to '%s'", + filter_field.name, + filters.keys(), ) df = self.discoverer.to_dataframe( @@ -922,42 +956,49 @@ def keys(self) -> set[str]: return set(self.unique) | set(self.auto_pick_last) -class IPredicate(abc.ABC): - """Interface for defining a complex predicate. - - This predicate will be used to filter records from file names listing and - parsing. - - Attributes - ---------- - indexes - Attributes - *args - Any input that will be used to create the predicate - """ +class IFilterBuilder(abc.ABC): + """Interface for building filters.""" + @classmethod @abc.abstractmethod - def __call__(self, record: tuple[tp.Any, ...]) -> bool: - """Call the predicate. + def build_predicate( + cls, record_mapping: dict[str, int], *args: tp.Any + ) -> tp.Callable[[tuple[tp.Any, ...]], bool]: + """Build a complex predicate. Parameters ---------- - record - The record to filter + record_mapping + Mapping between the record names and indexes. Records are given + as a tuple to the predicate, so we need the index to extract the + given fields. + args + Any input argument that is needed to build the predicate.
Returns ------- - result - True if the record complies with the criteria given by this - predicate + Callable + A predicate that checks whether the input record fulfills the stated + conditions. """ @classmethod @abc.abstractmethod - def record_fields(cls) -> tuple[str, ...]: - """Record fields needed by the predicate.""" + def build_filter(cls, *args: tp.Any) -> dict[str, tp.Any]: + """Build a simple filter. + + Parameters + ---------- + args + Any input argument that is needed to build the filter. + + Returns + ------- + dict[str, tp.Any] + Mapping associating the filter name to its authorized values. + """ @classmethod @abc.abstractmethod - def parameters(cls) -> tuple[str, ...]: - """Initialization parameters name for the class.""" + def parameter(cls) -> FileNameField: + """Field describing the query parameter handled by this builder.""" diff --git a/src/fcollections/core/_testers.py b/src/fcollections/core/_testers.py index a80c7c0..7f35b53 100644 --- a/src/fcollections/core/_testers.py +++ b/src/fcollections/core/_testers.py @@ -243,6 +243,35 @@ def type(self) -> type[Period]: return Period +GeoBox: tp.TypeAlias = tuple[float, float, float, float] + + +class GeoBoxTester(ITester[GeoBox, GeoBox]): + + @property + def test_description(self) -> str: + return ( + "As a GeoBox field, it can be filtered by giving a reference " + "GeoBox. A file will be filtered out if the GeoBox extracted from " + "its file name does not intersect the reference GeoBox."
) + + def test(self, reference: GeoBox, tested: GeoBox) -> bool: + # Two closed intervals overlap iff each one starts before the other + # ends; this also holds when one box fully contains the other. + longitude_match = ( + tested[0] <= reference[2] and reference[0] <= tested[2] + ) + latitude_match = ( + tested[1] <= reference[3] and reference[1] <= tested[3] + ) + return longitude_match and latitude_match + + @property + def type(self) -> type[GeoBox]: + return GeoBox + + def _sanitize_time( reference: ( tuple[str | None | np.datetime64, str | None | np.datetime64] diff --git a/src/fcollections/implementations/__init__.py b/src/fcollections/implementations/__init__.py index c6d9610..4a9cbba 100644 --- a/src/fcollections/implementations/__init__.py +++ b/src/fcollections/implementations/__init__.py @@ -2,6 +2,7 @@ __all__ = [] +from ._converters import SwotPhaseFilterBuilder from ._dac import ( BasicNetcdfFilesDatabaseDAC, FileNameConventionDAC, @@ -193,6 +194,8 @@ "SwotReaderL2LRSSH", "SwotReaderL3LRSSH", "SwotReaderL3WW", + # Predicates + "SwotPhaseFilterBuilder", # Common definitions "Delay", "ProductLevel", diff --git a/src/fcollections/implementations/_converters.py b/src/fcollections/implementations/_converters.py new file mode 100644 index 0000000..4573924 --- /dev/null +++ b/src/fcollections/implementations/_converters.py @@ -0,0 +1,51 @@ +import typing as tp + +from fcollections.core import CaseType, FileNameField, FileNameFieldEnum, IFilterBuilder +from fcollections.missions import MissionsPhases, Phase + +from ._definitions._swot import SwotPhases + + +class SwotPhaseFilterBuilder(IFilterBuilder): + """Swot phases filter builder. + + Converts a phase filter (science/calval orbit) to a range of valid + cycle numbers. + """ + + @classmethod + def build_filter(cls, phase: SwotPhases) -> dict[str, slice]: + """Converts a ``phase`` filter to a ``cycle_number``. + + Parameters + ---------- + phase + SWOT mission phase (calval or science).
+ + Returns + ------- + dict[str, tp.Any] + Mapping associating the filter name to its authorized values. + """ + entry: Phase = MissionsPhases[phase.name.lower()].value + return { + "cycle_number": slice(entry.half_orbits[0][0], entry.half_orbits[1][0] + 1) + } + + @classmethod + def build_predicate(cls, record_mapping: dict[str, int], *args: tp.Any): + msg = "SwotPhase filter can only be built as a simple filter." + raise NotImplementedError(msg) + + @classmethod + def parameter(cls) -> FileNameField: + return FileNameFieldEnum( + "phase", + SwotPhases, + description=( + "Phase of the SWOT mission that can be used to select the " + "associated cycle numbers range." + ), + case_type_decoded=CaseType.upper, + case_type_encoded=CaseType.lower, + ) diff --git a/src/fcollections/implementations/_l2_lr_ssh.py b/src/fcollections/implementations/_l2_lr_ssh.py index c43c06c..e49c210 100644 --- a/src/fcollections/implementations/_l2_lr_ssh.py +++ b/src/fcollections/implementations/_l2_lr_ssh.py @@ -22,7 +22,8 @@ SubsetsUnmixer, ) -from ._definitions._constants import DESCRIPTIONS, ProductLevel +from ._converters import SwotPhaseFilterBuilder +from ._definitions._constants import DESCRIPTIONS from ._definitions._swot import ProductSubset from ._readers import SwotReaderL2LRSSH @@ -577,16 +578,19 @@ class BasicNetcdfFilesDatabaseSwotLRL2(FilesDatabase, PeriodMixin): unique=("cycle_number", "pass_number"), auto_pick_last=("version",) ) + # Convert phase filter to cycle_numbers filter + filter_builders = [SwotPhaseFilterBuilder] + try: from fcollections.implementations.optional import ( GeoSwotReaderL2LRSSH, - SwotGeometryPredicate, + SwotGeometryFilterBuilder, ) class NetcdfFilesDatabaseSwotLRL2(BasicNetcdfFilesDatabaseSwotLRL2): reader = GeoSwotReaderL2LRSSH() - predicate_classes = [SwotGeometryPredicate] + filter_builders = [SwotGeometryFilterBuilder, SwotPhaseFilterBuilder] except ImportError: import logging diff --git a/src/fcollections/implementations/_l3_lr_ssh.py
b/src/fcollections/implementations/_l3_lr_ssh.py index 6d8f711..eeeea4b 100644 --- a/src/fcollections/implementations/_l3_lr_ssh.py +++ b/src/fcollections/implementations/_l3_lr_ssh.py @@ -15,6 +15,7 @@ SubsetsUnmixer, ) +from ._converters import SwotPhaseFilterBuilder from ._definitions._constants import DESCRIPTIONS, ProductLevel from ._definitions._swot import ProductSubset, Temporality from ._readers import SwotReaderL3LRSSH @@ -129,16 +130,19 @@ class BasicNetcdfFilesDatabaseSwotLRL3(FilesDatabase, PeriodMixin): partition_keys=["version", "subset"], auto_pick_last=("version",) ) + # Convert phase filter to cycle_numbers filter + filter_builders = [SwotPhaseFilterBuilder] + try: from fcollections.implementations.optional import ( GeoSwotReaderL3LRSSH, - SwotGeometryPredicate, + SwotGeometryFilterBuilder, ) class NetcdfFilesDatabaseSwotLRL3(BasicNetcdfFilesDatabaseSwotLRL3): reader = GeoSwotReaderL3LRSSH() - predicate_classes = [SwotGeometryPredicate] + filter_builders = [SwotGeometryFilterBuilder, SwotPhaseFilterBuilder] except ImportError: import logging diff --git a/src/fcollections/implementations/_l3_lr_ww.py b/src/fcollections/implementations/_l3_lr_ww.py index 9a6ec7e..777ccb3 100644 --- a/src/fcollections/implementations/_l3_lr_ww.py +++ b/src/fcollections/implementations/_l3_lr_ww.py @@ -14,6 +14,7 @@ SubsetsUnmixer, ) +from ._converters import SwotPhaseFilterBuilder from ._definitions._constants import DESCRIPTIONS from ._definitions._swot import ProductSubset from ._l3_lr_ssh import AVISO_L3_LR_SSH_LAYOUT_V2 @@ -80,16 +81,19 @@ class BasicNetcdfFilesDatabaseSwotLRWW(FilesDatabase, PeriodMixin): partition_keys=["version", "subset"], auto_pick_last=("version",) ) + # Convert phase filter to cycle_numbers filter + filter_builders = [SwotPhaseFilterBuilder] + try: from fcollections.implementations.optional import ( GeoSwotReaderL3WW, - SwotGeometryPredicate, + SwotGeometryFilterBuilder, ) class 
NetcdfFilesDatabaseSwotLRWW(BasicNetcdfFilesDatabaseSwotLRWW): reader = GeoSwotReaderL3WW() - predicate_classes = [SwotGeometryPredicate] + filter_builders = [SwotGeometryFilterBuilder, SwotPhaseFilterBuilder] except ImportError: import logging diff --git a/src/fcollections/implementations/optional/__init__.py b/src/fcollections/implementations/optional/__init__.py index 86319fa..b177448 100644 --- a/src/fcollections/implementations/optional/__init__.py +++ b/src/fcollections/implementations/optional/__init__.py @@ -17,7 +17,7 @@ SwathAreaSelector, TemporalSerieAreaSelector, ) -from ._predicates import SwotGeometryPredicate +from ._predicates import SwotGeometryFilterBuilder from ._reader import ( GeoOpenMfDataset, GeoSwotReaderL2LRSSH, @@ -30,7 +30,7 @@ "GeoSwotReaderL3WW", "GeoSwotReaderL2LRSSH", "GeoSwotReaderL3LRSSH", - "SwotGeometryPredicate", + "SwotGeometryFilterBuilder", "AreaSelector1D", "AreaSelector2D", "SwathAreaSelector", diff --git a/src/fcollections/implementations/optional/_predicates.py b/src/fcollections/implementations/optional/_predicates.py index c07630c..837bca7 100644 --- a/src/fcollections/implementations/optional/_predicates.py +++ b/src/fcollections/implementations/optional/_predicates.py @@ -4,34 +4,41 @@ import logging import typing as tp -from fcollections.core import IPredicate +from fcollections.core import FileNameField, FileNameFieldGeoBox, IFilterBuilder from fcollections.geometry import query_half_orbits_intersect from fcollections.missions import PHASES, Missions logger = logging.getLogger(__name__) -class SwotGeometryPredicate(IPredicate): - """Predicate builder for swot karin footprints. +class SwotGeometryFilterBuilder(IFilterBuilder): + """Predicate builder for swot karin footprints.""" - This predicate builder can transform a box in a callable that can predict if - a given half orbit crosses the box. It uses KaRIn reference footprints for - one cycle. 
+ @classmethod + def build_predicate( + cls, record_mapping: dict[str, int], bbox: tuple[float, float, float, float] + ) -> tp.Callable[[tuple[tp.Any, ...]], bool]: + """Build a complex predicate. - Parameters - ---------- - indexes - Indexes of the 'cycle_number' and 'pass_number' element in the input - record of the predicate - bbox - Bounding box, given as lon_min, lat_min, lon_max, lat_max - """ + This predicate builder can transform a ``bbox`` filter in a callable + that can predict if a given half orbit crosses the box. It uses KaRIn + reference footprints (reference footprints are the same across cycles). - def __init__( - self, indexes: tuple[int, int], bbox: tuple[float, float, float, float] - ): + Parameters + ---------- + record_mapping + Mapping between the record names and indexes. Records are given + as a tuple to the predicate, so we need the index to extract the + given fields. + bbox + Bounding box, given as lon_min, lat_min, lon_max, lat_max - self.cycle_number_index, self.pass_number_index = indexes + Returns + ------- + Callable + A predicate that checks whether the input record half orbit is in + the bounding box. 
+ """ def selected( cycle_number: int, @@ -63,22 +70,38 @@ def selected( selected_pass_numbers=pass_numbers_intersect, ) ) - self.predicates = predicates - def __call__(self, record: tuple[tp.Any, ...]) -> bool: - cycle_number, pass_number = ( - record[self.cycle_number_index], - record[self.pass_number_index], - ) - return functools.reduce( - lambda x, y: x or y, - [predicate(cycle_number, pass_number) for predicate in self.predicates], - ) + cycle_number_index = record_mapping["cycle_number"] + pass_number_index = record_mapping["pass_number"] + + def _predicate(record: tuple[tp.Any, ...]) -> bool: + cycle_number, pass_number = ( + record[cycle_number_index], + record[pass_number_index], + ) + return functools.reduce( + lambda x, y: x or y, + [predicate(cycle_number, pass_number) for predicate in predicates], + ) + + return _predicate @classmethod - def record_fields(cls) -> tuple[str, ...]: - return ("cycle_number", "pass_number") + def build_filter(cls): + msg = "Swot Geometry Filter can only be built as a predicate for records." + raise NotImplementedError(msg) @classmethod - def parameters(cls) -> tuple[str, ...]: - return ("bbox",) + def parameter(cls) -> FileNameField: + return FileNameFieldGeoBox( + "bbox", + description=( + "The bounding box (lon_min, lat_min, lon_max, lat_max) used to " + "select the data in a given area. Longitude coordinates can be " + "provided in [-180, 180[ or [0, 360[ convention. If bbox's " + "longitude crosses the circularity, it will be split in two " + "subboxes to ensure a proper selection (e.g. 
longitude interval" + ": [170, -170] -> data in [170, 180[ and [-180, -170] will be " + "retrieved" + ), + ) diff --git a/tests/core/test_filenames.py b/tests/core/test_filenames.py index 4b096d6..6e1b5fd 100644 --- a/tests/core/test_filenames.py +++ b/tests/core/test_filenames.py @@ -17,6 +17,7 @@ FileNameFieldDatetime, FileNameFieldEnum, FileNameFieldFloat, + FileNameFieldGeoBox, FileNameFieldInteger, FileNameFieldISODuration, FileNameFieldPeriod, @@ -310,6 +311,36 @@ def test_fields_decode_error(field: FileNameField, input_string: str): np.datetime64("2023-01-01"), False, ), + ( + FileNameFieldGeoBox(""), + (-60, -10, 60, 10), + (-50, -8, 58, -6), + True, + ), + ( + FileNameFieldGeoBox(""), + (-60, -10, 60, 10), + (48, 8, 52, 12), + True, + ), + ( + FileNameFieldGeoBox(""), + (-60, -10, 60, 10), + (-80, -20, -59, -9), + True, + ), + ( + FileNameFieldGeoBox(""), + (-60, -10, 60, 10), + (-50, 20, -40, 30), + False, + ), + ( + FileNameFieldGeoBox(""), + (-60, -10, 60, 10), + (70, -5, 80, 5), + False, + ), ], ) def test_field_test(field, reference, tested, filtered): @@ -334,6 +365,7 @@ def test_field_test(field, reference, tested, filtered): ), (FileNameFieldDateDelta("", "", np.timedelta64(1, "D")), Period), (FileNameFieldISODuration(""), ISODuration), + (FileNameFieldGeoBox(""), tuple[float, float, float, float]), ], ) def test_field_type(field, expected_type): @@ -362,6 +394,7 @@ def test_field_type_name(field: FileNameField, expected_type_name: str): (FileNameFieldPeriod("pfield", ""), ["[%Y-%m-%dT%H:%M:%S]"]), (FileNameFieldEnum("efield", Color), ["RED", "BLUE", "GREEN", "gray"]), (FileNameFieldISODuration("Ifield"), ["ISO8601"]), + (FileNameFieldGeoBox("gfield"), ["GeoBox", "intersect"]), ], ) def test_field_description(field: FileNameField, elements: list[str]): diff --git a/tests/core/test_filesdb.py b/tests/core/test_filesdb.py index 67a6e90..7346ccb 100644 --- a/tests/core/test_filesdb.py +++ b/tests/core/test_filesdb.py @@ -15,12 +15,13 @@ from 
fcollections.core import ( Deduplicator, FileNameConvention, + FileNameField, FileNameFieldDatetime, FileNameFieldInteger, FileNameFieldString, FilesDatabase, IFilesReader, - IPredicate, + IFilterBuilder, Layout, LayoutMismatchError, NotExistingPathError, @@ -99,26 +100,47 @@ class FilesDatabaseTest(FilesDatabaseTestNoUnmixer): unmixer = SubsetsUnmixer(("a_number",)) -class ModuloPredicate(IPredicate): +class ModuloFilterBuilder(IFilterBuilder): - def __init__(self, indexes: tuple[int], b_number: int): - self.index = indexes[0] - self.b_number = b_number + @classmethod + def build_predicate( + cls, record_mapping: dict[str, int], b_number: int + ) -> tp.Callable: + index = record_mapping["a_number"] + + def _predicate(record: tuple[tp.Any, ...]) -> bool: + return record[index] % b_number == 0 - def __call__(self, record: tuple[tp.Any, ...]) -> bool: - return record[self.index] % self.b_number == 0 + return _predicate + + @classmethod + def build_filter(cls, *args): + raise NotImplementedError() @classmethod - def record_fields(cls) -> tuple[str, ...]: - return ("a_number",) + def parameter(cls) -> FileNameField: + return FileNameFieldInteger("b_number") + + +class RangeFilterBuilder(IFilterBuilder): @classmethod - def parameters(cls) -> tuple[str, ...]: - return ("b_number",) + def build_filter(cls, c_number: int) -> dict[str, list[int]]: + return {"a_number": list(range(0, 100, c_number))} + + @classmethod + def build_predicate( + cls, record_mapping: dict[str, int], _number: int + ) -> tp.Callable: + raise NotImplementedError() + + @classmethod + def parameter(cls) -> FileNameField: + return FileNameFieldInteger("c_number") class FilesDatabaseTestPredicate(FilesDatabaseTestNoUnmixer): - predicate_classes = [ModuloPredicate] + filter_builders = [ModuloFilterBuilder, RangeFilterBuilder] def test_bad_path(): @@ -339,7 +361,7 @@ def test_list_files_wrong_filter(db_with_files: FilesDatabaseTest): @pytest.fixture(scope="session") -def db_predicate() -> 
FilesDatabaseTestPredicate: +def db_predicate_converter() -> FilesDatabaseTestPredicate: fs = fs_mem.MemoryFileSystem() fs.touch("predicate/a_file_001_20250101.nc") fs.touch("predicate/a_file_002_20250101.nc") @@ -349,8 +371,13 @@ def db_predicate() -> FilesDatabaseTestPredicate: return db -def test_list_files_predicate( - db_with_files: FilesDatabaseTest, db_predicate: FilesDatabaseTestPredicate +@pytest.mark.parametrize( + "filters", [dict(b_number=2), dict(c_number=2)], ids=["predicate", "converter"] +) +def test_list_files_filter_builders( + db_with_files: FilesDatabaseTest, + db_predicate_converter: FilesDatabaseTestPredicate, + filters: dict[str, int], ): expected = pda.DataFrame( [ @@ -362,13 +389,20 @@ def test_list_files_predicate( with pytest.raises(ValueError): # Predicate parameter is unknown in DB not setup properly - assert db_with_files.list_files(b_number=2) + assert db_with_files.list_files(**filters) # We should have applied a 'modulo' filter using the b_number argument - assert expected.equals(db_predicate.list_files(b_number=2)) + assert expected.equals(db_predicate_converter.list_files(**filters)) # Auto predicate will not be built - assert expected.equals(db_predicate.list_files(a_number=[2, 4])) + assert expected.equals(db_predicate_converter.list_files(a_number=[2, 4])) + + +def test_list_files_filter_builders_error( + db_predicate_converter: FilesDatabaseTestPredicate, +): + with pytest.raises(ValueError, match="Incompatible"): + db_predicate_converter.list_files(a_number=[2, 4], c_number=2) def test_query_empty(db_with_files: FilesDatabaseTest): diff --git a/tests/implementations/collections/test_l2_lr_ssh.py b/tests/implementations/collections/test_l2_lr_ssh.py index cfc71b7..11f5c87 100644 --- a/tests/implementations/collections/test_l2_lr_ssh.py +++ b/tests/implementations/collections/test_l2_lr_ssh.py @@ -16,6 +16,7 @@ ProductLevel, ProductSubset, StackLevel, + SwotPhases, SwotReaderL2LRSSH, Timeliness, ) @@ -465,6 +466,35 @@ class 
TestListing: ), ({"subset": ProductSubset.Unsmoothed}, [(10, 4)]), ({"version": L2Version(baseline="A")}, [(577, 18)]), + ( + {"phase": SwotPhases.CALVAL}, + [ + (577, 18), + (546, 11), + (546, 18), + (577, 11), + (577, 18), + (482, 11), + (482, 12), + (482, 25), + (482, 26), + (483, 25), + (483, 26), + (546, 18), + (546, 18), + ], + ), + ( + {"phase": SwotPhases.SCIENCE}, + [ + (6, 11), + (6, 532), + (6, 533), + (7, 532), + (7, 533), + (10, 4), + ], + ), ], ) def test_list( diff --git a/tests/implementations/collections/test_l3_lr_ssh.py b/tests/implementations/collections/test_l3_lr_ssh.py index b5d5018..0857fe1 100644 --- a/tests/implementations/collections/test_l3_lr_ssh.py +++ b/tests/implementations/collections/test_l3_lr_ssh.py @@ -20,6 +20,7 @@ ProductLevel, ProductSubset, StackLevel, + SwotPhases, SwotReaderL3LRSSH, Temporality, ) @@ -468,6 +469,26 @@ class TestListing: {"version": "2.0.1"}, [(532, 25), (532, 26), (533, 25), (533, 26), (10, 532)], ), + ( + {"phase": SwotPhases.CALVAL}, + [ + (531, 25), + (531, 26), + (532, 25), + (532, 26), + (532, 25), + (532, 25), + (532, 26), + (533, 25), + (533, 26), + ], + ), + ( + {"phase": SwotPhases.SCIENCE}, + [ + (10, 532), + ], + ), ], ) def test_list( diff --git a/tests/implementations/collections/test_l3_lr_windwave.py b/tests/implementations/collections/test_l3_lr_windwave.py index d415d38..265941a 100644 --- a/tests/implementations/collections/test_l3_lr_windwave.py +++ b/tests/implementations/collections/test_l3_lr_windwave.py @@ -14,6 +14,7 @@ NetcdfFilesDatabaseSwotLRWW, ProductLevel, ProductSubset, + SwotPhases, SwotReaderL3WW, ) from fcollections.time import Period @@ -272,6 +273,65 @@ def test_read_extended_geographical_selection_disabled( ) +class TestListing: + + @pytest.mark.parametrize( + "query, half_orbits", + [ + ( + {}, + [ + (482, 11), + (482, 12), + (10, 10), + ], + ), + ( + {"cycle_number": [482]}, + [(482, 11), (482, 12)], + ), + ({"pass_number": [10]}, [(10, 10)]), + ( + { + "time": ( + 
np.datetime64("2024-01-25T03"), + np.datetime64("2024-01-25T03:30"), + ) + }, + [(10, 10)], + ), + ( + {"subset": ProductSubset.Light}, + [(482, 11), (482, 12)], + ), + ( + {"version": "2.0"}, + [(482, 11), (482, 12)], + ), + ({"phase": SwotPhases.CALVAL}, [(482, 11), (482, 12)]), + ( + {"phase": SwotPhases.SCIENCE}, + [ + (10, 10), + ], + ), + ], + ) + def test_list( + self, + l3_lr_ww_dir_layout: Path, + query: dict[str, tp.Any], + half_orbits: list[tuple[int, int]], + ): + + db = NetcdfFilesDatabaseSwotLRWW(l3_lr_ww_dir_layout) + files = db.list_files(**query, sort=True) + actual_half_orbits = sorted( + [tuple(x) for x in files[["cycle_number", "pass_number"]].to_numpy()] + ) + assert actual_half_orbits == sorted(half_orbits) + + class TestQuery: @pytest.mark.without_geo_packages diff --git a/tests/implementations/test_filter_builders.py b/tests/implementations/test_filter_builders.py new file mode 100644 index 0000000..08e6c76 --- /dev/null +++ b/tests/implementations/test_filter_builders.py @@ -0,0 +1,22 @@ +import pytest + +from fcollections.implementations import SwotPhaseFilterBuilder, SwotPhases +from fcollections.implementations.optional import SwotGeometryFilterBuilder + + +def test_geometry_filter_builder_no_filter(): + builder = SwotGeometryFilterBuilder() + with pytest.raises(NotImplementedError): + builder.build_filter() + + +def test_phase_filter_builder(): + builder = SwotPhaseFilterBuilder() + actual = builder.build_filter(SwotPhases.CALVAL) + assert actual == {"cycle_number": slice(402, 579)} + + +def test_phase_filter_builder_no_predicate(): + builder = SwotPhaseFilterBuilder() + with pytest.raises(NotImplementedError): + builder.build_predicate({})