Conversation

@mihow
Collaborator

@mihow mihow commented Oct 11, 2025

Summary by CodeRabbit

  • New Features
    • Added global moth classifier pipeline combining object detection with species classification.
    • Enabled persistent caching for model downloads, improving performance on repeat operations.

@netlify

netlify bot commented Oct 11, 2025

👷 Deploy Preview for antenna-preview is processing.

🔨 Latest commit: fcfcd87
🔍 Latest deploy log: https://app.netlify.com/projects/antenna-preview/deploys/691cf01349cfe9000859418b

@mihow mihow force-pushed the feat/global-moths-model branch from 248fda8 to 9bd828b on October 13, 2025 at 22:46
@coderabbitai
Contributor

coderabbitai bot commented Nov 18, 2025

Walkthrough

Introduces a new Global Moth Classifier algorithm integrated into a two-stage pipeline with ZeroShotObjectDetector. Adds inference base classes supporting lazy-loaded PyTorch models, persistent caching for weights and labels, device selection utilities, and related configuration updates across the API layer.

Changes

Infrastructure & Configuration (.gitignore, requirements.txt, docker-compose.yml)
  Added cache directory ignores, the timm dependency, and a persistent cache volume mapping for model weights and labels.

Inference Base Classes (processing_services/example/api/base.py)
  Introduced SimplifiedInferenceBase with device selection, weight/label loading, and transform composition; added ResNet50Base and TimmResNet50Base concrete implementations for model construction and batch inference.

Global Moth Classifier Algorithm (processing_services/example/api/global_moth_classifier.py)
  Implemented the new GlobalMothClassifier algorithm subclassing TimmResNet50Base, supporting lazy compilation, batch processing, and ClassificationResponse attachment to detections.

Pipeline Definition & Integration (processing_services/example/api/pipelines.py, processing_services/example/api/api.py, processing_services/example/api/schemas.py)
  Added ZeroShotObjectDetectorWithGlobalMothClassifierPipeline orchestrating two-stage detection-then-classification; exported the pipeline to the API choices; extended the PipelineChoice literal.

Utility Functions (processing_services/example/api/utils.py)
  Added persistent cache directory handling and a new get_best_device() utility for CUDA/CPU selection.

Sequence Diagram(s)

sequenceDiagram
    participant Client
    participant API as API Handler
    participant Pipeline as ZeroShotObjectDetectorWithGlobalMothClassifierPipeline
    participant Detector as ZeroShotObjectDetector
    participant Classifier as GlobalMothClassifier
    participant Model as TimmResNet50 Model
    
    Client->>API: POST /run (with image)
    activate API
    
    API->>Pipeline: run()
    activate Pipeline
    
    Pipeline->>Detector: run(image, batch_size=1)
    activate Detector
    Detector->>Detector: detect objects
    Detector-->>Pipeline: detections[]
    deactivate Detector
    
    Pipeline->>Classifier: run(detections[], batch_size=4)
    activate Classifier
    Classifier->>Classifier: batch images, apply transforms
    Classifier->>Model: predict_batch(transformed_batch)
    activate Model
    Model-->>Classifier: logits
    deactivate Model
    Classifier->>Classifier: post_process (softmax, scores)
    Classifier->>Classifier: attach ClassificationResponse to each detection
    Classifier-->>Pipeline: detections[] with classifications
    deactivate Classifier
    
    Pipeline->>Pipeline: compute elapsed_time
    Pipeline-->>API: PipelineResultsResponse
    deactivate Pipeline
    
    API-->>Client: response with detections & classifications
    deactivate API
sequenceDiagram
    participant Caller
    participant Classifier as GlobalMothClassifier
    participant Base as SimplifiedInferenceBase
    participant Utils as utils.download_file
    participant Model as TimmResNet50
    
    Caller->>Classifier: compile()
    activate Classifier
    
    Classifier->>Base: __init__()
    activate Base
    Base->>Utils: get_weights(weights_path)
    activate Utils
    Utils-->>Base: cached_weight_path
    deactivate Utils
    Base->>Utils: get_labels(labels_path)
    activate Utils
    Utils-->>Base: label_dict
    deactivate Utils
    Base->>Base: get_transforms()
    Base->>Model: get_model() (via TimmResNet50Base)
    Model-->>Base: initialized model
    Base->>Base: move model to device, set eval mode
    deactivate Base
    
    Classifier->>Classifier: store model, transforms, category_map, num_classes
    Classifier-->>Caller: compile() complete
    deactivate Classifier

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

  • base.py: Multiple new base classes with abstract methods and device/weight-loading logic require careful review of PyTorch model initialization patterns and inheritance structure
  • global_moth_classifier.py: New algorithm subclassing two base classes; verify proper lifecycle (compile before run), batch processing correctness, and response object construction
  • pipelines.py: Duplicate class definition detected (ZeroShotObjectDetectorWithGlobalMothClassifierPipeline appears twice); requires clarification on whether this is intentional or a merge artifact
  • Cross-module integration: Verify consistency of imports, exports, and configuration flow across api.py, schemas.py, and pipelines.py

Poem

🐰 A moth takes flight through neural nets so deep,
With timm and torch, classifications keep,
From zero-shot detection to final class,
Our fuzzy classifier learns to pass,
Two stages dance—now moths are seen at last! 🦋

Pre-merge checks

❌ Failed checks (2 warnings)
Description check — ⚠️ Warning
  Explanation: The pull request description is completely empty, missing all required template sections including summary, list of changes, related issues, detailed description, testing instructions, and checklist.
  Resolution: Add a comprehensive PR description following the template with a summary, detailed list of changes, testing approach, and completion of the pre-merge checklist.

Docstring Coverage — ⚠️ Warning
  Explanation: Docstring coverage is 78.95%, which is below the required threshold of 80.00%.
  Resolution: Run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (1 passed)
Title check — ✅ Passed
  Explanation: The title clearly summarizes the main change: adding a general detector and global moth model to the example pipeline.


@mihow mihow changed the title from "[Draft] Add general detector + global moth model" to "Add general detector + global moth model" on Nov 18, 2025
@mihow mihow changed the title from "Add general detector + global moth model" to "Add general detector + global moth model to example pipeline" on Nov 18, 2025
@mihow mihow marked this pull request as ready for review November 18, 2025 22:17
Copilot AI review requested due to automatic review settings November 18, 2025 22:17
Contributor

Copilot AI left a comment


Pull Request Overview

This PR adds support for a global moth species classifier to the example processing service pipeline. The classifier provides high-quality identification across 29,176+ moth species using a ResNet50-based model trained on global data.

Key changes:

  • Adds a new GlobalMothClassifier algorithm with deferred model loading to optimize initialization (see the sketch after this list)
  • Introduces base classes for inference models (SimplifiedInferenceBase, ResNet50Base, TimmResNet50Base) to support model loading and inference
  • Updates file caching to use persistent cache directories instead of temporary directories
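
A minimal sketch of the deferred-loading pattern described in the first bullet, with illustrative names — LazyClassifierSketch is not the PR's class, and the raw-state-dict assumption is mine; only the timm/torch calls are real APIs:

import timm
import torch


class LazyClassifierSketch:
    """Heavy model construction is deferred until compile() is called."""

    def __init__(self, weights_path: str, num_classes: int):
        self.weights_path = weights_path
        self.num_classes = num_classes
        self.model = None  # nothing expensive happens at construction time

    def compile(self) -> None:
        # Build the backbone and load weights only when the pipeline
        # actually needs the model
        model = timm.create_model("resnet50", num_classes=self.num_classes)
        # Assumes the checkpoint is a raw state dict
        state_dict = torch.load(self.weights_path, map_location="cpu")
        model.load_state_dict(state_dict)
        model.eval()
        self.model = model

This keeps service startup cheap: constructing the algorithm object does no I/O, and the weights are read only on first use.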

Reviewed Changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 8 comments.

Show a summary per file

  • requirements.txt — Adds the timm==1.0.11 dependency for the timm-based ResNet50 model
  • docker-compose.yml — Adds a persistent cache volume mount for model weights and labels
  • api/utils.py — Refactors file download to use a persistent cache, adds a device selection helper, improves download progress logging
  • api/schemas.py — Registers the new pipeline slug for the global moth classifier pipeline
  • api/pipelines.py — Adds ZeroShotObjectDetectorWithGlobalMothClassifierPipeline, which combines object detection with moth classification
  • api/global_moth_classifier.py — Implements the GlobalMothClassifier algorithm with deferred compilation and batch inference
  • api/base.py — Introduces base classes for simplified inference models with weight/label loading and preprocessing
  • api/api.py — Registers the new pipeline in the API endpoints
  • .gitignore — Adds cache directories to gitignore
Comments suppressed due to low confidence (1)

processing_services/example/api/pipelines.py:70

        self.stages = self.stages or self.get_stages()



resulting_filepath = pathlib.Path(local_filepath).resolve()
logger.info(f"✅ Download completed: {resulting_filepath}")
logger.info(f"Downloaded to {resulting_filepath}")

Copilot AI Nov 18, 2025


Redundant logging: lines 91-92 log essentially the same information. Remove one of these duplicate log statements.

Suggested change
- logger.info(f"Downloaded to {resulting_filepath}")

Comment on lines 83 to 89
for chunk in response.iter_content(chunk_size=8192):
    f.write(chunk)
    downloaded += len(chunk)
    if total_size > 0:
        percent = (downloaded / total_size) * 100
        logger.info(f" Progress: {percent:.1f}% ({downloaded}/{total_size} bytes)")


Copilot AI Nov 18, 2025


Excessive logging in download loop: logging progress for every 8KB chunk will create thousands of log entries for large files. Consider logging progress at fixed intervals (e.g., every 10%) or reducing logging frequency.

Suggested change
- for chunk in response.iter_content(chunk_size=8192):
-     f.write(chunk)
-     downloaded += len(chunk)
-     if total_size > 0:
-         percent = (downloaded / total_size) * 100
-         logger.info(f" Progress: {percent:.1f}% ({downloaded}/{total_size} bytes)")
+ last_logged_percent = 0
+ for chunk in response.iter_content(chunk_size=8192):
+     f.write(chunk)
+     downloaded += len(chunk)
+     if total_size > 0:
+         percent = int((downloaded / total_size) * 100)
+         if percent >= last_logged_percent + 10 or percent == 100:
+             logger.info(f" Progress: {percent}% ({downloaded}/{total_size} bytes)")
+             last_logged_percent = percent

timestamp=datetime.datetime.now(),
algorithm=AlgorithmReference(
    name=self.name,
    key=self.get_key(),

Copilot AI Nov 18, 2025


The Algorithm base class does not have a get_key() method, but TimmResNet50Base inherits it from SimplifiedInferenceBase. This works due to multiple inheritance, but the method call relies on the MRO (Method Resolution Order). Consider adding get_key() to the Algorithm base class or using a mixin pattern to make this dependency more explicit.
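
A tiny self-contained illustration of the lookup this comment describes (class bodies here are hypothetical stand-ins, not the PR's code):

class Algorithm:
    pass  # defines no get_key()


class SimplifiedInferenceBase:
    def get_key(self) -> str:
        return type(self).__name__.lower()


class TimmResNet50Base(SimplifiedInferenceBase):
    pass


class MothClassifierLike(Algorithm, TimmResNet50Base):
    pass


# MRO: MothClassifierLike -> Algorithm -> TimmResNet50Base
#      -> SimplifiedInferenceBase (where get_key() is finally found)
print(MothClassifierLike().get_key())  # prints "mothclassifierlike"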

# Create a basic config response before compilation
self._algorithm_config_response = AlgorithmConfigResponse(
    name=self.name,
    key=self.get_key(),

Copilot AI Nov 18, 2025


Same issue as Comment 3: self.get_key() is called before the model is compiled, which means TimmResNet50Base.__init__() hasn't been called yet and the method is inherited via multiple inheritance. While this works, it's fragile and could break if the class hierarchy changes.

Comment on lines +96 to +104
def get_best_device() -> str:
    """
    Returns the best available device for running the model.
    MPS is not supported by the current algorithms.
    """
    if torch.cuda.is_available():
        return f"cuda:{torch.cuda.current_device()}"
    else:
        return "cpu"

Copilot AI Nov 18, 2025


This function duplicates the get_best_device() function already defined in algorithms.py (lines 24-33). Consider removing one implementation and importing from a single location to avoid code duplication.

self.category_map = self.get_labels(self.labels_path)
self.num_classes = self.num_classes or len(self.category_map)
self.weights = self.get_weights(self.weights_path)
self.transforms = self.get_transforms()

Copilot AI Nov 18, 2025


This call to SimplifiedInferenceBase.get_transforms in an initialization method is overridden by ResNet50Base.get_transforms.
This call to SimplifiedInferenceBase.get_transforms in an initialization method is overridden by GlobalMothClassifier.get_transforms.

Suggested change
- self.transforms = self.get_transforms()
+ self.transforms = None  # Subclasses should set this after their own initialization

self.transforms = self.get_transforms()

logger.info(f"Loading model for {self.name} with {len(self.category_map or [])} categories")
self.model = self.get_model()

Copilot AI Nov 18, 2025


This call to SimplifiedInferenceBase.get_model in an initialization method is overridden by ResNet50Base.get_model.
This call to SimplifiedInferenceBase.get_model in an initialization method is overridden by TimmResNet50Base.get_model.

Suggested change
- self.model = self.get_model()
+ self.model = None  # Model should be initialized in subclass after all setup

Copilot uses AI. Check for mistakes.
logger.setLevel(logging.INFO)


class GlobalMothClassifier(Algorithm, TimmResNet50Base):

Copilot AI Nov 18, 2025


This class does not call SimplifiedInferenceBase.__init__ during initialization. (GlobalMothClassifier.__init__ may be missing a call to a base class __init__.)
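
A hedged sketch of the explicit initialization this comment hints at, assuming cooperative __init__ signatures (not the PR's actual code):

class GlobalMothClassifier(Algorithm, TimmResNet50Base):
    def __init__(self, **kwargs):
        # Walk the MRO (Algorithm, then TimmResNet50Base, then
        # SimplifiedInferenceBase) so base-class attributes are set up.
        super().__init__(**kwargs)

Note the tension with the PR's deliberate deferred loading: if SimplifiedInferenceBase.__init__ eagerly downloads weights, calling it here would defeat compile(). An alternative is to assign the base attributes (model, transforms, category_map) to None in __init__ and document that compile() populates them.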

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
processing_services/example/api/utils.py (1)

42-93: Fix local-path handling regression and add timeout to get_or_download_file

Verification confirms both issues:

  1. Local paths are broken (regression). The function always routes through requests.get() without checking if the input is a local path first. For a local path like /home/user/file.jpg, urlparse().path extracts /home/user/file.jpg, then requests.get() is called with that value, which fails. This breaks the documented contract ("If the path is a local path, return the path"). Callers in processing_services/example/api/base.py:74 and :87 expect to pass local paths via weights_path and labels_path parameters.

  2. requests.get() has no timeout. Confirmed S113 violation—this can hang indefinitely on bad network conditions.

Suggested fix:

  • Add early return for local paths using the existing is_url() function (defined at line 30).
  • Add optional timeout parameter (default 60 seconds) and pass it to requests.get().

Example diff:

-def get_or_download_file(path_or_url, tempdir_prefix="antenna") -> pathlib.Path:
+def get_or_download_file(
+    path_or_url,
+    tempdir_prefix: str = "antenna",
+    timeout: float | None = 60,
+) -> pathlib.Path:
     """
     Fetch a file from a URL or local path. If the path is a URL, download the file.
     If the URL has already been downloaded, return the existing local path.
     If the path is a local path, return the path.
     """
     if not path_or_url:
         raise Exception("Specify a URL or path to fetch file from.")
 
+    # Handle local filesystem paths directly
+    if not is_url(str(path_or_url)):
+        local_path = pathlib.Path(path_or_url)
+        if local_path.exists():
+            return local_path
+        raise FileNotFoundError(f"Local file does not exist: {local_path}")
+
     # Use persistent cache directory for downloaded files
     cache_root = (
         pathlib.Path("/app/cache") if pathlib.Path("/app").exists() else pathlib.Path.home() / ".antenna_cache"
     )
     destination_dir = cache_root / tempdir_prefix
     fname = pathlib.Path(urlparse(path_or_url).path).name
     destination_dir.mkdir(parents=True, exist_ok=True)
     local_filepath = destination_dir / fname
 
-    if local_filepath and local_filepath.exists():
+    if local_filepath.exists():
         logger.info(f"📁 Using cached file: {local_filepath}")
         return local_filepath
 
     else:
         logger.info(f"⬇️  Downloading {path_or_url} to {local_filepath}")
         headers = {"User-Agent": USER_AGENT}
-        response = requests.get(path_or_url, stream=True, headers=headers)
+        request_kwargs = {"stream": True, "headers": headers}
+        if timeout is not None:
+            request_kwargs["timeout"] = timeout
+        response = requests.get(path_or_url, **request_kwargs)
         response.raise_for_status()

Additionally, get_best_device() is defined in both utils.py:96 and algorithms.py:24—consider centralizing it to a single location to prevent divergence.

🧹 Nitpick comments (5)
processing_services/example/api/utils.py (1)

96-105: Avoid duplicating get_best_device logic across modules

Defining get_best_device() here is fine, but there is an existing implementation in processing_services/example/api/algorithms.py with the same behavior. To avoid divergence later (e.g., if you ever add MPS support or tweak device selection), consider centralizing this helper in a single place (e.g., utils.get_best_device) and importing it where needed, removing the duplicate definition.
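
Assuming the example service's package layout, the consolidation could be a one-line import (the relative module path is an assumption):

# In algorithms.py: drop the local copy and reuse the utils implementation
from .utils import get_best_device

device = get_best_device()  # "cuda:0" when CUDA is available, otherwise "cpu"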

processing_services/example/api/pipelines.py (1)

12-12: Pipeline wiring is correct; consider aligning metadata initialization and class attributes

Functionally, ZeroShotObjectDetectorWithGlobalMothClassifierPipeline looks good:

  • Stage order and batching ([1, 4]) match the intended detector→classifier flow.
  • Existing detections are reused when provided.
  • Logging and timing mirror the other pipeline implementations.

Two small follow‑ups to consider:

  1. Metadata consistency for config.algorithms.
    Other pipelines populate config.algorithms at class definition time, so pipelines[i].config.algorithms and algorithm_choices are valid even before any pipeline instance is constructed. Here config.algorithms starts as [] and is only filled in get_stages(), meaning this pipeline’s algorithms won’t appear in purely static /info or introspection flows unless an instance has been created and get_stages() run. If you want consistent metadata behavior, you could initialize config.algorithms similarly:

    config = PipelineConfigResponse(
        name="Zero Shot Object Detector With Global Moth Classifier Pipeline",
        slug="zero-shot-object-detector-with-global-moth-classifier-pipeline",
        description=(
            "HF zero shot object detector with global moth species classifier. "
            "Supports 29,176+ moth species trained on global data."
        ),
        version=1,
        algorithms=[
            ZeroShotObjectDetector().algorithm_config_response,
            GlobalMothClassifier().algorithm_config_response,
        ],
    )

    (This is cheap, since GlobalMothClassifier.__init__ defers model loading.)

  2. Mutable class attributes.
    config (with algorithms=[]) is a mutable class attribute, which Ruff flags (RUF012). This pattern exists in the other pipelines too, but over time it may be worth annotating such attributes with typing.ClassVar and avoiding mutable defaults to reduce the risk of accidental cross‑instance state sharing.

Overall, the runtime behavior of the new pipeline looks solid; these are mainly metadata and hygiene improvements.

Also applies to: 352-410
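
The ClassVar annotation Ruff asks for would look roughly like this (a sketch; the field values are placeholders, not the real pipeline config):

from typing import ClassVar


class SomePipeline:
    # Declared as class-level metadata rather than a per-instance default,
    # which is what RUF012 expects for mutable class attributes.
    config: ClassVar[PipelineConfigResponse] = PipelineConfigResponse(
        name="...",
        slug="...",
        version=1,
        algorithms=[],
    )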

processing_services/example/api/base.py (3)

35-56: Initialization: tighten num_classes derivation and consider static-analysis hint on category_map

The overall init flow looks good (kwargs override, device selection, then labels/weights/model). One edge case and one style point:

  • self.num_classes = self.num_classes or len(self.category_map) will silently set num_classes to 0 when both num_classes and labels_path are unset (empty category_map). For a classifier this yields a degenerate head. A more explicit pattern would avoid that and only infer when labels exist:
-        self.category_map = self.get_labels(self.labels_path)
-        self.num_classes = self.num_classes or len(self.category_map)
+        self.category_map = self.get_labels(self.labels_path)
+        if self.num_classes is None and self.category_map:
+            self.num_classes = len(self.category_map)
  • Ruff’s RUF012 on category_map is about the mutable default at the class level. Since you always overwrite self.category_map in __init__, it’s not a functional bug; you can either ignore the warning or switch to a non-mutable default / ClassVar annotation if you want a clean lint run.

148-186: ResNet50/timm model creation and checkpoint loading: behavior OK, but align guardrails and reduce duplication

The torchvision and timm variants look correct and consistent with each other. A few targeted improvements:

  • ResNet50Base.get_model explicitly guards num_classes is None before constructing the classifier head, but TimmResNet50Base.get_model passes self.num_classes straight into timm.create_model. For misconfigured cases (no labels and no num_classes override), you’ll currently get different failure modes. Consider reusing the same guard (or doing the guard once in the base) so both paths fail fast and consistently when num_classes is missing.
  • The checkpoint handling logic (model_state_dict / state_dict / raw state dict) is duplicated between the two methods, and the timm variant has less logging. A small shared helper like _load_checkpoint(checkpoint, model) (or a load_state_dict_from_checkpoint utility) would keep behavior and logging in sync as more backbones are added.

Overall, the current behavior is functionally fine; these are about consistency and maintainability.

Also applies to: 216-237
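
One possible shape for the shared helper suggested above, assuming checkpoints follow the three layouts the review names (the function name and signature are suggestions, not existing code):

import torch


def load_state_dict_from_checkpoint(checkpoint_path, model, device="cpu"):
    """Load weights from a checkpoint that may be wrapped or raw."""
    checkpoint = torch.load(checkpoint_path, map_location=device)
    if "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    elif "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint  # assume a raw state dict
    model.load_state_dict(state_dict)
    return model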


187-208: Batch inference and post‑processing look correct; only minor readability tweak optional

The predict_batch implementation (no‑grad context, moving the batch to self.device, single self.model call) is idiomatic for inference. post_process_batch correctly applies softmax and returns per‑sample dicts with scores and logits.

If you want a tiny readability win, you could avoid using len(predictions) as an index and rely on enumerate instead:

-        for prob_tensor in probabilities:
-            prob_list = prob_tensor.cpu().numpy().tolist()
-            predictions.append(
-                {
-                    "scores": prob_list,
-                    "logits": logits[len(predictions)].cpu().numpy().tolist(),
-                }
-            )
+        for idx, prob_tensor in enumerate(probabilities):
+            predictions.append(
+                {
+                    "scores": prob_tensor.cpu().numpy().tolist(),
+                    "logits": logits[idx].cpu().numpy().tolist(),
+                }
+            )

Purely cosmetic; no behavior change.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 562f734 and fcfcd87.

📒 Files selected for processing (9)
  • processing_services/example/.gitignore (1 hunks)
  • processing_services/example/api/api.py (2 hunks)
  • processing_services/example/api/base.py (1 hunks)
  • processing_services/example/api/global_moth_classifier.py (1 hunks)
  • processing_services/example/api/pipelines.py (2 hunks)
  • processing_services/example/api/schemas.py (1 hunks)
  • processing_services/example/api/utils.py (2 hunks)
  • processing_services/example/docker-compose.yml (1 hunks)
  • processing_services/example/requirements.txt (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (5)
processing_services/example/api/api.py (1)
processing_services/example/api/pipelines.py (1)
  • ZeroShotObjectDetectorWithGlobalMothClassifierPipeline (352-410)
processing_services/example/api/global_moth_classifier.py (3)
processing_services/example/api/base.py (8)
  • TimmResNet50Base (211-237)
  • get_key (63-68)
  • get_transforms (98-106)
  • get_transforms (138-146)
  • predict_batch (115-120)
  • predict_batch (187-192)
  • post_process_batch (122-127)
  • post_process_batch (194-208)
processing_services/example/api/schemas.py (5)
  • AlgorithmCategoryMapResponse (150-179)
  • AlgorithmConfigResponse (182-208)
  • AlgorithmReference (75-77)
  • ClassificationResponse (80-101)
  • Detection (138-147)
processing_services/example/api/pipelines.py (6)
  • run (88-92)
  • run (176-200)
  • run (230-243)
  • run (275-296)
  • run (328-349)
  • run (384-410)
processing_services/example/api/utils.py (2)
processing_services/example/api/schemas.py (1)
  • open (55-68)
processing_services/example/api/algorithms.py (1)
  • get_best_device (24-33)
processing_services/example/api/base.py (1)
processing_services/example/api/utils.py (2)
  • get_best_device (96-104)
  • get_or_download_file (42-93)
processing_services/example/api/pipelines.py (4)
processing_services/example/api/global_moth_classifier.py (4)
  • GlobalMothClassifier (27-249)
  • algorithm_config_response (72-92)
  • algorithm_config_response (95-97)
  • run (133-215)
processing_services/minimal/api/pipelines.py (4)
  • Pipeline (182-200)
  • run (199-200)
  • run (209-221)
  • run (241-253)
processing_services/example/api/schemas.py (3)
  • PipelineConfigResponse (300-308)
  • PipelineResultsResponse (266-280)
  • Detection (138-147)
processing_services/example/api/algorithms.py (7)
  • Algorithm (36-66)
  • ZeroShotObjectDetector (69-178)
  • run (42-43)
  • run (99-155)
  • run (201-242)
  • run (326-332)
  • run (400-406)
🪛 Ruff (0.14.5)
processing_services/example/api/global_moth_classifier.py

148-148: Avoid specifying long messages outside the exception class

(TRY003)


177-177: Avoid specifying long messages outside the exception class

(TRY003)


189-189: zip() without an explicit strict= parameter

Add explicit value for parameter strict=

(B905)

processing_services/example/api/utils.py

77-77: Probable use of requests call without timeout

(S113)

processing_services/example/api/base.py

39-39: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)


158-158: Avoid specifying long messages outside the exception class

(TRY003)

processing_services/example/api/pipelines.py

358-358: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: Agent
  • GitHub Check: test
🔇 Additional comments (6)
processing_services/example/.gitignore (1)

1-4: Ignore rules align with new cache volumes

Ignoring cache/, huggingface_cache/, and pytorch_cache/ is consistent with the new persistent cache usage and prevents large artifacts from being committed. Looks good.

processing_services/example/docker-compose.yml (1)

5-9: Persistent cache volume wiring looks correct

The ./cache:/app/cache mapping matches the logic in get_or_download_file (preferring /app/cache) and will keep model weights/labels persistent across container restarts. No issues from the compose side.

processing_services/example/requirements.txt (1)

7-10: Confirm timm version compatibility with pinned torch/torchvision

Adding timm==1.0.11 is appropriate for the new TimmResNet50Base / GlobalMothClassifier flow. Please just double‑check that this specific timm version is tested with torch==2.6.0 and torchvision==0.21.0 in your runtime environment (CUDA/CPU variants), and that the version pin matches the training environment for the shipped weights.
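
A quick runtime sanity check for the pinned combination (a sketch; printing versions is not a substitute for an inference smoke test on the target CUDA/CPU variants):

import timm
import torch
import torchvision

print(timm.__version__, torch.__version__, torchvision.__version__)
# Expected with the pins in requirements.txt: 1.0.11 2.6.0 0.21.0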

processing_services/example/api/api.py (1)

9-16: New pipeline is correctly exposed via the API

The ZeroShotObjectDetectorWithGlobalMothClassifierPipeline import and its inclusion in the pipelines list are consistent with the new slug in PipelineChoice and the pipeline’s own config.slug. This makes the new two‑stage pipeline selectable through the existing /process and /info flows without altering error handling.

Also applies to: 40-46

processing_services/example/api/schemas.py (1)

211-217: PipelineChoice updated consistently with new pipeline slug

The new literal "zero-shot-object-detector-with-global-moth-classifier-pipeline" matches the slug defined in the corresponding pipeline’s PipelineConfigResponse, keeping request/response typing in sync.

processing_services/example/api/base.py (1)

70-107: Weights/labels loading and transforms are solid; only minor robustness nits

get_weights, get_labels, and get_transforms are straightforward and line up with get_or_download_file’s behavior. Two small considerations:

  • get_labels assumes the JSON is a {label: index} mapping and inverts it (illustrated below). If any existing label files are already {index: label}, this will flip them the wrong way around. Worth double‑checking that all current label JSONs follow the expected schema.
  • Log messages in get_weights always say “Downloading …” even when weights_path is a local path. If you care about log clarity, you could special‑case local paths, but this is cosmetic.
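
An illustration of that schema assumption with hypothetical data (not the shipped label file):

import json

raw = json.loads('{"Acherontia atropos": 0, "Actias luna": 1}')  # {label: index}
category_map = {index: label for label, index in raw.items()}  # {0: "...", 1: "..."}

# If a label file were already {index: label}, the same inversion would
# silently yield {"Acherontia atropos": "0"}-style mappings, and integer
# index lookups downstream would miss.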

Comment on lines +133 to +215
def run(self, detections: list[Detection]) -> list[Detection]:
    """
    Run classification on a list of detections.
    Args:
        detections: List of Detection objects with cropped images
    Returns:
        List of Detection objects with added classifications
    """
    if not detections:
        return []

    # Ensure model is compiled
    if self.model is None:
        raise RuntimeError("Model not compiled. Call compile() first.")

    logger.info(f"Running {self.name} on {len(detections)} detections")

    # Process detections in batches
    classified_detections = []

    for i in range(0, len(detections), self.batch_size):
        batch_detections = detections[i : i + self.batch_size]
        batch_images = []

        # Prepare batch of images
        for detection in batch_detections:
            if detection._pil:
                # Convert to RGB if needed
                if detection._pil.mode != "RGB":
                    img = detection._pil.convert("RGB")
                else:
                    img = detection._pil
                batch_images.append(img)
            else:
                logger.warning(f"Detection {detection.id} has no PIL image")
                continue

        if not batch_images:
            continue

        # Transform images
        if self.transforms is None:
            raise RuntimeError("Transforms not initialized. Call compile() first.")
        batch_tensor = torch.stack([self.transforms(img) for img in batch_images])

        # Run inference
        start_time = datetime.datetime.now()
        predictions = self.predict_batch(batch_tensor)
        processed_predictions = self.post_process_batch(predictions)
        end_time = datetime.datetime.now()

        inference_time = (end_time - start_time).total_seconds() / len(batch_images)

        # Add classifications to detections
        for detection, prediction in zip(batch_detections, processed_predictions):
            # Get best prediction
            best_score = max(prediction["scores"])
            best_idx = prediction["scores"].index(best_score)
            best_label = self.category_map.get(best_idx, f"class_{best_idx}")

            classification = ClassificationResponse(
                classification=best_label,
                labels=[best_label],
                scores=[best_score],
                logits=prediction["logits"],
                inference_time=inference_time,
                timestamp=datetime.datetime.now(),
                algorithm=AlgorithmReference(
                    name=self.name,
                    key=self.get_key(),
                ),
                terminal=True,
            )

            # Add classification to detection
            detection_with_classification = detection.copy(deep=True)
            detection_with_classification.classifications = [classification]
            classified_detections.append(detection_with_classification)

    logger.info(f"Classified {len(classified_detections)} detections")
    return classified_detections

⚠️ Potential issue | 🟠 Major

Fix detection/prediction alignment in run() and use zip(..., strict=True)

In run(), detections are filtered into batch_images based on detection._pil, but you still zip over the original batch_detections:

batch_detections = detections[i : i + self.batch_size]
batch_images = []
...
for detection in batch_detections:
    if detection._pil:
        ...
        batch_images.append(img)
    else:
        logger.warning(...)
        continue
...
processed_predictions = self.post_process_batch(predictions)
...
for detection, prediction in zip(batch_detections, processed_predictions):
    ...

If any Detection in the batch lacks a PIL image, batch_images and processed_predictions will have fewer elements than batch_detections, and predictions will be paired with the wrong detections (or the wrong ones will get classified). That’s a silent correctness bug.

You can fix this and address the B905 warning by tracking only valid detections and zipping with strict=True:

-        for i in range(0, len(detections), self.batch_size):
-            batch_detections = detections[i : i + self.batch_size]
-            batch_images = []
+        for i in range(0, len(detections), self.batch_size):
+            batch_detections = detections[i : i + self.batch_size]
+            batch_images = []
+            valid_detections: list[Detection] = []
@@
-            # Prepare batch of images
-            for detection in batch_detections:
-                if detection._pil:
+            # Prepare batch of images
+            for detection in batch_detections:
+                if detection._pil:
@@
-                    else:
-                        img = detection._pil
-                    batch_images.append(img)
-                else:
-                    logger.warning(f"Detection {detection.id} has no PIL image")
-                    continue
+                    else:
+                        img = detection._pil
+                    batch_images.append(img)
+                    valid_detections.append(detection)
+                else:
+                    logger.warning("Detection %s has no PIL image", detection.id)
+                    continue
@@
-            # Add classifications to detections
-            for detection, prediction in zip(batch_detections, processed_predictions):
+            # Add classifications to detections
+            for detection, prediction in zip(valid_detections, processed_predictions, strict=True):
@@
-                detection_with_classification = detection.copy(deep=True)
+                detection_with_classification = detection.copy(deep=True)
                 detection_with_classification.classifications = [classification]
                 classified_detections.append(detection_with_classification)

This ensures:

  • Only detections that actually contributed an image get classifications.
  • Any mismatch between the number of valid detections and predictions raises immediately instead of silently misaligning outputs.

🤖 Prompt for AI Agents
In processing_services/example/api/global_moth_classifier.py around lines
133-215, the run() loop builds batch_images by skipping detections with no PIL
image but later zips processed_predictions with the original batch_detections,
causing misalignment; change the code to collect a parallel list of
valid_detections (append detections only when you append their image), use
valid_detections for post-inference pairing, compute inference_time using
len(valid_detections), and replace zip(... ) with zip(valid_detections,
processed_predictions, strict=True) so any mismatch raises immediately while
preserving the existing logger.warning for skipped detections.
