Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def shear_data_by_shifting_profiles(
"""
# Skip shear for angles smaller than ~0.1° (0.00175 rad)
if abs(angle_rad) <= 0.00175:
return depth_data.astype(np.floating).copy()
return depth_data

height, width = depth_data.shape
center_x, center_y = width / 2, height / 2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,16 @@ def test_rotate_data_by_shifting_profiles():
assert np.std(max_positions) > 0


def test_shear_near_zero_angle_returns_copy():
    """Shearing with a near-zero angle (< ~0.1°) leaves the data unchanged.

    The near-zero branch short-circuits the shear entirely; the returned
    array must compare equal to the input and keep the float64 dtype.
    """
    data = np.array([[1.0, 2.0], [3.0, 4.0]])
    result = shear_data_by_shifting_profiles(
        data, angle_rad=0.001, cut_y_after_shift=False
    )
    np.testing.assert_array_equal(result, data)
    assert result.dtype == np.float64


def test_detect_striation_angle():
"""Test gradient-based striation angle detection."""
np.random.seed(42)
Expand Down
203 changes: 203 additions & 0 deletions scripts/conversion_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
"""Shared utilities for conversion scripts."""

import json
import logging
from collections.abc import Callable, Iterable, Iterator
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from conversion.data_formats import MarkType
from tqdm import tqdm

logging.basicConfig(level=logging.WARNING, format="%(levelname)s: %(message)s")
logger = logging.getLogger(__name__)


@dataclass
class ConversionConfig:
    """Shared configuration for conversion pipelines."""

    root: Path
    output_dir: Path
    api_url: str
    force: bool = False

    def __post_init__(self) -> None:
        # Normalize eagerly: absolute paths and no trailing slash on the
        # API base URL, so downstream f"{api_url}/{endpoint}" joins are clean.
        self.api_url = self.api_url.rstrip("/")
        self.root, self.output_dir = self.root.resolve(), self.output_dir.resolve()


def run_parallel(
    tasks: Iterable[tuple[Any, Callable, tuple]],
    workers: int,
    desc: str,
    unit: str,
) -> dict[Any, Any]:
    """Run tasks in parallel with a progress bar.

    :param tasks: iterable of ``(key, fn, args)`` tuples.
    :param workers: number of parallel workers.
    :param desc: progress bar description.
    :param unit: progress bar unit label.
    :returns: dict of ``{key: result}``.
    """
    submitted = list(tasks)
    results: dict[Any, Any] = {}
    with ThreadPoolExecutor(max_workers=workers) as pool:
        future_to_key = {pool.submit(fn, *args): key for key, fn, args in submitted}
        # as_completed yields in finish order; keys restore the association.
        progress = tqdm(as_completed(future_to_key), total=len(future_to_key), desc=desc, unit=unit)
        for done in progress:
            results[future_to_key[done]] = done.result()
    return results


# Folder-name fragment for each MarkType, ordered longest-first so the most
# specific fragment wins the substring match in infer_mark_type().
_MARK_TYPE_FOLDER_MAP: list[tuple[str, MarkType]] = sorted(
    ((mark_type.value.replace(" ", "_"), mark_type) for mark_type in MarkType),
    key=lambda pair: len(pair[0]),
    reverse=True,
)


def infer_mark_type(folder_name: str) -> MarkType | None:
    """Infer a :class:`MarkType` from a folder name.

    Handles suffixed variants (``_1``, ``_2``) and ``comparison_results`` folders.
    """
    name = folder_name.lower()
    # Fragments are pre-sorted longest-first, so the first hit is the best one.
    matches = (mark_type for fragment, mark_type in _MARK_TYPE_FOLDER_MAP if fragment in name)
    return next(matches, None)


def _parse_db_scratch(path: Path) -> dict[str, str]:
"""Parse a Java-properties-style db.scratch file.

:param path: path to the db.scratch file.
:returns: dict of key-value pairs (empty if file missing).
"""
if not path.exists():
return {}
props: dict[str, str] = {}
for line in path.read_text().splitlines():
line = line.strip() # noqa: PLW2901
if not line or line.startswith("#"):
continue
if "=" in line:
key, _, value = line.partition("=")
props[key.strip()] = value.strip().replace("\\:", ":")
return props


_tool_entries_root_cache: dict[Path, Path | None] = {}


def _get_tool_entries_root(output_dir: Path) -> Path | None:
"""Find (and cache) the parent of the ``tool-entries`` folder under *output_dir*."""
if output_dir not in _tool_entries_root_cache:
_tool_entries_root_cache[output_dir] = next(
(c.parent for c in output_dir.rglob("tool-entries") if c.is_dir()), None
)
return _tool_entries_root_cache[output_dir]


def _resolve_mark_dir(relative_path: str, output_dir: Path) -> Path:
    """Map a MATLAB-relative mark path to the converted output directory."""
    # Normalize Windows separators and strip surrounding slashes first.
    parts = relative_path.replace("\\", "/").strip("/").split("/")
    if "tool-entries" in parts:
        suffix = "/".join(parts[parts.index("tool-entries") :])
    else:
        suffix = "/".join(parts)
    base = _get_tool_entries_root(output_dir)
    if base is None:
        return output_dir / suffix
    return base / suffix


def _firearm_dir(mark_dir: Path) -> Path:
"""Return the firearm directory (first child of ``tool-entries``)."""
parts = mark_dir.parts
try:
return Path(*parts[: parts.index("tool-entries") + 2])
except ValueError:
return mark_dir.parent.parent.parent


def _extract_metadata(mark_dir: Path) -> dict[str, str]:
    """Extract MarkMetadata by walking up from *mark_dir* to ``tool-entries``.

    :param mark_dir: path to the mark directory.
    :returns: dict with case_id, firearm_id, specimen_id, measurement_id, mark_id.
    """
    parts = mark_dir.parts
    try:
        te = parts.index("tool-entries")
    except ValueError:
        # No tool-entries ancestor: nothing to anchor the hierarchy on.
        return {k: "unknown" for k in ("case_id", "firearm_id", "specimen_id", "measurement_id", "mark_id")}

    def _name(idx: int) -> str:
        # Ancestor at position *idx*; past-the-end indices fall back to
        # mark_dir itself. A db.scratch NAME entry overrides the folder name.
        p = Path(*parts[: idx + 1]) if idx < len(parts) else mark_dir
        return _parse_db_scratch(p / "db.scratch").get("NAME", p.name)

    # Layout: <case>/tool-entries/<firearm>/<specimen>/<measurement>/<mark>.
    return {
        "case_id": _name(te - 1) if te > 0 else "unknown",
        "firearm_id": _name(te + 1),
        "specimen_id": _name(te + 2),
        "measurement_id": _name(te + 3),
        "mark_id": _name(te + 4) if len(parts) > te + 4 else mark_dir.name,
    }


def find_comparison_results(root: Path) -> Iterator[tuple[Path, MarkType]]:
    """Yield ``(results_folder, mark_type)`` for each ``results_table.mat`` found."""
    for mat_file in root.rglob("mark-comparison-results/*/results_table.mat"):
        folder = mat_file.parent
        mark_type = infer_mark_type(folder.name)
        if mark_type is not None:
            yield folder, mark_type
        else:
            logger.warning("Cannot infer mark type from %s, skipping", folder.name)


@dataclass
class ComparisonEntry:
    """A single comparison pair with pre-resolved paths."""

    mark_dir_ref: Path  # reference mark directory
    mark_dir_comp: Path  # compared mark directory
    mark_type: MarkType  # striation vs impression decides the request body shape
    comparison_out: Path  # directory where comparison_results.json is written
    row_index: int  # originating row in the results table


def _build_body(entry: ComparisonEntry) -> dict[str, Any]:
    """Build the API request body for a comparison."""
    body: dict[str, Any] = {
        "mark_dir_ref": str(entry.mark_dir_ref / "processed"),
        "mark_dir_comp": str(entry.mark_dir_comp / "processed"),
    }
    if entry.mark_type.is_striation():
        # Striation comparisons carry both marks' metadata in the params.
        body["param"] = {
            "metadata_reference": _extract_metadata(entry.mark_dir_ref),
            "metadata_compared": _extract_metadata(entry.mark_dir_comp),
        }
    # TODO: update with actual CalculateScoreImpression fields
    return body


def _save_result(entry: ComparisonEntry, result: dict[str, Any] | None = None, error: str | None = None) -> None:
    """Write comparison_results.json with full context."""
    entry.comparison_out.mkdir(parents=True, exist_ok=True)
    comparison_results = result.get("comparison_results") if result else None
    payload: dict[str, Any] = {
        "mark_dir_ref": str(entry.mark_dir_ref),
        "mark_dir_comp": str(entry.mark_dir_comp),
        "mark_type": entry.mark_type.value,
        "metadata": {
            "metadata_reference": _extract_metadata(entry.mark_dir_ref),
            "metadata_compared": _extract_metadata(entry.mark_dir_comp),
        },
        "error": error,
        "comparison_results": comparison_results,
    }
    out_file = entry.comparison_out / "comparison_results.json"
    # default=str stringifies non-JSON types (e.g. Paths) rather than raising.
    out_file.write_text(json.dumps(payload, indent=2, default=str))
66 changes: 18 additions & 48 deletions scripts/convert_matlab_results.py → scripts/convert_marks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,26 @@
Walks a nested folder structure, converts x3p files, extracts crop and
preprocessing parameters from .mat files, and calls the local preprocessor
API to regenerate marks.

Usage:
python convert_matlab_results.py /path/to/root output/
python convert_matlab_results.py /path/to/root output/ --api-url http://localhost:8000
"""

import argparse
import contextlib
import io
import json
import logging
import os
import uuid
from collections.abc import Callable, Iterable, Iterator
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from collections.abc import Iterator
from datetime import datetime
from pathlib import Path
from typing import Any

import numpy as np
import requests
from parsers import load_scan_image, parse_to_x3p
from returns.unsafe import unsafe_perform_io
from tqdm import tqdm

from scripts.conversion_utils import ConversionConfig, run_parallel
from scripts.http_utils import _post_with_retry
from scripts.matlab_utils import (
extract_impression_params,
Expand All @@ -40,21 +37,8 @@
logger = logging.getLogger(__name__)


@dataclass
class ConversionConfig:
    """Shared configuration for the conversion pipeline."""

    root: Path  # root of the input MATLAB results tree
    output_dir: Path  # destination for converted output
    api_url: str  # base URL of the preprocessor API
    force: bool = False  # reprocess marks even if output already exists

    def __post_init__(self) -> None:
        # Strip trailing slash so f"{api_url}/{endpoint}" joins stay clean.
        self.api_url = self.api_url.rstrip("/")


def convert_x3p(input_path: Path, output_path: Path) -> tuple[int, int]:
"""Load an X3P from path and parse it with the pipelines and write the result."""
"""Load an X3P from path, parse it, and write the result."""
scan = unsafe_perform_io(load_scan_image(input_path).unwrap())
x3p = parse_to_x3p(scan).unwrap()

Expand Down Expand Up @@ -139,7 +123,8 @@ def convert_mark(
"""Process a single mark: extract params, call API, download results.

Reads crop info and preprocessing parameters from mark.mat, builds the
API request, and downloads the resulting files into the output directory.
API request as multipart form + file upload, and downloads the resulting
files into the output directory.
"""
mark_dir = cfg.output_dir / mark_folder.relative_to(cfg.root)

Expand All @@ -156,15 +141,20 @@ def convert_mark(
endpoint = f"preprocessor/prepare-mark-{'impression' if is_impression else 'striation'}"
params = extract_impression_params(struct, mark_type) if is_impression else extract_striation_params(struct)

body = {
params_dict = {
"scan_file": str(converted_x3p),
"mark_type": mark_type.value,
"mask": mask.astype(int).tolist(),
"mark_parameters": params,
"bounding_box_list": bounding_box_list,
"mask_is_bitpacked": False,
}

result = _post_with_retry(f"{cfg.api_url}/{endpoint}", body)
mask_bytes = mask.astype(np.bool_).tobytes()
result = _post_with_retry(
f"{cfg.api_url}/{endpoint}",
data={"params": json.dumps(params_dict)},
files={"mask_data": ("mask.bin", io.BytesIO(mask_bytes), "application/octet-stream")},
)

processed_dir = mark_dir / "processed"
mark_dir.mkdir(parents=True, exist_ok=True)
Expand All @@ -181,26 +171,6 @@ def convert_mark(
(dest / filename).write_bytes(resp.content)


def _run_parallel(
    tasks: Iterable[tuple[Any, Callable, tuple]],
    workers: int,
    desc: str,
    unit: str,
) -> dict[Any, Any]:
    """Run tasks in parallel with a progress bar.

    :param tasks: iterable of (key, fn, args) tuples.
    :returns: dict of {key: result}.
    """
    submitted = list(tasks)
    results: dict[Any, Any] = {}
    with ThreadPoolExecutor(max_workers=workers) as pool:
        future_to_key = {pool.submit(fn, *args): key for key, fn, args in submitted}
        # Completion order is arbitrary; the key map restores the association.
        for done in tqdm(as_completed(future_to_key), total=len(future_to_key), desc=desc, unit=unit):
            results[future_to_key[done]] = done.result()
    return results


def main() -> None:
"""Entry point: parse args and run the conversion pipeline."""
parser = argparse.ArgumentParser(description="Convert MATLAB results via Python API")
Expand All @@ -222,14 +192,14 @@ def main() -> None:
logger.info(f"Found {len(marks)} marks")

unique_measurements = list({mf for mf, _ in marks})
converted_x3ps = _run_parallel(
converted_x3ps = run_parallel(
((mf, convert_measurement_x3p, (mf, cfg)) for mf in unique_measurements),
args.workers,
"Converting x3p",
" files",
)

_run_parallel(
run_parallel(
((mf, convert_mark, (mf, *converted_x3ps[meas], cfg)) for meas, mf in marks),
args.workers,
"Converting marks",
Expand Down
Loading
Loading