diff --git a/bats_ai/core/admin/__init__.py b/bats_ai/core/admin/__init__.py index fe5a419e..e3212716 100644 --- a/bats_ai/core/admin/__init__.py +++ b/bats_ai/core/admin/__init__.py @@ -11,6 +11,7 @@ NABatSpectrogramAdmin, ) from .processing_task import ProcessingTaskAdmin +from .pulse_annotation import ComputedPulseAnnotationAdmin from .recording import RecordingAdmin from .recording_annotations import RecordingAnnotationAdmin from .recording_tag import RecordingTagAdmin @@ -18,6 +19,7 @@ from .species import SpeciesAdmin from .spectrogram import SpectrogramAdmin from .spectrogram_image import SpectrogramImageAdmin +from .spectrogram_svg import SpectrogramSvgAdmin __all__ = [ 'AnnotationsAdmin', @@ -34,9 +36,11 @@ 'ConfigurationAdmin', 'ExportedAnnotationFileAdmin', 'SpectrogramImageAdmin', + 'SpectrogramSvgAdmin', # NABat Models 'NABatRecordingAnnotationAdmin', 'NABatCompressedSpectrogramAdmin', 'NABatSpectrogramAdmin', 'NABatRecordingAdmin', + 'ComputedPulseAnnotationAdmin', ] diff --git a/bats_ai/core/admin/pulse_annotation.py b/bats_ai/core/admin/pulse_annotation.py new file mode 100644 index 00000000..6bb409fc --- /dev/null +++ b/bats_ai/core/admin/pulse_annotation.py @@ -0,0 +1,13 @@ +from django.contrib import admin + +from bats_ai.core.models import ComputedPulseAnnotation + + +@admin.register(ComputedPulseAnnotation) +class ComputedPulseAnnotationAdmin(admin.ModelAdmin): + list_display = [ + 'id', + 'recording', + 'bounding_box', + ] + list_select_related = True diff --git a/bats_ai/core/admin/spectrogram_svg.py b/bats_ai/core/admin/spectrogram_svg.py new file mode 100644 index 00000000..26ad1ce1 --- /dev/null +++ b/bats_ai/core/admin/spectrogram_svg.py @@ -0,0 +1,22 @@ +from django.contrib import admin + +from bats_ai.core.models import SpectrogramSvg + + +@admin.register(SpectrogramSvg) +class SpectrogramSvgAdmin(admin.ModelAdmin): + list_display = [ + 'pk', + 'content_type', + 'object_id', + 'index', + 'image_file', + ] + list_select_related = True + readonly_fields = [ + 'pk', + 'content_type', + 'object_id', + 'index', + 'image_file', + ] diff --git a/bats_ai/core/migrations/0024_spectrogramsvg_computedpulseannotation.py b/bats_ai/core/migrations/0024_spectrogramsvg_computedpulseannotation.py new file mode 100644 index 00000000..22e17053 --- /dev/null +++ b/bats_ai/core/migrations/0024_spectrogramsvg_computedpulseannotation.py @@ -0,0 +1,41 @@ +# Generated by Django 4.2.23 on 2025-12-08 22:19 + +import bats_ai.core.models.spectrogram_vector +import django.contrib.gis.db.models.fields +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('contenttypes', '0002_remove_content_type_name'), + ('core', '0023_recordingtag_recording_tags_and_more'), + ] + + operations = [ + migrations.CreateModel( + name='SpectrogramSvg', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('object_id', models.PositiveIntegerField()), + ('type', models.CharField(choices=[('spectrogram', 'Spectrogram'), ('compressed', 'Compressed')], default='spectrogram', max_length=20)), + ('index', models.PositiveIntegerField()), + ('image_file', models.FileField(upload_to=bats_ai.core.models.spectrogram_vector.spectrogram_svg_upload_to)), + ('content_type', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='contenttypes.contenttype')), + ], + options={ + 'ordering': ['index'], + }, + ), + migrations.CreateModel( + name='ComputedPulseAnnotation', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('index', models.IntegerField()), + ('bounding_box', django.contrib.gis.db.models.fields.PolygonField(srid=4326)), + ('contours', models.JSONField()), + ('recording', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='core.recording')), + ], + ), + ] diff --git a/bats_ai/core/models/__init__.py b/bats_ai/core/models/__init__.py index ad11fab9..1b2b89a9 100644 --- a/bats_ai/core/models/__init__.py +++ b/bats_ai/core/models/__init__.py @@ -5,6 +5,7 @@ from .grts_cells import GRTSCells from .image import Image from .processing_task import ProcessingTask, ProcessingTaskType +from .pulse_annotation import ComputedPulseAnnotation from .recording import Recording, RecordingTag from .recording_annotation import RecordingAnnotation from .recording_annotation_status import RecordingAnnotationStatus @@ -12,6 +13,7 @@ from .species import Species from .spectrogram import Spectrogram from .spectrogram_image import SpectrogramImage +from .spectrogram_vector import SpectrogramSvg __all__ = [ 'Annotations', @@ -30,4 +32,6 @@ 'ProcessingTaskType', 'ExportedAnnotationFile', 'SpectrogramImage', + 'SpectrogramSvg', + 'ComputedPulseAnnotation', ] diff --git a/bats_ai/core/models/compressed_spectrogram.py b/bats_ai/core/models/compressed_spectrogram.py index f4fa734a..90a01573 100644 --- a/bats_ai/core/models/compressed_spectrogram.py +++ b/bats_ai/core/models/compressed_spectrogram.py @@ -9,6 +9,7 @@ from .recording import Recording from .spectrogram import Spectrogram from .spectrogram_image import SpectrogramImage +from .spectrogram_vector import SpectrogramSvg # TimeStampedModel also provides "created" and "modified" fields @@ -17,6 +18,7 @@ class CompressedSpectrogram(TimeStampedModel, models.Model): spectrogram = models.ForeignKey(Spectrogram, on_delete=models.CASCADE) length = models.IntegerField() images = GenericRelation(SpectrogramImage) + vector_images = GenericRelation(SpectrogramSvg) starts = ArrayField(ArrayField(models.IntegerField())) stops = ArrayField(ArrayField(models.IntegerField())) widths = ArrayField(ArrayField(models.IntegerField())) @@ -28,6 +30,11 @@ def image_url_list(self): images = self.images.filter(type='compressed').order_by('index') return [default_storage.url(img.image_file.name) for img in images] + @property + def vector_url_list(self): + images = self.vector_images.filter(type='compressed').order_by('index') + return [default_storage.url(img.image_file.name) for img in images] + @property def image_pil_list(self): """List of PIL images in order.""" diff --git a/bats_ai/core/models/pulse_annotation.py b/bats_ai/core/models/pulse_annotation.py new file mode 100644 index 00000000..fbe519fb --- /dev/null +++ b/bats_ai/core/models/pulse_annotation.py @@ -0,0 +1,10 @@ +from django.contrib.gis.db import models + +from .recording import Recording + + +class ComputedPulseAnnotation(models.Model): + recording = models.ForeignKey(Recording, on_delete=models.CASCADE) + index = models.IntegerField(null=False, blank=False) + bounding_box = models.PolygonField(null=False, blank=False) + contours = models.JSONField() diff --git a/bats_ai/core/models/spectrogram.py b/bats_ai/core/models/spectrogram.py index 1d200878..574eb8f4 100644 --- a/bats_ai/core/models/spectrogram.py +++ b/bats_ai/core/models/spectrogram.py @@ -7,11 +7,13 @@ from .recording import Recording from .spectrogram_image import SpectrogramImage +from .spectrogram_vector import SpectrogramSvg class Spectrogram(TimeStampedModel, models.Model): recording = models.ForeignKey(Recording, on_delete=models.CASCADE) images = GenericRelation(SpectrogramImage) + vector_images = GenericRelation(SpectrogramSvg) width = models.IntegerField() # pixels height = models.IntegerField() # pixels duration = models.IntegerField() # milliseconds @@ -24,6 +26,11 @@ def image_url_list(self): images = self.images.filter(type='spectrogram').order_by('index') return [default_storage.url(img.image_file.name) for img in images] + @property + def vector_url_list(self): + images = self.vector_images.filter(type='spectrogram').order_by('index') + return [default_storage.url(img.image_file.name) for img in images] + @property def image_pil_list(self): """List of PIL images in order.""" diff --git a/bats_ai/core/models/spectrogram_vector.py b/bats_ai/core/models/spectrogram_vector.py new file mode 100644 index 00000000..f4f47a1e --- /dev/null +++ b/bats_ai/core/models/spectrogram_vector.py @@ -0,0 +1,43 @@ +from django.contrib.contenttypes.fields import GenericForeignKey +from django.contrib.contenttypes.models import ContentType +from django.db import models +from django.dispatch import receiver + + +def spectrogram_svg_upload_to(instance, filename): + related = instance.content_object + + recording = getattr(related, 'recording', None) or getattr(related, 'nabat_recording', None) + recording_id = getattr(recording, 'id', None) + + if not recording_id: + raise ValueError('Related content must have a recording or nabat_recording.') + + return f'recording_{recording_id}/{instance.type}/svg_{instance.index}_{filename}' + + +class SpectrogramSvg(models.Model): + SPECTROGRAM_TYPE_CHOICES = [ + ('spectrogram', 'Spectrogram'), + ('compressed', 'Compressed'), + ] + content_object = GenericForeignKey('content_type', 'object_id') + + content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE) + object_id = models.PositiveIntegerField() + type = models.CharField( + max_length=20, + choices=SPECTROGRAM_TYPE_CHOICES, + default='spectrogram', + ) + index = models.PositiveIntegerField() + image_file = models.FileField(upload_to=spectrogram_svg_upload_to) + + class Meta: + ordering = ['index'] + + +@receiver(models.signals.pre_delete, sender=SpectrogramSvg) +def delete_content(sender, instance, **kwargs): + if instance.image_file: + instance.image_file.delete(save=False) diff --git a/bats_ai/core/views/recording.py b/bats_ai/core/views/recording.py index d4b5552f..997a111f 100644 --- a/bats_ai/core/views/recording.py +++ b/bats_ai/core/views/recording.py @@ -1,7 +1,7 @@ from datetime import datetime import json import logging -from typing import List, Optional +from typing import Any, List, Optional from django.contrib.auth.models import User from django.contrib.gis.geos import Point @@ -16,6 +16,7 @@ from bats_ai.core.models import ( Annotations, CompressedSpectrogram, + ComputedPulseAnnotation, Recording, RecordingAnnotation, RecordingTag, @@ -127,6 +128,22 @@ class UpdateAnnotationsSchema(Schema): id: int | None +class ComputedPulseAnnotationSchema(Schema): + id: int | None + index: int + bounding_box: Any + contours: list + + @classmethod + def from_orm(cls, obj: ComputedPulseAnnotation): + return cls( + id=obj.id, + index=obj.index, + contours=obj.contours, + bounding_box=json.loads(obj.bounding_box.geojson) + ) + + @router.post('/') def create_recording( request: HttpRequest, @@ -373,6 +390,7 @@ def get_spectrogram(request: HttpRequest, id: int): spectro_data = { 'urls': spectrogram.image_url_list, + 'vectors': spectrogram.vector_url_list, 'spectroInfo': { 'spectroId': spectrogram.pk, 'width': spectrogram.width, @@ -443,6 +461,7 @@ def get_spectrogram_compressed(request: HttpRequest, id: int): spectro_data = { 'urls': compressed_spectrogram.image_url_list, + 'vectors': compressed_spectrogram.vector_url_list, 'spectroInfo': { 'spectroId': compressed_spectrogram.pk, 'width': compressed_spectrogram.spectrogram.width, @@ -526,6 +545,26 @@ def get_annotations(request: HttpRequest, id: int): return {'error': 'Recording not found'} +@router.get('/{id}/pulse_data') +def get_pulse_data(request: HttpRequest, id: int): + try: + recording = Recording.objects.get(pk=id) + if recording.owner == request.user or recording.public: + computed_pulse_annotation_qs = ComputedPulseAnnotation.objects.filter( + recording=recording + ).order_by('index') + return [ + ComputedPulseAnnotationSchema.from_orm(pulse) + for pulse in computed_pulse_annotation_qs.all() + ] + else: + return { + 'error': 'Permission denied. You do not own this recording, and it is not public.' + } + except Recording.DoesNotExist: + return {'error': 'Recording not found'} + + @router.get('/{id}/annotations/other_users') def get_other_user_annotations(request: HttpRequest, id: int): try: diff --git a/bats_ai/tasks/tasks.py b/bats_ai/tasks/tasks.py index 33a2783f..91cbf189 100644 --- a/bats_ai/tasks/tasks.py +++ b/bats_ai/tasks/tasks.py @@ -5,16 +5,19 @@ from PIL import Image from celery import shared_task from django.contrib.contenttypes.models import ContentType +from django.contrib.gis.geos import Polygon from django.core.files import File from bats_ai.core.models import ( CompressedSpectrogram, + ComputedPulseAnnotation, Configuration, Recording, RecordingAnnotation, Species, Spectrogram, SpectrogramImage, + SpectrogramSvg, ) from bats_ai.utils.spectrogram_utils import generate_spectrogram_assets, predict_from_compressed @@ -59,6 +62,18 @@ def recording_compute_spectrogram(recording_id: int): }, ) + for idx, svg_path in enumerate(results['normal']['vectors']): + with open(svg_path, 'rb') as f: + SpectrogramSvg.objects.get_or_create( + content_type=ContentType.objects.get_for_model(spectrogram), + object_id=spectrogram.id, + index=idx, + defaults={ + 'type': 'spectrogram', + 'image_file': File(f, name=os.path.basename(svg_path)), + }, + ) + # Create or get CompressedSpectrogram compressed = results['compressed'] compressed_obj, _ = CompressedSpectrogram.objects.get_or_create( @@ -86,6 +101,58 @@ def recording_compute_spectrogram(recording_id: int): }, ) + for idx, svg_path in enumerate(compressed['vectors']): + with open(svg_path, 'rb') as f: + SpectrogramSvg.objects.get_or_create( + content_type=ContentType.objects.get_for_model(compressed_obj), + object_id=compressed_obj.id, + index=idx, + defaults={ + 'image_file': File(f, name=os.path.basename(svg_path)), + 'type': 'compressed', + }, + ) + + # Generate computed annotations for contours + logger.info( + "Adding contour and bounding boxes for " + f"{len(results.get('contours', []))} pulses" + ) + for idx, contour in enumerate(results.get('contours', [])): + # Transform contour (x, y) pairs into (time, freq) pairs + widths, starts, stops = compressed['widths'], compressed['starts'], compressed['stops'] + start_time = starts[idx] + end_time = stops[idx] + width = widths[idx] + time_per_pixel = (end_time - start_time) / width + mhz_per_pixel = (results['freq_max'] - results['freq_min']) / compressed['height'] + transformed_lines = [] + for contour_line in contour: + new_curve = [ + [ + point[0] * time_per_pixel + start_time, + results['freq_max'] - (point[1] * mhz_per_pixel) + ] + for point in contour_line["curve"] + ] + transformed_lines.append({ + "curve": new_curve, + "level": contour_line["level"], + "index": idx + }) + ComputedPulseAnnotation.objects.get_or_create( + index=idx, + recording=recording, + contours=transformed_lines, + bounding_box=Polygon(( + (start_time, results['freq_max']), + (end_time, results['freq_max']), + (end_time, results['freq_min']), + (start_time, results['freq_min']), + (start_time, results['freq_max']), + )), + ) + config = Configuration.objects.first() if config and config.run_inference_on_upload: predict_results = predict_from_compressed(compressed_obj) diff --git a/bats_ai/utils/contour_utils.py b/bats_ai/utils/contour_utils.py new file mode 100644 index 00000000..421094e0 --- /dev/null +++ b/bats_ai/utils/contour_utils.py @@ -0,0 +1,361 @@ +import logging + +import cv2 +import numpy as np +from scipy.ndimage import gaussian_filter1d +from skimage import measure +from skimage.filters import threshold_multiotsu +import svgwrite + +logger = logging.getLogger(__name__) + + +# This function computes the contour levels based on the selected mode. +def auto_histogram_levels( + data: np.ndarray, + bins: int = 512, + smooth_sigma: float = 2.0, + variance_threshold: float = 400.0, + max_levels: int = 5, +) -> list[float]: + """Select intensity levels by grouping histogram bins until variance exceeds a threshold.""" + if data.size == 0: + return [] + + hist, edges = np.histogram(data, bins=bins) + counts = gaussian_filter1d(hist.astype(np.float64), sigma=smooth_sigma) + centers = (edges[:-1] + edges[1:]) / 2.0 + + mask = counts > 0 + counts = counts[mask] + centers = centers[mask] + + if counts.size == 0: + return [] + + groups = [] + current_centers = [] + current_weights = [] + + for center, weight in zip(centers, counts): + weight = max(float(weight), 1e-9) + current_centers.append(center) + current_weights.append(weight) + + values = np.array(current_centers, dtype=np.float64) + weights = np.array(current_weights, dtype=np.float64) + mean = np.average(values, weights=weights) + variance = np.average((values - mean) ** 2, weights=weights) + + if variance > variance_threshold and len(current_centers) > 1: + last_center = current_centers.pop() + last_weight = current_weights.pop() + + values = np.array(current_centers, dtype=np.float64) + weights = np.array(current_weights, dtype=np.float64) + if weights.sum() > 0: + grouped_mean = np.average(values, weights=weights) + groups.append(grouped_mean) + + current_centers = [last_center] + current_weights = [last_weight] + + if current_centers: + values = np.array(current_centers, dtype=np.float64) + weights = np.array(current_weights, dtype=np.float64) + grouped_mean = np.average(values, weights=weights) + groups.append(grouped_mean) + + groups = sorted(set(groups)) + + if len(groups) <= 1: + return groups + + groups = groups[1:] + + if max_levels is not None and len(groups) > max_levels: + indices = np.linspace(0, len(groups) - 1, max_levels, dtype=int) + groups = [groups[i] for i in indices] + + def subdivide_high_end(levels: list[float]) -> list[float]: + if len(levels) < 2: + return levels + gaps = np.diff(levels) + largest_gap_idx = int(np.argmax(gaps)) + remaining_slots = ( + max(0, max_levels - len(levels)) if max_levels is not None else len(levels) + ) + subdivisions = min(remaining_slots, 2) if remaining_slots > 0 else 0 + subdivided = [] + if subdivisions > 0: + if largest_gap_idx == len(levels) - 1: + low = levels[-2] + high = levels[-1] + stride = (high - low) / (subdivisions + 1) + subdivided = [low + stride * (i + 1) for i in range(subdivisions)] + levels = levels[:-1] + subdivided + [levels[-1]] + return sorted(levels) + + return subdivide_high_end(groups) + + +def compute_auto_levels( + data: np.ndarray, + mode: str, + percentile_values, + multi_otsu_classes: int, + min_intensity: float, + hist_bins: int = 512, + hist_sigma: float = 2.0, + hist_variance_threshold: float = 400.0, + hist_max_levels: int = 5, +) -> list[float]: + """Compute contour levels based on selected mode.""" + percentile_values = list(percentile_values) + percentile_values.sort() + + valid = data[data >= min_intensity] + if valid.size == 0: + return [] + + if mode == 'multi-otsu': + try: + thresholds = threshold_multiotsu(valid, classes=multi_otsu_classes) + return thresholds.tolist() + except Exception: + # Fallback to simple percentiles if multi-otsu fails + if len(percentile_values) == 0: + return [] + return np.percentile(valid, percentile_values).tolist() + elif mode == 'histogram': + return auto_histogram_levels( + valid, + bins=hist_bins, + smooth_sigma=hist_sigma, + variance_threshold=hist_variance_threshold, + max_levels=hist_max_levels, + ) + else: # percentile mode + if len(percentile_values) == 0: + return [] + return np.percentile(valid, percentile_values).tolist() + + +# This function computes the area of a polygon. +def polygon_area(points: np.ndarray) -> float: + """Return absolute area of a closed polygon given as Nx2 array.""" + if len(points) < 3: + return 0.0 + x = points[:, 0] + y = points[:, 1] + return 0.5 * np.abs(np.dot(x, np.roll(y, -1)) - np.dot(y, np.roll(x, -1))) + + +# This function smooths a contour using spline interpolation. +def smooth_contour_spline(contour, smoothing_factor=0.1): + """Smooth contour using spline interpolation""" + # Reshape contour + if contour.ndim != 2 or contour.shape[1] != 2: + if contour.size % 2 == 0: + contour = contour.reshape(-1, 2) + else: + logger.warning(f'Invalid contour shape: {contour.shape}') + # contour = contour.reshape(-1, 2) + + # Close the contour by adding first point at end + if not np.array_equal(contour[0], contour[-1]): + contour = np.vstack([contour, contour[0]]) + + # Calculate cumulative distance along contour + distances = np.cumsum(np.sqrt(np.sum(np.diff(contour, axis=0) ** 2, axis=1))) + distances = np.insert(distances, 0, 0) + + # Interpolate using splines + from scipy import interpolate + + # Create periodic spline + num_points = max(len(contour), 100) + alpha = np.linspace(0, 1, num_points) + + # Fit spline + try: + tck, u = interpolate.splprep( + [contour[:, 0], contour[:, 1]], s=len(contour) * smoothing_factor, per=True + ) + x_smooth, y_smooth = interpolate.splev(alpha, tck) + smooth_contour = np.column_stack([x_smooth, y_smooth]) + except Exception as e: + # Fallback to simple smoothing if spline fails + logger.info(f'Spline fitting failed {e}. Falling back to simple smoothing.') + smooth_contour = contour + + return smooth_contour + + +# This function saves the contours to an SVG file. +def save_contours_to_svg( + contours_with_levels, + output_path, + image_shape, + reference_image=None, + fill_opacity=0.6, + stroke_opacity=0.9, + stroke_width=1.0, + draw_stroke=True, + sample_shrink_px=3, + sample_radius=5, +): + """Save contours to SVG with filled shapes (optionally matching image colors).""" + height, width = image_shape[:2] + dwg = svgwrite.Drawing(output_path, size=(width, height)) + + # Default palette if no image supplied + colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#F7B267', '#CDB4DB'] + + if reference_image is not None and reference_image.shape[:2] != (height, width): + raise ValueError("reference_image shape does not match image_shape") + + def color_from_image(points, fallback_color): + if reference_image is None: + return fallback_color + polygon = np.round(points).astype(np.int32) + polygon[:, 0] = np.clip(polygon[:, 0], 0, width - 1) + polygon[:, 1] = np.clip(polygon[:, 1], 0, height - 1) + mask = np.zeros((height, width), dtype=np.uint8) + cv2.fillPoly(mask, [polygon], (255,)) + + eroded = mask.copy() + if sample_shrink_px > 0: + kernel_size = sample_shrink_px * 2 + 1 + kernel = np.ones((kernel_size, kernel_size), dtype=np.uint8) + eroded = cv2.erode(mask, kernel, iterations=1) + if not np.count_nonzero(eroded): + eroded = mask + + dist = cv2.distanceTransform(eroded, cv2.DIST_L2, 5) + _, max_val, _, max_loc = cv2.minMaxLoc(dist) + + if max_val <= 0: + region = reference_image[mask == 255] + if region.size == 0: + return fallback_color + mean_bgr = region.mean(axis=0) + else: + cx, cy = max_loc[0], max_loc[1] + x0 = max(cx - sample_radius, 0) + x1 = min(cx + sample_radius + 1, width) + y0 = max(cy - sample_radius, 0) + y1 = min(cy + sample_radius + 1, height) + patch = reference_image[y0:y1, x0:x1] + patch_mask = eroded[y0:y1, x0:x1] + region = patch[patch_mask > 0] + if region.size == 0: + region = reference_image[mask == 255] + mean_bgr = region.mean(axis=0) + + r, g, b = [int(np.clip(c, 0, 255)) for c in mean_bgr[::-1]] + return f"#{r:02X}{g:02X}{b:02X}" + + # Draw lower levels first so higher ones sit on top + contours_with_levels_sorted = sorted(contours_with_levels, key=lambda x: x[1]) + logger.info(f'Sorted contours length: {len(contours_with_levels_sorted)}') + + for i, (contour, level) in enumerate(contours_with_levels_sorted): + # logger.info(f'Attempting to add path for level {level}') + pts = contour.tolist() + if len(pts) < 3: + continue + + # Build a simple closed path (straight segments). Beziers look nice for strokes + # but can self-intersect when filled; straight segments are safer for fills. + d = [f"M {pts[0][0]},{pts[0][1]}"] + for j in range(1, len(pts)): + d.append(f"L {pts[j][0]},{pts[j][1]}") + d.append("Z") + path_data = " ".join(d) + + fallback = colors[i % len(colors)] + fill_color = color_from_image(np.array(pts), fallback) + + path = dwg.path( + d=path_data, + fill=fill_color, + fill_opacity=fill_opacity, + stroke=fill_color if draw_stroke else 'none', + stroke_opacity=stroke_opacity, + stroke_width=stroke_width, + ) + + # Helps when there are holes; keeps visual sane without hierarchy bookkeeping + path.update({'fill-rule': 'evenodd'}) + + dwg.add(path) + + dwg.save() + logger.info(f"Saved smooth filled contours to {output_path}") + + +def extract_marching_squares_contours( + image_path, + output_path='marching_squares.svg', + levels=None, + gaussian_kernel=(15, 15), + gaussian_sigma=3, + min_area=500, + smoothing_factor=0.08, + levels_mode='percentile', + percentile_values=(90, 95, 98), + min_intensity=1.0, + multi_otsu_classes=4, + hist_bins=512, + hist_sigma=2.0, + hist_variance_threshold=400.0, + hist_max_levels=5, + save_to_file=True, + verbose=True, +): + """Extract contours using marching squares (skimage.find_contours).""" + img = cv2.imread(image_path) + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + + blurred = cv2.GaussianBlur(gray, gaussian_kernel, gaussian_sigma) + + if levels is None: + mask = blurred > 0 + if not np.any(mask): + return [] + levels = compute_auto_levels( + blurred[mask], + mode=levels_mode, + percentile_values=percentile_values, + multi_otsu_classes=multi_otsu_classes, + min_intensity=min_intensity, + hist_bins=hist_bins, + hist_sigma=hist_sigma, + hist_variance_threshold=hist_variance_threshold, + hist_max_levels=hist_max_levels, + ) + if verbose: + logger.info(f"Marching squares levels ({levels_mode}): {levels}") + + marching_contours = [] + + for level in levels: + raw_contours = measure.find_contours(blurred, level=level) + for contour in raw_contours: + # skimage returns (row, col); flip to (x, y) + contour_xy = contour[:, ::-1] + if not np.array_equal(contour_xy[0], contour_xy[-1]): + contour_xy = np.vstack([contour_xy, contour_xy[0]]) + + if polygon_area(contour_xy) < min_area: + continue + + smooth = smooth_contour_spline(contour_xy, smoothing_factor=smoothing_factor) + marching_contours.append((smooth, level)) + + if marching_contours and save_to_file: + logger.info(f'Saving contours to {output_path}') + save_contours_to_svg(marching_contours, output_path, img.shape, reference_image=img) + + return sorted(marching_contours, key=lambda x: x[1], reverse=True) diff --git a/bats_ai/utils/spectrogram_utils.py b/bats_ai/utils/spectrogram_utils.py index 441b9fcc..cd2ca2c6 100644 --- a/bats_ai/utils/spectrogram_utils.py +++ b/bats_ai/utils/spectrogram_utils.py @@ -4,6 +4,7 @@ import math import os from pathlib import Path +import tempfile from typing import TypedDict from PIL import Image @@ -20,6 +21,8 @@ from bats_ai.core.models import CompressedSpectrogram from bats_ai.core.models.nabat import NABatCompressedSpectrogram +from .contour_utils import extract_marching_squares_contours + logger = logging.getLogger(__name__) FREQ_MIN = 5e3 @@ -29,12 +32,14 @@ class SpectrogramAssetResult(TypedDict): paths: list[str] + vectors: list[str] width: int height: int class SpectrogramCompressedAssetResult(TypedDict): paths: list[str] + vectors: list[str] width: int height: int widths: list[float] @@ -42,12 +47,18 @@ class SpectrogramCompressedAssetResult(TypedDict): stops: list[float] +class Contour(TypedDict): + curve: list[list[int | float]] + level: int | float + + class SpectrogramAssets(TypedDict): duration: float freq_min: int freq_max: int normal: SpectrogramAssetResult compressed: SpectrogramCompressedAssetResult + contours: list[list[Contour]] class PredictionOutput(TypedDict): @@ -226,10 +237,18 @@ def generate_spectrogram_assets( os.path.splitext(os.path.basename(output_base))[0] + '_spectrogram', ) os.makedirs(os.path.dirname(normal_out_path_base), exist_ok=True) - normal_paths = save_img(normal_img_resized, normal_out_path_base) + normal_paths, vector_paths = save_img(normal_img_resized, normal_out_path_base) real_duration = math.ceil(duration * 1e3) - compressed_img, compressed_paths, widths, starts, stops = generate_compressed( - normal_img_resized, real_duration, output_base + ( + compressed_img, + compressed_paths, + compressed_vector_paths, + widths, + starts, + stops, + contours, + ) = ( + generate_compressed(normal_img_resized, real_duration, output_base) ) result = { @@ -238,11 +257,13 @@ def generate_spectrogram_assets( 'freq_max': freq_high, 'normal': { 'paths': normal_paths, + 'vectors': vector_paths, 'width': normal_img_resized.shape[1], 'height': normal_img_resized.shape[0], }, 'compressed': { 'paths': compressed_paths, + 'vectors': compressed_vector_paths, 'width': compressed_img.shape[1], 'height': compressed_img.shape[0], 'widths': widths, @@ -250,10 +271,35 @@ def generate_spectrogram_assets( 'stops': stops, }, } + if contours: + result["contours"] = contours return result +def generate_pulse_contours(segments: list[np.ndarray], widths: list): + logger.info(f"Generating pulse contours for {len(segments)} pulses") + contours = [] + with tempfile.TemporaryDirectory() as tmpdir: + for index, segment in enumerate(segments): + # Save the NDArray as a file in the tempdir + out_img = Image.fromarray(segment, "RGB") + segment_path = f"{tmpdir}/{index}.jpg" + out_img.save(segment_path, format="JPEG", optimize=True, quality=80) + # Generate marching square contours from temp file + np_contours = extract_marching_squares_contours( + segment_path, + "", + save_to_file=False + ) + logger.info(f"Generated {len(np_contours)} for pulse {index}") + segment_contours = [ + {"curve": c[0].tolist(), "level": c[1]} for c in np_contours + ] + contours.append(segment_contours) + return contours + + def generate_compressed(img: np.ndarray, duration: float, output_base: str): threshold = 0.5 compressed_img = img.copy() @@ -337,7 +383,9 @@ def generate_compressed(img: np.ndarray, duration: float, output_base: str): segments.append(segment) widths.append(stop_clamped - start_clamped) + contours = [] if segments: + contours = generate_pulse_contours(segments, widths) compressed_img = np.hstack(segments) break @@ -359,9 +407,9 @@ def generate_compressed(img: np.ndarray, duration: float, output_base: str): compressed_out_path = os.path.join(out_folder, f'{base_name}_compressed') # save_img should be your existing function to save images and return file paths - paths = save_img(compressed_img, compressed_out_path) + paths, vector_paths = save_img(compressed_img, compressed_out_path) - return compressed_img, paths, widths, starts_time, stops_time + return compressed_img, paths, vector_paths, widths, starts_time, stops_time, contours def save_img(img: np.ndarray, output_base: str): @@ -374,6 +422,7 @@ def save_img(img: np.ndarray, output_base: str): ) total = len(chunks) output_paths = [] + output_svg_paths = [] for index, chunk in enumerate(chunks): out_path = f'{output_base}.{index + 1:02d}_of_{total:02d}.jpg' out_img = Image.fromarray(chunk, 'RGB') @@ -381,4 +430,12 @@ def save_img(img: np.ndarray, output_base: str): output_paths.append(out_path) logger.info(f'Saved image: {out_path}') - return output_paths + svg_path = f'{output_base}.{index + 1:02d}_of_{total:02d}.svg' + try: + extract_marching_squares_contours(out_path, svg_path) + output_svg_paths.append(svg_path) + logger.info(f'Saved SVG {svg_path}') + except Exception as e: + logger.error(f'Failed to create SVG for {out_path}. {e}') + + return output_paths, output_svg_paths diff --git a/client/src/api/api.ts b/client/src/api/api.ts index 3b13e655..9c0013ff 100644 --- a/client/src/api/api.ts +++ b/client/src/api/api.ts @@ -118,6 +118,7 @@ export interface UpdateFileAnnotation { export interface Spectrogram { urls: string[]; + vectors: string[]; filename?: string; annotations?: SpectrogramAnnotation[]; fileAnnotations: FileAnnotation[]; @@ -131,6 +132,7 @@ export interface Spectrogram { otherUsers?: UserInfo[]; } + export type OtherUserAnnotations = Record< string, { annotations: SpectrogramAnnotation[]; sequence: SpectrogramSequenceAnnotation[] } @@ -506,6 +508,23 @@ async function getExportStatus(exportId: number) { return result.data; } +export interface Contour { + curve: number[][]; + level: number; + index: number; +} + +export interface ComputedPulseAnnotation { + id: number; + index: number; + contours: Contour[]; +} + +async function getComputedPulseAnnotations(recordingId: number) { + const result = await axiosInstance.get(`/recording/${recordingId}/pulse_data`); + return result.data; +} + export { uploadRecordingFile, getRecordings, @@ -540,4 +559,5 @@ export { getFileAnnotationDetails, getExportStatus, getRecordingTags, + getComputedPulseAnnotations, }; diff --git a/client/src/components/SpectrogramViewer.vue b/client/src/components/SpectrogramViewer.vue index b4248dfc..3a7ff337 100644 --- a/client/src/components/SpectrogramViewer.vue +++ b/client/src/components/SpectrogramViewer.vue @@ -267,6 +267,7 @@ export default defineComponent({ :spectro-info="spectroInfo" :scaled-width="scaledWidth" :scaled-height="scaledHeight" + :recording-id="recordingId" @update:annotation="updateAnnotation($event)" @create:annotation="createAnnotation($event)" @set-cursor="setCursor($event)" diff --git a/client/src/components/ThumbnailViewer.vue b/client/src/components/ThumbnailViewer.vue index 716b66ec..5478eab8 100644 --- a/client/src/components/ThumbnailViewer.vue +++ b/client/src/components/ThumbnailViewer.vue @@ -165,6 +165,7 @@ export default defineComponent({ :spectro-info="spectroInfo" :scaled-width="scaledWidth" :scaled-height="scaledHeight" + :recording-id="recordingId" thumbnail @selected="$emit('selected',$event)" /> @@ -189,7 +190,7 @@ export default defineComponent({ margin:2px; &.geojs-map:focus { outline: none; - } + } } .playback-container { diff --git a/client/src/components/geoJS/LayerManager.vue b/client/src/components/geoJS/LayerManager.vue index 266952a0..f75dfe90 100644 --- a/client/src/components/geoJS/LayerManager.vue +++ b/client/src/components/geoJS/LayerManager.vue @@ -1,7 +1,7 @@