49 changes: 18 additions & 31 deletions src/post_processing/utils/core_utils.py
@@ -63,12 +63,10 @@ def get_season(ts: Timestamp, *, northern: bool = True) -> tuple[str, int]:

"""
if northern:
winter = [1, 2, 12]
spring = [3, 4, 5]
summer = [6, 7, 8]
autumn = [9, 10, 11]
else:
winter = [6, 7, 8]
spring = [9, 10, 11]
summer = [1, 2, 12]
autumn = [3, 4, 5]
@@ -79,11 +77,8 @@ def get_season(ts: Timestamp, *, northern: bool = True) -> tuple[str, int]:
season = "summer"
elif ts.month in autumn:
season = "autumn"
elif ts.month in winter:
season = "winter"
else:
msg = "Invalid timestamp"
raise ValueError(msg)
season = "winter"

return season, ts.year - 1 if ts.month in [1, 2] else ts.year
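
For reference, a minimal standalone usage sketch of the simplified season lookup (the import path is assumed from the file location, and naive timestamps are assumed to be accepted since only .month and .year are used here):

from pandas import Timestamp
from post_processing.utils.core_utils import get_season  # assumed import path

# January belongs to the northern winter and is attributed to the previous year.
assert get_season(Timestamp("2025-01-15")) == ("winter", 2024)
# July is summer in the north and winter in the south.
assert get_season(Timestamp("2025-07-15")) == ("summer", 2025)
assert get_season(Timestamp("2025-07-15"), northern=False) == ("winter", 2025)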

@@ -93,10 +88,7 @@ def get_sun_times(
stop: Timestamp,
lat: float,
lon: float,
) -> (
list[float],
list[float],
):
) -> tuple[list[float], list[float]]:
"""Fetch sunrise and sunset hours for dates between start and stop.

Parameters
@@ -171,7 +163,7 @@ def get_coordinates() -> tuple:
f"'{lat}' is not a valid latitude. It must be between -90 and 90.\n"
)
except ValueError:
errmsg += f"'{lat}' is not a valid entry for latitude.\n"
errmsg += f"'lat', invalid entry: '{lat}'.\n"

try:
lon_val = float(lon.strip()) # Convert to float for longitude
@@ -349,7 +341,7 @@ def set_bar_height(ax: plt.Axes, pixel_height: int = 10) -> float:

"""
if not ax.has_data():
msg = "Axe has no data"
msg = "Axe have no data"
raise ValueError(msg)

display_to_data = ax.transData.inverted().transform
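
As context for the transform call above, a standalone sketch of the pixel-to-data-units conversion it relies on (the figure and values are illustrative, not taken from the module):

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([0, 1], [0, 10])
fig.canvas.draw()  # make sure the transform reflects the final layout

# Map two display-space points that are 10 pixels apart vertically back to data space;
# their difference is the bar height expressed in y-axis data units.
display_to_data = ax.transData.inverted().transform
_, y0 = display_to_data((0, 0))
_, y1 = display_to_data((0, 10))
bar_height_in_data_units = abs(y1 - y0)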
@@ -380,7 +372,7 @@ def add_recording_period(

"""
if not ax.has_data():
msg = "Axe has no data"
msg = "Axe have no data"
raise ValueError(msg)

recorder_intervals = [
@@ -419,10 +411,6 @@ def get_count(df: DataFrame, bin_size: Timedelta | BaseOffset) -> DataFrame:
"<label>-<annotator>", containing the count of observations in that bin.

"""
if not isinstance(df, DataFrame):
msg = "`df` must be a DataFrame"
raise TypeError(msg)

if df.empty:
msg = "`df` contains no data"
raise ValueError(msg)
@@ -464,10 +452,6 @@ def get_labels_and_annotators(df: DataFrame) -> tuple[list, list]:
A tuple containing the labels and annotators lists.

"""
if not isinstance(df, DataFrame):
msg = "`df` must be a DataFrame"
raise TypeError(msg)

if df.empty:
msg = "`df` contains no data"
raise ValueError(msg)
@@ -497,9 +481,10 @@ def localize_timestamps(timestamps: list[Timestamp], tz: tzinfo) -> list[Timesta
return localized


def get_time_range_and_bin_size(timestamp_list: list[Timestamp],
bin_size: Timedelta | BaseOffset,
) -> (DatetimeIndex, Timedelta):
def get_time_range_and_bin_size(
timestamp_list: list[Timestamp],
bin_size: Timedelta | BaseOffset,
) -> tuple[DatetimeIndex, Timedelta]:
"""Return time vector given a bin size."""
if (not isinstance(timestamp_list, list) or
not all(isinstance(ts, Timestamp) for ts in timestamp_list)):
@@ -512,18 +497,20 @@ def get_time_range_and_bin_size(timestamp_list: list[Timestamp],

start, end, _ = round_begin_end_timestamps(timestamp_list, bin_size)
timestamp_range = date_range(start=start, end=end, freq=bin_size)

if isinstance(bin_size, Timedelta):
return timestamp_range, bin_size
if isinstance(bin_size, BaseOffset):
elif isinstance(bin_size, BaseOffset):
return timestamp_range, timestamp_range[1] - timestamp_range[0]

msg = "Could not get time range."
raise ValueError(msg)
else:
msg = "bin_size must be a Timedelta or BaseOffset."
raise TypeError(msg)
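
To illustrate the BaseOffset branch: an offset alias such as month-start has no fixed duration, so the effective bin size is read off the generated range itself. A standalone sketch (values are illustrative):

from pandas import Timestamp, date_range
from pandas.tseries.frequencies import to_offset

rng = date_range(Timestamp("2025-01-01"), Timestamp("2025-04-01"), freq=to_offset("MS"))
# rng[1] - rng[0] is Timedelta("31 days"): the concrete width of the first monthly bin.
effective_bin_size = rng[1] - rng[0]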


def round_begin_end_timestamps(timestamp_list: list[Timestamp],
bin_size: Timedelta | BaseOffset,
) -> (Timestamp, Timestamp, Timedelta):
def round_begin_end_timestamps(
timestamp_list: list[Timestamp],
bin_size: Timedelta | BaseOffset,
) -> tuple[Timestamp, Timestamp, Timedelta]:
"""Return time vector given a bin size."""
if (not isinstance(timestamp_list, list) or
not all(isinstance(ts, Timestamp) for ts in timestamp_list)):
33 changes: 18 additions & 15 deletions src/post_processing/utils/filtering_utils.py
@@ -4,6 +4,7 @@

import bisect
import csv
import datetime
from typing import TYPE_CHECKING

import pytz
@@ -68,7 +69,7 @@ def filter_by_time(
if end is not None:
df = df[df["end_datetime"] <= end]
if df.empty:
msg = f"No detection found after '{end}'."
msg = f"No detection found before '{end}'."
raise ValueError(msg)

return df
@@ -81,9 +82,6 @@ def filter_by_annotator(
"""Filter a DataFrame based on annotator selection."""
list_annotators = get_annotators(df)

if not annotator:
return df

if isinstance(annotator, str):
ensure_in_list(annotator, list_annotators, "annotator")
return df[df["annotator"] == annotator]
@@ -101,9 +99,6 @@ def filter_by_label(
"""Filter a DataFrame based on label selection."""
list_labels = get_labels(df)

if not label:
return df

if isinstance(label, str):
ensure_in_list(label, list_labels, "label")
return df[df["annotation"] == label]
@@ -192,16 +187,24 @@ def get_dataset(df: DataFrame) -> list[str]:
return datasets if len(datasets) > 1 else datasets[0]


def get_timezone(df: DataFrame):
"""Return timezone(s) from DataFrame."""
def get_canonical_tz(tz):
"""Return timezone of object as a pytz timezone."""
if isinstance(tz, datetime.timezone):
if tz == datetime.timezone.utc:
return pytz.utc
offset_minutes = int(tz.utcoffset(None).total_seconds() / 60)
return pytz.FixedOffset(offset_minutes)
if hasattr(tz, "zone") and tz.zone:
return pytz.timezone(tz.zone)
if hasattr(tz, "key"):
return pytz.timezone(tz.key)
else:
msg = f"Unknown timezone: {tz}"
raise TypeError(msg)

def get_canonical_tz(tz):
if hasattr(tz, "zone") and tz.zone:
return pytz.timezone(tz.zone)
if hasattr(tz, "key"):
return pytz.timezone(tz.key)
return pytz.UTC

def get_timezone(df: DataFrame):
"""Return timezone(s) from DataFrame."""
timezones = {get_canonical_tz(ts.tzinfo) for ts in df["start_datetime"]}

if len(timezones) == 1:
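
A standalone sketch of what the canonicalisation buys (the import path is assumed from the file location): UTC expressed through the stdlib, pytz, or zoneinfo collapses to a single pytz timezone, so the set comprehension in get_timezone counts it once.

import datetime
import pytz
from zoneinfo import ZoneInfo
from post_processing.utils.filtering_utils import get_canonical_tz  # assumed import path

canonical = {
    get_canonical_tz(datetime.timezone.utc),  # stdlib fixed-offset UTC
    get_canonical_tz(pytz.utc),               # pytz singleton, matched via .zone
    get_canonical_tz(ZoneInfo("UTC")),        # zoneinfo, matched via .key
}
assert canonical == {pytz.utc}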
2 changes: 0 additions & 2 deletions tests/test_audio_utils.py
@@ -19,8 +19,6 @@ def test_normalize_audio_default_folder(sample_audio: Path, tmp_path: Path) -> N

def test_normalize_audio_custom_folder(sample_audio: Path, tmp_path: Path) -> None:
out_folder = tmp_path / "output"
out_folder.mkdir()

normalize_audio(sample_audio, output_folder=out_folder)

normalized_file = out_folder / sample_audio.name
126 changes: 113 additions & 13 deletions tests/test_core_utils.py
@@ -1,4 +1,7 @@
from unittest.mock import patch

import pytest
from matplotlib import pyplot as plt
from pandas import DataFrame, Timedelta, Timestamp, date_range
from pandas.tseries import frequencies
from pytz import timezone
@@ -15,6 +18,10 @@
localize_timestamps,
round_begin_end_timestamps,
timedelta_to_str,
add_season_period,
add_recording_period,
set_bar_height,
json2df,
)


@@ -38,7 +45,7 @@ def fake_box(msg: str, title: str, fields: list[str]) -> None:


def test_coordinates_invalid_then_valid_input(monkeypatch: pytest.MonkeyPatch) -> None:
inputs = [["900", "50"], ["45", "100"]] # first invalid, then valid
inputs = [["900", "50"], ["45", "900"], ["45", "100"]]

def fake_box(msg: str, title: str, fields: list[str]) -> list[str]:
return inputs.pop(0)
@@ -49,7 +56,7 @@ def fake_box(msg: str, title: str, fields: list[str]) -> list[str]:


def test_coordinates_non_numeric_input(monkeypatch: pytest.MonkeyPatch) -> None:
inputs = [["abc", "-20"], ["10", "-20"]]
inputs = [["abc", "-20"], ["-20", "abc"], ["10", "-20"]]

def fake_box(msg: str, title: str, fields: list[str]) -> list[str]:
return inputs.pop(0)
@@ -127,11 +134,12 @@ def test_get_sun_times_valid_input(start: Timestamp,
-1.5167),
],
)
def test_get_sun_times_naive_timestamps(start: Timestamp,
stop: Timestamp,
lat: float,
lon: float,
) -> None:
def test_get_sun_times_naive_timestamps(
start: Timestamp,
stop: Timestamp,
lat: float,
lon: float,
) -> None:
with pytest.raises(ValueError, match="start and stop must be timezone-aware"):
get_sun_times(start, stop, lat, lon)

@@ -369,14 +377,106 @@ def test_round_begin_end_timestamps_valid_entry_2() -> None:
def test_timedelta_to_str(td, expected) -> None:
assert timedelta_to_str(td) == expected


# %% add_weak_detection / json2df


def test_add_wd(sample_df: DataFrame) -> None:
df_only_wd = sample_df[sample_df["is_box"] == 1]
strong_det = sample_df[sample_df["is_box"] == 0].iloc[0]
add_weak_detection(df=df_only_wd.copy(),
datetime_format="%Y_%m_%d_%H_%M_%S",
max_time=strong_det["end_time"],
max_freq=strong_det["end_frequency"],
)
add_weak_detection(
df=df_only_wd.copy(),
datetime_format="%Y_%m_%d_%H_%M_%S",
)


# %% add_season_period

def test_add_season_valid() -> None:
fig, ax = plt.subplots()
start = Timestamp("2025-01-01T00:00:00+00:00")
stop = Timestamp("2025-01-02T00:00:00+00:00")

ts = date_range(start=start, end=stop, freq="H", tz="UTC")
values = list(range(len(ts)))
ax.plot(ts, values)
add_season_period(ax=ax)


def test_add_season_no_data() -> None:
fig, ax = plt.subplots()
with pytest.raises(ValueError, match=r"have no data"):
add_season_period(ax=ax)

# %% add_recording_period

def test_add_recording_period_valid() -> None:
fig, ax = plt.subplots()
start = Timestamp("2025-01-01T00:00:00+00:00")
stop = Timestamp("2025-01-02T00:00:00+00:00")

ts = date_range(start=start, end=stop, freq="H", tz="UTC")
values = list(range(len(ts)))
ax.plot(ts, values)

df = DataFrame(
data=[
[
Timestamp("2025-01-01T00:00:00+00:00"),
Timestamp("2025-01-02T00:00:00+00:00"),
]
],
columns=["deployment_date", "recovery_date"],
)
add_recording_period(df=df, ax=ax)


def test_add_recording_period_no_data() -> None:
fig, ax = plt.subplots()
df = DataFrame()
with pytest.raises(ValueError, match=r"have no data"):
add_recording_period(df=df, ax=ax)

# %% set_bar_height

def test_set_bar_height_valid() -> None:
fig, ax = plt.subplots()
start = Timestamp("2025-01-01T00:00:00+00:00")
stop = Timestamp("2025-01-02T00:00:00+00:00")

ts = date_range(start=start, end=stop, freq="H", tz="UTC")
values = list(range(len(ts)))
ax.plot(ts, values)

set_bar_height(ax=ax)


def test_set_bar_height_no_data() -> None:
fig, ax = plt.subplots()
with pytest.raises(ValueError, match=r"have no data"):
set_bar_height(ax=ax)

# %% json2df

def test_json2df_valid(tmp_path):
fake_json = {
"deployment_date": "2025-01-01T00:00:00+00:00",
"recovery_date": "2025-01-02T00:00:00+00:00",
}

json_file = tmp_path / "metadatax.json"
json_file.write_text("{}", encoding="utf-8")

with patch("json.load", return_value=fake_json):
df = json2df(json_file)

expected = DataFrame(
data=[
[
Timestamp("2025-01-01T00:00:00+00:00"),
Timestamp("2025-01-02T00:00:00+00:00"),
]
],
columns=["deployment_date", "recovery_date"],
)

assert df.equals(expected)