Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,54 @@
(e.g., breech face impressions) through leveling, filtering, and resampling steps.
"""

from dataclasses import asdict
from pydantic import BaseModel

from container_models.base import DepthData
from conversion.data_formats import Mark
from conversion.filter import (
apply_gaussian_filter_mark,
apply_filter_pipeline,
apply_gaussian_filter_mark,
)
from conversion.leveling import SurfaceTerms, level_map
from conversion.mask import crop_to_mask
from conversion.preprocess_impression.center import compute_center_local
from conversion.preprocess_impression.parameters import PreprocessingImpressionParams
from conversion.preprocess_impression.resample import (
resample,
needs_resampling,
resample,
)
from conversion.preprocess_impression.tilt import apply_tilt_correction
from conversion.preprocess_impression.utils import update_mark_data, Point2D
from conversion.preprocess_impression.utils import Point2D, update_mark_data
from conversion.resample import get_scaling_factors, resample_array_2d


class ImpressionParams(BaseModel):
pixel_size: float | None = None
adjust_pixel_spacing: bool = True
level_offset: bool = True
level_tilt: bool = True
level_2nd: bool = True
interp_method: str = "cubic"
highpass_cutoff: float | None = 250.0e-6
lowpass_cutoff: float | None = 5.0e-6
highpass_regression_order: int = 2
lowpass_regression_order: int = 0

@property
def surface_terms(self) -> SurfaceTerms:
"""Convert leveling flags to SurfaceTerms."""
terms = SurfaceTerms.NONE
if self.level_offset:
terms |= SurfaceTerms.OFFSET
if self.level_tilt:
terms |= SurfaceTerms.TILT_X | SurfaceTerms.TILT_Y
if self.level_2nd:
terms |= SurfaceTerms.ASTIG_45 | SurfaceTerms.DEFOCUS | SurfaceTerms.ASTIG_0
return terms


def preprocess_impression_mark(
mark: Mark,
params: PreprocessingImpressionParams,
params: ImpressionParams,
) -> tuple[Mark, Mark]:
"""
Preprocess trimmed impression image data.
Expand Down Expand Up @@ -101,7 +125,7 @@ def preprocess_impression_mark(
)

# Build output metadata
mark.meta_data.update(**asdict(params))
mark.meta_data.update(**params.model_dump())

