Add general detector + global moth model to example pipeline #992
base: main
Conversation
Walkthrough

Introduces a new Global Moth Classifier algorithm integrated into a two-stage pipeline with ZeroShotObjectDetector. Adds inference base classes supporting lazy-loaded PyTorch models, persistent caching for weights and labels, device selection utilities, and supporting configuration updates across the API layer.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client
    participant API as API Handler
    participant Pipeline as ZeroShotObjectDetectorWithGlobalMothClassifierPipeline
    participant Detector as ZeroShotObjectDetector
    participant Classifier as GlobalMothClassifier
    participant Model as TimmResNet50 Model

    Client->>API: POST /run (with image)
    activate API
    API->>Pipeline: run()
    activate Pipeline
    Pipeline->>Detector: run(image, batch_size=1)
    activate Detector
    Detector->>Detector: detect objects
    Detector-->>Pipeline: detections[]
    deactivate Detector
    Pipeline->>Classifier: run(detections[], batch_size=4)
    activate Classifier
    Classifier->>Classifier: batch images, apply transforms
    Classifier->>Model: predict_batch(transformed_batch)
    activate Model
    Model-->>Classifier: logits
    deactivate Model
    Classifier->>Classifier: post_process (softmax, scores)
    Classifier->>Classifier: attach ClassificationResponse to each detection
    Classifier-->>Pipeline: detections[] with classifications
    deactivate Classifier
    Pipeline->>Pipeline: compute elapsed_time
    Pipeline-->>API: PipelineResultsResponse
    deactivate Pipeline
    API-->>Client: response with detections & classifications
    deactivate API
```
```mermaid
sequenceDiagram
    participant Caller
    participant Classifier as GlobalMothClassifier
    participant Base as SimplifiedInferenceBase
    participant Utils as utils.download_file
    participant Model as TimmResNet50

    Caller->>Classifier: compile()
    activate Classifier
    Classifier->>Base: __init__()
    activate Base
    Base->>Utils: get_weights(weights_path)
    activate Utils
    Utils-->>Base: cached_weight_path
    deactivate Utils
    Base->>Utils: get_labels(labels_path)
    activate Utils
    Utils-->>Base: label_dict
    deactivate Utils
    Base->>Base: get_transforms()
    Base->>Model: get_model() (via TimmResNet50Base)
    Model-->>Base: initialized model
    Base->>Base: move model to device, set eval mode
    deactivate Base
    Classifier->>Classifier: store model, transforms, category_map, num_classes
    Classifier-->>Caller: compile() complete
    deactivate Classifier
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
Pull Request Overview
This PR adds support for a global moth species classifier to the example processing service pipeline. The classifier provides high-quality identification across 29,176+ moth species using a ResNet50-based model trained on global data.
Key changes:
- Adds a new GlobalMothClassifier algorithm with deferred model loading to optimize initialization
- Introduces base classes for inference models (SimplifiedInferenceBase, ResNet50Base, TimmResNet50Base) to support model loading and inference
- Updates file caching to use persistent cache directories instead of temporary directories
Reviewed Changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| requirements.txt | Adds timm==1.0.11 dependency for the timm-based ResNet50 model |
| docker-compose.yml | Adds persistent cache volume mount for model weights and labels |
| api/utils.py | Refactors file download to use persistent cache, adds device selection helper, improves download progress logging |
| api/schemas.py | Registers new pipeline slug for the global moth classifier pipeline |
| api/pipelines.py | Adds ZeroShotObjectDetectorWithGlobalMothClassifierPipeline that combines object detection with moth classification |
| api/global_moth_classifier.py | Implements GlobalMothClassifier algorithm with deferred compilation and batch inference |
| api/base.py | Introduces base classes for simplified inference models with weight/label loading and preprocessing |
| api/api.py | Registers the new pipeline in the API endpoints |
| .gitignore | Adds cache directories to gitignore |
Comments suppressed due to low confidence (1)

processing_services/example/api/pipelines.py:70

This call to Pipeline.get_stages in an initialization method is overridden by:

- ZeroShotHFClassifierPipeline.get_stages
- ZeroShotObjectDetectorPipeline.get_stages
- ZeroShotObjectDetectorWithRandomSpeciesClassifierPipeline.get_stages
- ZeroShotObjectDetectorWithConstantClassifierPipeline.get_stages
- ZeroShotObjectDetectorWithGlobalMothClassifierPipeline.get_stages

```python
self.stages = self.stages or self.get_stages()
```
```python
resulting_filepath = pathlib.Path(local_filepath).resolve()
logger.info(f"✅ Download completed: {resulting_filepath}")
logger.info(f"Downloaded to {resulting_filepath}")
```
Copilot AI (Nov 18, 2025)
Redundant logging: lines 91-92 log essentially the same information. Remove one of these duplicate log statements.
Suggested change:

```python
logger.info(f"Downloaded to {resulting_filepath}")
```
```python
for chunk in response.iter_content(chunk_size=8192):
    f.write(chunk)
    downloaded += len(chunk)
    if total_size > 0:
        percent = (downloaded / total_size) * 100
        logger.info(f"  Progress: {percent:.1f}% ({downloaded}/{total_size} bytes)")
```
Copilot AI (Nov 18, 2025)
Excessive logging in download loop: logging progress for every 8KB chunk will create thousands of log entries for large files. Consider logging progress at fixed intervals (e.g., every 10%) or reducing logging frequency.
Suggested change:

```diff
+last_logged_percent = 0
 for chunk in response.iter_content(chunk_size=8192):
     f.write(chunk)
     downloaded += len(chunk)
     if total_size > 0:
-        percent = (downloaded / total_size) * 100
-        logger.info(f"  Progress: {percent:.1f}% ({downloaded}/{total_size} bytes)")
+        percent = int((downloaded / total_size) * 100)
+        if percent >= last_logged_percent + 10 or percent == 100:
+            logger.info(f"  Progress: {percent}% ({downloaded}/{total_size} bytes)")
+            last_logged_percent = percent
```
```python
timestamp=datetime.datetime.now(),
algorithm=AlgorithmReference(
    name=self.name,
    key=self.get_key(),
```
Copilot AI (Nov 18, 2025)
The Algorithm base class does not have a get_key() method, but TimmResNet50Base inherits it from SimplifiedInferenceBase. This works due to multiple inheritance, but the method call relies on the MRO (Method Resolution Order). Consider adding get_key() to the Algorithm base class or using a mixin pattern to make this dependency more explicit.
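To illustrate the mixin alternative the comment suggests, here is a minimal sketch; all class names except get_key() are hypothetical and only show where the method would live so the dependency is visible in the class signature:

```python
class KeyedMixin:
    """Hypothetical mixin that makes the get_key() dependency explicit."""

    name: str = "unnamed"

    def get_key(self) -> str:
        # Derive a stable key from the algorithm's display name.
        return self.name.lower().replace(" ", "-")

class Algorithm:
    name = "Base Algorithm"

class ClassifierSketch(KeyedMixin, Algorithm):
    name = "Global Moth Classifier"

print(ClassifierSketch().get_key())  # global-moth-classifier
```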
```python
# Create a basic config response before compilation
self._algorithm_config_response = AlgorithmConfigResponse(
    name=self.name,
    key=self.get_key(),
```
Copilot AI (Nov 18, 2025)
Same issue as Comment 3: self.get_key() is called before the model is compiled, which means TimmResNet50Base.__init__() hasn't been called yet and the method is inherited via multiple inheritance. While this works, it's fragile and could break if the class hierarchy changes.
```python
def get_best_device() -> str:
    """
    Returns the best available device for running the model.
    MPS is not supported by the current algorithms.
    """
    if torch.cuda.is_available():
        return f"cuda:{torch.cuda.current_device()}"
    else:
        return "cpu"
```
Copilot AI (Nov 18, 2025)
This function duplicates the get_best_device() function already defined in algorithms.py (lines 24-33). Consider removing one implementation and importing from a single location to avoid code duplication.
```python
self.category_map = self.get_labels(self.labels_path)
self.num_classes = self.num_classes or len(self.category_map)
self.weights = self.get_weights(self.weights_path)
self.transforms = self.get_transforms()
```
Copilot AI (Nov 18, 2025)
This call to SimplifiedInferenceBase.get_transforms in an initialization method is overridden by ResNet50Base.get_transforms.
This call to SimplifiedInferenceBase.get_transforms in an initialization method is overridden by GlobalMothClassifier.get_transforms.
Suggested change:

```diff
-self.transforms = self.get_transforms()
+self.transforms = None  # Subclasses should set this after their own initialization
```
```python
self.transforms = self.get_transforms()

logger.info(f"Loading model for {self.name} with {len(self.category_map or [])} categories")
self.model = self.get_model()
```
Copilot AI (Nov 18, 2025)
This call to SimplifiedInferenceBase.get_model in an initialization method is overridden by ResNet50Base.get_model.
This call to SimplifiedInferenceBase.get_model in an initialization method is overridden by TimmResNet50Base.get_model.
Suggested change:

```diff
-self.model = self.get_model()
+self.model = None  # Model should be initialized in subclass after all setup
```
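Both suggestions address the same hazard: a base-class __init__ that calls an overridable method runs the subclass override before the subclass has finished its own setup. A minimal sketch of the failure mode (Base and Child are illustrative, not the PR's classes):

```python
class Base:
    def __init__(self) -> None:
        # Dispatches to the subclass override, which may run too early.
        self.transforms = self.get_transforms()

    def get_transforms(self) -> str:
        return "base-transforms"

class Child(Base):
    def __init__(self) -> None:
        super().__init__()  # Base.__init__ calls Child.get_transforms here...
        self.size = 224     # ...but size is assigned only afterwards

    def get_transforms(self) -> str:
        return f"resize-to-{self.size}"  # AttributeError at construction time

try:
    Child()
except AttributeError as exc:
    print(exc)  # 'Child' object has no attribute 'size'
```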
```python
logger.setLevel(logging.INFO)


class GlobalMothClassifier(Algorithm, TimmResNet50Base):
```
Copilot AI (Nov 18, 2025)
This class does not call SimplifiedInferenceBase.__init__ during initialization (GlobalMothClassifier.__init__ may be missing a call to a base class __init__).
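A common remedy is cooperative initialization, where every class in the hierarchy calls super().__init__() so each base along the MRO runs exactly once. A sketch with placeholder classes, not the PR's exact hierarchy:

```python
class InferenceBaseSketch:
    def __init__(self, **kwargs) -> None:
        self.device = kwargs.pop("device", "cpu")
        super().__init__(**kwargs)  # continue along the MRO

class AlgorithmSketch:
    def __init__(self, **kwargs) -> None:
        self.compiled = False
        super().__init__(**kwargs)

class ClassifierSketch(AlgorithmSketch, InferenceBaseSketch):
    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)  # runs both base __init__ methods in MRO order

c = ClassifierSketch(device="cuda:0")
print(c.compiled, c.device)  # False cuda:0
```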
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
processing_services/example/api/utils.py (1)
42-93: Fix local-path handling regression and add timeout to get_or_download_file

Verification confirms both issues:

- Local paths are broken (regression). The function always routes through requests.get() without checking if the input is a local path first. For a local path like /home/user/file.jpg, urlparse().path extracts /home/user/file.jpg, then requests.get() is called with that value, which fails. This breaks the documented contract ("If the path is a local path, return the path"). Callers in processing_services/example/api/base.py:74 and :87 expect to pass local paths via the weights_path and labels_path parameters.
- requests.get() has no timeout. Confirmed S113 violation; this can hang indefinitely on bad network conditions.

Suggested fix:

- Add an early return for local paths using the existing is_url() function (defined at line 30).
- Add an optional timeout parameter (default 60 seconds) and pass it to requests.get().

Example diff:

```diff
-def get_or_download_file(path_or_url, tempdir_prefix="antenna") -> pathlib.Path:
+def get_or_download_file(
+    path_or_url,
+    tempdir_prefix: str = "antenna",
+    timeout: float | None = 60,
+) -> pathlib.Path:
     """
     Fetch a file from a URL or local path.
     If the path is a URL, download the file.
     If the URL has already been downloaded, return the existing local path.
     If the path is a local path, return the path.
     """
     if not path_or_url:
         raise Exception("Specify a URL or path to fetch file from.")
 
+    # Handle local filesystem paths directly
+    if not is_url(str(path_or_url)):
+        local_path = pathlib.Path(path_or_url)
+        if local_path.exists():
+            return local_path
+        raise FileNotFoundError(f"Local file does not exist: {local_path}")
+
     # Use persistent cache directory for downloaded files
     cache_root = (
         pathlib.Path("/app/cache")
         if pathlib.Path("/app").exists()
         else pathlib.Path.home() / ".antenna_cache"
     )
     destination_dir = cache_root / tempdir_prefix
     fname = pathlib.Path(urlparse(path_or_url).path).name
     destination_dir.mkdir(parents=True, exist_ok=True)
     local_filepath = destination_dir / fname
 
-    if local_filepath and local_filepath.exists():
+    if local_filepath.exists():
         logger.info(f"📁 Using cached file: {local_filepath}")
         return local_filepath
     else:
         logger.info(f"⬇️ Downloading {path_or_url} to {local_filepath}")
         headers = {"User-Agent": USER_AGENT}
-        response = requests.get(path_or_url, stream=True, headers=headers)
+        request_kwargs = {"stream": True, "headers": headers}
+        if timeout is not None:
+            request_kwargs["timeout"] = timeout
+        response = requests.get(path_or_url, **request_kwargs)
         response.raise_for_status()
```

Additionally, get_best_device() is defined in both utils.py:96 and algorithms.py:24; consider centralizing it to a single location to prevent divergence.
🧹 Nitpick comments (5)
processing_services/example/api/utils.py (1)
96-105: Avoid duplicating get_best_device logic across modules

Defining get_best_device() here is fine, but there is an existing implementation in processing_services/example/api/algorithms.py with the same behavior. To avoid divergence later (e.g., if you ever add MPS support or tweak device selection), consider centralizing this helper in a single place (e.g., utils.get_best_device) and importing it where needed, removing the duplicate definition.
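The centralization could look like the following sketch: one definition in utils.py, imported elsewhere. The module layout is taken from the review; the relative import line is an assumption about the package structure:

```python
# utils.py: single source of truth
import torch

def get_best_device() -> str:
    """Return the best available device; MPS is intentionally excluded."""
    if torch.cuda.is_available():
        return f"cuda:{torch.cuda.current_device()}"
    return "cpu"

# algorithms.py: reuse instead of redefining
# from .utils import get_best_device
```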
processing_services/example/api/pipelines.py (1)

12-12: Pipeline wiring is correct; consider aligning metadata initialization and class attributes

Functionally, ZeroShotObjectDetectorWithGlobalMothClassifierPipeline looks good:

- Stage order and batching ([1, 4]) match the intended detector→classifier flow.
- Existing detections are reused when provided.
- Logging and timing mirror the other pipeline implementations.
Two small follow‑ups to consider:
1. Metadata consistency for config.algorithms.
   Other pipelines populate config.algorithms at class definition time, so pipelines[i].config.algorithms and algorithm_choices are valid even before any pipeline instance is constructed. Here config.algorithms starts as [] and is only filled in get_stages(), meaning this pipeline's algorithms won't appear in purely static /info or introspection flows unless an instance has been created and get_stages() run. If you want consistent metadata behavior, you could initialize config.algorithms similarly:

   ```python
   config = PipelineConfigResponse(
       name="Zero Shot Object Detector With Global Moth Classifier Pipeline",
       slug="zero-shot-object-detector-with-global-moth-classifier-pipeline",
       description=(
           "HF zero shot object detector with global moth species classifier. "
           "Supports 29,176+ moth species trained on global data."
       ),
       version=1,
       algorithms=[
           ZeroShotObjectDetector().algorithm_config_response,
           GlobalMothClassifier().algorithm_config_response,
       ],
   )
   ```

   (This is cheap, since GlobalMothClassifier.__init__ defers model loading.)

2. Mutable class attributes.
   config (with algorithms=[]) is a mutable class attribute, which Ruff flags (RUF012). This pattern exists in the other pipelines too, but over time it may be worth annotating such attributes with typing.ClassVar and avoiding mutable defaults to reduce the risk of accidental cross-instance state sharing.

Overall, the runtime behavior of the new pipeline looks solid; these are mainly metadata and hygiene improvements.
Also applies to: 352-410
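In isolation, the ClassVar annotation mentioned above looks like this (a sketch; the attribute shapes are simplified from the real PipelineConfigResponse):

```python
from typing import ClassVar

class PipelineSketch:
    # ClassVar tells type checkers (and Ruff's RUF012) that this is
    # intentionally shared class-level state, not a per-instance default.
    slug: ClassVar[str] = "example-pipeline"
    algorithm_names: ClassVar[tuple[str, ...]] = ()  # immutable default

print(PipelineSketch.slug)
```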
processing_services/example/api/base.py (3)
35-56: Initialization: tighten num_classes derivation and consider static-analysis hint on category_map

The overall init flow looks good (kwargs override, device selection, then labels/weights/model). One edge case and one style point:

- self.num_classes = self.num_classes or len(self.category_map) will silently set num_classes to 0 when both num_classes and labels_path are unset (empty category_map). For a classifier this yields a degenerate head. A more explicit pattern would avoid that and only infer when labels exist:

  ```diff
  -        self.category_map = self.get_labels(self.labels_path)
  -        self.num_classes = self.num_classes or len(self.category_map)
  +        self.category_map = self.get_labels(self.labels_path)
  +        if self.num_classes is None and self.category_map:
  +            self.num_classes = len(self.category_map)
  ```

- Ruff's RUF012 on category_map is about the mutable default at the class level. Since you always overwrite self.category_map in __init__, it's not a functional bug; you can either ignore the warning or switch to a non-mutable default / ClassVar annotation if you want a clean lint run.
148-186: ResNet50/timm model creation and checkpoint loading: behavior OK, but align guardrails and reduce duplication

The torchvision and timm variants look correct and consistent with each other. A few targeted improvements:

- ResNet50Base.get_model explicitly guards num_classes is None before constructing the classifier head, but TimmResNet50Base.get_model passes self.num_classes straight into timm.create_model. For misconfigured cases (no labels and no num_classes override), you'll currently get different failure modes. Consider reusing the same guard (or doing the guard once in the base) so both paths fail fast and consistently when num_classes is missing.
- The checkpoint handling logic (model_state_dict / state_dict / raw state dict) is duplicated between the two methods, and the timm variant has less logging. A small shared helper like _load_checkpoint(checkpoint, model) (or a load_state_dict_from_checkpoint utility) would keep behavior and logging in sync as more backbones are added.

Overall, the current behavior is functionally fine; these are about consistency and maintainability.
Also applies to: 216-237
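A shared helper along those lines might look like this sketch; the three checkpoint layouts come from the review, while the function body and logging are illustrative assumptions:

```python
import logging
import torch

logger = logging.getLogger(__name__)

def load_state_dict_from_checkpoint(checkpoint: dict, model: torch.nn.Module) -> None:
    """Load weights from any of the three checkpoint layouts described above."""
    if "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
        logger.info("Loading weights from 'model_state_dict'")
    elif "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
        logger.info("Loading weights from 'state_dict'")
    else:
        state_dict = checkpoint  # assume the file is a raw state dict
        logger.info("Loading raw state dict")
    model.load_state_dict(state_dict)
```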
187-208: Batch inference and post-processing look correct; only minor readability tweak optional

The predict_batch implementation (no-grad context, moving the batch to self.device, single self.model call) is idiomatic for inference. post_process_batch correctly applies softmax and returns per-sample dicts with scores and logits.

If you want a tiny readability win, you could avoid using len(predictions) as an index and rely on enumerate instead:

```diff
-        for prob_tensor in probabilities:
-            prob_list = prob_tensor.cpu().numpy().tolist()
-            predictions.append(
-                {
-                    "scores": prob_list,
-                    "logits": logits[len(predictions)].cpu().numpy().tolist(),
-                }
-            )
+        for idx, prob_tensor in enumerate(probabilities):
+            predictions.append(
+                {
+                    "scores": prob_tensor.cpu().numpy().tolist(),
+                    "logits": logits[idx].cpu().numpy().tolist(),
+                }
+            )
```

Purely cosmetic; no behavior change.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (9)
- processing_services/example/.gitignore (1 hunks)
- processing_services/example/api/api.py (2 hunks)
- processing_services/example/api/base.py (1 hunks)
- processing_services/example/api/global_moth_classifier.py (1 hunks)
- processing_services/example/api/pipelines.py (2 hunks)
- processing_services/example/api/schemas.py (1 hunks)
- processing_services/example/api/utils.py (2 hunks)
- processing_services/example/docker-compose.yml (1 hunks)
- processing_services/example/requirements.txt (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (5)
processing_services/example/api/api.py (1)
- processing_services/example/api/pipelines.py (1)
  - ZeroShotObjectDetectorWithGlobalMothClassifierPipeline (352-410)

processing_services/example/api/global_moth_classifier.py (3)
- processing_services/example/api/base.py (8)
  - TimmResNet50Base (211-237)
  - get_key (63-68)
  - get_transforms (98-106)
  - get_transforms (138-146)
  - predict_batch (115-120)
  - predict_batch (187-192)
  - post_process_batch (122-127)
  - post_process_batch (194-208)
- processing_services/example/api/schemas.py (5)
  - AlgorithmCategoryMapResponse (150-179)
  - AlgorithmConfigResponse (182-208)
  - AlgorithmReference (75-77)
  - ClassificationResponse (80-101)
  - Detection (138-147)
- processing_services/example/api/pipelines.py (6)
  - run (88-92)
  - run (176-200)
  - run (230-243)
  - run (275-296)
  - run (328-349)
  - run (384-410)

processing_services/example/api/utils.py (2)
- processing_services/example/api/schemas.py (1)
  - open (55-68)
- processing_services/example/api/algorithms.py (1)
  - get_best_device (24-33)

processing_services/example/api/base.py (1)
- processing_services/example/api/utils.py (2)
  - get_best_device (96-104)
  - get_or_download_file (42-93)

processing_services/example/api/pipelines.py (4)
- processing_services/example/api/global_moth_classifier.py (4)
  - GlobalMothClassifier (27-249)
  - algorithm_config_response (72-92)
  - algorithm_config_response (95-97)
  - run (133-215)
- processing_services/minimal/api/pipelines.py (4)
  - Pipeline (182-200)
  - run (199-200)
  - run (209-221)
  - run (241-253)
- processing_services/example/api/schemas.py (3)
  - PipelineConfigResponse (300-308)
  - PipelineResultsResponse (266-280)
  - Detection (138-147)
- processing_services/example/api/algorithms.py (7)
  - Algorithm (36-66)
  - ZeroShotObjectDetector (69-178)
  - run (42-43)
  - run (99-155)
  - run (201-242)
  - run (326-332)
  - run (400-406)
🪛 Ruff (0.14.5)
processing_services/example/api/global_moth_classifier.py
148-148: Avoid specifying long messages outside the exception class
(TRY003)
177-177: Avoid specifying long messages outside the exception class
(TRY003)
189-189: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
processing_services/example/api/utils.py
77-77: Probable use of requests call without timeout
(S113)
processing_services/example/api/base.py
39-39: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
158-158: Avoid specifying long messages outside the exception class
(TRY003)
processing_services/example/api/pipelines.py
358-358: Mutable class attributes should be annotated with typing.ClassVar
(RUF012)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Agent
- GitHub Check: test
🔇 Additional comments (6)
processing_services/example/.gitignore (1)
1-4: Ignore rules align with new cache volumes

Ignoring cache/, huggingface_cache/, and pytorch_cache/ is consistent with the new persistent cache usage and prevents large artifacts from being committed. Looks good.

processing_services/example/docker-compose.yml (1)
5-9: Persistent cache volume wiring looks correct

The ./cache:/app/cache mapping matches the logic in get_or_download_file (preferring /app/cache) and will keep model weights/labels persistent across container restarts. No issues from the compose side.

processing_services/example/requirements.txt (1)
7-10: Confirm timm version compatibility with pinned torch/torchvision

Adding timm==1.0.11 is appropriate for the new TimmResNet50Base/GlobalMothClassifier flow. Please just double-check that this specific timm version is tested with torch==2.6.0 and torchvision==0.21.0 in your runtime environment (CUDA/CPU variants), and that the version pin matches the training environment for the shipped weights.
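One quick way to confirm the pins resolve together in the runtime image is a small smoke test (a sketch; the expected versions are the pins named above, and resnet50 is the timm backbone the review describes):

```python
import timm
import torch
import torchvision

# Expect 1.0.11 / 2.6.0 / 0.21.0 per the pins in requirements.txt
print(timm.__version__, torch.__version__, torchvision.__version__)

# Verify timm can construct the backbone with a custom head size
model = timm.create_model("resnet50", pretrained=False, num_classes=10)
print(sum(p.numel() for p in model.parameters()))
```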
processing_services/example/api/api.py (1)

9-16: New pipeline is correctly exposed via the API

The ZeroShotObjectDetectorWithGlobalMothClassifierPipeline import and its inclusion in the pipelines list are consistent with the new slug in PipelineChoice and the pipeline's own config.slug. This makes the new two-stage pipeline selectable through the existing /process and /info flows without altering error handling.

Also applies to: 40-46
processing_services/example/api/schemas.py (1)
211-217: PipelineChoice updated consistently with new pipeline slug

The new literal "zero-shot-object-detector-with-global-moth-classifier-pipeline" matches the slug defined in the corresponding pipeline's PipelineConfigResponse, keeping request/response typing in sync.

processing_services/example/api/base.py (1)
70-107: Weights/labels loading and transforms are solid; only minor robustness nits

get_weights, get_labels, and get_transforms are straightforward and line up with get_or_download_file's behavior. Two small considerations:

- get_labels assumes the JSON is a {label: index} mapping and inverts it. If any existing label files are already {index: label}, this will flip them the wrong way around. Worth double-checking that all current label JSONs follow the expected schema.
- Log messages in get_weights always say "Downloading …" even when weights_path is a local path. If you care about log clarity, you could special-case local paths, but this is cosmetic.
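The inversion concern is easy to demonstrate in isolation (a sketch; the two JSON shapes are the possibilities named above, and the species names are placeholders):

```python
import json

# Expected schema {label: index}: inversion yields {index: label}
expected = json.loads('{"Actias luna": 0, "Automeris io": 1}')
category_map = {v: k for k, v in expected.items()}
print(category_map[0])  # Actias luna

# A file already shaped {index: label} gets flipped the wrong way:
already_inverted = json.loads('{"0": "Actias luna", "1": "Automeris io"}')
wrong = {v: k for k, v in already_inverted.items()}
print(wrong)  # {'Actias luna': '0', 'Automeris io': '1'} -- labels became keys
```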
processing_services/example/api/global_moth_classifier.py, run() (code under review):

```python
def run(self, detections: list[Detection]) -> list[Detection]:
    """
    Run classification on a list of detections.

    Args:
        detections: List of Detection objects with cropped images

    Returns:
        List of Detection objects with added classifications
    """
    if not detections:
        return []

    # Ensure model is compiled
    if self.model is None:
        raise RuntimeError("Model not compiled. Call compile() first.")

    logger.info(f"Running {self.name} on {len(detections)} detections")

    # Process detections in batches
    classified_detections = []

    for i in range(0, len(detections), self.batch_size):
        batch_detections = detections[i : i + self.batch_size]
        batch_images = []

        # Prepare batch of images
        for detection in batch_detections:
            if detection._pil:
                # Convert to RGB if needed
                if detection._pil.mode != "RGB":
                    img = detection._pil.convert("RGB")
                else:
                    img = detection._pil
                batch_images.append(img)
            else:
                logger.warning(f"Detection {detection.id} has no PIL image")
                continue

        if not batch_images:
            continue

        # Transform images
        if self.transforms is None:
            raise RuntimeError("Transforms not initialized. Call compile() first.")
        batch_tensor = torch.stack([self.transforms(img) for img in batch_images])

        # Run inference
        start_time = datetime.datetime.now()
        predictions = self.predict_batch(batch_tensor)
        processed_predictions = self.post_process_batch(predictions)
        end_time = datetime.datetime.now()

        inference_time = (end_time - start_time).total_seconds() / len(batch_images)

        # Add classifications to detections
        for detection, prediction in zip(batch_detections, processed_predictions):
            # Get best prediction
            best_score = max(prediction["scores"])
            best_idx = prediction["scores"].index(best_score)
            best_label = self.category_map.get(best_idx, f"class_{best_idx}")

            classification = ClassificationResponse(
                classification=best_label,
                labels=[best_label],
                scores=[best_score],
                logits=prediction["logits"],
                inference_time=inference_time,
                timestamp=datetime.datetime.now(),
                algorithm=AlgorithmReference(
                    name=self.name,
                    key=self.get_key(),
                ),
                terminal=True,
            )

            # Add classification to detection
            detection_with_classification = detection.copy(deep=True)
            detection_with_classification.classifications = [classification]
            classified_detections.append(detection_with_classification)

    logger.info(f"Classified {len(classified_detections)} detections")
    return classified_detections
```
Fix detection/prediction alignment in run() and use zip(..., strict=True)
In run(), detections are filtered into batch_images based on detection._pil, but you still zip over the original batch_detections:
```python
batch_detections = detections[i : i + self.batch_size]
batch_images = []
...
for detection in batch_detections:
    if detection._pil:
        ...
        batch_images.append(img)
    else:
        logger.warning(...)
        continue
...
processed_predictions = self.post_process_batch(predictions)
...
for detection, prediction in zip(batch_detections, processed_predictions):
    ...
```

If any Detection in the batch lacks a PIL image, batch_images and processed_predictions will have fewer elements than batch_detections, and predictions will be paired with the wrong detections (or the wrong ones will get classified). That's a silent correctness bug.
You can fix this and address the B905 warning by tracking only valid detections and zipping with strict=True:
```diff
-    for i in range(0, len(detections), self.batch_size):
-        batch_detections = detections[i : i + self.batch_size]
-        batch_images = []
+    for i in range(0, len(detections), self.batch_size):
+        batch_detections = detections[i : i + self.batch_size]
+        batch_images = []
+        valid_detections: list[Detection] = []
@@
-        # Prepare batch of images
-        for detection in batch_detections:
-            if detection._pil:
+        # Prepare batch of images
+        for detection in batch_detections:
+            if detection._pil:
@@
-                else:
-                    img = detection._pil
-                batch_images.append(img)
-            else:
-                logger.warning(f"Detection {detection.id} has no PIL image")
-                continue
+                else:
+                    img = detection._pil
+                batch_images.append(img)
+                valid_detections.append(detection)
+            else:
+                logger.warning("Detection %s has no PIL image", detection.id)
+                continue
@@
-        # Add classifications to detections
-        for detection, prediction in zip(batch_detections, processed_predictions):
+        # Add classifications to detections
+        for detection, prediction in zip(valid_detections, processed_predictions, strict=True):
@@
-            detection_with_classification = detection.copy(deep=True)
+            detection_with_classification = detection.copy(deep=True)
             detection_with_classification.classifications = [classification]
             classified_detections.append(detection_with_classification)
```
classified_detections.append(detection_with_classification)This ensures:
- Only detections that actually contributed an image get classifications.
- Any mismatch between the number of valid detections and predictions raises immediately instead of silently misaligning outputs.
🧰 Tools
🪛 Ruff (0.14.5)
148-148: Avoid specifying long messages outside the exception class
(TRY003)
177-177: Avoid specifying long messages outside the exception class
(TRY003)
189-189: zip() without an explicit strict= parameter
Add explicit value for parameter strict=
(B905)
🤖 Prompt for AI Agents
In processing_services/example/api/global_moth_classifier.py around lines
133-215, the run() loop builds batch_images by skipping detections with no PIL
image but later zips processed_predictions with the original batch_detections,
causing misalignment; change the code to collect a parallel list of
valid_detections (append detections only when you append their image), use
valid_detections for post-inference pairing, compute inference_time using
len(valid_detections), and replace zip(... ) with zip(valid_detections,
processed_predictions, strict=True) so any mismatch raises immediately while
preserving the existing logger.warning for skipped detections.