From 584da72f57605e64e2f6e3dafe102cddbcbabe2f Mon Sep 17 00:00:00 2001 From: bw4sz Date: Fri, 1 May 2026 11:21:02 -0700 Subject: [PATCH] feat(active_learning): ensemble crop+H-CAST targeting and model-disagreement strategy - Add match_or_genus_consistent filtering for target-labels/taxonomy via ensemble_target_mode - Add model-disagreement strategy with optional genus-mismatch strictness and label filters - Load species_to_genus from hierarchical label CSV once; pass to select_images and Label Studio - Extend prediction_summary with genus fallback ensemble suggestion lines - Log numeric AL stats and ensemble_target_mode parameter to Comet Co-authored-by: Cursor --- boem_conf/boem_config.yaml | 6 + src/active_learning.py | 227 +++++++++++++++++++++++++++------- src/annotators.py | 4 + src/label_studio.py | 80 ++++++------ src/pipeline.py | 44 +++++-- tests/test_active_learning.py | 109 ++++++++++++++-- 6 files changed, 366 insertions(+), 104 deletions(-) diff --git a/boem_conf/boem_config.yaml b/boem_conf/boem_config.yaml index 57c391b..1583d59 100644 --- a/boem_conf/boem_config.yaml +++ b/boem_conf/boem_config.yaml @@ -91,6 +91,12 @@ active_learning: drop_n_most_common: 5 # Number of most common classes to drop rarest_confidence_selection: highest # "highest" or "lowest" - select highest or lowest confidence samples from rarest classes + # Crop + H-CAST ensemble (requires hierarchical.checkpoint for match_or_genus_consistent and model-disagreement): + # ensemble_target_mode: crop_only # crop_only | match_or_genus_consistent — for target-labels / taxonomy + # strategy: model-disagreement # prioritize images where cropmodel_label != hcast_species + # disagreement_require_genus_mismatch: false # if true, exclude congeneric species disagreements + # disagreement_target_labels: [] # optional list — keep disagreements where crop OR H-CAST species is in this set + # Optional parameters: evaluation: dask_client: diff --git a/src/active_learning.py b/src/active_learning.py index b5a459c..b289a16 100644 --- a/src/active_learning.py +++ b/src/active_learning.py @@ -1,11 +1,97 @@ import json +import os import random from pathlib import Path +import pandas as pd + from src import detection from src import hierarchical +def load_species_to_genus_from_csv(label_csv_path: str | None) -> dict[str, str]: + """Build species -> genus map from H-CAST label CSV (species + genus columns).""" + if not label_csv_path or not os.path.isfile(label_csv_path): + return {} + df = pd.read_csv(label_csv_path).dropna(subset=["species"]) + if "genus" not in df.columns: + return {} + return dict(zip(df["species"].astype(str), df["genus"].astype(str))) + + +def crop_hcast_supported_match_or_genus_consistent(row: pd.Series, species_to_genus: dict[str, str]) -> bool: + """H-CAST supports crop target species if species heads agree or genus agrees.""" + crop = row.get("cropmodel_label") + if crop is None or pd.isna(crop): + return False + crop_s = str(crop) + h_sp = row.get("hcast_species") + if h_sp is None or pd.isna(h_sp): + return False + if crop_s == str(h_sp): + return True + crop_g = species_to_genus.get(crop_s) + h_gn = row.get("hcast_genus") + if crop_g is None or h_gn is None or pd.isna(h_gn): + return False + return str(crop_g) == str(h_gn) + + +def row_crop_hcast_disagrees(row: pd.Series, species_to_genus: dict[str, str], strict_genus_mismatch: bool) -> bool: + """True when crop species != H-CAST species; if strict_genus_mismatch, exclude same-genus congeneric swaps.""" + crop = row.get("cropmodel_label") + h_sp = row.get("hcast_species") + if crop is None or pd.isna(crop) or h_sp is None or pd.isna(h_sp): + return False + crop_s, h_s = str(crop), str(h_sp) + if crop_s == h_s: + return False + if not strict_genus_mismatch: + return True + crop_g = species_to_genus.get(crop_s) + h_gn = row.get("hcast_genus") + if crop_g is None or h_gn is None or pd.isna(h_gn): + return True + return str(crop_g) != str(h_gn) + + +def _row_min_class_confidence(row: pd.Series) -> float: + vals = [] + for key in ("cropmodel_score", "hcast_species_score"): + x = row.get(key) + if x is not None and pd.notna(x) and isinstance(x, (int, float)): + vals.append(float(x)) + return min(vals) if vals else 0.0 + + +def format_ensemble_suggestion_line(row: pd.Series, species_to_genus: dict[str, str] | None) -> str | None: + """Human-readable ensemble suggestion for Label Studio summary (genus fallback when species disagree).""" + if "hcast_species" not in row or pd.isna(row.get("hcast_species")): + return None + crop = row.get("cropmodel_label") + if crop is None or pd.isna(crop): + return None + crop_s, h_sp = str(crop), str(row["hcast_species"]) + h_gn = row.get("hcast_genus") + gn_s = str(h_gn) if h_gn is not None and pd.notna(h_gn) else None + if crop_s == h_sp: + return f"Ensemble suggestion: {crop_s} (species agreement)" + crop_genus = species_to_genus.get(crop_s) if species_to_genus else None + if crop_genus is not None and gn_s is not None and str(crop_genus) == gn_s: + return ( + f"Ensemble suggestion: genus {crop_genus} (species disagree: crop={crop_s}, H-CAST={h_sp})" + ) + fa = row.get("hcast_family") + fa_s = str(fa) if fa is not None and pd.notna(fa) else None + parts = [f"Ensemble ambiguous: crop={crop_s}, H-CAST species={h_sp}"] + if gn_s: + parts.append(f", H-CAST genus={gn_s}") + if fa_s: + parts.append(f", H-CAST family={fa_s}") + parts.append(" — verify taxonomy manually") + return "".join(parts) + + def _collect_leaf_aliases(node: dict) -> list[str]: """Collect all leaf (species) alias strings under a taxonomy node.""" if node.get("isLeaf"): @@ -41,6 +127,7 @@ def walk(node: dict) -> None: walk(item) return result + def human_review(predictions, min_detection_score=0.6, min_classification_score=0.5, confident_threshold=0.5): """ Predict on images and divide into confident and uncertain predictions. @@ -51,22 +138,22 @@ def human_review(predictions, min_detection_score=0.6, min_classification_score= predictions (pd.DataFrame, optional): A DataFrame of existing predictions. Defaults to None. Returns: tuple: A tuple of confident and uncertain predictions. - """ + """ filtered_predictions = predictions[ (predictions["score"] >= min_detection_score) & (predictions["cropmodel_score"] < min_classification_score) ] - # Split predictions into confident and uncertain uncertain_predictions = filtered_predictions[ filtered_predictions["cropmodel_score"] <= confident_threshold] - + confident_predictions = filtered_predictions[ ~filtered_predictions["image_path"].isin( uncertain_predictions["image_path"])] - + return confident_predictions, uncertain_predictions + def generate_pool_predictions( pool, patch_size=512, @@ -135,6 +222,7 @@ def generate_pool_predictions( return preannotations + def _validate_target_labels(target_labels: list[str], valid_labels: set[str] | list[str] | None) -> None: """Raise ValueError if any target label is not in the crop model's label set (catches typos).""" if valid_labels is None: @@ -161,43 +249,35 @@ def select_images( taxonomy_path=None, taxonomy_aliases=None, valid_labels=None, + ensemble_target_mode: str = "crop_only", + species_to_genus: dict[str, str] | None = None, + disagreement_require_genus_mismatch: bool = False, + disagreement_target_labels: list[str] | None = None, ): """ Select images to annotate based on the strategy. Args: preannotations (pd.DataFrame): A DataFrame of predictions. - strategy (str): The strategy for choosing images. Available strategies are: - - "random": Choose images randomly from the pool. - - "most-detections": Choose images with the most detections based on predictions. - - "target-labels": Choose images with target labels (species-level). - - "taxonomy": Like target-labels but taxonomy_aliases (e.g. Aves, Mammalia, Cepphus) - are expanded to all leaf species under those nodes using transformed_taxonomy.json. - - "rarest": Choose images with rarest class labels. - n (int, optional): The number of images to choose. Defaults to 10. - target_labels (list, optional): For target-labels: list of species labels. Defaults to None. - min_score (float, optional): The minimum detection score for a prediction to be included. Defaults to 0.3. - drop_n_most_common (int, optional): For rarest strategy, number of most common classes to drop. Defaults to 1. - rarest_confidence_selection (str, optional): For rarest strategy, "highest" or "lowest" confidence selection. Defaults to "lowest". - min_classification_score (float, optional): Minimum classification confidence score. Defaults to None (no filter). - taxonomy_path (str | Path, optional): Path to transformed_taxonomy.json. Required for strategy "taxonomy". - taxonomy_aliases (list[str], optional): For strategy "taxonomy": e.g. ["Aves", "Mammalia", "Cepphus"]. Defaults to None. - valid_labels (set | list, optional): Crop model label set (e.g. label_dict.keys()). If provided, target-labels - and taxonomy-expanded labels are validated to catch typos/misspellings. + strategy (str): One of random, most-detections, target-labels, taxonomy, rarest, model-disagreement. + ensemble_target_mode (str): For target-labels/taxonomy: crop_only or match_or_genus_consistent. + species_to_genus (dict): Species binomial -> genus for ensemble consistency checks. + disagreement_require_genus_mismatch (bool): For model-disagreement: exclude congeneric species flips. + disagreement_target_labels (list): Optional union filter on crop or H-CAST species labels. Returns: - list: A list of image paths. - pd.DataFrame: A DataFrame of preannotations for the chosen images. + tuple: (chosen_image_paths, chosen_preannotations_df, al_stats_dict) """ + al_stats: dict = {} if preannotations.empty: - return [], None + return [], None, al_stats if strategy == "random": n = min(n, len(preannotations["image_path"].unique())) chosen_images = random.sample(preannotations["image_path"].unique().tolist(), n) else: - preannotations = preannotations[preannotations["score"] >= min_score] + preannotations = preannotations[preannotations["score"] >= min_score].copy() if strategy == "taxonomy": if taxonomy_aliases is None or not taxonomy_aliases: @@ -210,7 +290,7 @@ def select_images( ) target_labels = list(get_leaf_labels_for_taxonomy_aliases(taxonomy_path, taxonomy_aliases)) if not target_labels: - return [], None + return [], None, al_stats if valid_labels is not None: valid_set = set(valid_labels) target_labels = [lbl for lbl in target_labels if lbl in valid_set] @@ -227,45 +307,100 @@ def select_images( _validate_target_labels(target_labels, valid_labels) if strategy == "most-detections": - # Sort images by total number of predictions chosen_images = preannotations.groupby("image_path").size().sort_values(ascending=False).head(n).index.tolist() elif strategy == "target-labels": - # Filter images by target labels (already validated above if valid_labels provided) - chosen_images = preannotations[preannotations.cropmodel_label.isin(target_labels)].groupby("image_path")["score"].mean().sort_values(ascending=False).head(n).index.tolist() + mask_crop_target = preannotations["cropmodel_label"].isin(target_labels) + al_stats["al_target_crop_hits_rows"] = int(mask_crop_target.sum()) + + if ensemble_target_mode == "match_or_genus_consistent": + if "hcast_species" not in preannotations.columns: + raise ValueError( + "ensemble_target_mode='match_or_genus_consistent' requires hierarchical prediction columns " + "(hcast_species). Enable hierarchical.checkpoint or use ensemble_target_mode='crop_only'." + ) + sg = species_to_genus if species_to_genus is not None else {} + supported = preannotations.apply( + lambda r: crop_hcast_supported_match_or_genus_consistent(r, sg), + axis=1, + ) + mask = mask_crop_target & supported + al_stats["al_target_after_ensemble_rows"] = int(mask.sum()) + al_stats["al_ensemble_target_mode"] = ensemble_target_mode + elif ensemble_target_mode == "crop_only": + mask = mask_crop_target + else: + raise ValueError( + f"Unknown ensemble_target_mode {ensemble_target_mode!r}. " + "Use 'crop_only' or 'match_or_genus_consistent'." + ) + + filtered = preannotations[mask] + if filtered.empty: + return [], None, al_stats + chosen_images = ( + filtered.groupby("image_path")["score"].mean().sort_values(ascending=False).head(n).index.tolist() + ) + elif strategy == "model-disagreement": + if "hcast_species" not in preannotations.columns: + raise ValueError( + "strategy 'model-disagreement' requires hcast_species (enable hierarchical.checkpoint)." + ) + sg = species_to_genus if species_to_genus is not None else {} + disagree_mask = preannotations.apply( + lambda r: row_crop_hcast_disagrees(r, sg, disagreement_require_genus_mismatch), + axis=1, + ) + pool = preannotations[disagree_mask].copy() + al_stats["al_disagreement_boxes_before_target_filter"] = len(pool) + if disagreement_target_labels: + tl = set(disagreement_target_labels) + pool = pool[pool["cropmodel_label"].isin(tl) | pool["hcast_species"].isin(tl)] + if min_classification_score is not None and "cropmodel_score" in pool.columns: + pool = pool[pool["cropmodel_score"] >= min_classification_score] + al_stats["al_disagreement_boxes_after_filters"] = len(pool) + al_stats["al_disagreement_strict_genus"] = float(bool(disagreement_require_genus_mismatch)) + if pool.empty: + print("model-disagreement: no disagreeing boxes after filters") + return [], None, al_stats + pool["_joint_conf"] = pool.apply(_row_min_class_confidence, axis=1) + agg = ( + pool.groupby("image_path") + .agg(disagreement_count=("image_path", "count"), mean_joint_conf=("_joint_conf", "mean")) + .sort_values(["disagreement_count", "mean_joint_conf"], ascending=[False, False]) + ) + al_stats["al_disagreement_images_available"] = int(agg.shape[0]) + print( + f"model-disagreement: {len(pool)} disagreeing boxes across " + f"{agg.shape[0]} images (after filters); selecting top {n} images" + ) + chosen_images = agg.head(n).index.tolist() elif strategy == "rarest": - # Filter by minimum classification score if provided if min_classification_score is not None and "cropmodel_score" in preannotations.columns: preannotations = preannotations[preannotations["cropmodel_score"] >= min_classification_score] - - # Drop n most common classes + if drop_n_most_common > 0: most_common_labels = preannotations["cropmodel_label"].value_counts().nlargest(drop_n_most_common).index preannotations = preannotations[~preannotations["cropmodel_label"].isin(most_common_labels)] - + if preannotations.empty: - return [], None - - # Sort images by least common label + return [], None, al_stats + label_counts = preannotations.groupby("cropmodel_label").size().sort_values(ascending=True) - - # Sort preannotations by least common label preannotations["label_count"] = preannotations["cropmodel_label"].map(label_counts) - - # Sort by label count first, then by confidence score + if "cropmodel_score" in preannotations.columns: ascending_conf = rarest_confidence_selection == "lowest" preannotations.sort_values(["label_count", "cropmodel_score"], ascending=[True, ascending_conf], inplace=True) else: preannotations.sort_values("label_count", ascending=True, inplace=True) - + chosen_images = preannotations.drop_duplicates(subset=["image_path"], keep="first").head(n)["image_path"].tolist() else: raise ValueError( - "Invalid strategy. Must be one of 'random', 'most-detections', 'target-labels', 'taxonomy', or 'rarest'." + "Invalid strategy. Must be one of 'random', 'most-detections', 'target-labels', 'taxonomy', " + "'rarest', or 'model-disagreement'." ) - # Get preannotations for chosen images chosen_preannotations = preannotations[preannotations["image_path"].isin(chosen_images)] - - # Chosen preannotations is a dict with image_path as the key - return chosen_images, chosen_preannotations \ No newline at end of file + al_stats["al_selected_images"] = len(chosen_images) + return chosen_images, chosen_preannotations, al_stats diff --git a/src/annotators.py b/src/annotators.py index d8878db..65c6d8c 100644 --- a/src/annotators.py +++ b/src/annotators.py @@ -81,6 +81,7 @@ def upload( images: List[str], instance_name: str, preannotations: Optional[Dict[str, pd.DataFrame]] = None, + species_to_genus: Optional[Dict[str, str]] = None, ) -> None: raise NotImplementedError @@ -114,6 +115,7 @@ def upload( images: List[str], instance_name: str, preannotations: Optional[Dict[str, pd.DataFrame]] = None, + species_to_genus: Optional[Dict[str, str]] = None, ) -> None: project_name = self.cfg.annotation.label_studio.instances[instance_name].project_name ls_mod.upload_to_label_studio( @@ -124,6 +126,7 @@ def upload( images_to_annotate_dir=self.cfg.image_dir, folder_name=self.cfg.annotation.label_studio.folder_name, preannotations=preannotations, + species_to_genus=species_to_genus, ) def check_for_new_annotations(self, instance_name: str, image_dir: str) -> Optional[pd.DataFrame]: @@ -157,6 +160,7 @@ def upload( images: List[str], instance_name: str, preannotations: Optional[Dict[str, pd.DataFrame]] = None, + species_to_genus: Optional[Dict[str, str]] = None, ) -> None: # Build S3 URIs s3_prefix = getattr(self.cfg.annotation.sagemaker, "s3_prefix", "").rstrip("/") diff --git a/src/label_studio.py b/src/label_studio.py index d5f4fd5..0f1b498 100644 --- a/src/label_studio.py +++ b/src/label_studio.py @@ -9,6 +9,8 @@ from PIL import Image from deepforest.utilities import read_file +from src.active_learning import format_ensemble_suggestion_line + def get_taxonomy_leaf_paths(taxonomy_path): """Load transformed_taxonomy.json and return a dict mapping leaf alias -> full path (list of aliases from root to leaf). @@ -34,7 +36,16 @@ def visit(node, path_so_far): return result -def upload_to_label_studio(images, sftp_client, url, project_name, images_to_annotate_dir, folder_name, preannotations): +def upload_to_label_studio( + images, + sftp_client, + url, + project_name, + images_to_annotate_dir, + folder_name, + preannotations, + species_to_genus=None, +): """ Upload images to Label Studio and import image tasks. @@ -64,7 +75,13 @@ def upload_to_label_studio(images, sftp_client, url, project_name, images_to_ann ) label_studio_project = connect_to_label_studio(url=url, project_name=project_name, label_config=default_label_config) upload_images(sftp_client=sftp_client, images=images, folder_name=folder_name) - import_image_tasks(label_studio_project=label_studio_project, image_names=images, local_image_dir=images_to_annotate_dir, predictions=preannotations) + import_image_tasks( + label_studio_project=label_studio_project, + image_names=images, + local_image_dir=images_to_annotate_dir, + predictions=preannotations, + species_to_genus=species_to_genus, + ) def check_for_new_annotations(url, project_name, csv_dir, image_dir): """ @@ -159,43 +176,14 @@ def label_studio_bbox_format(local_image_dir, preannotations, taxonomy_path=None return {"result": predictions} -def format_prediction_summary_for_task(prediction_df: pd.DataFrame) -> str: - """Format crop model + hierarchical model predictions as read-only text for Label Studio task data. - - When the prediction DataFrame has hcast_* columns, includes species/genus/family and scores. - Display in Label Studio via a Header or Text tag with value=$prediction_summary. - """ - if prediction_df is None or prediction_df.empty: - return "No detections for this image." - lines = [] - for i, (_, row) in enumerate(prediction_df.iterrows(), 1): - crop_label = row.get("cropmodel_label", row.get("label", "—")) - score = row.get("score", row.get("cropmodel_score", "")) - if isinstance(score, (int, float)): - parts = [f"Crop {i}: {crop_label} (score={score:.2f})"] - else: - parts = [f"Crop {i}: {crop_label}"] - if "hcast_species" in row and pd.notna(row.get("hcast_species")): - sp = row["hcast_species"] - gn = row.get("hcast_genus") - fa = row.get("hcast_family") - sp_s = row.get("hcast_species_score") - parts.append(f" H-CAST: species={sp}") - if pd.notna(gn): - parts.append(f", genus={gn}") - if pd.notna(fa): - parts.append(f", family={fa}") - if sp_s is not None and isinstance(sp_s, (int, float)): - parts.append(f" (species_score={sp_s:.2f})") - lines.append("".join(parts)) - return "\n".join(lines) if lines else "No detections for this image." - - -def format_prediction_summary_for_task(prediction_df: pd.DataFrame) -> str: - """Format crop model + hierarchical (H-CAST) predictions as read-only text for Label Studio task data. +def format_prediction_summary_for_task( + prediction_df: pd.DataFrame, + species_to_genus: dict | None = None, +) -> str: + """Format crop + H-CAST predictions for Label Studio task data (prediction_summary). - Add the result to task data as prediction_summary and display it with - or
in the label config. + When species_to_genus is provided and hcast_* columns exist, appends ensemble suggestion line + (species agreement, genus fallback on mismatch, or ambiguous note). """ if prediction_df is None or prediction_df.empty: return "No detections for this image." @@ -220,6 +208,9 @@ def format_prediction_summary_for_task(prediction_df: pd.DataFrame) -> str: if sp_s is not None and isinstance(sp_s, (int, float)): parts.append(f" (species_score={sp_s:.2f})") lines.append("".join(parts)) + ens = format_ensemble_suggestion_line(row, species_to_genus) + if ens: + lines.append(f" {ens}") return "\n".join(lines) if lines else "No detections for this image." @@ -379,7 +370,13 @@ def delete_completed_tasks(label_studio_project): for task in tasks: label_studio_project.delete_task(task["id"]) -def import_image_tasks(label_studio_project, image_names, local_image_dir, predictions=None): +def import_image_tasks( + label_studio_project, + image_names, + local_image_dir, + predictions=None, + species_to_genus=None, +): """ Import image tasks into Label Studio project. @@ -403,7 +400,10 @@ def import_image_tasks(label_studio_project, image_names, local_image_dir, predi } if predictions is not None: prediction = predictions.get(basename, pd.DataFrame()) - data_dict["prediction_summary"] = format_prediction_summary_for_task(prediction) + data_dict["prediction_summary"] = format_prediction_summary_for_task( + prediction, + species_to_genus=species_to_genus, + ) # Skip predictions if there are none if prediction.empty: result_dict = [] diff --git a/src/pipeline.py b/src/pipeline.py index 8158cc0..679ab41 100644 --- a/src/pipeline.py +++ b/src/pipeline.py @@ -3,7 +3,12 @@ import shutil from omegaconf import DictConfig -from src.active_learning import generate_pool_predictions, select_images, human_review +from src.active_learning import ( + generate_pool_predictions, + human_review, + load_species_to_genus_from_csv, + select_images, +) from deepforest.model import CropModel from src import sagemaker_gt from src.annotators import get_annotator, LabelStudioAnnotator, SageMakerAnnotator @@ -408,6 +413,7 @@ def run(self): ) hcast_batch_size = getattr(self.config.hierarchical, "batch_size", 16) hcast_workers = getattr(self.config.hierarchical, "workers", 4) + species_to_genus = load_species_to_genus_from_csv(hcast_label_csv) # Apply pool_limit: when using cache, still honor config so debug/small runs stay fast configured_pool_limit = getattr(self.config.active_learning, "pool_limit", None) @@ -457,11 +463,6 @@ def run(self): label_dict = trained_classification_model.label_dict - species_to_genus = {} - if hcast_label_csv and os.path.exists(hcast_label_csv): - label_df = pd.read_csv(hcast_label_csv).dropna(subset=["species"]) - if "genus" in label_df.columns: - species_to_genus = dict(zip(label_df["species"], label_df["genus"])) pipeline_monitor = PipelineEvaluation( # species_to_genus added for hierarchical metrics predictions=evaluation_predictions, annotations=evaluation_annotations, @@ -477,7 +478,7 @@ def run(self): return None test_preannotations = flightline_predictions[~flightline_predictions.image_path.isin(self.existing_images)] - test_images_to_annotate, preannotations = select_images( + test_images_to_annotate, preannotations, _al_stats_test = select_images( preannotations=test_preannotations, strategy="random", n=self.config.active_testing.n_images, @@ -492,7 +493,10 @@ def run(self): # Default taxonomy path: project root transformed_taxonomy.json (for strategy "taxonomy") _default_taxonomy_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "transformed_taxonomy.json") - train_images_to_annotate, preannotations = select_images( + disagreement_target_labels = getattr(self.config.active_learning, "disagreement_target_labels", None) + if disagreement_target_labels is not None and len(disagreement_target_labels) == 0: + disagreement_target_labels = None + train_images_to_annotate, preannotations, al_stats_train = select_images( preannotations=training_preannotations, strategy=self.config.active_learning.strategy, n=self.config.active_learning.n_images, @@ -504,7 +508,24 @@ def run(self): taxonomy_path=getattr(self.config.active_learning, "taxonomy_path", None) or _default_taxonomy_path, taxonomy_aliases=getattr(self.config.active_learning, "taxonomy_aliases", None), valid_labels=list(trained_classification_model.label_dict.keys()), + ensemble_target_mode=getattr(self.config.active_learning, "ensemble_target_mode", "crop_only"), + species_to_genus=species_to_genus if species_to_genus else None, + disagreement_require_genus_mismatch=getattr( + self.config.active_learning, "disagreement_require_genus_mismatch", False + ), + disagreement_target_labels=disagreement_target_labels, ) + if al_stats_train: + numeric_metrics = { + k: float(v) if isinstance(v, bool) else v + for k, v in al_stats_train.items() + if isinstance(v, (int, float)) + } + if numeric_metrics: + self.comet_logger.experiment.log_metrics(numeric_metrics) + mode = al_stats_train.get("al_ensemble_target_mode") + if mode is not None: + self.comet_logger.experiment.log_parameter("al_ensemble_target_mode", mode) if len(train_images_to_annotate) == 0 and training_preannotations.empty: print("Training images to annotate: 0 (all images with detections ≥min_score were assigned to test)") @@ -706,7 +727,12 @@ def run(self): group["image_path"] = basename # Use basename as key to match the format expected by SageMaker upload preannotations[basename] = group.drop(columns=["image_path_basename"], errors="ignore") - self.annotator.upload(images=image_paths, instance_name=instance, preannotations=preannotations) + self.annotator.upload( + images=image_paths, + instance_name=instance, + preannotations=preannotations, + species_to_genus=species_to_genus if species_to_genus else None, + ) self.comet_logger.experiment.add_tag("complete") return None \ No newline at end of file diff --git a/tests/test_active_learning.py b/tests/test_active_learning.py index e4526c9..4b102fb 100644 --- a/tests/test_active_learning.py +++ b/tests/test_active_learning.py @@ -7,8 +7,10 @@ from deepforest.utilities import read_file from src.active_learning import ( + format_ensemble_suggestion_line, generate_pool_predictions, get_leaf_labels_for_taxonomy_aliases, + row_crop_hcast_disagrees, select_images, ) @@ -77,9 +79,9 @@ def test_select_train_images(detection_model): patch_size=450, model=detection_model, patch_overlap=0, - min_score=0.5, + min_score=0, ) - chosen_images, _ = select_images( + chosen_images, _, _stats = select_images( preannotations=train_image_pool, strategy="random", n=1, @@ -108,7 +110,6 @@ def test_select_images_taxonomy_strategy(): path = repo_root / "transformed_taxonomy.json" if not path.exists(): pytest.skip("transformed_taxonomy.json not found") - # Preannotations with one image that has a bird species label preannotations = pd.DataFrame( { "image_path": ["img1.jpg", "img1.jpg", "img2.jpg"], @@ -116,7 +117,7 @@ def test_select_images_taxonomy_strategy(): "score": [0.9, 0.8, 0.7], } ) - chosen_images, _chosen_pre = select_images( + chosen_images, _chosen_pre, stats = select_images( preannotations=preannotations, strategy="taxonomy", n=5, @@ -125,8 +126,8 @@ def test_select_images_taxonomy_strategy(): ) assert len(chosen_images) >= 1 assert "img1.jpg" in chosen_images - # Only Cepphus grylle is under Cepphus; img2 has Object so should not be selected assert "img2.jpg" not in chosen_images + assert stats.get("al_target_crop_hits_rows") == 1 def test_select_images_target_labels_validates_against_crop_model(): @@ -140,8 +141,7 @@ def test_select_images_target_labels_validates_against_crop_model(): ) valid = {"Cepphus grylle", "Actitis macularius"} - # All valid: succeeds - chosen_images, _ = select_images( + chosen_images, _, _ = select_images( preannotations=preannotations, strategy="target-labels", n=5, @@ -150,12 +150,103 @@ def test_select_images_target_labels_validates_against_crop_model(): ) assert "img1.jpg" in chosen_images - # Typo / unknown label: raises with pytest.raises(ValueError, match="not in crop model label dict"): select_images( preannotations=preannotations, strategy="target-labels", n=5, - target_labels=["Cepphus grille"], # typo: grille vs grylle + target_labels=["Cepphus grille"], valid_labels=valid, ) + + +def test_match_or_genus_consistent_requires_hcast(): + preannotations = pd.DataFrame( + { + "image_path": ["a.jpg"], + "cropmodel_label": ["Foo bar"], + "score": [0.9], + } + ) + with pytest.raises(ValueError, match="match_or_genus_consistent"): + select_images( + preannotations=preannotations, + strategy="target-labels", + n=5, + target_labels=["Foo bar"], + ensemble_target_mode="match_or_genus_consistent", + species_to_genus={"Foo bar": "Foo"}, + ) + + +def test_match_or_genus_consistent_filters_rows(): + sg = {"AAA bbb": "AAA", "CCC ddd": "CCC"} + preannotations = pd.DataFrame( + { + "image_path": ["x.jpg", "y.jpg", "z.jpg"], + "cropmodel_label": ["AAA bbb", "AAA bbb", "CCC ddd"], + "score": [0.9, 0.9, 0.9], + "hcast_species": ["AAA bbb", "XXX yyy", "CCC ddd"], + "hcast_genus": ["AAA", "XXX", "CCC"], + } + ) + chosen, _, stats = select_images( + preannotations=preannotations, + strategy="target-labels", + n=5, + target_labels=["AAA bbb", "CCC ddd"], + ensemble_target_mode="match_or_genus_consistent", + species_to_genus=sg, + ) + assert "x.jpg" in chosen + assert "z.jpg" in chosen + assert "y.jpg" not in chosen + assert stats["al_target_crop_hits_rows"] == 3 + assert stats["al_target_after_ensemble_rows"] == 2 + + +def test_model_disagreement_strategy_ranking(): + sg = {} + preannotations = pd.DataFrame( + { + "image_path": ["a.jpg", "a.jpg", "b.jpg", "b.jpg", "b.jpg"], + "cropmodel_label": ["Sp one", "Sp one", "Sp two", "Sp two", "Sp two"], + "score": [0.9, 0.9, 0.85, 0.85, 0.85], + "cropmodel_score": [0.9, 0.95, 0.8, 0.85, 0.9], + "hcast_species": ["Sp other", "Sp other", "Sp two", "Sp alt", "Sp alt"], + "hcast_genus": ["G", "G", "G2", "G2", "G2"], + "hcast_species_score": [0.88, 0.9, 0.85, 0.82, 0.87], + } + ) + chosen, _, stats = select_images( + preannotations=preannotations, + strategy="model-disagreement", + n=2, + min_score=0.3, + species_to_genus=sg, + ) + assert chosen[0] == "a.jpg" + assert stats["al_disagreement_boxes_after_filters"] >= 4 + + +def test_row_crop_hcast_disagree_strict_excludes_congener(): + sg = {"Uria aalge": "Uria", "Uria lomvia": "Uria"} + row_match = pd.Series( + { + "cropmodel_label": "Uria aalge", + "hcast_species": "Uria lomvia", + "hcast_genus": "Uria", + } + ) + assert row_crop_hcast_disagrees(row_match, sg, strict_genus_mismatch=False) + assert not row_crop_hcast_disagrees(row_match, sg, strict_genus_mismatch=True) + + +def test_format_ensemble_suggestion_line(): + sg = {"A b": "A"} + row_agree = pd.Series({"cropmodel_label": "A b", "hcast_species": "A b", "hcast_genus": "A"}) + assert "species agreement" in format_ensemble_suggestion_line(row_agree, sg) + row_genus = pd.Series({"cropmodel_label": "A b", "hcast_species": "A c", "hcast_genus": "A"}) + assert "genus A" in format_ensemble_suggestion_line(row_genus, sg) + row_amb = pd.Series({"cropmodel_label": "A b", "hcast_species": "X y", "hcast_genus": "X"}) + assert "ambiguous" in format_ensemble_suggestion_line(row_amb, sg)