From ed1bc6a5ee4ca4837fe717550658f9a7cf24ae51 Mon Sep 17 00:00:00 2001 From: bw4sz Date: Fri, 24 Apr 2026 09:39:55 -0700 Subject: [PATCH 1/3] add single-path slurm annotation prep workflow Introduce manifest-based refresh modes in prepare_USGS and add one submit script that runs shard refresh then final prepare, so stale split_raster work is parallelized without rerunning up-to-date crops. Made-with: Cursor --- scripts/prepare_USGS.py | 241 +++++++++++++++++++++++++++------- submit_prepare_annotations.sh | 108 +++++++++++++++ 2 files changed, 302 insertions(+), 47 deletions(-) create mode 100755 submit_prepare_annotations.sh diff --git a/scripts/prepare_USGS.py b/scripts/prepare_USGS.py index b0809dd..15bcd2c 100644 --- a/scripts/prepare_USGS.py +++ b/scripts/prepare_USGS.py @@ -109,6 +109,18 @@ def parse_args(): help="Do not overwrite UBFAI crop CSVs from detection/crops when dest exists. " "Default is to update (overwrite) so labels stay current.", ) + parser.add_argument( + "--write-detection-refresh-manifest", + type=str, + default=None, + help="Write CSV manifest of stale/missing detection crop work and exit.", + ) + parser.add_argument( + "--process-detection-refresh-manifest", + type=str, + default=None, + help="Process only rows in a detection refresh manifest CSV and exit.", + ) return parser.parse_args() @@ -254,61 +266,29 @@ def generate_detection_crops(): ) for flight_name in flights: - root_dir = os.path.join(IMAGERY_BASE, flight_name) - save_dir = os.path.join(DETECTION_CROPS_BASE, flight_name) - if not os.path.isdir(root_dir): - print(f" Skip {flight_name}: imagery dir not found {root_dir}") - continue - - csvs = [] - for sub in ("train", "validation", "review"): - csvs.extend( - glob.glob(os.path.join(ANNOTATIONS_BASE, sub, flight_name, "*.csv")) - ) - if not csvs: - print(f" Skip {flight_name}: no annotation CSVs") + flight_data = _build_flight_detection_annotations(flight_name) + if flight_data is None: continue - - # Build combined and record 
each row's source CSV mtime - parts = [] - for f in csvs: - df = pd.read_csv(f) - df["_source_mtime"] = os.path.getmtime(f) - parts.append(df) - combined = pd.concat(parts, ignore_index=True).drop_duplicates() - combined = _normalize_annotation_columns(combined) - - # Keep only annotations whose images actually exist on disk - combined["_path"] = combined["image_path"].apply( - lambda p: os.path.join(root_dir, p) - ) - combined = combined[combined["_path"].apply(os.path.exists)].drop( - columns=["_path"] - ) - if combined.empty: - print(f" Skip {flight_name}: no annotations with existing images") - continue - - # Per image: max mtime of any annotation CSV that contains it - image_ann_mtime = combined.groupby("image_path")["_source_mtime"].max() - combined = combined.drop(columns=["_source_mtime"]) + root_dir, save_dir, combined, image_ann_mtime = flight_data # Only refresh images whose annotation is newer than existing crop CSV (or missing) - os.makedirs(save_dir, exist_ok=True) - images_to_refresh = [] - for image_path in combined["image_path"].unique(): - image_stem = os.path.splitext(os.path.basename(image_path))[0] - crop_csv = os.path.join(save_dir, f"{image_stem}.csv") - ann_mtime = image_ann_mtime.loc[image_path] - if not os.path.exists(crop_csv) or os.path.getmtime(crop_csv) < ann_mtime: - images_to_refresh.append(image_path) - if os.path.exists(crop_csv): - os.remove(crop_csv) + refresh_df = _build_refresh_rows_for_flight( + flight_name=flight_name, + root_dir=root_dir, + save_dir=save_dir, + combined=combined, + image_ann_mtime=image_ann_mtime, + ) + images_to_refresh = refresh_df["image_path"].tolist() if not images_to_refresh: print(f" {flight_name}: no images need refresh (all crop CSVs up to date)") continue + for crop_csv in refresh_df["crop_csv"]: + if os.path.exists(crop_csv): + os.remove(crop_csv) + combined_refresh = combined[combined["image_path"].isin(images_to_refresh)] data_processing.preprocess_images( combined_refresh, @@ -326,6 +306,165 @@ 
def generate_detection_crops(): print("Detection crop generation done.") +def _build_flight_detection_annotations( + flight_name: str, +) -> tuple[str, str, pd.DataFrame, pd.Series] | None: + root_dir = os.path.join(IMAGERY_BASE, flight_name) + save_dir = os.path.join(DETECTION_CROPS_BASE, flight_name) + if not os.path.isdir(root_dir): + print(f" Skip {flight_name}: imagery dir not found {root_dir}") + return None + + csvs = [] + for sub in ("train", "validation", "review"): + csvs.extend(glob.glob(os.path.join(ANNOTATIONS_BASE, sub, flight_name, "*.csv"))) + if not csvs: + print(f" Skip {flight_name}: no annotation CSVs") + return None + + parts = [] + for f in csvs: + df = pd.read_csv(f) + df["_source_mtime"] = os.path.getmtime(f) + parts.append(df) + combined = pd.concat(parts, ignore_index=True).drop_duplicates() + combined = _normalize_annotation_columns(combined) + combined["_path"] = combined["image_path"].apply(lambda p: os.path.join(root_dir, p)) + combined = combined[combined["_path"].apply(os.path.exists)].drop(columns=["_path"]) + if combined.empty: + print(f" Skip {flight_name}: no annotations with existing images") + return None + + image_ann_mtime = combined.groupby("image_path")["_source_mtime"].max() + combined = combined.drop(columns=["_source_mtime"]) + return root_dir, save_dir, combined, image_ann_mtime + + +def _build_refresh_rows_for_flight( + flight_name: str, + root_dir: str, + save_dir: str, + combined: pd.DataFrame, + image_ann_mtime: pd.Series, +) -> pd.DataFrame: + rows = [] + for image_path in combined["image_path"].unique(): + image_stem = os.path.splitext(os.path.basename(image_path))[0] + crop_csv = os.path.join(save_dir, f"{image_stem}.csv") + ann_mtime = float(image_ann_mtime.loc[image_path]) + needs_refresh = (not os.path.exists(crop_csv)) or ( + os.path.getmtime(crop_csv) < ann_mtime + ) + if needs_refresh: + rows.append( + { + "flight_name": flight_name, + "root_dir": root_dir, + "save_dir": save_dir, + "image_path": 
image_path, + "crop_csv": crop_csv, + "ann_mtime": ann_mtime, + } + ) + return pd.DataFrame(rows) + + +def write_detection_refresh_manifest(manifest_path: str): + rows = [] + flight_dirs = set() + for sub in ("train", "validation", "review"): + parent = os.path.join(ANNOTATIONS_BASE, sub) + if os.path.isdir(parent): + for name in os.listdir(parent): + if os.path.isdir(os.path.join(parent, name)): + flight_dirs.add(name) + + for flight_name in sorted(flight_dirs): + flight_data = _build_flight_detection_annotations(flight_name) + if flight_data is None: + continue + root_dir, save_dir, combined, image_ann_mtime = flight_data + refresh_df = _build_refresh_rows_for_flight( + flight_name=flight_name, + root_dir=root_dir, + save_dir=save_dir, + combined=combined, + image_ann_mtime=image_ann_mtime, + ) + if not refresh_df.empty: + rows.append(refresh_df) + + if rows: + manifest = pd.concat(rows, ignore_index=True) + else: + manifest = pd.DataFrame( + columns=[ + "flight_name", + "root_dir", + "save_dir", + "image_path", + "crop_csv", + "ann_mtime", + ] + ) + manifest.to_csv(manifest_path, index=False) + print(f"Wrote refresh manifest with {len(manifest)} rows to {manifest_path}") + + +def process_detection_refresh_manifest(manifest_path: str): + from src import data_processing + + manifest = pd.read_csv(manifest_path) + if manifest.empty: + print(f"Manifest {manifest_path} is empty; nothing to process.") + return + + total_runtime_skips = 0 + total_refreshed = 0 + + for flight_name in sorted(manifest["flight_name"].unique()): + flight_manifest = manifest[manifest["flight_name"] == flight_name] + flight_data = _build_flight_detection_annotations(flight_name) + if flight_data is None: + continue + root_dir, save_dir, combined, image_ann_mtime = flight_data + available_images = set(combined["image_path"].unique()) + requested_images = set(flight_manifest["image_path"].astype(str)) + refresh_now = [] + for image_path in requested_images: + if image_path not in 
available_images: + continue + image_stem = os.path.splitext(os.path.basename(image_path))[0] + crop_csv = os.path.join(save_dir, f"{image_stem}.csv") + ann_mtime = float(image_ann_mtime.loc[image_path]) + if os.path.exists(crop_csv) and os.path.getmtime(crop_csv) >= ann_mtime: + total_runtime_skips += 1 + continue + if os.path.exists(crop_csv): + os.remove(crop_csv) + refresh_now.append(image_path) + + if not refresh_now: + continue + + combined_refresh = combined[combined["image_path"].isin(refresh_now)] + data_processing.preprocess_images( + combined_refresh, + root_dir=root_dir, + save_dir=save_dir, + patch_size=PATCH_SIZE, + patch_overlap=PATCH_OVERLAP, + allow_empty=True, + ) + total_refreshed += len(refresh_now) + print(f" {flight_name}: refreshed {len(refresh_now)} images from manifest shard") + + print( + f"Manifest processing complete: refreshed={total_refreshed}, " + f"runtime_skipped_up_to_date={total_runtime_skips}" + ) + + # --------------------------------------------------------------------------- # Stage 1: Pre-workflow annotations (UBFAI cumulative CSV) # --------------------------------------------------------------------------- @@ -648,6 +787,14 @@ def main(): args = parse_args() set_seed(args.seed) + if args.write_detection_refresh_manifest: + write_detection_refresh_manifest(args.write_detection_refresh_manifest) + return + + if args.process_detection_refresh_manifest: + process_detection_refresh_manifest(args.process_detection_refresh_manifest) + return + # Generate detection crops from annotations (default: on; only refreshes when newer) if not args.no_generate_detection_crops: generate_detection_crops() diff --git a/submit_prepare_annotations.sh b/submit_prepare_annotations.sh new file mode 100755 index 0000000..5036f24 --- /dev/null +++ b/submit_prepare_annotations.sh @@ -0,0 +1,108 @@ +#!/bin/bash +# One-command SLURM route for USGS annotation preparation. 
+# 1) Build manifest of stale/missing detection crop CSVs +# 2) Run shard array to refresh only those crops +# 3) Run final prepare_USGS stages 1-3 (skip Stage 0) + +set -euo pipefail + +cd "$(dirname "$0")" + +RUN_ID="$(date +%Y%m%d_%H%M%S)" +WORK_DIR="tmp/prepare_annotations_${RUN_ID}" +MANIFEST="${WORK_DIR}/manifest.csv" +SHARD_DIR="${WORK_DIR}/shards" + +mkdir -p "$SHARD_DIR" + +echo "Building manifest..." +uv run python scripts/prepare_USGS.py --write-detection-refresh-manifest "$MANIFEST" + +N_ROWS="$(uv run python - <<'PY' "$MANIFEST" +import pandas as pd +import sys +print(len(pd.read_csv(sys.argv[1]))) +PY +)" + +if [[ "$N_ROWS" -eq 0 ]]; then + echo "No stale crops found. Running final prepare only." + sbatch <<'EOF' +#!/bin/bash +#SBATCH --job-name=prep_ann_final +#SBATCH --account=ewhite +#SBATCH --partition=hpg-b200 +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=2 +#SBATCH --mem=16GB +#SBATCH --time=08:00:00 +#SBATCH --output=/home/b.weinstein/logs/prep_ann_final_%j.out +#SBATCH --error=/home/b.weinstein/logs/prep_ann_final_%j.err +cd /blue/ewhite/b.weinstein/BOEM || exit 1 +uv run python scripts/prepare_USGS.py --no-generate-detection-crops +EOF + exit 0 +fi + +N_SHARDS="$(uv run python - <<'PY' "$MANIFEST" "$SHARD_DIR" +import math +import pandas as pd +import sys +from pathlib import Path + +manifest = Path(sys.argv[1]) +shard_dir = Path(sys.argv[2]) +shard_size = 200 +df = pd.read_csv(manifest) +n = len(df) +n_shards = math.ceil(n / shard_size) +for i in range(n_shards): + start = i * shard_size + end = min((i + 1) * shard_size, n) + df.iloc[start:end].to_csv(shard_dir / f"shard_{i:05d}.csv", index=False) +print(n_shards) +PY +)" + +ARRAY_JOB_ID="$( + sbatch --parsable < Date: Fri, 24 Apr 2026 10:35:33 -0700 Subject: [PATCH 2/3] Revert "add single-path slurm annotation prep workflow" This reverts commit ed1bc6a5ee4ca4837fe717550658f9a7cf24ae51. 
--- scripts/prepare_USGS.py | 241 +++++++--------------------------- submit_prepare_annotations.sh | 108 --------------- 2 files changed, 47 insertions(+), 302 deletions(-) delete mode 100755 submit_prepare_annotations.sh diff --git a/scripts/prepare_USGS.py b/scripts/prepare_USGS.py index 15bcd2c..b0809dd 100644 --- a/scripts/prepare_USGS.py +++ b/scripts/prepare_USGS.py @@ -109,18 +109,6 @@ def parse_args(): help="Do not overwrite UBFAI crop CSVs from detection/crops when dest exists. " "Default is to update (overwrite) so labels stay current.", ) - parser.add_argument( - "--write-detection-refresh-manifest", - type=str, - default=None, - help="Write CSV manifest of stale/missing detection crop work and exit.", - ) - parser.add_argument( - "--process-detection-refresh-manifest", - type=str, - default=None, - help="Process only rows in a detection refresh manifest CSV and exit.", - ) return parser.parse_args() @@ -266,29 +254,61 @@ def generate_detection_crops(): ) for flight_name in flights: - flight_data = _build_flight_detection_annotations(flight_name) - if flight_data is None: + root_dir = os.path.join(IMAGERY_BASE, flight_name) + save_dir = os.path.join(DETECTION_CROPS_BASE, flight_name) + if not os.path.isdir(root_dir): + print(f" Skip {flight_name}: imagery dir not found {root_dir}") continue - root_dir, save_dir, combined, image_ann_mtime = flight_data - # Only refresh images whose annotation is newer than existing crop CSV (or missing) - refresh_df = _build_refresh_rows_for_flight( - flight_name=flight_name, - root_dir=root_dir, - save_dir=save_dir, - combined=combined, - image_ann_mtime=image_ann_mtime, + csvs = [] + for sub in ("train", "validation", "review"): + csvs.extend( + glob.glob(os.path.join(ANNOTATIONS_BASE, sub, flight_name, "*.csv")) + ) + if not csvs: + print(f" Skip {flight_name}: no annotation CSVs") + continue + + # Build combined and record each row's source CSV mtime + parts = [] + for f in csvs: + df = pd.read_csv(f) + 
df["_source_mtime"] = os.path.getmtime(f) + parts.append(df) + combined = pd.concat(parts, ignore_index=True).drop_duplicates() + combined = _normalize_annotation_columns(combined) + + # Keep only annotations whose images actually exist on disk + combined["_path"] = combined["image_path"].apply( + lambda p: os.path.join(root_dir, p) ) - images_to_refresh = refresh_df["image_path"].tolist() + combined = combined[combined["_path"].apply(os.path.exists)].drop( + columns=["_path"] + ) + if combined.empty: + print(f" Skip {flight_name}: no annotations with existing images") + continue + + # Per image: max mtime of any annotation CSV that contains it + image_ann_mtime = combined.groupby("image_path")["_source_mtime"].max() + combined = combined.drop(columns=["_source_mtime"]) + + # Only refresh images whose annotation is newer than existing crop CSV (or missing) + os.makedirs(save_dir, exist_ok=True) + images_to_refresh = [] + for image_path in combined["image_path"].unique(): + image_stem = os.path.splitext(os.path.basename(image_path))[0] + crop_csv = os.path.join(save_dir, f"{image_stem}.csv") + ann_mtime = image_ann_mtime.loc[image_path] + if not os.path.exists(crop_csv) or os.path.getmtime(crop_csv) < ann_mtime: + images_to_refresh.append(image_path) + if os.path.exists(crop_csv): + os.remove(crop_csv) if not images_to_refresh: print(f" {flight_name}: no images need refresh (all crop CSVs up to date)") continue - for crop_csv in refresh_df["crop_csv"]: - if os.path.exists(crop_csv): - os.remove(crop_csv) - combined_refresh = combined[combined["image_path"].isin(images_to_refresh)] data_processing.preprocess_images( combined_refresh, @@ -306,165 +326,6 @@ def generate_detection_crops(): print("Detection crop generation done.") -def _build_flight_detection_annotations( - flight_name: str, -) -> tuple[str, str, pd.DataFrame, pd.Series] | None: - root_dir = os.path.join(IMAGERY_BASE, flight_name) - save_dir = os.path.join(DETECTION_CROPS_BASE, flight_name) - if not 
os.path.isdir(root_dir): - print(f" Skip {flight_name}: imagery dir not found {root_dir}") - return None - - csvs = [] - for sub in ("train", "validation", "review"): - csvs.extend(glob.glob(os.path.join(ANNOTATIONS_BASE, sub, flight_name, "*.csv"))) - if not csvs: - print(f" Skip {flight_name}: no annotation CSVs") - return None - - parts = [] - for f in csvs: - df = pd.read_csv(f) - df["_source_mtime"] = os.path.getmtime(f) - parts.append(df) - combined = pd.concat(parts, ignore_index=True).drop_duplicates() - combined = _normalize_annotation_columns(combined) - combined["_path"] = combined["image_path"].apply(lambda p: os.path.join(root_dir, p)) - combined = combined[combined["_path"].apply(os.path.exists)].drop(columns=["_path"]) - if combined.empty: - print(f" Skip {flight_name}: no annotations with existing images") - return None - - image_ann_mtime = combined.groupby("image_path")["_source_mtime"].max() - combined = combined.drop(columns=["_source_mtime"]) - return root_dir, save_dir, combined, image_ann_mtime - - -def _build_refresh_rows_for_flight( - flight_name: str, - root_dir: str, - save_dir: str, - combined: pd.DataFrame, - image_ann_mtime: pd.Series, -) -> pd.DataFrame: - rows = [] - for image_path in combined["image_path"].unique(): - image_stem = os.path.splitext(os.path.basename(image_path))[0] - crop_csv = os.path.join(save_dir, f"{image_stem}.csv") - ann_mtime = float(image_ann_mtime.loc[image_path]) - needs_refresh = (not os.path.exists(crop_csv)) or ( - os.path.getmtime(crop_csv) < ann_mtime - ) - if needs_refresh: - rows.append( - { - "flight_name": flight_name, - "root_dir": root_dir, - "save_dir": save_dir, - "image_path": image_path, - "crop_csv": crop_csv, - "ann_mtime": ann_mtime, - } - ) - return pd.DataFrame(rows) - - -def write_detection_refresh_manifest(manifest_path: str): - rows = [] - flight_dirs = set() - for sub in ("train", "validation", "review"): - parent = os.path.join(ANNOTATIONS_BASE, sub) - if os.path.isdir(parent): - for 
name in os.listdir(parent): - if os.path.isdir(os.path.join(parent, name)): - flight_dirs.add(name) - - for flight_name in sorted(flight_dirs): - flight_data = _build_flight_detection_annotations(flight_name) - if flight_data is None: - continue - root_dir, save_dir, combined, image_ann_mtime = flight_data - refresh_df = _build_refresh_rows_for_flight( - flight_name=flight_name, - root_dir=root_dir, - save_dir=save_dir, - combined=combined, - image_ann_mtime=image_ann_mtime, - ) - if not refresh_df.empty: - rows.append(refresh_df) - - if rows: - manifest = pd.concat(rows, ignore_index=True) - else: - manifest = pd.DataFrame( - columns=[ - "flight_name", - "root_dir", - "save_dir", - "image_path", - "crop_csv", - "ann_mtime", - ] - ) - manifest.to_csv(manifest_path, index=False) - print(f"Wrote refresh manifest with {len(manifest)} rows to {manifest_path}") - - -def process_detection_refresh_manifest(manifest_path: str): - from src import data_processing - - manifest = pd.read_csv(manifest_path) - if manifest.empty: - print(f"Manifest {manifest_path} is empty; nothing to process.") - return - - total_runtime_skips = 0 - total_refreshed = 0 - - for flight_name in sorted(manifest["flight_name"].unique()): - flight_manifest = manifest[manifest["flight_name"] == flight_name] - flight_data = _build_flight_detection_annotations(flight_name) - if flight_data is None: - continue - root_dir, save_dir, combined, image_ann_mtime = flight_data - available_images = set(combined["image_path"].unique()) - requested_images = set(flight_manifest["image_path"].astype(str)) - refresh_now = [] - for image_path in requested_images: - if image_path not in available_images: - continue - image_stem = os.path.splitext(os.path.basename(image_path))[0] - crop_csv = os.path.join(save_dir, f"{image_stem}.csv") - ann_mtime = float(image_ann_mtime.loc[image_path]) - if os.path.exists(crop_csv) and os.path.getmtime(crop_csv) >= ann_mtime: - total_runtime_skips += 1 - continue - if 
os.path.exists(crop_csv): - os.remove(crop_csv) - refresh_now.append(image_path) - - if not refresh_now: - continue - - combined_refresh = combined[combined["image_path"].isin(refresh_now)] - data_processing.preprocess_images( - combined_refresh, - root_dir=root_dir, - save_dir=save_dir, - patch_size=PATCH_SIZE, - patch_overlap=PATCH_OVERLAP, - allow_empty=True, - ) - total_refreshed += len(refresh_now) - print(f" {flight_name}: refreshed {len(refresh_now)} images from manifest shard") - - print( - f"Manifest processing complete: refreshed={total_refreshed}, " - f"runtime_skipped_up_to_date={total_runtime_skips}" - ) - - # --------------------------------------------------------------------------- # Stage 1: Pre-workflow annotations (UBFAI cumulative CSV) # --------------------------------------------------------------------------- @@ -787,14 +648,6 @@ def main(): args = parse_args() set_seed(args.seed) - if args.write_detection_refresh_manifest: - write_detection_refresh_manifest(args.write_detection_refresh_manifest) - return - - if args.process_detection_refresh_manifest: - process_detection_refresh_manifest(args.process_detection_refresh_manifest) - return - # Generate detection crops from annotations (default: on; only refreshes when newer) if not args.no_generate_detection_crops: generate_detection_crops() diff --git a/submit_prepare_annotations.sh b/submit_prepare_annotations.sh deleted file mode 100755 index 5036f24..0000000 --- a/submit_prepare_annotations.sh +++ /dev/null @@ -1,108 +0,0 @@ -#!/bin/bash -# One-command SLURM route for USGS annotation preparation. 
-# 1) Build manifest of stale/missing detection crop CSVs -# 2) Run shard array to refresh only those crops -# 3) Run final prepare_USGS stages 1-3 (skip Stage 0) - -set -euo pipefail - -cd "$(dirname "$0")" - -RUN_ID="$(date +%Y%m%d_%H%M%S)" -WORK_DIR="tmp/prepare_annotations_${RUN_ID}" -MANIFEST="${WORK_DIR}/manifest.csv" -SHARD_DIR="${WORK_DIR}/shards" - -mkdir -p "$SHARD_DIR" - -echo "Building manifest..." -uv run python scripts/prepare_USGS.py --write-detection-refresh-manifest "$MANIFEST" - -N_ROWS="$(uv run python - <<'PY' "$MANIFEST" -import pandas as pd -import sys -print(len(pd.read_csv(sys.argv[1]))) -PY -)" - -if [[ "$N_ROWS" -eq 0 ]]; then - echo "No stale crops found. Running final prepare only." - sbatch <<'EOF' -#!/bin/bash -#SBATCH --job-name=prep_ann_final -#SBATCH --account=ewhite -#SBATCH --partition=hpg-b200 -#SBATCH --nodes=1 -#SBATCH --ntasks-per-node=1 -#SBATCH --cpus-per-task=2 -#SBATCH --mem=16GB -#SBATCH --time=08:00:00 -#SBATCH --output=/home/b.weinstein/logs/prep_ann_final_%j.out -#SBATCH --error=/home/b.weinstein/logs/prep_ann_final_%j.err -cd /blue/ewhite/b.weinstein/BOEM || exit 1 -uv run python scripts/prepare_USGS.py --no-generate-detection-crops -EOF - exit 0 -fi - -N_SHARDS="$(uv run python - <<'PY' "$MANIFEST" "$SHARD_DIR" -import math -import pandas as pd -import sys -from pathlib import Path - -manifest = Path(sys.argv[1]) -shard_dir = Path(sys.argv[2]) -shard_size = 200 -df = pd.read_csv(manifest) -n = len(df) -n_shards = math.ceil(n / shard_size) -for i in range(n_shards): - start = i * shard_size - end = min((i + 1) * shard_size, n) - df.iloc[start:end].to_csv(shard_dir / f"shard_{i:05d}.csv", index=False) -print(n_shards) -PY -)" - -ARRAY_JOB_ID="$( - sbatch --parsable < Date: Fri, 24 Apr 2026 10:36:43 -0700 Subject: [PATCH 3/3] parallelize detection split_raster in-process from SLURM CPUs Use ProcessPoolExecutor for Stage 0 per-image crops when multiple workers are available (PREPARE_USGS_CROP_WORKERS or 
SLURM_CPUS_PER_TASK). Add a minimal sbatch wrapper that allocates cpus-per-task and runs prepare_USGS. Made-with: Cursor --- scripts/prepare_USGS.py | 75 +++++++++++++++++++++++++++++++---- submit_prepare_annotations.sh | 17 ++++++++ 2 files changed, 84 insertions(+), 8 deletions(-) create mode 100755 submit_prepare_annotations.sh diff --git a/scripts/prepare_USGS.py b/scripts/prepare_USGS.py index b0809dd..f8b9c57 100644 --- a/scripts/prepare_USGS.py +++ b/scripts/prepare_USGS.py @@ -19,6 +19,9 @@ and to updating labels (overwriting UBFAI from detection/crops). Use --no-generate-detection-crops or --no-update-labels to skip either step. +Stage 0 parallelizes split_raster across images when multiple CPUs are available +(set SLURM_CPUS_PER_TASK in your sbatch script, or PREPARE_USGS_CROP_WORKERS). + Why new Label Studio annotations might not appear in classification: - Stage 0 (generate detection crops): runs by default; only regenerates a crop CSV when that image's annotation CSV is newer (or the crop is missing). @@ -42,6 +45,7 @@ import os import random import shutil +from concurrent.futures import ProcessPoolExecutor, as_completed import numpy as np import pandas as pd @@ -64,6 +68,37 @@ PATCH_OVERLAP = 0 +def _detection_crop_parallel_workers() -> int: + """Workers for per-image split_raster in Stage 0. + + Prefer PREPARE_USGS_CROP_WORKERS, then SLURM_CPUS_PER_TASK, then host CPU count. 
+ """ + for key in ("PREPARE_USGS_CROP_WORKERS", "SLURM_CPUS_PER_TASK"): + raw = os.environ.get(key) + if raw: + return max(1, int(raw)) + return max(1, os.cpu_count() or 1) + + +def _detection_crop_worker(task: tuple) -> str: + """Run split_raster for one image (pickled args; must stay top-level for spawn).""" + root_dir, save_dir, image_path, records, patch_size, patch_overlap = task + import pandas as pd + from src import data_processing + + annotation_df = pd.DataFrame.from_records(records) + data_processing.process_image( + image_path=image_path, + annotation_df=annotation_df, + root_dir=root_dir, + save_dir=save_dir, + patch_size=patch_size, + patch_overlap=patch_overlap, + allow_empty=True, + ) + return image_path + + def parse_args(): parser = argparse.ArgumentParser( description="Prepare USGS detection data for training" @@ -310,14 +345,38 @@ def generate_detection_crops(): continue combined_refresh = combined[combined["image_path"].isin(images_to_refresh)] - data_processing.preprocess_images( - combined_refresh, - root_dir=root_dir, - save_dir=save_dir, - patch_size=PATCH_SIZE, - patch_overlap=PATCH_OVERLAP, - allow_empty=True, - ) + n_workers = min(_detection_crop_parallel_workers(), len(images_to_refresh)) + if n_workers <= 1: + data_processing.preprocess_images( + combined_refresh, + root_dir=root_dir, + save_dir=save_dir, + patch_size=PATCH_SIZE, + patch_overlap=PATCH_OVERLAP, + allow_empty=True, + ) + else: + tasks = [] + for image_path in images_to_refresh: + ann = combined_refresh[combined_refresh["image_path"] == image_path] + tasks.append( + ( + root_dir, + save_dir, + image_path, + ann.to_dict("records"), + PATCH_SIZE, + PATCH_OVERLAP, + ) + ) + print( + f" {flight_name}: parallel split_raster " + f"({len(tasks)} images, {n_workers} workers)" + ) + with ProcessPoolExecutor(max_workers=n_workers) as pool: + futures = [pool.submit(_detection_crop_worker, t) for t in tasks] + for fut in as_completed(futures): + fut.result() print( f" {flight_name}: 
refreshed {len(images_to_refresh)} images (of " f"{combined['image_path'].nunique()} total) -> {save_dir}" diff --git a/submit_prepare_annotations.sh b/submit_prepare_annotations.sh new file mode 100755 index 0000000..4be0d66 --- /dev/null +++ b/submit_prepare_annotations.sh @@ -0,0 +1,17 @@ +#!/bin/bash +# One SLURM job: request many CPUs; prepare_USGS.py parallelizes split_raster per image. +# Run from repo root on the cluster: sbatch submit_prepare_annotations.sh + +#SBATCH --job-name=prep_ann +#SBATCH --account=ewhite +#SBATCH --partition=hpg-b200 +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=32 +#SBATCH --mem=64GB +#SBATCH --time=24:00:00 +#SBATCH --output=/home/b.weinstein/logs/prep_ann_%j.out +#SBATCH --error=/home/b.weinstein/logs/prep_ann_%j.err + +cd "${SLURM_SUBMIT_DIR}" +uv run python scripts/prepare_USGS.py