Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions boem_conf/boem_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ active_learning:
drop_n_most_common: 5 # Number of most common classes to drop
rarest_confidence_selection: highest # "highest" or "lowest" - select highest or lowest confidence samples from rarest classes

# Crop + H-CAST ensemble (requires hierarchical.checkpoint for match_or_genus_consistent and model-disagreement):
# ensemble_target_mode: crop_only # crop_only | match_or_genus_consistent — for target-labels / taxonomy
# strategy: model-disagreement # prioritize images where cropmodel_label != hcast_species
# disagreement_require_genus_mismatch: false # if true, exclude congeneric species disagreements
# disagreement_target_labels: [] # optional list — keep disagreements where crop OR H-CAST species is in this set

# Optional parameters:
evaluation:
dask_client:
Expand Down
227 changes: 181 additions & 46 deletions src/active_learning.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,97 @@
import json
import os
import random
from pathlib import Path

import pandas as pd

from src import detection
from src import hierarchical


def load_species_to_genus_from_csv(label_csv_path: str | None) -> dict[str, str]:
"""Build species -> genus map from H-CAST label CSV (species + genus columns)."""
if not label_csv_path or not os.path.isfile(label_csv_path):
return {}
df = pd.read_csv(label_csv_path).dropna(subset=["species"])
if "genus" not in df.columns:
return {}
return dict(zip(df["species"].astype(str), df["genus"].astype(str)))


def crop_hcast_supported_match_or_genus_consistent(row: pd.Series, species_to_genus: dict[str, str]) -> bool:
    """Tell whether H-CAST backs up the crop model's species for this row.

    Support means either the two species labels are identical, or the genus
    looked up for the crop species in ``species_to_genus`` equals the row's
    ``hcast_genus``. A missing/NaN crop label, H-CAST species, or H-CAST
    genus, or a crop species absent from the mapping, counts as unsupported.
    """
    crop_label = row.get("cropmodel_label")
    hcast_label = row.get("hcast_species")
    # Both species predictions must be present before anything is compared.
    for value in (crop_label, hcast_label):
        if value is None or pd.isna(value):
            return False

    crop_name = str(crop_label)
    if crop_name == str(hcast_label):
        return True

    # Species differ: fall back to a genus-level comparison.
    mapped_genus = species_to_genus.get(crop_name)
    if mapped_genus is None:
        return False
    hcast_genus = row.get("hcast_genus")
    if hcast_genus is None or pd.isna(hcast_genus):
        return False
    return str(mapped_genus) == str(hcast_genus)


def row_crop_hcast_disagrees(row: pd.Series, species_to_genus: dict[str, str], strict_genus_mismatch: bool) -> bool:
    """Tell whether the crop model and H-CAST give different species for this row.

    Rows where either species label is missing/NaN never count as a
    disagreement. With ``strict_genus_mismatch`` set, a species mismatch only
    counts when the genera also differ; if the genus cannot be determined on
    either side, the row is still reported as a disagreement.
    """
    crop_label = row.get("cropmodel_label")
    hcast_label = row.get("hcast_species")
    if any(value is None or pd.isna(value) for value in (crop_label, hcast_label)):
        return False

    crop_name = str(crop_label)
    hcast_name = str(hcast_label)
    if crop_name == hcast_name:
        return False

    if strict_genus_mismatch:
        mapped_genus = species_to_genus.get(crop_name)
        hcast_genus = row.get("hcast_genus")
        genus_unknown = mapped_genus is None or hcast_genus is None or pd.isna(hcast_genus)
        if not genus_unknown:
            # Same-genus swaps are excluded in strict mode.
            return str(mapped_genus) != str(hcast_genus)
    return True


def _row_min_class_confidence(row: pd.Series) -> float:
vals = []
for key in ("cropmodel_score", "hcast_species_score"):
x = row.get(key)
if x is not None and pd.notna(x) and isinstance(x, (int, float)):
vals.append(float(x))
return min(vals) if vals else 0.0


