diff --git a/src/post_processing/dataclass/detection_filter.py b/src/post_processing/dataclass/detection_filter.py
index c7938f4..d636c4c 100644
--- a/src/post_processing/dataclass/detection_filter.py
+++ b/src/post_processing/dataclass/detection_filter.py
@@ -49,12 +49,12 @@ def from_yaml(
     cls,
     file: Path,
 ) -> DetectionFilter | list[DetectionFilter]:
-    """Return a DetectionFilter object from a yaml file.
+    """Return one or several DetectionFilter objects from a YAML file.

     Parameters
     ----------
     file: Path
-        The path to a yaml configuration file.
+        The path to a YAML configuration file.

     Returns
     -------
@@ -86,7 +86,7 @@ def from_dict(
     """
     filters = []
     for detection_file, filters_dict in parameters.items():
-        df_preview = read_dataframe(Path(detection_file), nrows=5)
+        df_preview = read_dataframe(Path(detection_file), rows=5)
         filters_dict["timebin_origin"] = Timedelta(
             max(df_preview["end_time"]),
             "s",
diff --git a/src/post_processing/utils/filtering_utils.py b/src/post_processing/utils/filtering_utils.py
index 036aca5..c391ff6 100644
--- a/src/post_processing/utils/filtering_utils.py
+++ b/src/post_processing/utils/filtering_utils.py
@@ -26,21 +26,50 @@
 def find_delimiter(file: Path) -> str:
     """Find the proper delimiter for a csv file."""
-    with file.open(newline="") as csv_file:
-        try:
-            temp_lines = csv_file.readline() + "\n" + csv_file.readline()
-            dialect = csv.Sniffer().sniff(temp_lines, delimiters=",;")
-            delimiter = dialect.delimiter
-        except csv.Error as err:
-            msg = "Could not determine delimiter"
-            raise ValueError(msg) from err
-        return delimiter
+    allowed_delimiters = {",", ";", "\t", "|"}
+    try:
+        with file.open("r", encoding="utf-8") as f:
+            # Read a sample from the start of the file to detect the delimiter
+            sample = f.read(4096)
+
+            if not sample.strip():
+                msg = f"Could not determine delimiter for '{file}': file is empty"
+                raise ValueError(msg)
+
+            sniffer = csv.Sniffer()
+            dialect = sniffer.sniff(sample)
+
+            if dialect.delimiter not in allowed_delimiters:
+                msg = (f"Could not determine delimiter for '{file}': "
+                       f"unsupported delimiter '{dialect.delimiter}'")
+                raise ValueError(msg)
+
+            return dialect.delimiter
+
+    except csv.Error as e:
+        msg = f"Could not determine delimiter for '{file}': {e}"
+        raise ValueError(msg) from e


 def filter_strong_detection(
     df: DataFrame,
 ) -> DataFrame:
-    """Filter strong detections of a DataFrame."""
+    """Filter to keep only weak detections (exclude box/strong annotations).
+
+    This function identifies and removes "strong" or "box" type annotations,
+    keeping only "weak" detections. It checks for either an 'is_box' or
+    'type' column.
+
+    Parameters
+    ----------
+    df : DataFrame
+        APLOSE-formatted DataFrame with either 'is_box' or 'type' column.
+
+    Returns
+    -------
+    DataFrame
+        Filtered DataFrame containing only weak detections.
+
+    """
     if "type" in df.columns:
         df = df[df["type"] == "WEAK"]
     elif "is_box" in df.columns:
@@ -59,7 +88,23 @@ def filter_by_time(
     df: DataFrame,
     begin: Timestamp | None,
     end: Timestamp | None,
 ) -> DataFrame:
-    """Filter a DataFrame based on begin and/or end timestamps."""
+    """Filter detections by time range.
+
+    Parameters
+    ----------
+    df : DataFrame
+        APLOSE DataFrame containing 'start_datetime' and 'end_datetime'
+        columns.
+    begin : Timestamp, optional
+        Start of time range (inclusive). If None, no lower bound is applied.
+    end : Timestamp, optional
+        End of time range (inclusive). If None, no upper bound is applied.
+
+    Returns
+    -------
+    DataFrame
+        Filtered DataFrame containing only detections within the specified
+        time range.
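+
+    Examples
+    --------
+    A minimal sketch (values are illustrative, not from a real dataset):
+
+    >>> df = DataFrame({
+    ...     "start_datetime": [Timestamp("2025-01-01 12:00:00")],
+    ...     "end_datetime": [Timestamp("2025-01-01 12:05:00")],
+    ... })
+    >>> filter_by_time(df, begin=Timestamp("2025-01-01"), end=None).shape[0]
+    1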
+
+    """
     if begin is not None:
         df = df[df["start_datetime"] >= begin]
         if df.empty:
@@ -79,7 +124,21 @@ def filter_by_annotator(
     df: DataFrame,
     annotator: str | list[str] | None,
 ) -> DataFrame:
-    """Filter a DataFrame based on annotator selection."""
+    """Filter a DataFrame based on annotator selection.
+
+    Parameters
+    ----------
+    df : DataFrame
+        APLOSE-formatted DataFrame containing an 'annotator' column.
+    annotator : str or list of str, optional
+        Single annotator name or list of annotator names to filter by.
+        If None, the DataFrame is returned unfiltered.
+
+    Returns
+    -------
+    DataFrame
+        Filtered DataFrame containing only detections from the specified
+        annotator(s).
+
+    """
     if annotator is None:
         return df

@@ -99,7 +158,21 @@ def filter_by_label(
     df: DataFrame,
     label: str | list[str] | None,
 ) -> DataFrame:
-    """Filter a DataFrame based on label selection."""
+    """Filter a DataFrame based on label selection.
+
+    Parameters
+    ----------
+    df : DataFrame
+        APLOSE-formatted DataFrame containing an 'annotation' column.
+    label : str or list of str, optional
+        Single label or list of labels to filter by.
+        If None, the DataFrame is returned unfiltered.
+
+    Returns
+    -------
+    DataFrame
+        Filtered DataFrame containing only detections with the specified
+        label(s).
+
+    """
     if label is None:
         return df

@@ -120,7 +193,23 @@ def filter_by_freq(
     df: DataFrame,
     f_min: int | None,
     f_max: int | None,
 ) -> DataFrame:
-    """Filter a DataFrame based on frequency selection."""
+    """Filter a DataFrame based on frequency selection.
+
+    Parameters
+    ----------
+    df : DataFrame
+        APLOSE DataFrame containing 'start_frequency' and 'end_frequency'
+        columns.
+    f_min : int, optional
+        Minimum frequency in Hz (inclusive). If None, no lower bound is
+        applied.
+    f_max : int, optional
+        Maximum frequency in Hz (inclusive). If None, no upper bound is
+        applied.
+
+    Returns
+    -------
+    DataFrame
+        Filtered DataFrame with only detections within the specified
+        frequency range.
+
+    """
     if f_min is not None:
         df = df[df["start_frequency"] >= f_min]
         if df.empty:
@@ -136,7 +225,21 @@

 def filter_by_score(df: DataFrame, score: float) -> DataFrame:
-    """Filter a DataFrame based on minimum score."""
+    """Filter detections by confidence score.
+
+    Parameters
+    ----------
+    df : DataFrame
+        APLOSE-formatted DataFrame containing a 'score' column.
+    score : float
+        The minimum confidence score threshold (inclusive).
+
+    Returns
+    -------
+    DataFrame
+        Filtered DataFrame containing only detections whose score is greater
+        than or equal to the threshold.
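+
+    Examples
+    --------
+    Illustrative sketch, assuming the documented ``score >= threshold`` rule:
+
+    >>> df = DataFrame({"score": [0.2, 0.9]})
+    >>> filter_by_score(df, score=0.5)["score"].to_list()
+    [0.9]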
+
+    """
     if not score:
         return df

@@ -151,14 +254,14 @@ def filter_by_score(df: DataFrame, score: float) -> DataFrame:
     return df


-def read_dataframe(file: Path, nrows: int | None = None) -> DataFrame:
-    """Read csv file."""
+def read_dataframe(file: Path, rows: int | None = None) -> DataFrame:
+    """Read an APLOSE-formatted CSV file into a DataFrame."""
     delimiter = find_delimiter(file)
     return (
         read_csv(file,
                  sep=delimiter,
                  parse_dates=["start_datetime", "end_datetime"],
-                 nrows=nrows,
+                 nrows=rows,
                  )
         .drop_duplicates()
         .dropna(subset=["annotation"])
@@ -168,33 +271,48 @@

 def get_annotators(df: DataFrame) -> list[str]:
-    """Return list of annotators."""
+    """Return the annotator list of an APLOSE DataFrame."""
     return sorted(set(df["annotator"]))


 def get_labels(df: DataFrame) -> list[str]:
-    """Return list of labels."""
+    """Return the label list of an APLOSE DataFrame."""
     return sorted(set(df["annotation"]))


 def get_max_freq(df: DataFrame) -> float:
-    """Return the maximum frequency of DataFrame."""
+    """Return the maximum frequency of an APLOSE DataFrame."""
     return df["end_frequency"].max()


 def get_max_time(df: DataFrame) -> float:
-    """Return the maximum time of DataFrame."""
+    """Return the maximum time of an APLOSE DataFrame."""
     return df["end_time"].max()


-def get_dataset(df: DataFrame) -> list[str]:
-    """Return list of datasets."""
+def get_dataset(df: DataFrame) -> str | list[str]:
+    """Return the dataset(s) of an APLOSE DataFrame."""
     datasets = sorted(set(df["dataset"]))
     return datasets if len(datasets) > 1 else datasets[0]


-def get_canonical_tz(tz):
-    """Return timezone of object as a pytz timezone."""
+def get_canonical_tz(tz: datetime.tzinfo) -> pytz.tzinfo.BaseTzInfo:
+    """Convert a timezone object to its canonical pytz representation.
+
+    This function ensures compatibility between different timezone
+    implementations (pytz, zoneinfo) by converting them to pytz timezone
+    objects.
+
+    Parameters
+    ----------
+    tz : datetime.tzinfo
+        Timezone object (can be pytz timezone or ZoneInfo).
+
+    Returns
+    -------
+    pytz.tzinfo.BaseTzInfo
+        Canonical pytz timezone object.
+
+    """
     if isinstance(tz, datetime.timezone):
         if tz == datetime.UTC:
             return pytz.utc
@@ -208,7 +326,8 @@
     raise TypeError(msg)


-def get_timezone(df: DataFrame):
+def get_timezone(
+    df: DataFrame,
+) -> pytz.tzinfo.BaseTzInfo | list[pytz.tzinfo.BaseTzInfo]:
     """Return timezone(s) from APLOSE DataFrame.

     Parameters
@@ -218,8 +337,8 @@
     Returns
     -------
-    tzoffset: list[tzoffset]
-        list of timezones
+    pytz.tzinfo.BaseTzInfo | list[pytz.tzinfo.BaseTzInfo]
+        The timezone of the detections, or a list of timezones if several
+        are present.

     """
     timezones = {get_canonical_tz(ts.tzinfo) for ts in df["start_datetime"]}
@@ -230,13 +349,15 @@

 def check_timestamp(df: DataFrame, timestamp_audio: list[Timestamp]) -> None:
-    """Check if provided timestamp_audio list is correctly formated.
+    """Check if a provided timestamp_audio list is correctly formatted.

     Parameters
     ----------
-    df: DataFrame APLOSE results Dataframe.
-    timestamp_audio: A list of timestamps. Each timestamp is the start datetime of the
-        corresponding audio file for each detection in df.
+    df: DataFrame
+        APLOSE results DataFrame.
+    timestamp_audio: list[Timestamp]
+        A list of timestamps. Each timestamp is the start datetime of the
+        corresponding audio file for each detection in df.
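+
+    Raises
+    ------
+    ValueError
+        If timestamp_audio is not consistent with the detections in df.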
""" if timestamp_audio is None: @@ -247,11 +368,137 @@ def check_timestamp(df: DataFrame, timestamp_audio: list[Timestamp]) -> None: raise ValueError(msg) +def _build_filename_vector( + time_vector: list[Timestamp], + ts_detect_beg: list[Timestamp], + timestamp_audio: list[Timestamp], + filenames: list[str], +) -> list[str]: + """Build the filename vector for each time bin.""" + filename_vector = [] + for ts in time_vector: + idx = bisect.bisect_left(ts_detect_beg, ts) + + if idx == 0: + filename_vector.append(filenames[0]) + elif idx == len(ts_detect_beg): + filename_vector.append(filenames[-1]) + else: + # Choose a filename based on timestamp_audio + filename_vector.append( + filenames[idx] if timestamp_audio[idx] <= ts else filenames[idx - 1], + ) + + return filename_vector + + +def _build_detection_vector( + time_vector: list[Timestamp], + ts_detect_beg: list[Timestamp], + ts_detect_end: list[Timestamp], +) -> list[int]: + """Build a binary detection vector indicating presence in each time bin.""" + detect_vec = [0] * len(time_vector) + + for start, end in zip(ts_detect_beg, ts_detect_end, strict=False): + idx = bisect.bisect_left(time_vector, start) + idx = idx if start in time_vector else max(0, idx - 1) + + while idx < len(time_vector) and time_vector[idx] < end: + detect_vec[idx] = 1 + idx += 1 + + return detect_vec + + +def _create_result_dataframe( + file_vector: list[str], + start_datetime: list[Timestamp], + timebin_new: Timedelta, + max_freq: float, + dataset: str, + label: str, + annotator: str, +) -> DataFrame: + """Create result DataFrame for one annotator-label combination.""" + return DataFrame({ + "dataset": [dataset] * len(file_vector), + "filename": file_vector, + "start_time": [0] * len(file_vector), + "end_time": [timebin_new.total_seconds()] * len(file_vector), + "start_frequency": [0] * len(file_vector), + "end_frequency": [max_freq] * len(file_vector), + "annotation": [label] * len(file_vector), + "annotator": [annotator] * len(file_vector), + "start_datetime": start_datetime, + "end_datetime": [t + timebin_new for t in start_datetime], + "type": ["WEAK"] * len(file_vector), + }) + + +def _normalize_timezones(df: DataFrame) -> DataFrame: + """Convert all timestamps to UTC if multiple timezones are present.""" + if isinstance(get_timezone(df), list): + df["start_datetime"] = [ + to_datetime(elem, utc=True) for elem in df["start_datetime"] + ] + df["end_datetime"] = [ + to_datetime(elem, utc=True) for elem in df["end_datetime"] + ] + return df + + +def _process_annotator_label_pair( + df: DataFrame, + annotator: str, + label: str, + timebin_new: Timedelta, + timestamp_audio: list[Timestamp], + max_freq: float, + dataset: str, +) -> DataFrame | None: + """Process detections for one annotator-label combination.""" + df_subset = df[(df["annotator"] == annotator) & (df["annotation"] == label)] + + if df_subset.empty: + return None + + # Create a time vector + t1 = min(df_subset["start_datetime"]).floor(timebin_new) + t2 = max(df_subset["end_datetime"]).ceil(timebin_new) + time_vector = date_range(start=t1, end=t2, freq=timebin_new) + + # Extract detection data + ts_detect_beg = df_subset["start_datetime"].to_list() + ts_detect_end = df_subset["end_datetime"].to_list() + filenames = df_subset["filename"].to_list() + + # Build vectors + filename_vector = _build_filename_vector( + time_vector, ts_detect_beg, timestamp_audio, filenames, + ) + detect_vec = _build_detection_vector(time_vector, ts_detect_beg, ts_detect_end) + + # Filter to only detected time bins + 
+    start_datetime = [
+        time_vector[i] for i, detected in enumerate(detect_vec) if detected
+    ]
+    file_vector = [
+        filename_vector[i] for i, detected in enumerate(detect_vec) if detected
+    ]
+
+    if not start_datetime:
+        return None
+
+    return _create_result_dataframe(
+        file_vector, start_datetime, timebin_new, max_freq, dataset, label, annotator,
+    )
+
+
 def reshape_timebin(
     df: DataFrame,
-    *,
     timebin_new: Timedelta | None,
-    timestamp_audio: list[Timestamp] | None = None,
+    timestamp_audio: list[Timestamp] | None,
 ) -> DataFrame:
     """Reshape an APLOSE result DataFrame according to a new time bin.

@@ -280,103 +527,34 @@
     check_timestamp(df, timestamp_audio)

+    # Extract metadata
     annotators = get_annotators(df)
     labels = get_labels(df)
     max_freq = get_max_freq(df)
     dataset = get_dataset(df)

-    if isinstance(get_timezone(df), list):
-        df["start_datetime"] = [to_datetime(elem, utc=True)
-                                for elem in df["start_datetime"]
-                                ]
-        df["end_datetime"] = [to_datetime(elem, utc=True)
-                              for elem in df["end_datetime"]
-                              ]
+    # Normalize timezones if needed
+    df = _normalize_timezones(df)

+    # Process each annotator-label combination
     results = []
     for ant in annotators:
         for lbl in labels:
-            df_1annot_1label = df[(df["annotator"] == ant) & (df["annotation"] == lbl)]
-
-            if df_1annot_1label.empty:
-                continue
-
-            if timestamp_audio is not None:
-                # I do not remember if this is a regular case or not
-                # might need to be deleted
-                #origin_timebin = timestamp_audio[1] - timestamp_audio[0]
-                #step = int(timebin_new / origin_timebin)
-                #time_vector = timestamp_audio[0::step]
-            #else:
-            t1 = min(df_1annot_1label["start_datetime"]).floor(timebin_new)
-            t2 = max(df_1annot_1label["end_datetime"]).ceil(timebin_new)
-            time_vector = date_range(start=t1, end=t2, freq=timebin_new)
-
-            ts_detect_beg = df_1annot_1label["start_datetime"].to_list()
-            ts_detect_end = df_1annot_1label["end_datetime"].to_list()
-            filenames = df_1annot_1label["filename"].to_list()
-
-            # filename_vector
-            filename_vector = []
-            for ts in time_vector:
-                if (bisect.bisect_left(ts_detect_beg, ts) > 0 and
-                        bisect.bisect_left(ts_detect_beg, ts) != len(ts_detect_beg)):
-                    idx = bisect.bisect_left(ts_detect_beg, ts)
-                    filename_vector.append(
-                        filenames[idx] if timestamp_audio[idx] <= ts else
-                        filenames[idx - 1],
-                    )
-                elif bisect.bisect_left(ts_detect_beg, ts) == len(ts_detect_beg):
-                    filename_vector.append(filenames[-1])
-                else:
-                    filename_vector.append(filenames[0])
-
-            # detection vector
-            detect_vec = [0] * len(time_vector)
-            for start, end in zip(ts_detect_beg, ts_detect_end, strict=False):
-                idx = bisect.bisect_left(time_vector, start)
-                idx = idx if start in time_vector else max(0, idx - 1)
-                while idx < len(time_vector) and time_vector[idx] < end:
-                    detect_vec[idx] = 1
-                    idx += 1
-
-            # rows for dataframe
-            start_datetime = [
-                time_vector[i] for i in range(len(time_vector)) if detect_vec[i]
-            ]
-            end_datetime = [t + timebin_new for t in start_datetime]
-            file_vector = [
-                filename_vector[i] for i in range(len(time_vector)) if detect_vec[i]
-            ]
-
-            if start_datetime:
-                results.append(
-                    DataFrame(
-                        {
-                            "dataset": [dataset] * len(file_vector),
-                            "filename": file_vector,
-                            "start_time": [0] * len(file_vector),
-                            "end_time": [timebin_new.total_seconds()]
-                            * len(file_vector),
-                            "start_frequency": [0] * len(file_vector),
-                            "end_frequency": [max_freq] * len(file_vector),
-                            "annotation": [lbl] * len(file_vector),
-                            "annotator": [ant] * len(file_vector),
-                            "start_datetime": start_datetime,
-                            "end_datetime": end_datetime,
-                            "type": ["WEAK"] * len(file_vector),
-                    },
-                ),
-            )
-
-    return (concat(results).
-            sort_values(by=["start_datetime", "end_datetime",
-                            "annotator", "annotation"]).reset_index(drop=True)
+            result = _process_annotator_label_pair(
+                df, ant, lbl, timebin_new, timestamp_audio, max_freq, dataset,
             )
+            if result is not None:
+                results.append(result)
+
+    return (
+        concat(results)
+        .sort_values(by=["start_datetime", "end_datetime", "annotator", "annotation"])
+        .reset_index(drop=True)
+    )


 def get_filename_timestamps(df: DataFrame, date_parser: str) -> list[Timestamp]:
-    """Get start timestamps of the wav files of each detection contained in df.
+    """Get audio file start timestamps of each detection contained in df.

-    Parameters.
+    Parameters
     ----------
@@ -403,6 +581,7 @@ def get_filename_timestamps(df: DataFrame, date_parser: str) -> list[Timestamp]:
     msg = """Could not parse timestamps from `df["filename"]`."""
     raise ValueError(msg) from None

+
 def ensure_in_list(value: str, candidates: list[str], label: str) -> None:
     """Check for non-valid elements of a list."""
     if value not in candidates:
@@ -442,7 +621,7 @@ def load_detections(filters: DetectionFilter) -> DataFrame:
     filename_ts = get_filename_timestamps(df, filters.filename_format)
     df = reshape_timebin(df,
                          timebin_new=filters.timebin_new,
-                         timestamp_audio=filename_ts
+                         timestamp_audio=filename_ts,
                          )

     annotators = get_annotators(df)
diff --git a/tests/test_filtering_utils.py b/tests/test_filtering_utils.py
index 68b8d20..95fd987 100644
--- a/tests/test_filtering_utils.py
+++ b/tests/test_filtering_utils.py
@@ -9,24 +9,24 @@
 from pandas import DataFrame, Timedelta, Timestamp, concat, to_datetime

 from post_processing.utils.filtering_utils import (
+    ensure_no_invalid,
     filter_by_annotator,
-    filter_strong_detection,
     filter_by_freq,
     filter_by_label,
     filter_by_score,
     filter_by_time,
+    filter_strong_detection,
     find_delimiter,
     get_annotators,
+    get_canonical_tz,
     get_dataset,
     get_labels,
     get_max_freq,
     get_max_time,
     get_timezone,
+    intersection_or_union,
     read_dataframe,
     reshape_timebin,
-    get_canonical_tz,
-    ensure_no_invalid,
-    intersection_or_union,
 )

 # %% find delimiter
@@ -49,9 +49,15 @@
     assert detected == delimiter


-def test_find_delimiter_invalid(tmp_path: Path) -> None:
-    file = tmp_path / "invalid.csv"
-    file.write_text("this is not really&csv&content")
+def test_find_delimiter_invalid(
+    tmp_path: Path,
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    file = tmp_path / "bad.csv"
+    file.write_text("a,b,c")
+
+    def raise_error(*args, **kwargs):
+        raise csv.Error("sniff failed")
+
+    monkeypatch.setattr(csv.Sniffer, "sniff", raise_error)
+
     with pytest.raises(ValueError, match="Could not determine delimiter"):
         find_delimiter(file)

@@ -63,6 +69,19 @@
     find_delimiter(file)


+def test_find_delimiter_unsupported_delimiter(tmp_path: Path) -> None:
+    file = tmp_path / "lame.csv"
+
+    # '&' is consistent and sniffable, but not allowed
+    file.write_text("a&b&c\n1&2&3\n")
+
+    with pytest.raises(
+        ValueError,
+        match=r"unsupported delimiter '&'",
+    ):
+        find_delimiter(file)
+
+
 # %% filter utils

 # filter_by_time

@@ -369,7 +388,7 @@
         "2025-01-01 13:00:00,2025-01-01 13:05:00,dolphin\n",
     )

-    df = read_dataframe(csv_file, nrows=1)
+    df = read_dataframe(csv_file, rows=1)

     assert len(df) == 1
     assert df.iloc[0]["annotation"] in {"whale", "dolphin"}


 # %% reshape_timebin
 def test_no_timebin_returns_original(sample_df: DataFrame) -> None:
-    df_out = reshape_timebin(sample_df, timebin_new=None)
+    df_out = reshape_timebin(sample_df, timebin_new=None, timestamp_audio=None)

     assert df_out.equals(sample_df)
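+
+
+def test_filter_by_score_zero_returns_original(sample_df: DataFrame) -> None:
+    # Suggested extra test (name is a suggestion): score=0 is falsy, so the
+    # `if not score` early return in filter_by_score should hand the
+    # DataFrame back unchanged.
+    assert filter_by_score(sample_df, score=0).equals(sample_df)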