From 9d6b5943d37054365617dcfcbc5eeed051e1a1bb Mon Sep 17 00:00:00 2001 From: bw4sz Date: Mon, 18 May 2026 17:00:43 -0400 Subject: [PATCH] update tifs --- boem_conf/classification_model/USGS.yaml | 2 +- scripts/USGS_classification.py | 34 ++++++++++++++++++++++++ scripts/visualize_metadata_priors.py | 33 +++++++++++++++++++++++ src/classification.py | 7 ++++- src/pipeline.py | 11 ++++++-- src/spatiotemporal_metadata.py | 9 ++++++- src/visualization.py | 3 ++- 7 files changed, 93 insertions(+), 6 deletions(-) diff --git a/boem_conf/classification_model/USGS.yaml b/boem_conf/classification_model/USGS.yaml index 82373d3..30cae1b 100644 --- a/boem_conf/classification_model/USGS.yaml +++ b/boem_conf/classification_model/USGS.yaml @@ -11,7 +11,7 @@ batch_size: 96 workers: 8 expand: 30 balance_classes: false -use_metadata: false +use_metadata: true metadata_dim: 32 metadata_dropout: 0.5 # metadata_dir defaults to report.metadata_dir when use_metadata is true. diff --git a/scripts/USGS_classification.py b/scripts/USGS_classification.py index 55107a1..d981c04 100644 --- a/scripts/USGS_classification.py +++ b/scripts/USGS_classification.py @@ -38,6 +38,17 @@ # Crop image_path is like C1_L6_F560_T20241219_173703_737_23.png -> parent stem is C1_L6_F560_T20241219_173703_737 _CROP_SUFFIX_RE = re.compile(r"^(.+)_\d+\.(png|PNG|jpg|JPG|jpeg|JPEG)$") +# New-format UBFAI image names embed the flight datetime as T{YYYYMMDD}_{HHMMSS}. +# This maps directly to a captures CSV key (e.g. 20241219_165719_captures.csv). +# Old-format names (e.g. 668140-0806171338187-CAM3.png) have no parseable datetime. +_FLIGHT_DATETIME_RE = re.compile(r"T(\d{8}_\d{6})") + + +def _flight_name_from_image_path(image_path: str) -> str | None: + """Extract flight datetime key from new-format UBFAI image paths, or None.""" + m = _FLIGHT_DATETIME_RE.search(os.path.basename(str(image_path))) + return m.group(1) if m else None + # Detection-task combined splits: exclude so we only load per-image CSVs (avoids duplicate rows and Object vs species label conflicts). UBFAI_CROPS_EXCLUDE_CSV = frozenset({"train.csv", "test.csv", "zero_shot.csv"}) UBFAI_CROPS_EXCLUDE_PREFIX = "train_max_empty_" # e.g. train_max_empty_0_10.csv @@ -322,6 +333,25 @@ def main(cfg: DictConfig): balance_classes = bool(cfg.classification_model.get("balance_classes", False)) comet_logger.experiment.log_parameter("balance_classes", balance_classes) + use_metadata = bool(cfg.classification_model.get("use_metadata", False)) + metadata_dir = cfg.classification_model.get("metadata_dir", None) or cfg.report.get("metadata_dir", None) + metadata_dim = int(cfg.classification_model.get("metadata_dim", 32)) + metadata_dropout = float(cfg.classification_model.get("metadata_dropout", 0.5)) + comet_logger.experiment.log_parameter("use_metadata", use_metadata) + if use_metadata: + comet_logger.experiment.log_parameter("metadata_dir", metadata_dir) + comet_logger.experiment.log_parameter("metadata_dim", metadata_dim) + comet_logger.experiment.log_parameter("metadata_dropout", metadata_dropout) + # UBFAI crops span many flights in one directory. Extract per-row flight_name + # from the T{YYYYMMDD}_{HHMMSS} pattern in new-format image paths so + # build_crop_metadata_rows can look up the right captures CSV for each image. + # Old-format images (no T-datetime pattern) get None and are skipped silently. + for df in (train_df, validation_df): + df["flight_name"] = df["image_path"].map(_flight_name_from_image_path) + n_with = int(train_df["flight_name"].notna().sum()) + n_without = int(train_df["flight_name"].isna().sum()) + print(f"[metadata] flight_name extracted: {n_with} train crops matched, {n_without} unmatched (old-format, will be skipped)") + trained_model = preprocess_and_train( train_df=train_df, validation_df=validation_df, @@ -338,6 +368,10 @@ def main(cfg: DictConfig): batch_size=cfg.classification_model.batch_size, workers=cfg.classification_model.workers, balance_classes=balance_classes, + use_metadata=use_metadata, + metadata_dir=metadata_dir, + metadata_dim=metadata_dim, + metadata_dropout=metadata_dropout, ) diff --git a/scripts/visualize_metadata_priors.py b/scripts/visualize_metadata_priors.py index 34f1dec..d0bdf21 100644 --- a/scripts/visualize_metadata_priors.py +++ b/scripts/visualize_metadata_priors.py @@ -32,6 +32,31 @@ except ImportError: # pragma: no cover - contextily is an optional visual enhancement. ctx = None +try: + import geopandas as gpd +except ImportError: + gpd = None + +# Natural Earth 110m countries — small download, cached after first call. +_COUNTRIES_URL = "https://naciscdn.org/naturalearth/110m/cultural/ne_110m_admin_0_countries.zip" +_BOUNDARY_CACHE: "gpd.GeoDataFrame | None" = None + + +def load_land_boundaries() -> "gpd.GeoDataFrame | None": + """Return a GeoDataFrame of USA + Canada + Mexico boundaries, or None.""" + global _BOUNDARY_CACHE + if gpd is None: + return None + if _BOUNDARY_CACHE is not None: + return _BOUNDARY_CACHE + try: + world = gpd.read_file(_COUNTRIES_URL) + _BOUNDARY_CACHE = world[world["SOVEREIGNT"].isin(["United States of America", "Canada", "Mexico"])] + return _BOUNDARY_CACHE + except Exception as exc: + print(f"Could not load land boundaries: {exc}") + return None + SPECIES_ALIASES = { "Northern Gannet": "Morus bassanus", @@ -189,10 +214,17 @@ def plot_species_map( vmax=1 if plot_column == "relative_score" else None, ) fig.colorbar(image, ax=ax, label=plot_column.replace("_", " ")) + + boundaries = load_land_boundaries() + if boundaries is not None: + boundaries.boundary.plot(ax=ax, color="black", linewidth=0.8, zorder=3) + ax.set_title(f"{species} metadata prior, {date}") ax.set_xlabel("Longitude") ax.set_ylabel("Latitude") ax.grid(color="white", linewidth=0.3, alpha=0.4) + ax.set_xlim(min_lon, max_lon) + ax.set_ylim(min_lat, max_lat) fig.savefig(output_path, dpi=250, bbox_inches="tight") plt.close(fig) @@ -269,6 +301,7 @@ def main() -> None: species = resolve_species(args.species) grid = make_grid(tuple(args.bounds), args.cell_degrees) model = load_metadata_model(args.checkpoint, args.device) + load_land_boundaries() # warm cache once before the plotting loop all_scores = [] for date in args.dates: diff --git a/src/classification.py b/src/classification.py index 00d25c5..bf75725 100644 --- a/src/classification.py +++ b/src/classification.py @@ -208,7 +208,12 @@ def preprocess_and_train( ignore_index=True, ).drop_duplicates(subset=["filename"]) if metadata_rows.empty: - raise ValueError(f"No crop metadata rows were created for flight {flight_name}") + print( + f"[preprocess_and_train] WARNING: no crop metadata rows found for any image " + f"(default flight_name={flight_name!r}). Disabling metadata for this run." + ) + use_metadata = False + metadata_csv = None metadata_csv = os.path.join(checkpoint_dir, "classification_crop_metadata.csv") os.makedirs(checkpoint_dir, exist_ok=True) metadata_rows.to_csv(metadata_csv, index=False) diff --git a/src/pipeline.py b/src/pipeline.py index e152657..a598411 100644 --- a/src/pipeline.py +++ b/src/pipeline.py @@ -23,6 +23,13 @@ import random import tempfile +_IMAGE_EXTS = ("jpg", "JPG", "jpeg", "JPEG", "tif", "TIF", "tiff", "TIFF") + + +def _collect_images(image_dir: str) -> list[str]: + return [p for ext in _IMAGE_EXTS for p in glob.glob(os.path.join(image_dir, f"*.{ext}"))] + + class Pipeline: """Pipeline for training and evaluating a detection and classification model""" def __init__(self, cfg: DictConfig): @@ -33,7 +40,7 @@ def __init__(self, cfg: DictConfig): self.annotator = get_annotator(self.config) # Pool of all images - self.all_images = glob.glob(os.path.join(self.config.image_dir, "*.jpg")) + glob.glob(os.path.join(self.config.image_dir, "*.JPG")) + self.all_images = _collect_images(self.config.image_dir) self.comet_logger = CometLogger(project_name=self.config.comet.project, workspace=self.config.comet.workspace) self.comet_logger.experiment.add_tag("pipeline") @@ -388,7 +395,7 @@ def run(self): else: raise NotImplementedError("Only deepforest classification backend is currently implemented") - pool = glob.glob(os.path.join(self.config.image_dir, "*.jpg")) + glob.glob(os.path.join(self.config.image_dir, "*.JPG")) + pool = _collect_images(self.config.image_dir) pool = [image for image in pool if os.path.basename(image) not in self.existing_images] pool_limit = getattr(self.config.active_learning, "pool_limit", None) print(f"Pool: {len(pool)} images (after excluding existing), pool_limit={pool_limit}") diff --git a/src/spatiotemporal_metadata.py b/src/spatiotemporal_metadata.py index f556a7e..66405fd 100644 --- a/src/spatiotemporal_metadata.py +++ b/src/spatiotemporal_metadata.py @@ -24,9 +24,16 @@ def _image_stem(image_path: str) -> str: def load_flight_metadata(flight_name: str, metadata_dir: str | Path) -> dict[str, dict]: - """Return image-stem keyed metadata dicts for one flight.""" + """Return image-stem keyed metadata dicts for one flight. + + Returns an empty dict (rather than raising) when no captures CSV exists for + the flight — this happens for old-format UBFAI images whose filenames don't + encode a parseable flight datetime. + """ metadata_dir = Path(metadata_dir) captures_path = metadata_dir / f"{flight_datetime_key(flight_name)}_captures.csv" + if not captures_path.exists(): + return {} captures = pd.read_csv(captures_path) required = {"Basename", "Lat", "Lon"} missing = required - set(captures.columns) diff --git a/src/visualization.py b/src/visualization.py index 79531d1..a75993c 100644 --- a/src/visualization.py +++ b/src/visualization.py @@ -181,7 +181,8 @@ def crop_images(annotations, root_dir, experiment=None, expand=30): return crop_image_paths def select_images_for_video(image_dir, thin_factor): - all_images = glob.glob(image_dir + "/*.jpg") + exts = ("jpg", "JPG", "jpeg", "JPEG", "tif", "TIF", "tiff", "TIFF") + all_images = [p for ext in exts for p in glob.glob(os.path.join(image_dir, f"*.{ext}"))] # Thin by factor, select every nth image thinned_images = all_images[::thin_factor] return thinned_images