def format_ensemble_suggestion_line(row: pd.Series, species_to_genus: dict[str, str] | None) -> str | None:
"""Human-readable ensemble suggestion for Label Studio summary (genus fallback when species disagree)."""
if "hcast_species" not in row or pd.isna(row.get("hcast_species")):
return None
crop = row.get("cropmodel_label")
if crop is None or pd.isna(crop):
return None
crop_s, h_sp = str(crop), str(row["hcast_species"])
h_gn = row.get("hcast_genus")
gn_s = str(h_gn) if h_gn is not None and pd.notna(h_gn) else None
if crop_s == h_sp:
return f"Ensemble suggestion: {crop_s} (species agreement)"
crop_genus = species_to_genus.get(crop_s) if species_to_genus else None
if crop_genus is not None and gn_s is not None and str(crop_genus) == gn_s:
return (
f"Ensemble suggestion: genus {crop_genus} (species disagree: crop={crop_s}, H-CAST={h_sp})"
)
fa = row.get("hcast_family")
fa_s = str(fa) if fa is not None and pd.notna(fa) else None
parts = [f"Ensemble ambiguous: crop={crop_s}, H-CAST species={h_sp}"]
if gn_s:
parts.append(f", H-CAST genus={gn_s}")
if fa_s:
parts.append(f", H-CAST family={fa_s}")
parts.append(" — verify taxonomy manually")
return "".join(parts)


def _collect_leaf_aliases(node: dict) -> list[str]:
"""Collect all leaf (species) alias strings under a taxonomy node."""
if node.get("isLeaf"):
Expand Down Expand Up @@ -41,6 +127,7 @@ def walk(node: dict) -> None:
walk(item)
return result


def human_review(predictions, min_detection_score=0.6, min_classification_score=0.5, confident_threshold=0.5):
"""
Predict on images and divide into confident and uncertain predictions.
Expand All @@ -51,22 +138,22 @@ def human_review(predictions, min_detection_score=0.6, min_classification_score=
predictions (pd.DataFrame, optional): A DataFrame of existing predictions. Defaults to None.
Returns:
tuple: A tuple of confident and uncertain predictions.
"""
"""
filtered_predictions = predictions[
(predictions["score"] >= min_detection_score) &
(predictions["cropmodel_score"] < min_classification_score)
]

# Split predictions into confident and uncertain
uncertain_predictions = filtered_predictions[
filtered_predictions["cropmodel_score"] <= confident_threshold]

confident_predictions = filtered_predictions[
~filtered_predictions["image_path"].isin(
uncertain_predictions["image_path"])]

return confident_predictions, uncertain_predictions


def generate_pool_predictions(
pool,
patch_size=512,
Expand Down Expand Up @@ -135,6 +222,7 @@ def generate_pool_predictions(

return preannotations


def _validate_target_labels(target_labels: list[str], valid_labels: set[str] | list[str] | None) -> None:
"""Raise ValueError if any target label is not in the crop model's label set (catches typos)."""
if valid_labels is None:
Expand All @@ -161,43 +249,35 @@ def select_images(
taxonomy_path=None,
taxonomy_aliases=None,
valid_labels=None,
ensemble_target_mode: str = "crop_only",
species_to_genus: dict[str, str] | None = None,
disagreement_require_genus_mismatch: bool = False,
disagreement_target_labels: list[str] | None = None,
):
"""
Select images to annotate based on the strategy.

Args:
preannotations (pd.DataFrame): A DataFrame of predictions.
strategy (str): The strategy for choosing images. Available strategies are:
- "random": Choose images randomly from the pool.
- "most-detections": Choose images with the most detections based on predictions.
- "target-labels": Choose images with target labels (species-level).
- "taxonomy": Like target-labels but taxonomy_aliases (e.g. Aves, Mammalia, Cepphus)
are expanded to all leaf species under those nodes using transformed_taxonomy.json.
- "rarest": Choose images with rarest class labels.
n (int, optional): The number of images to choose. Defaults to 10.
target_labels (list, optional): For target-labels: list of species labels. Defaults to None.
min_score (float, optional): The minimum detection score for a prediction to be included. Defaults to 0.3.
drop_n_most_common (int, optional): For rarest strategy, number of most common classes to drop. Defaults to 1.
rarest_confidence_selection (str, optional): For rarest strategy, "highest" or "lowest" confidence selection. Defaults to "lowest".
min_classification_score (float, optional): Minimum classification confidence score. Defaults to None (no filter).
taxonomy_path (str | Path, optional): Path to transformed_taxonomy.json. Required for strategy "taxonomy".
taxonomy_aliases (list[str], optional): For strategy "taxonomy": e.g. ["Aves", "Mammalia", "Cepphus"]. Defaults to None.
valid_labels (set | list, optional): Crop model label set (e.g. label_dict.keys()). If provided, target-labels
and taxonomy-expanded labels are validated to catch typos/misspellings.
strategy (str): One of random, most-detections, target-labels, taxonomy, rarest, model-disagreement.
ensemble_target_mode (str): For target-labels/taxonomy: crop_only or match_or_genus_consistent.
species_to_genus (dict): Species binomial -> genus for ensemble consistency checks.
disagreement_require_genus_mismatch (bool): For model-disagreement: exclude congeneric species flips.
disagreement_target_labels (list): Optional union filter on crop or H-CAST species labels.

Returns:
list: A list of image paths.
pd.DataFrame: A DataFrame of preannotations for the chosen images.
tuple: (chosen_image_paths, chosen_preannotations_df, al_stats_dict)
"""
al_stats: dict = {}
if preannotations.empty:
return [], None
return [], None, al_stats