return mark_filtered, mark_leveled_final

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,34 @@
- Fine rotation to align striations horizontally and profile extraction
"""

from dataclasses import asdict

import numpy as np
from pydantic import BaseModel

from container_models.base import FloatArray2D
from container_models.scan_image import ScanImage
from conversion.data_formats import Mark
from conversion.filter import (
cutoff_to_gaussian_sigma,
apply_striation_preserving_filter_1d,
cutoff_to_gaussian_sigma,
)
from conversion.preprocess_striation import PreprocessingStriationParams
from conversion.preprocess_striation.alignment import fine_align_bullet_marks
from conversion.preprocess_striation.shear import propagate_nan
from conversion.profile_correlator import Profile


class StriationParams(BaseModel):
highpass_cutoff: float = 2e-3
lowpass_cutoff: float = 2.5e-4
cut_borders_after_smoothing: bool = True
use_mean: bool = True
angle_accuracy: float = 0.1
max_iter: int = 25
subsampling_factor: int = 1


def preprocess_striation_mark(
mark: Mark,
params: PreprocessingStriationParams,
params: StriationParams,
) -> tuple[Mark, Profile]:
"""
Complete the preprocessing pipeline for striated marks.
Expand Down Expand Up @@ -89,7 +97,7 @@ def preprocess_striation_mark(
# Build meta_data with mask and total_angle
aligned_meta_data = {
**mark.meta_data,
**asdict(params),
**params.model_dump(),
"total_angle": total_angle,
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,29 @@
Tests for preprocess_striation.py and related filter functions.
"""

from math import ceil

import numpy as np
import pytest
from math import ceil
from scipy.constants import micro

from container_models.scan_image import ScanImage
from conversion.data_formats import MarkType
from ..helper_functions import make_mark
from conversion.filter import (
apply_striation_preserving_filter_1d,
cutoff_to_gaussian_sigma,
)
from conversion.filter.gaussian import _apply_nan_weighted_gaussian_1d
from conversion.preprocess_striation import (
PreprocessingStriationParams,
apply_shape_noise_removal,
fine_align_bullet_marks,
preprocess_striation_mark,
)
from conversion.preprocess_striation.shear import shear_data_by_shifting_profiles
from conversion.preprocess_striation.alignment import _detect_striation_angle
from conversion.preprocess_striation.pipeline import StriationParams
from conversion.preprocess_striation.shear import shear_data_by_shifting_profiles

from ..helper_functions import make_mark


def test_cutoff_to_gaussian_sigma():
Expand Down Expand Up @@ -336,7 +338,7 @@ def test_preprocess_striation_mark():
scale_y=micro,
mark_type=MarkType.BULLET_LEA_STRIATION,
)
params = PreprocessingStriationParams(
params = StriationParams(
highpass_cutoff=2e-3,
lowpass_cutoff=2.5e-4,
cut_borders_after_smoothing=False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,18 @@

from container_models.base import DepthData
from conversion.data_formats import MarkType
from ..helper_functions import make_mark
from conversion.preprocess_striation import (
PreprocessingStriationParams,
preprocess_striation_mark,
)
from conversion.preprocess_striation.pipeline import StriationParams

from ..helper_functions import (
_compute_correlation,
_crop_to_common_shape,
_compute_difference_stats,
_crop_to_common_shape,
make_mark,
)


MARK_TYPE_MAPPING = {
"bullet lea striation": MarkType.BULLET_LEA_STRIATION,
"bullet gea striation": MarkType.BULLET_GEA_STRIATION,
Expand Down Expand Up @@ -166,7 +166,7 @@ def run_python_preprocessing(
test_case: MatlabTestCase,
) -> tuple[np.ndarray, np.ndarray | None, float | None]:
"""Run Python preprocess_striation_mark and return the results."""
params = PreprocessingStriationParams(
params = StriationParams(
highpass_cutoff=test_case.cutoff_hi * micro,
lowpass_cutoff=test_case.cutoff_lo * micro,
use_mean=test_case.use_mean,
Expand Down
30 changes: 15 additions & 15 deletions packages/scratch-core/tests/conversion/test_impression.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
_points_are_collinear,
compute_center_local,
)
from conversion.preprocess_impression.parameters import PreprocessingImpressionParams
from conversion.preprocess_impression.preprocess_impression import (
ImpressionParams,
preprocess_impression_mark,
)
from conversion.preprocess_impression.resample import needs_resampling
Expand Down Expand Up @@ -422,7 +422,7 @@ def test_basic_pipeline_runs(self):
scale_y=micro,
mark_type=MarkType.FIRING_PIN_IMPRESSION,
)
params = PreprocessingImpressionParams(
params = ImpressionParams(
pixel_size=2 * micro,
adjust_pixel_spacing=False,
level_offset=True,
Expand All @@ -446,7 +446,7 @@ def test_output_has_correct_scale(self):
mark_type=MarkType.FIRING_PIN_IMPRESSION,
)
target_size = 2 * micro
params = PreprocessingImpressionParams(
params = ImpressionParams(
pixel_size=target_size,
adjust_pixel_spacing=False,
)
Expand All @@ -468,7 +468,7 @@ def test_output_is_smaller_after_downsampling(self):
scale_y=micro,
mark_type=MarkType.FIRING_PIN_IMPRESSION,
)
params = PreprocessingImpressionParams(
params = ImpressionParams(
pixel_size=2 * micro, # 2x downsampling
adjust_pixel_spacing=False,
)
Expand All @@ -491,7 +491,7 @@ def test_filtered_and_leveled_differ(self):
scale_y=micro,
mark_type=MarkType.FIRING_PIN_IMPRESSION,
)
params = PreprocessingImpressionParams(
params = ImpressionParams(
adjust_pixel_spacing=False,
highpass_cutoff=50 * micro, # Apply high-pass to create difference
)
Expand All @@ -514,7 +514,7 @@ def test_breech_face_uses_circle_center(self):
scale_y=micro,
mark_type=MarkType.BREECH_FACE_IMPRESSION,
)
params = PreprocessingImpressionParams(
params = ImpressionParams(
pixel_size=micro, # No resampling
adjust_pixel_spacing=False,
)
Expand All @@ -535,7 +535,7 @@ def test_no_resampling_when_pixel_size_matches(self):
scale_y=micro,
mark_type=MarkType.FIRING_PIN_IMPRESSION,
)
params = PreprocessingImpressionParams(
params = ImpressionParams(
pixel_size=micro, # Same as input
adjust_pixel_spacing=False,
)
Expand All @@ -555,7 +555,7 @@ def test_is_resampled_flag_set_on_resampling(self):
scale_y=micro,
mark_type=MarkType.FIRING_PIN_IMPRESSION,
)
params = PreprocessingImpressionParams(
params = ImpressionParams(
pixel_size=2 * micro, # Different from input
adjust_pixel_spacing=False,
)
Expand All @@ -575,7 +575,7 @@ def test_without_lowpass_filter(self):
scale_y=micro,
mark_type=MarkType.FIRING_PIN_IMPRESSION,
)
params = PreprocessingImpressionParams(
params = ImpressionParams(
pixel_size=2 * micro,
adjust_pixel_spacing=False,
lowpass_cutoff=None,
Expand All @@ -596,7 +596,7 @@ def test_without_highpass_filter(self):
scale_y=micro,
mark_type=MarkType.FIRING_PIN_IMPRESSION,
)
params = PreprocessingImpressionParams(
params = ImpressionParams(
pixel_size=2 * micro,
adjust_pixel_spacing=False,
highpass_cutoff=None,
Expand All @@ -617,7 +617,7 @@ def test_without_any_filters(self):
scale_y=micro,
mark_type=MarkType.FIRING_PIN_IMPRESSION,
)
params = PreprocessingImpressionParams(
params = ImpressionParams(
pixel_size=2 * micro,
adjust_pixel_spacing=False,
lowpass_cutoff=None,
Expand All @@ -639,7 +639,7 @@ def test_with_tilt_adjustment(self):
scale_y=micro,
mark_type=MarkType.FIRING_PIN_IMPRESSION,
)
params = PreprocessingImpressionParams(
params = ImpressionParams(
pixel_size=2 * micro,
adjust_pixel_spacing=True,
)
Expand All @@ -659,7 +659,7 @@ def test_with_second_order_leveling(self):
scale_y=micro,
mark_type=MarkType.FIRING_PIN_IMPRESSION,
)
params = PreprocessingImpressionParams(
params = ImpressionParams(
pixel_size=2 * micro,
adjust_pixel_spacing=False,
level_offset=True,
Expand All @@ -682,7 +682,7 @@ def test_output_data_is_finite_where_valid(self):
scale_y=micro,
mark_type=MarkType.FIRING_PIN_IMPRESSION,
)
params = PreprocessingImpressionParams(
params = ImpressionParams(
pixel_size=2 * micro,
adjust_pixel_spacing=False,
)
Expand Down Expand Up @@ -715,7 +715,7 @@ def test_leveled_preserves_form(self):
scale_y=micro,
mark_type=MarkType.FIRING_PIN_IMPRESSION,
)
params = PreprocessingImpressionParams(
params = ImpressionParams(
pixel_size=micro,
adjust_pixel_spacing=False,
level_offset=True,
Expand Down
11 changes: 6 additions & 5 deletions packages/scratch-core/tests/conversion/test_impression_matlab.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@

from container_models.base import FloatArray2D
from conversion.data_formats import MarkType
from .helper_functions import make_mark
from conversion.preprocess_impression.preprocess_impression import (
ImpressionParams,
preprocess_impression_mark,
)
from conversion.preprocess_impression.parameters import PreprocessingImpressionParams

from .helper_functions import (
_compute_correlation,
_crop_to_common_shape,
_compute_difference_stats,
_crop_to_common_shape,
make_mark,
)


Expand All @@ -32,7 +33,7 @@ class MatlabTestCase:
input_data: FloatArray2D
pixel_spacing: tuple[float, float]
# Processing options
params: PreprocessingImpressionParams
params: ImpressionParams
use_circle_center: bool
# Expected output
output_data: FloatArray2D
Expand Down Expand Up @@ -89,7 +90,7 @@ def from_directory(cls, case_dir: Path) -> "MatlabTestCase":
lowpass_cutoff = cutoff_val
lowpass_order = int(f.get("n_order", 0))

params = PreprocessingImpressionParams(
params = ImpressionParams(
adjust_pixel_spacing=meta.get("adjust_pixel_spacing", False),
level_offset=level_offset,
level_tilt=level_tilt,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,18 @@
import pytest
from scipy.constants import micro

from container_models.base import DepthData, BinaryMask, StriationProfile
from container_models.base import BinaryMask, DepthData, StriationProfile
from conversion.data_formats import MarkType
from .helper_functions import make_mark
from conversion.preprocess_striation import (
PreprocessingStriationParams,
preprocess_striation_mark,
)
from conversion.preprocess_striation.pipeline import StriationParams

from .helper_functions import (
_compute_correlation,
_crop_to_common_shape,
_compute_difference_stats,
_crop_to_common_shape,
make_mark,
)


Expand Down Expand Up @@ -170,7 +171,7 @@ def run_python_preprocessing(
mark_type=mark_type,
)

params = PreprocessingStriationParams(
params = StriationParams(
highpass_cutoff=test_case.cutoff_hi * micro,
lowpass_cutoff=test_case.cutoff_lo * micro,
use_mean=test_case.use_mean,
Expand Down
Loading
Loading