Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion boem_conf/classification_model/USGS.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ batch_size: 96
workers: 8
expand: 30
balance_classes: false
use_metadata: false
use_metadata: true
metadata_dim: 32
metadata_dropout: 0.5
# metadata_dir defaults to report.metadata_dir when use_metadata is true.
34 changes: 34 additions & 0 deletions scripts/USGS_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,17 @@
# Crop image_path is like C1_L6_F560_T20241219_173703_737_23.png -> parent stem is C1_L6_F560_T20241219_173703_737
_CROP_SUFFIX_RE = re.compile(r"^(.+)_\d+\.(png|PNG|jpg|JPG|jpeg|JPEG)$")

# New-format UBFAI image names embed the flight datetime as T{YYYYMMDD}_{HHMMSS}.
# The captured YYYYMMDD_HHMMSS string is used as the key for looking up that
# flight's captures CSV (e.g. 20241219_165719_captures.csv).
# Old-format names (e.g. 668140-0806171338187-CAM3.png) have no parseable datetime.
_FLIGHT_DATETIME_RE = re.compile(r"T(\d{8}_\d{6})")


def _flight_name_from_image_path(image_path: str) -> str | None:
"""Extract flight datetime key from new-format UBFAI image paths, or None."""
m = _FLIGHT_DATETIME_RE.search(os.path.basename(str(image_path)))
return m.group(1) if m else None

# Detection-task combined splits: exclude so we only load per-image CSVs (avoids duplicate rows and Object vs species label conflicts).
UBFAI_CROPS_EXCLUDE_CSV = frozenset({"train.csv", "test.csv", "zero_shot.csv"})
UBFAI_CROPS_EXCLUDE_PREFIX = "train_max_empty_" # e.g. train_max_empty_0_10.csv
Expand Down Expand Up @@ -322,6 +333,25 @@ def main(cfg: DictConfig):
balance_classes = bool(cfg.classification_model.get("balance_classes", False))
comet_logger.experiment.log_parameter("balance_classes", balance_classes)

use_metadata = bool(cfg.classification_model.get("use_metadata", False))
metadata_dir = cfg.classification_model.get("metadata_dir", None) or cfg.report.get("metadata_dir", None)
metadata_dim = int(cfg.classification_model.get("metadata_dim", 32))
metadata_dropout = float(cfg.classification_model.get("metadata_dropout", 0.5))
comet_logger.experiment.log_parameter("use_metadata", use_metadata)
if use_metadata:
comet_logger.experiment.log_parameter("metadata_dir", metadata_dir)
comet_logger.experiment.log_parameter("metadata_dim", metadata_dim)
comet_logger.experiment.log_parameter("metadata_dropout", metadata_dropout)
# UBFAI crops span many flights in one directory. Extract per-row flight_name
# from the T{YYYYMMDD}_{HHMMSS} pattern in new-format image paths so
# build_crop_metadata_rows can look up the right captures CSV for each image.
# Old-format images (no T-datetime pattern) get None and are skipped silently.
for df in (train_df, validation_df):
df["flight_name"] = df["image_path"].map(_flight_name_from_image_path)
n_with = int(train_df["flight_name"].notna().sum())
n_without = int(train_df["flight_name"].isna().sum())
print(f"[metadata] flight_name extracted: {n_with} train crops matched, {n_without} unmatched (old-format, will be skipped)")

trained_model = preprocess_and_train(
train_df=train_df,
validation_df=validation_df,
Expand All @@ -338,6 +368,10 @@ def main(cfg: DictConfig):
batch_size=cfg.classification_model.batch_size,
workers=cfg.classification_model.workers,
balance_classes=balance_classes,
use_metadata=use_metadata,
metadata_dir=metadata_dir,
metadata_dim=metadata_dim,
metadata_dropout=metadata_dropout,
)


Expand Down
33 changes: 33 additions & 0 deletions scripts/visualize_metadata_priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,31 @@
except ImportError: # pragma: no cover - contextily is an optional visual enhancement.
ctx = None