if strategy == "random":
n = min(n, len(preannotations["image_path"].unique()))
chosen_images = random.sample(preannotations["image_path"].unique().tolist(), n)

else:
preannotations = preannotations[preannotations["score"] >= min_score]
preannotations = preannotations[preannotations["score"] >= min_score].copy()

if strategy == "taxonomy":
if taxonomy_aliases is None or not taxonomy_aliases:
Expand All @@ -210,7 +290,7 @@ def select_images(
)
target_labels = list(get_leaf_labels_for_taxonomy_aliases(taxonomy_path, taxonomy_aliases))
if not target_labels:
return [], None
return [], None, al_stats
if valid_labels is not None:
valid_set = set(valid_labels)
target_labels = [lbl for lbl in target_labels if lbl in valid_set]
Expand All @@ -227,45 +307,100 @@ def select_images(
_validate_target_labels(target_labels, valid_labels)

if strategy == "most-detections":
# Sort images by total number of predictions
chosen_images = preannotations.groupby("image_path").size().sort_values(ascending=False).head(n).index.tolist()
elif strategy == "target-labels":
# Filter images by target labels (already validated above if valid_labels provided)
chosen_images = preannotations[preannotations.cropmodel_label.isin(target_labels)].groupby("image_path")["score"].mean().sort_values(ascending=False).head(n).index.tolist()
mask_crop_target = preannotations["cropmodel_label"].isin(target_labels)
al_stats["al_target_crop_hits_rows"] = int(mask_crop_target.sum())

if ensemble_target_mode == "match_or_genus_consistent":
if "hcast_species" not in preannotations.columns:
raise ValueError(
"ensemble_target_mode='match_or_genus_consistent' requires hierarchical prediction columns "
"(hcast_species). Enable hierarchical.checkpoint or use ensemble_target_mode='crop_only'."
)
sg = species_to_genus if species_to_genus is not None else {}
supported = preannotations.apply(
lambda r: crop_hcast_supported_match_or_genus_consistent(r, sg),
axis=1,
)
mask = mask_crop_target & supported
al_stats["al_target_after_ensemble_rows"] = int(mask.sum())
al_stats["al_ensemble_target_mode"] = ensemble_target_mode
elif ensemble_target_mode == "crop_only":
mask = mask_crop_target
else:
raise ValueError(
f"Unknown ensemble_target_mode {ensemble_target_mode!r}. "
"Use 'crop_only' or 'match_or_genus_consistent'."
)

filtered = preannotations[mask]
if filtered.empty:
return [], None, al_stats
chosen_images = (
filtered.groupby("image_path")["score"].mean().sort_values(ascending=False).head(n).index.tolist()
)
elif strategy == "model-disagreement":
if "hcast_species" not in preannotations.columns:
raise ValueError(
"strategy 'model-disagreement' requires hcast_species (enable hierarchical.checkpoint)."
)
sg = species_to_genus if species_to_genus is not None else {}
disagree_mask = preannotations.apply(
lambda r: row_crop_hcast_disagrees(r, sg, disagreement_require_genus_mismatch),
axis=1,
)
pool = preannotations[disagree_mask].copy()
al_stats["al_disagreement_boxes_before_target_filter"] = len(pool)
if disagreement_target_labels:
tl = set(disagreement_target_labels)
pool = pool[pool["cropmodel_label"].isin(tl) | pool["hcast_species"].isin(tl)]
if min_classification_score is not None and "cropmodel_score" in pool.columns:
pool = pool[pool["cropmodel_score"] >= min_classification_score]
al_stats["al_disagreement_boxes_after_filters"] = len(pool)
al_stats["al_disagreement_strict_genus"] = float(bool(disagreement_require_genus_mismatch))
if pool.empty:
print("model-disagreement: no disagreeing boxes after filters")
return [], None, al_stats
pool["_joint_conf"] = pool.apply(_row_min_class_confidence, axis=1)
agg = (
pool.groupby("image_path")
.agg(disagreement_count=("image_path", "count"), mean_joint_conf=("_joint_conf", "mean"))
.sort_values(["disagreement_count", "mean_joint_conf"], ascending=[False, False])
)
al_stats["al_disagreement_images_available"] = int(agg.shape[0])
print(
f"model-disagreement: {len(pool)} disagreeing boxes across "
f"{agg.shape[0]} images (after filters); selecting top {n} images"
)
chosen_images = agg.head(n).index.tolist()
elif strategy == "rarest":
# Filter by minimum classification score if provided
if min_classification_score is not None and "cropmodel_score" in preannotations.columns:
preannotations = preannotations[preannotations["cropmodel_score"] >= min_classification_score]

# Drop n most common classes

if drop_n_most_common > 0:
most_common_labels = preannotations["cropmodel_label"].value_counts().nlargest(drop_n_most_common).index
preannotations = preannotations[~preannotations["cropmodel_label"].isin(most_common_labels)]

if preannotations.empty:
return [], None

# Sort images by least common label
return [], None, al_stats

label_counts = preannotations.groupby("cropmodel_label").size().sort_values(ascending=True)

# Sort preannotations by least common label
preannotations["label_count"] = preannotations["cropmodel_label"].map(label_counts)

# Sort by label count first, then by confidence score

if "cropmodel_score" in preannotations.columns:
ascending_conf = rarest_confidence_selection == "lowest"
preannotations.sort_values(["label_count", "cropmodel_score"], ascending=[True, ascending_conf], inplace=True)
else:
preannotations.sort_values("label_count", ascending=True, inplace=True)

chosen_images = preannotations.drop_duplicates(subset=["image_path"], keep="first").head(n)["image_path"].tolist()
else:
raise ValueError(
"Invalid strategy. Must be one of 'random', 'most-detections', 'target-labels', 'taxonomy', or 'rarest'."
"Invalid strategy. Must be one of 'random', 'most-detections', 'target-labels', 'taxonomy', "
"'rarest', or 'model-disagreement'."
)

# Get preannotations for chosen images
chosen_preannotations = preannotations[preannotations["image_path"].isin(chosen_images)]

# Chosen preannotations is a dict with image_path as the key
return chosen_images, chosen_preannotations
al_stats["al_selected_images"] = len(chosen_images)
return chosen_images, chosen_preannotations, al_stats
4 changes: 4 additions & 0 deletions src/annotators.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def upload(
images: List[str],
instance_name: str,
preannotations: Optional[Dict[str, pd.DataFrame]] = None,
species_to_genus: Optional[Dict[str, str]] = None,
) -> None:
raise NotImplementedError

Expand Down Expand Up @@ -114,6 +115,7 @@ def upload(
images: List[str],
instance_name: str,
preannotations: Optional[Dict[str, pd.DataFrame]] = None,
species_to_genus: Optional[Dict[str, str]] = None,
) -> None:
project_name = self.cfg.annotation.label_studio.instances[instance_name].project_name
ls_mod.upload_to_label_studio(
Expand All @@ -124,6 +126,7 @@ def upload(
images_to_annotate_dir=self.cfg.image_dir,
folder_name=self.cfg.annotation.label_studio.folder_name,
preannotations=preannotations,
species_to_genus=species_to_genus,
)

def check_for_new_annotations(self, instance_name: str, image_dir: str) -> Optional[pd.DataFrame]:
Expand Down Expand Up @@ -157,6 +160,7 @@ def upload(
images: List[str],
instance_name: str,
preannotations: Optional[Dict[str, pd.DataFrame]] = None,
species_to_genus: Optional[Dict[str, str]] = None,
) -> None:
# Build S3 URIs
s3_prefix = getattr(self.cfg.annotation.sagemaker, "s3_prefix", "").rstrip("/")
Expand Down
Loading
Loading