try:
import geopandas as gpd
except ImportError:
gpd = None

# Natural Earth 110m countries — small download, cached after first call.
_COUNTRIES_URL = "https://naciscdn.org/naturalearth/110m/cultural/ne_110m_admin_0_countries.zip"
_BOUNDARY_CACHE: "gpd.GeoDataFrame | None" = None
# True once a load has been tried, so a failed download is not retried (and its
# error not re-printed) on every subsequent plot.
_BOUNDARY_LOAD_ATTEMPTED = False


def load_land_boundaries() -> "gpd.GeoDataFrame | None":
    """Return a GeoDataFrame of USA + Canada + Mexico boundaries, or None.

    The countries file is read from ``_COUNTRIES_URL`` at most once per
    process. The filtered result is kept in ``_BOUNDARY_CACHE``; if the read
    fails, the failure is remembered too and later calls return ``None``
    immediately instead of repeating the attempt.

    Returns:
        The rows whose ``SOVEREIGNT`` is one of the three countries, or
        ``None`` when geopandas is not installed or the file could not be
        loaded.
    """
    global _BOUNDARY_CACHE, _BOUNDARY_LOAD_ATTEMPTED
    if gpd is None:
        return None
    # A populated cache always wins, even if it was filled from outside this function.
    if _BOUNDARY_CACHE is not None or _BOUNDARY_LOAD_ATTEMPTED:
        return _BOUNDARY_CACHE

    _BOUNDARY_LOAD_ATTEMPTED = True
    try:
        world = gpd.read_file(_COUNTRIES_URL)
        _BOUNDARY_CACHE = world[
            world["SOVEREIGNT"].isin(["United States of America", "Canada", "Mexico"])
        ]
    except Exception as exc:
        # Deliberately broad: the boundaries are an optional overlay, so any
        # network, parsing, or schema problem only disables them.
        print(f"Could not load land boundaries: {exc}")
    return _BOUNDARY_CACHE


SPECIES_ALIASES = {
"Northern Gannet": "Morus bassanus",
Expand Down Expand Up @@ -189,10 +214,17 @@ def plot_species_map(
vmax=1 if plot_column == "relative_score" else None,
)
fig.colorbar(image, ax=ax, label=plot_column.replace("_", " "))

boundaries = load_land_boundaries()
if boundaries is not None:
boundaries.boundary.plot(ax=ax, color="black", linewidth=0.8, zorder=3)

ax.set_title(f"{species} metadata prior, {date}")
ax.set_xlabel("Longitude")
ax.set_ylabel("Latitude")
ax.grid(color="white", linewidth=0.3, alpha=0.4)
ax.set_xlim(min_lon, max_lon)
ax.set_ylim(min_lat, max_lat)
fig.savefig(output_path, dpi=250, bbox_inches="tight")
plt.close(fig)

Expand Down Expand Up @@ -269,6 +301,7 @@ def main() -> None:
species = resolve_species(args.species)
grid = make_grid(tuple(args.bounds), args.cell_degrees)
model = load_metadata_model(args.checkpoint, args.device)
load_land_boundaries() # warm cache once before the plotting loop

all_scores = []
for date in args.dates:
Expand Down
7 changes: 6 additions & 1 deletion src/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,12 @@ def preprocess_and_train(
ignore_index=True,
).drop_duplicates(subset=["filename"])
if metadata_rows.empty:
raise ValueError(f"No crop metadata rows were created for flight {flight_name}")
print(
f"[preprocess_and_train] WARNING: no crop metadata rows found for any image "
f"(default flight_name={flight_name!r}). Disabling metadata for this run."
)
use_metadata = False
metadata_csv = None
metadata_csv = os.path.join(checkpoint_dir, "classification_crop_metadata.csv")
os.makedirs(checkpoint_dir, exist_ok=True)
metadata_rows.to_csv(metadata_csv, index=False)
Expand Down
11 changes: 9 additions & 2 deletions src/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@
import random
import tempfile

_IMAGE_EXTS = ("jpg", "JPG", "jpeg", "JPEG", "tif", "TIF", "tiff", "TIFF")


def _collect_images(image_dir: str) -> list[str]:
return [p for ext in _IMAGE_EXTS for p in glob.glob(os.path.join(image_dir, f"*.{ext}"))]


class Pipeline:
"""Pipeline for training and evaluating a detection and classification model"""
def __init__(self, cfg: DictConfig):
Expand All @@ -33,7 +40,7 @@ def __init__(self, cfg: DictConfig):
self.annotator = get_annotator(self.config)

# Pool of all images
self.all_images = glob.glob(os.path.join(self.config.image_dir, "*.jpg")) + glob.glob(os.path.join(self.config.image_dir, "*.JPG"))
self.all_images = _collect_images(self.config.image_dir)

self.comet_logger = CometLogger(project_name=self.config.comet.project, workspace=self.config.comet.workspace)
self.comet_logger.experiment.add_tag("pipeline")
Expand Down Expand Up @@ -388,7 +395,7 @@ def run(self):
else:
raise NotImplementedError("Only deepforest classification backend is currently implemented")

pool = glob.glob(os.path.join(self.config.image_dir, "*.jpg")) + glob.glob(os.path.join(self.config.image_dir, "*.JPG"))
pool = _collect_images(self.config.image_dir)
pool = [image for image in pool if os.path.basename(image) not in self.existing_images]
pool_limit = getattr(self.config.active_learning, "pool_limit", None)
print(f"Pool: {len(pool)} images (after excluding existing), pool_limit={pool_limit}")
Expand Down
9 changes: 8 additions & 1 deletion src/spatiotemporal_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,16 @@ def _image_stem(image_path: str) -> str:


def load_flight_metadata(flight_name: str, metadata_dir: str | Path) -> dict[str, dict]:
"""Return image-stem keyed metadata dicts for one flight."""
"""Return image-stem keyed metadata dicts for one flight.

Returns an empty dict (rather than raising) when no captures CSV exists for
the flight — this happens for old-format UBFAI images whose filenames don't
encode a parseable flight datetime.
"""
metadata_dir = Path(metadata_dir)
captures_path = metadata_dir / f"{flight_datetime_key(flight_name)}_captures.csv"
if not captures_path.exists():
return {}
captures = pd.read_csv(captures_path)
required = {"Basename", "Lat", "Lon"}
missing = required - set(captures.columns)
Expand Down
3 changes: 2 additions & 1 deletion src/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,8 @@ def crop_images(annotations, root_dir, experiment=None, expand=30):
return crop_image_paths

def select_images_for_video(image_dir, thin_factor):
    """Pick every ``thin_factor``-th image in ``image_dir`` for a video.

    JPEG and TIFF files are matched by extension, case-insensitively, from a
    single directory listing, so no file can appear twice. The candidates are
    sorted by file name before thinning, so the kept frames follow name order
    rather than the arbitrary order of the directory listing.

    Args:
        image_dir: Directory containing the frames (not searched recursively).
        thin_factor: Slice step; 1 keeps every image, 2 every second one, etc.

    Returns:
        List of image paths (``image_dir`` joined with the file name). Empty
        when the directory is missing, unreadable, or has no matching files.
    """
    suffixes = (".jpg", ".jpeg", ".tif", ".tiff")
    try:
        names = os.listdir(image_dir or os.curdir)
    except OSError:
        # Same outcome a glob gives for a missing directory.
        return []
    all_images = [
        os.path.join(image_dir, name)
        for name in sorted(names)
        # Skip dotfiles, which a "*.ext" glob pattern would not match either.
        if not name.startswith(".") and name.lower().endswith(suffixes)
    ]
    # Thin by factor, select every nth image
    thinned_images = all_images[::thin_factor]
    return thinned_images
Expand Down
Loading