From d3b98fdab47b776eb4b766c8675093aa92016532 Mon Sep 17 00:00:00 2001 From: dominosauro Date: Tue, 7 Oct 2025 11:39:27 +0200 Subject: [PATCH 1/9] Initialized function auto_filtering. --- src/dynsight/_internal/data_processing/auto_filtering.py | 2 ++ src/dynsight/data_processing.py | 4 ++++ 2 files changed, 6 insertions(+) create mode 100644 src/dynsight/_internal/data_processing/auto_filtering.py diff --git a/src/dynsight/_internal/data_processing/auto_filtering.py b/src/dynsight/_internal/data_processing/auto_filtering.py new file mode 100644 index 00000000..ccfe9da6 --- /dev/null +++ b/src/dynsight/_internal/data_processing/auto_filtering.py @@ -0,0 +1,2 @@ +def auto_filtering() -> None: + print("auto") diff --git a/src/dynsight/data_processing.py b/src/dynsight/data_processing.py index db4eb959..b5845793 100644 --- a/src/dynsight/data_processing.py +++ b/src/dynsight/data_processing.py @@ -1,5 +1,8 @@ """data processing package.""" +from dynsight._internal.data_processing.auto_filtering import ( + auto_filtering, +) from dynsight._internal.data_processing.classify import ( applyclassification, createreferencesfromtrajectory, @@ -20,6 +23,7 @@ __all__ = [ "applyclassification", + "auto_filtering", "createreferencesfromtrajectory", "getdistancebetween", "getdistancesfromref", From c1b788753092ed63078548a3884b8555bb945dfc Mon Sep 17 00:00:00 2001 From: dominosauro Date: Thu, 9 Oct 2025 15:08:05 +0200 Subject: [PATCH 2/9] Full auto_filtering code. 
--- .../data_processing/auto_filtering.py | 1363 ++++++++++++++++- 1 file changed, 1361 insertions(+), 2 deletions(-) diff --git a/src/dynsight/_internal/data_processing/auto_filtering.py b/src/dynsight/_internal/data_processing/auto_filtering.py index ccfe9da6..cd038158 100644 --- a/src/dynsight/_internal/data_processing/auto_filtering.py +++ b/src/dynsight/_internal/data_processing/auto_filtering.py @@ -1,2 +1,1361 @@ -def auto_filtering() -> None: - print("auto") +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +if TYPE_CHECKING: + import os + + from matplotlib.axes import Axes + from matplotlib.figure import Figure + +import io +import logging +import random +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, TypeAlias + +import imageio.v2 as imageio +import matplotlib.pyplot as plt +import numpy as np +import seaborn as sns # type: ignore[import-untyped] +from numpy.fft import fft, fftfreq +from numpy.typing import NDArray +from scipy.signal import butter, filtfilt + +from dynsight.trajectory import Insight + +# Type alias for 64-bit float numpy arrays +ArrayF64: TypeAlias = NDArray[np.float64] + + +# --------------------------- Constants --------------------------- + +# Frequency conversion constants +FREQ_TERA = 1e12 # Terahertz in Hz +FREQ_GIGA = 1e9 # Gigahertz in Hz +FREQ_MEGA = 1e6 # Megahertz in Hz +FREQ_KILO = 1e3 # Kilohertz in Hz + +# Default parameters for filtering +DEFAULT_FRAMES_TO_REMOVE = 20 # Frames to trim from each end +DEFAULT_FILTER_ORDER = 4 # Butterworth filter order + +# Image processing constants +IMG_NDIM_GRAYSCALE = 2 # Number of dimensions for grayscale +IMG_CHANNELS_RGBA = 4 # Number of channels for RGBA images + +# Numerical constants +SMALL_EPSILON = 1e-9 # Small number to avoid division by zero +NDIM_EXPECTED = 2 # Expected number of dimensions for input +MIN_FRAMES_TO_DROP = 2 # Minimum frames needed to drop first frame + +# Initialize 
# Module-level logger shared by all helpers in this file
logger = logging.getLogger(__name__)


# --------------------------- Result container ---------------------------


@dataclass(frozen=True)
class AutoFiltInsight:
    """Immutable record of what one auto-filtering run produced.

    Attributes:
        output_dir: Base directory where all outputs are saved
        video_path: Path to the rendered evolution video, or None
        cutoffs: Cutoff frequencies that were used (Hz)
        filtered_files: Maps each cutoff frequency to its saved .npy path
        meta: Parameters the run was executed with
        filtered_collection: Filtered signal arrays, one per cutoff
    """

    # Required fields come first (dataclass ordering rule)
    output_dir: Path
    video_path: Path | None

    # Optional fields are excluded from repr() so printing stays short
    cutoffs: list[float] = field(default_factory=list, repr=False)
    filtered_files: dict[float, Path] = field(
        default_factory=dict, repr=False
    )
    meta: dict[str, Any] = field(default_factory=dict, repr=False)
    filtered_collection: tuple[ArrayF64, ...] = field(
        default_factory=tuple, repr=False
    )


# --------------------------- Helpers (I/O, plots) ---------------------------


def _resolve_dataset_path(user_path: str | os.PathLike[str]) -> Path:
    """Resolve user input to a concrete dataset file path.

    A file path is returned as-is (expanded and made absolute). For a
    directory, the extensions .json, .npy and .npz are tried in that
    order and the first extension with exactly one matching regular
    file wins.

    Args:
        user_path: User-provided path (file or directory)

    Returns:
        Resolved Path to a dataset file

    Raises:
        FileNotFoundError: If the path does not exist, is neither a file
            nor a directory, or the directory holds no candidate file
        ValueError: If several files share the preferred extension
    """
    p = Path(user_path).expanduser().resolve()

    if p.is_file():
        return p

    if not p.exists():
        msg = f"Path does not exist: {p}"
        raise FileNotFoundError(msg)

    if p.is_dir():
        for ext in (".json", ".npy", ".npz"):
            # Only regular files qualify: a sub-directory that happens to
            # be called "backup.npy" must not be returned as a dataset nor
            # make the lookup ambiguous.
            hits = sorted(h for h in p.glob(f"*{ext}") if h.is_file())

            if len(hits) == 1:
                return hits[0]

            if len(hits) > 1:
                names = ", ".join(h.name for h in hits)
                msg = f"Multiple {ext} files in {p}: {names}"
                raise ValueError(msg)

        msg = f"No .json/.npy/.npz found in {p}"
        raise FileNotFoundError(msg)

    # Exists but is neither a regular file nor a directory
    msg = f"Unsupported path: {p}"
    raise FileNotFoundError(msg)
def _load_array_any(
    path: Path,
    *,
    mmap_mode: Literal["r+", "r", "w+", "c"] | None = None,
    enforce_2d: bool = True,
) -> NDArray[np.float64]:
    """Load a dataset from a .json, .npy, or .npz file.

    The loaded data is wrapped into an Insight object and the
    underlying array is returned.

    Args:
        path: Path to dataset file
        mmap_mode: Memory-mapping mode forwarded to numpy.load
            (not used for .json files)
        enforce_2d: If True, raise an error unless the array is 2D

    Returns:
        Loaded numpy array

    Raises:
        ValueError: If the file type is unsupported, the .npz archive
            is empty, or the array has the wrong number of dimensions
    """
    sfx = path.suffix.lower()

    if sfx == ".json":
        # numpy.load cannot parse JSON text, so the Insight reader is used.
        # NOTE(review): relies on Insight.load_from_json from
        # dynsight.trajectory - confirm the installed version provides it.
        ins = Insight.load_from_json(path)
    elif sfx == ".npy":
        arr = np.load(path, mmap_mode=mmap_mode)
        ins = Insight(dataset=np.asarray(arr), meta={"source": path.name})
    elif sfx == ".npz":
        # The archive is a file handle: close it once the array is read
        with np.load(path, mmap_mode=mmap_mode) as z:
            if not z.files:
                msg = "Empty .npz file."
                raise ValueError(msg)

            # Only the first stored array is used
            key = z.files[0]
            ins = Insight(
                dataset=np.asarray(z[key]),
                meta={"source": path.name, "key": key},
            )
    else:
        msg = f"Unsupported file type: {sfx}"
        raise ValueError(msg)

    if enforce_2d and ins.dataset.ndim != NDIM_EXPECTED:
        msg = f"Expected 2D array (series x frames), got {ins.dataset.shape}"
        raise ValueError(msg)

    return np.asarray(ins.dataset)
def _make_dir_safe(directory: Path) -> None:
    """Ensure *directory* exists, creating missing parents on the way.

    Args:
        directory: Path to directory to create
    """
    directory.mkdir(exist_ok=True, parents=True)


def _freq_label_for_folder(freq_hz: float) -> str:
    """Format a frequency in Hz as a short label with a unit suffix.

    The largest unit (THz, GHz, MHz, kHz) whose scale does not exceed
    the value is used; smaller values are labelled in plain Hz.

    Args:
        freq_hz: Frequency in Hertz

    Returns:
        Formatted string like "1.234GHz"
    """
    # Ordered from largest to smallest so the first match wins
    scaled_units = (
        (FREQ_TERA, "THz"),
        (FREQ_GIGA, "GHz"),
        (FREQ_MEGA, "MHz"),
        (FREQ_KILO, "kHz"),
    )
    for scale, unit in scaled_units:
        if freq_hz >= scale:
            return f"{freq_hz / scale:.3f}{unit}"
    return f"{freq_hz:.3f}Hz"
def _plot_fft(
    freq: NDArray[np.float64],
    mag: NDArray[np.float64],
    title: str,
    path: Path,
    mark_freqs: list[float] | None = None,
) -> None:
    """Save a plot of FFT magnitude against frequency.

    The x-axis is shown in GHz. When cutoff frequencies are given,
    they are drawn as scatter points on the curve.

    Args:
        freq: Frequency array in Hz
        mag: Magnitude array
        title: Plot title
        path: Where to save the plot
        mark_freqs: Optional frequencies (Hz) to mark on the curve
    """
    plt.figure()
    plt.plot(freq / FREQ_GIGA, mag, lw=1.5, label="Summed |FFT|")

    if mark_freqs:
        # Marker heights are read off the curve by linear interpolation
        marker_y = np.interp(mark_freqs, freq, mag)
        marker_x = np.array(mark_freqs) / FREQ_GIGA
        plt.scatter(marker_x, marker_y, s=30, label="Cutoffs")

    plt.title(title)
    plt.xlabel("Frequency (GHz)")
    plt.ylabel("Summed Magnitude |FFT|")
    plt.grid(alpha=0.3)
    plt.legend()
    plt.tight_layout()

    plt.savefig(path, dpi=200)
    plt.close()


def _plot_signals_with_kde(
    signals: NDArray[np.float64], title: str, path: Path
) -> None:
    """Save a two-panel figure: signal traces and their value KDE.

    The left panel draws every row of *signals* plus the mean over
    rows; the right panel shows a KDE of all values.

    Args:
        signals: 2D array (series x frames)
        title: Title of the left panel
        path: Where to save the plot
    """
    _fig, axes = plt.subplots(ncols=2, figsize=(10, 5))
    traces_ax, kde_ax = axes

    # Left panel: individual traces in gray, mean across rows in red
    traces_ax.plot(signals.T, lw=0.3, alpha=0.35, c="gray")
    traces_ax.plot(np.mean(signals, axis=0), color="red", lw=1.0, label="Mean")
    traces_ax.set_title(title)
    traces_ax.set_xlabel("Frame")
    traces_ax.set_ylabel("Signal")
    traces_ax.grid(alpha=0.3)
    traces_ax.legend()

    # Right panel: KDE over the flattened values
    sns.kdeplot(y=signals.ravel(), ax=kde_ax, fill=True, alpha=0.3)
    kde_ax.set_title("KDE Distribution")

    plt.tight_layout()
    plt.savefig(path, dpi=200)
    plt.close()
def _plot_single_atom_comparison(
    orig: NDArray[np.float64],
    filt: NDArray[np.float64],
    atom_idx: int,
    path: Path,
) -> None:
    """Plot original vs filtered signal for a single series.

    Args:
        orig: Original signals (series x frames)
        filt: Filtered signals (series x frames)
        atom_idx: Index of the series (row) to plot
        path: Where to save the plot
    """
    plt.figure()

    plt.plot(orig[atom_idx], label="Original", lw=1.0)
    plt.plot(filt[atom_idx], label="Filtered", lw=1.0)

    plt.title(f"Atom/Series {atom_idx}: Original vs Filtered")
    plt.xlabel("Frame")
    plt.ylabel("Signal")
    plt.legend()
    plt.grid(alpha=0.3)
    plt.tight_layout()

    plt.savefig(path, dpi=200)
    plt.close()


# --------------------------- Filt helpers ---------------------------


def _compute_fft_summed(
    signals: NDArray[np.float64], dt: float
) -> tuple[NDArray[np.float64], NDArray[np.float64]]:
    """Compute the FFT along axis 1 and sum magnitudes across rows.

    Only strictly positive frequencies are kept (the zero-frequency
    bin and all negative bins are dropped).

    Args:
        signals: 2D array (series x frames)
        dt: Time step in seconds

    Returns:
        freq: Positive frequency array
        mag_sum: Magnitude summed over all series, one value per
            entry of freq. Both arrays are empty when the input has
            no frames.
    """
    _n_series, n_frames = signals.shape

    # fftfreq divides by n_frames, so a zero-frame input would raise
    # ZeroDivisionError. Return an empty spectrum instead.
    if n_frames == 0:
        return (
            np.empty(0, dtype=np.float64),
            np.empty(0, dtype=np.float64),
        )

    f_all = fftfreq(n_frames, d=dt)
    pos_mask = f_all > 0

    fft_vals = fft(signals, axis=1)[:, pos_mask]

    mag_sum: NDArray[np.float64] = np.asarray(
        np.abs(fft_vals), dtype=np.float64
    ).sum(axis=0)
    freq: NDArray[np.float64] = np.asarray(f_all[pos_mask], dtype=np.float64)

    return freq, mag_sum
def _find_cutoffs_biased(
    freq: NDArray[np.float64],
    mag: NDArray[np.float64],
    num_levels: int,
    low_frac: float = 0.20,
    low_ratio: float = 2.0,
    min_frac: float = 0.05,
    max_frac: float = 0.95,
) -> list[float]:
    """Find cutoff frequencies biased toward low frequencies.

    Thresholds are placed on the normalized cumulative magnitude:
    part of them between min_frac and low_frac, the rest between
    low_frac and max_frac, split roughly low_ratio:1. Each threshold
    is mapped to the first frequency whose cumulative fraction
    reaches it.

    Args:
        freq: Frequency array
        mag: Magnitude array (same length as freq)
        num_levels: Number of thresholds to place; the returned list
            can be shorter because duplicate frequencies are merged
        low_frac: Cumulative fraction defining "low frequency" region
        low_ratio: Ratio of low-freq to high-freq cutoffs
        min_frac: Minimum cumulative fraction to consider
        max_frac: Maximum cumulative fraction to consider

    Returns:
        List of cutoff frequencies (sorted, unique)

    Raises:
        ValueError: If num_levels is smaller than 1
    """
    if num_levels < 1:
        msg = "num_levels must be >= 1"
        raise ValueError(msg)

    cum = np.cumsum(mag)
    total = cum[-1] if cum.size else 0.0

    # Nothing to normalize by: fall back to the highest frequency
    if total == 0.0:
        warn_msg = (
            "[WARN] Summed magnitude is all zeros; "
            "using max frequency as single cutoff."
        )
        logger.warning(warn_msg)
        return [float(freq[-1])] if freq.size else []

    cum = cum / total

    # Split the budget between the two regions. The two counts always
    # add up to num_levels; both regions get at least one threshold
    # whenever num_levels allows it.
    n_low = max(1, round(num_levels * (low_ratio / (low_ratio + 1.0))))
    if num_levels > 1:
        n_low = min(n_low, num_levels - 1)
    n_high = num_levels - n_low

    # Keep the region boundary strictly inside (min_frac, max_frac)
    low_hi = max(min(low_frac, max_frac - 1e-6), min_frac + 1e-6)

    th_low = np.linspace(min_frac, low_hi, n_low, endpoint=True)
    # n_high == 0 yields an empty array, which concatenates cleanly
    th_high = np.linspace(low_hi + 1e-6, max_frac, n_high, endpoint=True)
    thresholds = np.concatenate([th_low, th_high])

    cutoffs: list[float] = []
    for th in thresholds:
        idx = np.searchsorted(cum, th, side="left")

        # A threshold beyond the last bin maps to the last frequency
        if idx >= len(freq):
            idx = len(freq) - 1

        c = float(freq[int(idx)])

        # Skip values numerically equal to the previous cutoff
        if len(cutoffs) == 0 or not np.isclose(c, cutoffs[-1]):
            cutoffs.append(c)

    cutoffs = sorted(set(cutoffs))

    # Report where the low/high boundary falls on the frequency axis
    idx_split = np.searchsorted(cum, low_frac, side="left")
    f_split = (
        float(freq[min(int(idx_split), len(freq) - 1)])
        if len(freq)
        else float("nan")
    )
    logger.info(
        "[INFO] Biased cutoffs: low_frac=%.2f (~f=%.3e Hz) | total unique=%d",
        low_frac,
        f_split,
        len(cutoffs),
    )

    return cutoffs
def _butter_lowpass_filter(
    signal: NDArray[np.float64],
    cutoff: float,
    fs: float,
    order: int = DEFAULT_FILTER_ORDER,
) -> NDArray[np.float64]:
    """Low-pass a signal with a Butterworth filter run through filtfilt.

    When the cutoff is at or above the Nyquist frequency (fs / 2) the
    input is returned unchanged and a warning is logged.

    Args:
        signal: 1D signal array
        cutoff: Cutoff frequency in Hz
        fs: Sampling frequency in Hz
        order: Filter order

    Returns:
        Filtered signal array, or the input itself when no filtering
        was possible
    """
    nyquist = 0.5 * fs

    if cutoff >= nyquist:
        # The normalized cutoff must stay below 1, so pass through
        warn_msg = (
            f"[WARN] cutoff {cutoff:.3e} >= Nyquist {nyquist:.3e}; "
            "passing signal through."
        )
        logger.warning(warn_msg)
        return signal

    normalized_cutoff = cutoff / nyquist
    numerator, denominator = butter(order, normalized_cutoff, btype="low")
    return filtfilt(numerator, denominator, signal)
def _remove_filter_artifacts(
    signals: NDArray[np.float64],
    frames_to_remove: int = DEFAULT_FRAMES_TO_REMOVE,
) -> NDArray[np.float64]:
    """Trim the same number of frames from the start and the end.

    Args:
        signals: 2D array (series x frames)
        frames_to_remove: Number of frames to remove from each end;
            0 leaves the array untouched

    Returns:
        Trimmed signal array. The input is returned unchanged when
        frames_to_remove is 0 or when there are not more than
        2 * frames_to_remove frames.

    Raises:
        ValueError: If frames_to_remove is negative
    """
    _n_series, n_frames = signals.shape

    if frames_to_remove < 0:
        msg = "frames_to_remove must be >= 0"
        raise ValueError(msg)

    # A "-0" stop index would select nothing, so handle 0 explicitly
    if frames_to_remove == 0:
        return signals

    if n_frames <= 2 * frames_to_remove:
        warn_msg = (
            f"[WARN] Not enough frames ({n_frames}) to remove "
            f"{frames_to_remove} per side. Skipping trim."
        )
        logger.warning(warn_msg)
        return signals

    # Explicit stop index avoids negative-index corner cases
    trimmed = signals[:, frames_to_remove : n_frames - frames_to_remove]

    logger.info(
        "[STEP] Removed %d frames per side -> new shape %s",
        frames_to_remove,
        trimmed.shape,
    )

    return trimmed


# --------------------------- Video helpers ---------------------------


def _draw_left_panel(
    ax: Axes,
    x: NDArray[np.float64],
    mean: NDArray[np.float64],
    std: NDArray[np.float64],
    overlay: list[NDArray[np.float64]] | None,
    title: str | None,
    y_limits: tuple[float, float] | None,
    show_legend: bool,
) -> None:
    """Draw the left panel of a video frame: mean, spread and traces.

    Args:
        ax: Matplotlib axes to draw on
        x: X-axis values (frame indices)
        mean: Mean signal
        std: Standard deviation, drawn as a band around the mean
        overlay: Optional list of individual traces to overlay
        title: Panel title; "Filtered" is used when empty or None
        y_limits: Optional (ymin, ymax) to fix y-axis
        show_legend: Whether to show legend
    """
    ax.plot(x, mean, lw=1.2, label="Mean")
    ax.fill_between(x, mean - std, mean + std, alpha=0.25, label="+/- 1 sigma")

    if overlay:
        for tr in overlay:
            ax.plot(x, tr, lw=0.6, alpha=0.35)

    ax.set_title(title or "Filtered")
    ax.set_xlabel("Frame (trimmed)")
    ax.set_ylabel("Signal")
    ax.grid(alpha=0.3)

    if show_legend:
        ax.legend()

    if y_limits is not None:
        ax.set_ylim(*y_limits)
def _draw_kde_panel(
    ax: Axes,
    dist_values: NDArray[np.float64] | None,
    kde_bw: float,
) -> None:
    """Draw the right panel of a video frame: KDE of signal values.

    Non-finite values are discarded first. A KDE is drawn when more
    than one value with non-zero spread remains; if the remaining
    values have zero spread, a dashed horizontal line is drawn at
    the first value instead.

    Args:
        ax: Matplotlib axes to draw on
        dist_values: Values for the KDE, or None to leave the panel
            empty
        kde_bw: Bandwidth adjustment passed to the KDE
    """
    ax.set_title("KDE")
    ax.set_xlabel("Density")
    ax.grid(alpha=0.3)

    if dist_values is not None:
        vals = np.asarray(dist_values)
        vals = vals[np.isfinite(vals)]

        if vals.size > 1 and np.nanstd(vals) > 0:
            sns.kdeplot(y=vals, ax=ax, fill=True, alpha=0.3, bw_adjust=kde_bw)
        elif vals.size > 0:
            ax.axhline(float(vals[0]), ls="--", alpha=0.6)

    # Applied on every path, including "no data", so all frames of a
    # video share the same axis placement
    ax.yaxis.tick_right()
    ax.yaxis.set_label_position("right")
def _render_frame_array(
    mean: NDArray[np.float64],
    std: NDArray[np.float64],
    y_limits: tuple[float, float] | None = None,
    title_override: str | None = None,
    overlay: list[NDArray[np.float64]] | None = None,
    dist_values: NDArray[np.float64] | None = None,
    show_legend: bool = False,
    kde_bw: float = 1.0,
) -> NDArray[np.uint8]:
    """Render a single video frame as a numpy image array.

    Builds a two-panel figure (signals on the left, KDE on the right,
    sharing the y-axis) and hands it to _finalize_frame for
    conversion.

    Args:
        mean: Mean signal
        std: Standard deviation of signal
        y_limits: Optional y-axis limits
        title_override: Title for left panel
        overlay: Optional traces to overlay
        dist_values: Values for KDE plot
        show_legend: Whether to show legend
        kde_bw: KDE bandwidth adjustment

    Returns:
        Image array produced by _finalize_frame
    """
    x = np.arange(mean.size, dtype=np.float64)

    fig, (ax1, ax2) = plt.subplots(
        ncols=2,
        figsize=(10, 4),
        sharey=True,
        facecolor="white",
    )

    try:
        _draw_left_panel(
            ax1, x, mean, std, overlay, title_override, y_limits, show_legend
        )
        _draw_kde_panel(ax2, dist_values, kde_bw)
        fig.tight_layout()
    except Exception:
        # pyplot keeps every open figure registered, so a failed frame
        # must be closed here or it leaks for the rest of the process
        plt.close(fig)
        raise

    # _finalize_frame closes the figure after saving it
    return _finalize_frame(fig)
def _finalize_frame(fig: Figure) -> NDArray[np.uint8]:
    """Rasterize a matplotlib figure into a 3-channel image array.

    The figure is saved as PNG into memory, closed, and read back.
    Grayscale results are expanded to three channels, an alpha
    channel is dropped, and odd heights/widths are cropped by one
    pixel so both dimensions are even.

    Args:
        fig: Matplotlib figure (closed by this function)

    Returns:
        Image array (height x width x 3) with even dimensions
    """
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png", dpi=220, facecolor="white")
    plt.close(fig)

    buffer.seek(0)
    decoded = imageio.imread(buffer)

    # Give grayscale images three identical channels
    img: NDArray[np.uint8] = (
        np.stack([decoded, decoded, decoded], axis=2)
        if decoded.ndim == IMG_NDIM_GRAYSCALE
        else decoded
    )

    # Keep only the first three channels of a four-channel image
    if img.shape[2] == IMG_CHANNELS_RGBA:
        img = img[:, :, :3]

    # Crop a trailing row/column when a dimension is odd
    height, width = img.shape[:2]
    odd_rows = height % 2
    odd_cols = width % 2
    if odd_rows or odd_cols:
        img = img[: height - odd_rows, : width - odd_cols, :]

    return img
def _render_video(
    raw_trim: NDArray[np.float64],
    filtered_list: list[NDArray[np.float64]],
    cutoffs: list[float],
    out_path_base: Path,
    y_limits: tuple[float, float] | None = None,
    max_overlay_traces: int = 10,
    seed: int = 42,
    frame_duration: float = 0.25,
) -> Path:
    """Render one AVI video showing the signals at every filter level.

    One frame is rendered for the unfiltered data and one per entry
    of filtered_list (paired with cutoffs in the given order). Frames
    are written to the file in reverse order of rendering, so the
    unfiltered frame is the last one in the video.

    Args:
        raw_trim: Original trimmed signals (series x frames)
        filtered_list: List of filtered signal arrays (one per cutoff)
        cutoffs: List of cutoff frequencies, used for frame titles
        out_path_base: Base path for the output video; its suffix is
            replaced by ".avi"
        y_limits: Optional fixed y-axis limits for all frames
        max_overlay_traces: Max number of individual traces to overlay
        seed: Random seed for trace selection
        frame_duration: Duration of each frame in seconds

    Returns:
        Path of the written .avi file

    Raises:
        RuntimeError: If the video writer raises a TypeError
    """
    rng = random.Random(seed)  # noqa: S311

    def _pick_traces(data: NDArray[np.float64]) -> list[NDArray[np.float64]]:
        """Randomly choose up to max_overlay_traces rows of *data*."""
        n_rows = data.shape[0]
        if not n_rows:
            return []
        picks = rng.sample(range(n_rows), k=min(max_overlay_traces, n_rows))
        return [data[i] for i in picks]

    # (title, data) for every frame: unfiltered first, then each level
    panels: list[tuple[str, NDArray[np.float64]]] = [
        ("Original (unfiltered, trimmed)", raw_trim)
    ]
    panels.extend(
        (f"Filtered @ {_freq_label_for_folder(cutoff)}", filt)
        for filt, cutoff in zip(filtered_list, cutoffs)
    )

    frames: list[NDArray[np.uint8]] = [
        _render_frame_array(
            data.mean(axis=0),
            data.std(axis=0),
            y_limits=y_limits,
            title_override=title,
            overlay=_pick_traces(data),
            dist_values=data.ravel(),
            show_legend=False,
        )
        for title, data in panels
    ]

    # SMALL_EPSILON guards against a zero frame duration
    fps = 1.0 / max(frame_duration, SMALL_EPSILON)

    avi_path = out_path_base.with_suffix(".avi")

    try:
        writer = imageio.get_writer(
            avi_path,
            format="FFMPEG",  # type: ignore[arg-type]
            mode="I",
            fps=fps,
            codec="mpeg4",
            bitrate="10M",
            macro_block_size=None,
            output_params=["-pix_fmt", "yuv420p"],
        )
        try:
            for frame in reversed(frames):
                writer.append_data(frame)
        finally:
            # Close the writer even when appending a frame fails
            writer.close()
    except TypeError as e:
        msg = f"Video creation failed for {avi_path}: {e}"
        raise RuntimeError(msg) from e

    logger.info("[OUT ] Saved AVI (%d frames): %s", len(frames), avi_path)

    return avi_path
# --------------------------- Core workflow ---------------------------


def _validate_params(dt_ps: float, levels: int) -> tuple[float, float]:
    """Validate input parameters and compute derived values.

    Args:
        dt_ps: Time step in picoseconds
        levels: Number of filter levels

    Returns:
        Tuple of (dt_seconds, sampling_frequency_hz)

    Raises:
        ValueError: If dt_ps is not a finite positive number or
            levels is smaller than 1
    """
    # NaN slips through "<= 0" and inf gives a sampling frequency of
    # 0, so non-finite values are rejected explicitly
    if not np.isfinite(dt_ps):
        msg = "dt_ps must be finite"
        raise ValueError(msg)

    if dt_ps <= 0:
        msg = "dt_ps must be > 0"
        raise ValueError(msg)

    if levels < 1:
        msg = "levels must be >= 1"
        raise ValueError(msg)

    # Picoseconds -> seconds, then sampling frequency
    dt = dt_ps * 1e-12
    fs = 1.0 / dt

    logger.info(
        "dt = %.3g ps (%.3e s) | fs = %.3e Hz | Nyquist = %.3e Hz",
        dt_ps,
        dt,
        fs,
        0.5 * fs,
    )

    return dt, fs


def _resolve_signals(
    signals: NDArray[np.float64] | None,
    path: str | Path,
    drop_first_frame: bool,
) -> NDArray[np.float64]:
    """Load or validate input signals.

    A provided array (or array-like) is used directly; otherwise the
    dataset is located and loaded from path. Optionally the first
    frame (first column) is dropped.

    Args:
        signals: Optional pre-loaded signals array
        path: Path to load signals from (if signals is None)
        drop_first_frame: Whether to drop the first frame

    Returns:
        Validated 2D signals array

    Raises:
        ValueError: If array is not 2D or too few frames to drop
    """
    if signals is None:
        ds_path = _resolve_dataset_path(path)
        signals = _load_array_any(ds_path)
    else:
        # Accept nested lists and other array-likes too
        signals = np.asarray(signals)

    if signals.ndim != NDIM_EXPECTED:
        msg = f"Expected 2D array (series x frames), got {signals.shape}"
        raise ValueError(msg)

    if drop_first_frame:
        if signals.shape[1] < MIN_FRAMES_TO_DROP:
            msg = f"Need at least {MIN_FRAMES_TO_DROP} frames."
            raise ValueError(msg)
        signals = signals[:, 1:]

    logger.info("Using signals -> shape %s", signals.shape)

    return signals
def _select_output_dir(out_dir: str | Path | None) -> Path:
    """Create and return the output directory.

    Args:
        out_dir: Optional output directory path; a leading "~" is
            expanded. When None, "autofilter_outputs" inside the
            current working directory is used.

    Returns:
        Path to output directory (created if it does not exist)
    """
    # expanduser() keeps "~/results" from creating a literal "~" folder,
    # matching how the dataset path is resolved
    base = (
        Path(out_dir).expanduser()
        if out_dir is not None
        else Path.cwd() / "autofilter_outputs"
    )

    _make_dir_safe(base)

    return base


def _fft_and_cutoffs(
    signals: NDArray[np.float64],
    dt: float,
    levels: int,
    low_frac: float,
    low_ratio: float,
) -> tuple[NDArray[np.float64], NDArray[np.float64], list[float]]:
    """Compute the summed FFT and choose cutoff frequencies from it.

    Args:
        signals: 2D signals array
        dt: Time step in seconds
        levels: Number of cutoff levels to find
        low_frac: Fraction defining low-frequency region
        low_ratio: Ratio of low to high frequency cutoffs

    Returns:
        Tuple of (frequency_array, magnitude_array, cutoff_list)

    Raises:
        ValueError: If the signals are too short to yield any
            positive-frequency bin
    """
    logger.info("[STEP] Computing summed FFT (original data) ...")
    freq, mag = _compute_fft_summed(signals, dt)

    # min()/max() on an empty array fail with an opaque numpy message,
    # so report the actual problem instead
    if freq.size == 0:
        msg = (
            "Not enough frames to compute a frequency spectrum "
            f"(signals shape {signals.shape})."
        )
        raise ValueError(msg)

    logger.info(
        "Frequency bins: %d | Min/Max freq: %.3e/%.3e Hz",
        len(freq),
        freq.min(),
        freq.max(),
    )

    logger.info(
        "Selecting %d cutoff(s) with low-freq bias (<=%.2f cum |FFT|)",
        levels,
        low_frac,
    )
    cutoffs = _find_cutoffs_biased(
        freq,
        mag,
        levels,
        low_frac=low_frac,
        low_ratio=low_ratio,
        min_frac=0.05,
        max_frac=0.95,
    )

    logger.info("Cutoffs (Hz, ascending): %s", [f"{c:.2e}" for c in cutoffs])

    return freq, mag, cutoffs
def _process_level(
    i: int,
    total: int,
    cutoff: float,
    signals: NDArray[np.float64],
    fs_hz: float,
    dt: float,
    folder: Path,
    original_trim: NDArray[np.float64],
    n_series: int,
    frames_to_remove: int,
    reuse_existing: bool,
) -> NDArray[np.float64]:
    """Process a single filter level (cutoff frequency).

    Filters every row of signals, trims the edges, saves the result
    as "filtered_signals.npy" in folder and writes diagnostic plots
    next to it. When reuse_existing is set and that file already
    exists, it is loaded and returned without recomputing or
    re-plotting.

    Args:
        i: Current level index (used for logging)
        total: Total number of levels (used for logging)
        cutoff: Cutoff frequency for this level
        signals: Original signals array
        fs_hz: Sampling frequency in Hz
        dt: Time step in seconds
        folder: Output folder for this level
        original_trim: Trimmed original signals (for comparison)
        n_series: Number of series (for random sampling)
        frames_to_remove: Number of frames to trim from edges
        reuse_existing: Whether to reuse existing filtered file

    Returns:
        Filtered and trimmed signals array
    """
    _make_dir_safe(folder)
    out_path = folder / "filtered_signals.npy"

    if out_path.exists() and reuse_existing:
        logger.info("[LEVEL %d/%d] cutoff=%.3e Hz -> REUSE", i, total, cutoff)
        return np.load(out_path)

    logger.info(
        "[LEVEL %d/%d] cutoff=%.3e Hz -> COMPUTE & SAVE to: %s",
        i,
        total,
        cutoff,
        folder,
    )

    # Filter each series independently
    t0 = time.time()
    filtered = np.array(
        [_butter_lowpass_filter(row, cutoff, fs_hz) for row in signals]
    )
    logger.info("Applied Butterworth via filtfilt in %.2fs", time.time() - t0)

    filtered_trim = _remove_filter_artifacts(
        filtered,
        frames_to_remove=frames_to_remove,
    )

    np.save(out_path, filtered_trim)
    logger.info(
        "Saved filtered signals: %s -> %s",
        filtered_trim.shape,
        out_path,
    )

    # Spectrum of the filtered data
    f_filt, mag_filt = _compute_fft_summed(filtered_trim, dt)
    _plot_fft(
        f_filt,
        mag_filt,
        f"Summed FFT (filtered, cutoff={cutoff:.2e} Hz)",
        folder / "fft_plot.png",
    )

    kde_path = folder / "filt_kde.png"
    _plot_signals_with_kde(
        filtered_trim,
        "Filt_Data + KDE",
        kde_path,
    )

    # Pick up to 3 series for the comparison plots. A private generator
    # with a fixed seed gives the same picks on every level without
    # reseeding the process-wide random state.
    n_pick = min(3, n_series)
    rng = random.Random(42)  # noqa: S311
    rand_atoms = rng.sample(range(n_series), n_pick)

    # Compare against the last `length` frames of the original
    length = filtered_trim.shape[1]
    original_aligned = original_trim[:, -length:]

    for idx in rand_atoms:
        _plot_single_atom_comparison(
            original_aligned,
            filtered_trim,
            idx,
            folder / f"atom_{idx}_comparison.png",
        )

    logger.info(
        "Saved %d Original vs Filtered overlays: %s",
        n_pick,
        rand_atoms,
    )

    return filtered_trim
n_pick) + + # Align original to same length as filtered + length = filtered_trim.shape[1] + original_aligned = original_trim[:, -length:] + + # Plot original vs filtered for selected series + for idx in rand_atoms: + _plot_single_atom_comparison( + original_aligned, + filtered_trim, + idx, + folder / f"atom_{idx}_comparison.png", + ) + + # Log completion + logger.info( + "Saved %d Original vs Filtered overlays: %s", + n_pick, + rand_atoms, + ) + + return filtered_trim + + +def auto_filtering( + signals: NDArray[np.float64] | None = None, + *, + path: str | Path = ".", + dt_ps: float = 100.0, + levels: int = 50, + out_dir: str | Path | None = None, + reuse_existing: bool = True, + frames_to_remove: int = DEFAULT_FRAMES_TO_REMOVE, + low_frac: float = 0.20, + low_ratio: float = 2.0, + seed: int = 42, + max_overlay_traces: int = 5, + frame_duration: float = 0.25, + drop_first_frame: bool = True, +) -> AutoFiltInsight: + """Automatic multi-level Butterworth lowpass filtering. + + Main workflow: + 1. Load/validate input signals + 2. Compute FFT to find frequency content + 3. Select multiple cutoff frequencies (biased to low freq) + 4. Apply Butterworth filter at each cutoff + 5. 
Create diagnostic plots and videos + + Args: + signals: Pre-loaded signals array (series x frames), or None + to load from path + path: Path to dataset file/folder (used if signals is None) + dt_ps: Time step in picoseconds + levels: Number of filter cutoff frequencies to test + out_dir: Output directory (default: ./autofilter_outputs) + reuse_existing: Reuse existing filtered files if available + frames_to_remove: Frames to trim from edges (removes filter + artifacts) + low_frac: Cumulative FFT fraction defining "low frequency" + (<0.20 = low) + low_ratio: Ratio of low-freq to high-freq cutoffs (2.0 = twice + as many low) + seed: Random seed for reproducibility + max_overlay_traces: Max traces to overlay in video frames + frame_duration: Duration of each video frame in seconds + drop_first_frame: Drop first frame from input (removes + initialization) + + Returns: + AutoFiltInsight object containing: + - cutoffs: List of cutoff frequencies used + - output_dir: Path to output directory + - filtered_files: Dict mapping cutoff to saved .npy file + - video_path: Path to forward evolution video + - meta: Dictionary of parameters used + - filtered_collection: Tuple of all filtered arrays + """ + # Validate parameters and compute derived values + dt, fs = _validate_params(dt_ps, levels) + + # Load or validate signals + signals = _resolve_signals(signals, path, drop_first_frame) + + # Set up output directory + base = _select_output_dir(out_dir) + + # Compute FFT and determine cutoff frequencies + freq, mag, cutoffs = _fft_and_cutoffs( + signals, + dt, + levels, + low_frac, + low_ratio, + ) + + # Plot original FFT with cutoff markers + _plot_fft( + freq, + mag, + "Summed FFT of Original Signals", + base / "original_fft.png", + mark_freqs=cutoffs, + ) + logger.info("[PLOT] Saved original FFT with cutoff markers.") + + # Set up for filtering loop + fs_hz = fs # Sampling frequency + n_series, _ = signals.shape + logger.info( + "Processing %d cutoff level(s); fs=%.3e Hz, 
Nyquist=%.3e Hz", + len(cutoffs), + fs_hz, + 0.5 * fs_hz, + ) + + # Trim original signals (for comparison and video) + original_trim = _remove_filter_artifacts( + signals, + frames_to_remove=frames_to_remove, + ) + + # Initialize result containers + filtered_collection: list[NDArray[np.float64]] = [] + used_cutoffs: list[float] = [] + filtered_paths: dict[float, Path] = {} + + # Initialize global y-limits from original data + global_ymin = float(np.nanmin(original_trim)) + global_ymax = float(np.nanmax(original_trim)) + + # Process each cutoff level + for i, cutoff in enumerate(sorted(cutoffs), start=1): + # Create folder for this level + folder = base / f"cutoff_{_freq_label_for_folder(cutoff)}" + + # Filter signals at this cutoff + filtered_trim = _process_level( + i, + len(cutoffs), + cutoff, + signals, + fs_hz, + dt, + folder, + original_trim, + n_series, + frames_to_remove, + reuse_existing, + ) + + # Update global y-limits to encompass this level + global_ymin = min(global_ymin, float(np.nanmin(filtered_trim))) + global_ymax = max(global_ymax, float(np.nanmax(filtered_trim))) + + # Store results + filtered_collection.append(filtered_trim) + used_cutoffs.append(cutoff) + filtered_paths[cutoff] = folder / "filtered_signals.npy" + + # Add padding to y-limits + pad = ( + 1e-7 * (global_ymax - global_ymin) + if global_ymax > global_ymin + else 1.0 + ) + yl = (global_ymin - pad, global_ymax + pad) + + # Create evolution videos with consistent y-limits + video_path = _render_video( + original_trim, + filtered_collection, + used_cutoffs, + out_path_base=base / "filtered_evolution", + y_limits=yl, + max_overlay_traces=max_overlay_traces, + seed=seed, + frame_duration=frame_duration, + ) + + # Save all filtered arrays to single compressed NPZ file + npz_path = base / "filtered_collection.npz" + np.savez_compressed(npz_path, *filtered_collection) + + # Collect metadata + meta = { + "dt_ps": dt_ps, + "frames_to_remove": frames_to_remove, + "levels": levels, + "low_frac": 
low_frac, + "low_ratio": low_ratio, + "seed": seed, + "frame_duration_s": frame_duration, + } + + # Return results container + return AutoFiltInsight( + cutoffs=used_cutoffs, + output_dir=base, + filtered_files=filtered_paths, + video_path=video_path, + meta=meta, + filtered_collection=tuple(filtered_collection) + ) From 1fe6ea495e4fe4a1e1cc1629ad4ffaeed55f4e86 Mon Sep 17 00:00:00 2001 From: dominosauro Date: Thu, 9 Oct 2025 15:22:48 +0200 Subject: [PATCH 3/9] Added overrides. --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index f5d2ee5d..ec052ce9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -131,5 +131,6 @@ module = [ 'trackpy.*', 'deeptime.*', 'sklearn.decomposition.*', + 'imageio.*', ] ignore_missing_imports = true From 8a7f4704ff7741996215d618578216566fc21bfe Mon Sep 17 00:00:00 2001 From: dominosauro Date: Thu, 9 Oct 2025 15:33:25 +0200 Subject: [PATCH 4/9] Formatting. --- .../data_processing/auto_filtering.py | 41 ++++++++----------- 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/src/dynsight/_internal/data_processing/auto_filtering.py b/src/dynsight/_internal/data_processing/auto_filtering.py index cd038158..a092ef27 100644 --- a/src/dynsight/_internal/data_processing/auto_filtering.py +++ b/src/dynsight/_internal/data_processing/auto_filtering.py @@ -34,22 +34,22 @@ # Frequency conversion constants FREQ_TERA = 1e12 # Terahertz in Hz -FREQ_GIGA = 1e9 # Gigahertz in Hz -FREQ_MEGA = 1e6 # Megahertz in Hz -FREQ_KILO = 1e3 # Kilohertz in Hz +FREQ_GIGA = 1e9 # Gigahertz in Hz +FREQ_MEGA = 1e6 # Megahertz in Hz +FREQ_KILO = 1e3 # Kilohertz in Hz # Default parameters for filtering DEFAULT_FRAMES_TO_REMOVE = 20 # Frames to trim from each end -DEFAULT_FILTER_ORDER = 4 # Butterworth filter order +DEFAULT_FILTER_ORDER = 4 # Butterworth filter order # Image processing constants -IMG_NDIM_GRAYSCALE = 2 # Number of dimensions for grayscale -IMG_CHANNELS_RGBA = 4 # Number of channels for RGBA 
images +IMG_NDIM_GRAYSCALE = 2 # Number of dimensions for grayscale +IMG_CHANNELS_RGBA = 4 # Number of channels for RGBA images # Numerical constants -SMALL_EPSILON = 1e-9 # Small number to avoid division by zero -NDIM_EXPECTED = 2 # Expected number of dimensions for input -MIN_FRAMES_TO_DROP = 2 # Minimum frames needed to drop first frame +SMALL_EPSILON = 1e-9 # Small number to avoid division by zero +NDIM_EXPECTED = 2 # Expected number of dimensions for input +MIN_FRAMES_TO_DROP = 2 # Minimum frames needed to drop first frame # Initialize logger for this module logger = logging.getLogger(__name__) @@ -73,17 +73,15 @@ class AutoFiltInsight: meta: Dictionary of metadata (parameters used) filtered_collection: Tuple of filtered signal arrays """ + # Non-default fields must come first output_dir: Path video_path: Path | None # Default fields (hide large arrays from repr) - cutoffs: list[float] = field( - default_factory=list, repr=False) - filtered_files: dict[float, Path] = field( - default_factory=dict, repr=False) - meta: dict[str, Any] = field( - default_factory=dict, repr=False) + cutoffs: list[float] = field(default_factory=list, repr=False) + filtered_files: dict[float, Path] = field(default_factory=dict, repr=False) + meta: dict[str, Any] = field(default_factory=dict, repr=False) filtered_collection: tuple[ArrayF64, ...] 
= field( default_factory=tuple, repr=False ) @@ -404,13 +402,13 @@ def _compute_fft_summed( # Sum magnitudes across all series - mag_sum: NDArray[np.float64] = np.asarray(np.abs(fft_vals), - dtype=np.float64).sum(axis=0) + mag_sum: NDArray[np.float64] = np.asarray( + np.abs(fft_vals), dtype=np.float64 + ).sum(axis=0) # Get positive frequencies freq: NDArray[np.float64] = np.asarray(f_all[pos_mask], dtype=np.float64) - return freq, mag_sum @@ -752,9 +750,7 @@ def _finalize_frame(fig: Figure) -> NDArray[np.uint8]: # Convert grayscale to RGB if needed if img_raw.ndim == IMG_NDIM_GRAYSCALE: - img: NDArray[np.uint8] = np.stack( - [img_raw, img_raw, img_raw], axis=2 - ) + img: NDArray[np.uint8] = np.stack([img_raw, img_raw, img_raw], axis=2) else: img = img_raw @@ -869,7 +865,6 @@ def _render_video( # Try to write videos with detailed parameters try: - # Write video writer = imageio.get_writer( avi_path, @@ -1357,5 +1352,5 @@ def auto_filtering( filtered_files=filtered_paths, video_path=video_path, meta=meta, - filtered_collection=tuple(filtered_collection) + filtered_collection=tuple(filtered_collection), ) From 3b4178aa7449b94f6b054bf8ee5b35b45a7d5049 Mon Sep 17 00:00:00 2001 From: dominosauro Date: Thu, 9 Oct 2025 15:41:58 +0200 Subject: [PATCH 5/9] Added dependencies. --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index ec052ce9..f2a7f100 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "trackpy", "ultralytics", "deeptime", + "imageio", ] # Set by cpctools. From d5ff3d877b6b2a78aeb3965952f2f24b881f5f6b Mon Sep 17 00:00:00 2001 From: dominosauro Date: Thu, 9 Oct 2025 15:48:48 +0200 Subject: [PATCH 6/9] Added dependencies. 
--- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index f2a7f100..8c049296 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "ultralytics", "deeptime", "imageio", + "seaborn", ] # Set by cpctools. @@ -133,5 +134,6 @@ module = [ 'deeptime.*', 'sklearn.decomposition.*', 'imageio.*', + 'seaborn.*, ] ignore_missing_imports = true From ebd2941d1d7fb758e093f585c230a0bb8cc40614 Mon Sep 17 00:00:00 2001 From: dominosauro Date: Thu, 9 Oct 2025 15:56:29 +0200 Subject: [PATCH 7/9] Adjusted typo. --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8c049296..5cd5deaa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -134,6 +134,6 @@ module = [ 'deeptime.*', 'sklearn.decomposition.*', 'imageio.*', - 'seaborn.*, + 'seaborn.*', ] ignore_missing_imports = true From 3ea92c8fed391a335a60f3821d63efce1e60710a Mon Sep 17 00:00:00 2001 From: dominosauro Date: Thu, 27 Nov 2025 11:15:14 +0100 Subject: [PATCH 8/9] Added AutoFiltering code. 
--- .../data_processing/auto_filtering.py | 2052 +++++++++-------- src/dynsight/data_processing.py | 2 + 2 files changed, 1032 insertions(+), 1022 deletions(-) diff --git a/src/dynsight/_internal/data_processing/auto_filtering.py b/src/dynsight/_internal/data_processing/auto_filtering.py index a092ef27..70ce46b6 100644 --- a/src/dynsight/_internal/data_processing/auto_filtering.py +++ b/src/dynsight/_internal/data_processing/auto_filtering.py @@ -76,7 +76,11 @@ class AutoFiltInsight: # Non-default fields must come first output_dir: Path - video_path: Path | None + collection_path: Path + + # Optional fields + video_path: Path | None = None + fft_video_path: Path | None = None # Default fields (hide large arrays from repr) cutoffs: list[float] = field(default_factory=list, repr=False) @@ -87,59 +91,35 @@ class AutoFiltInsight: ) -# --------------------------- Helpers (I/O, plots) --------------------------- +# --------------------------- Helper Functions --------------------------- def _resolve_dataset_path(user_path: str | os.PathLike[str]) -> Path: - """Resolve user input to a concrete dataset file path. - - Accepts either a file or folder. For folders, looks for a single - file with preference: .json > .npy > .npz. 
- - Args: - user_path: User-provided path (file or directory) - - Returns: - Resolved Path to a dataset file - - Raises: - FileNotFoundError: If path doesn't exist or no valid files found - ValueError: If multiple files of same type found (ambiguous) - """ - # Expand ~ and resolve to absolute path + """Resolve user input to a concrete dataset file path.""" p = Path(user_path).expanduser().resolve() - # If it's already a file, return it if p.is_file(): return p - # Check if path exists at all if not p.exists(): msg = f"Path does not exist: {p}" raise FileNotFoundError(msg) - # If it's a directory, search for dataset files if p.is_dir(): - # Try each extension in preference order for ext in (".json", ".npy", ".npz"): - # Find all files with this extension hits = sorted(p.glob(f"*{ext}")) - # Exactly one file found - use it if len(hits) == 1: return hits[0] - # Multiple files found - ambiguous if len(hits) > 1: names = ", ".join(h.name for h in hits) msg = f"Multiple {ext} files in {p}: {names}" raise ValueError(msg) - # No valid files found msg = f"No .json/.npy/.npz found in {p}" raise FileNotFoundError(msg) - # Shouldn't reach here (not file, not dir, but exists?) msg = f"Unsupported path: {p}" raise FileNotFoundError(msg) @@ -150,86 +130,44 @@ def _load_array_any( mmap_mode: Literal["r+", "r", "w+", "c"] | None = None, enforce_2d: bool = True, ) -> NDArray[np.float64]: - """Load dataset from .json, .npy, or .npz file. - - Wraps loaded data into an Insight object for validation, - then returns the underlying array. 
- - Args: - path: Path to dataset file - mmap_mode: Memory-mapping mode for numpy.load - enforce_2d: If True, raise error if not 2D array - - Returns: - Loaded numpy array - - Raises: - ValueError: If file type unsupported, empty .npz, or wrong - dimensions - """ - # Get file extension (lowercase) + """Load dataset from .json, .npy, or .npz file.""" sfx = path.suffix.lower() - # Load based on file type if sfx == ".json": - # Load from JSON format arr1 = np.load(path, mmap_mode=mmap_mode) ins = Insight(arr1) elif sfx == ".npy": - # Load from numpy binary format arr = np.load(path, mmap_mode=mmap_mode) ins = Insight(dataset=np.asarray(arr), meta={"source": path.name}) elif sfx == ".npz": - # Load from compressed numpy format z = np.load(path, mmap_mode=mmap_mode) - # Check if npz is empty if not z.files: msg = "Empty .npz file." raise ValueError(msg) - # Use first key in npz key = z.files[0] ins = Insight( dataset=np.asarray(z[key]), meta={"source": path.name, "key": key} ) else: - # Unsupported file type msg = f"Unsupported file type: {sfx}" raise ValueError(msg) - # Validate dimensions if requested if enforce_2d and ins.dataset.ndim != NDIM_EXPECTED: msg = f"Expected 2D array (series x frames), got {ins.dataset.shape}" raise ValueError(msg) - # Return as numpy array return np.asarray(ins.dataset) def _make_dir_safe(directory: Path) -> None: - """Create directory and all parent directories if they don't exist. - - Args: - directory: Path to directory to create - """ - # Create directory with parents, don't error if exists + """Create directory and all parent directories if they don't exist.""" directory.mkdir(parents=True, exist_ok=True) def _freq_label_for_folder(freq_hz: float) -> str: - """Convert frequency in Hz to human-readable string with units. - - Chooses appropriate unit (THz, GHz, MHz, kHz, Hz) based on - magnitude. 
- - Args: - freq_hz: Frequency in Hertz - - Returns: - Formatted string like "1.234GHz" - """ - # Choose unit based on frequency magnitude + """Convert frequency in Hz to human-readable string with units.""" if freq_hz >= FREQ_TERA: return f"{freq_hz / FREQ_TERA:.3f}THz" if freq_hz >= FREQ_GIGA: @@ -241,942 +179,1104 @@ def _freq_label_for_folder(freq_hz: float) -> str: return f"{freq_hz:.3f}Hz" -def _plot_fft( - freq: NDArray[np.float64], - mag: NDArray[np.float64], - title: str, - path: Path, - mark_freqs: list[float] | None = None, -) -> None: - """Create and save FFT magnitude plot. - - Plots frequency vs magnitude with optional markers for cutoff - frequencies. - - Args: - freq: Frequency array in Hz - mag: Magnitude array (summed across all series) - title: Plot title - path: Where to save the plot - mark_freqs: Optional list of frequencies to mark with scatter - points +# --------------------------- Main Class --------------------------- + + +class AutoFilteringPipeline: + """Automatic multi-level Butterworth lowpass filtering pipeline. + + This class provides a modular interface for applying multi-level + filtering to time series data. Each output type can be generated + independently. + + Example: + >>> import numpy as np + >>> from dynsight._internal.data_processing.auto_filtering import ( + ... AutoFilteringPipeline, + ... ) + >>> # Create dummy data: 10 series, 100 frames + >>> data = np.random.randn(10, 100) + >>> pipeline = AutoFilteringPipeline( + ... signals=data, + ... dt_ps=100.0, + ... levels=5, + ... out_dir="./outputs" + ... 
) + >>> pipeline.compute_fft_and_cutoffs() # doctest: +SKIP + >>> pipeline.apply_filtering() # doctest: +SKIP + >>> result = pipeline.save_filtered_collection() # doctest: +SKIP + >>> pipeline.save_fft_plots() # doctest: +SKIP + >>> pipeline.save_cutoff_folders() # doctest: +SKIP + >>> pipeline.create_signal_video() # doctest: +SKIP + >>> pipeline.create_fft_video() # doctest: +SKIP """ - # Create new figure - plt.figure() - - # Plot FFT magnitude vs frequency (in GHz) - plt.plot(freq / FREQ_GIGA, mag, lw=1.5, label="Summed |FFT|") - - # Add markers for cutoff frequencies if provided - if mark_freqs: - # Interpolate magnitude values at cutoff frequencies - y_interp = np.interp(mark_freqs, freq, mag) - # Plot cutoff markers - plt.scatter( - np.array(mark_freqs) / FREQ_GIGA, - y_interp, - s=30, - label="Cutoffs", - ) - - # Format plot - plt.title(title) - plt.xlabel("Frequency (GHz)") - plt.ylabel("Summed Magnitude |FFT|") - plt.grid(alpha=0.3) - plt.legend() - plt.tight_layout() - - # Save and close - plt.savefig(path, dpi=200) - plt.close() - - -def _plot_signals_with_kde( - signals: NDArray[np.float64], title: str, path: Path -) -> None: - """Create dual-panel plot: signals + KDE distribution. - - Left panel shows all signal traces plus mean. - Right panel shows KDE of all signal values. 
- - Args: - signals: 2D array (series x frames) - title: Plot title - path: Where to save the plot - """ - # Calculate mean signal across all series - mean_signal = np.mean(signals, axis=0) - - # Flatten all values for KDE - all_values = signals.ravel() - - # Create figure with 2 columns - _fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(10, 5)) - - # Left panel: plot all traces in gray - ax1.plot(signals.T, lw=0.3, alpha=0.35, c="gray") - # Overlay mean in red - ax1.plot(mean_signal, color="red", lw=1.0, label="Mean") - ax1.set_title(title) - ax1.set_xlabel("Frame") - ax1.set_ylabel("Signal") - ax1.grid(alpha=0.3) - ax1.legend() - - # Right panel: KDE of all values - sns.kdeplot(y=all_values, ax=ax2, fill=True, alpha=0.3) - ax2.set_title("KDE Distribution") - - # Save and close - plt.tight_layout() - plt.savefig(path, dpi=200) - plt.close() - - -def _plot_single_atom_comparison( - orig: NDArray[np.float64], - filt: NDArray[np.float64], - atom_idx: int, - path: Path, -) -> None: - """Plot original vs filtered signal for a single series. - - Args: - orig: Original signals (series x frames) - filt: Filtered signals (series x frames) - atom_idx: Index of series to plot - path: Where to save the plot - """ - # Create figure - plt.figure() - - # Plot original signal - plt.plot(orig[atom_idx], label="Original", lw=1.0) - # Plot filtered signal - plt.plot(filt[atom_idx], label="Filtered", lw=1.0) - - # Format plot - plt.title(f"Atom/Series {atom_idx}: Original vs Filtered") - plt.xlabel("Frame") - plt.ylabel("Signal") - plt.legend() - plt.grid(alpha=0.3) - plt.tight_layout() - - # Save and close - plt.savefig(path, dpi=200) - plt.close() - - -# --------------------------- Filt helpers --------------------------- - - -def _compute_fft_summed( - signals: NDArray[np.float64], dt: float -) -> tuple[NDArray[np.float64], NDArray[np.float64]]: - """Compute FFT along time axis and sum magnitudes across series. - - Only keeps positive frequencies. 
Useful for finding dominant - frequency components across all signals. - - Args: - signals: 2D array (series x frames) - dt: Time step in seconds - - Returns: - freq: Positive frequency array - mag_sum: Summed magnitude across all series - """ - # Get shape - _n_series, n_frames = signals.shape - - # Compute frequency bins - f_all = fftfreq(n_frames, d=dt) - - # Keep only positive frequencies - pos_mask = f_all > 0 - - # Compute FFT along time axis (axis=1) - fft_vals = fft(signals, axis=1)[:, pos_mask] - - # Sum magnitudes across all series - - mag_sum: NDArray[np.float64] = np.asarray( - np.abs(fft_vals), dtype=np.float64 - ).sum(axis=0) - # Get positive frequencies - freq: NDArray[np.float64] = np.asarray(f_all[pos_mask], dtype=np.float64) - - return freq, mag_sum - - -def _find_cutoffs_biased( - freq: NDArray[np.float64], - mag: NDArray[np.float64], - num_levels: int, - low_frac: float = 0.20, - low_ratio: float = 2.0, - min_frac: float = 0.05, - max_frac: float = 0.95, -) -> list[float]: - """Find cutoff frequencies biased toward low frequencies. - - Uses cumulative FFT magnitude to select frequencies. Puts more - cutoffs in the low-frequency region (below low_frac) by a ratio - of low_ratio:1. - - Args: - freq: Frequency array - mag: Magnitude array - num_levels: Total number of cutoffs to find - low_frac: Cumulative fraction defining "low frequency" region - low_ratio: Ratio of low-freq to high-freq cutoffs - min_frac: Minimum cumulative fraction to consider - max_frac: Maximum cumulative fraction to consider + def __init__( + self, + signals: NDArray[np.float64] | None = None, + path: str | Path = ".", + dt_ps: float = 100.0, + levels: int = 50, + out_dir: str | Path | None = None, + reuse_existing: bool = True, + frames_to_remove: int = DEFAULT_FRAMES_TO_REMOVE, + low_frac: float = 0.20, + low_ratio: float = 2.0, + seed: int = 42, + drop_first_frame: bool = True, + ) -> None: + """Initialize the filtering pipeline. 
+ + Args: + signals: Pre-loaded signals array (series x frames), or None + path: Path to dataset file/folder (used if signals is None) + dt_ps: Time step in picoseconds + levels: Number of filter cutoff frequencies to test + out_dir: Output directory (default: ./autofilter_outputs) + reuse_existing: Reuse existing filtered files if available + frames_to_remove: Frames to trim from edges + low_frac: Cumulative FFT fraction defining "low frequency" + low_ratio: Ratio of low-freq to high-freq cutoffs + seed: Random seed for reproducibility + drop_first_frame: Drop first frame from input + """ + # Validate and store parameters + self.dt_ps = dt_ps + self.levels = levels + self.frames_to_remove = frames_to_remove + self.low_frac = low_frac + self.low_ratio = low_ratio + self.seed = seed + self.reuse_existing = reuse_existing + + # Validate parameters and compute derived values + self.dt, self.fs = self._validate_params() + + # Load or validate signals + self.signals = self._resolve_signals(signals, path, drop_first_frame) + self.n_series, self.n_frames = self.signals.shape + + # Set up output directory + self.output_dir = self._select_output_dir(out_dir) + + # Initialize state variables (computed later) + self.freq: NDArray[np.float64] | None = None + self.mag: NDArray[np.float64] | None = None + self.cutoffs: list[float] = [] + self.original_trim: NDArray[np.float64] | None = None + self.filtered_collection: list[ArrayF64] = [] + self.filtered_paths: dict[float, Path] = {} + self.global_ymin: float | None = None + self.global_ymax: float | None = None + + logger.info("Pipeline initialized: %s", self.output_dir) + + def _validate_params(self) -> tuple[float, float]: + """Validate input parameters and compute derived values.""" + if self.dt_ps <= 0: + msg = "dt_ps must be > 0" + raise ValueError(msg) - Returns: - List of cutoff frequencies (sorted, unique) - """ - # Compute cumulative sum of magnitude - cum = np.cumsum(mag) + if self.levels < 1: + msg = "levels must be 
>= 1" + raise ValueError(msg) - # Get total magnitude - total = cum[-1] if cum.size else 0.0 + dt = self.dt_ps * 1e-12 + fs = 1.0 / dt - # Handle all-zero case - if total == 0.0: - warn_msg = ( - "[WARN] Summed magnitude is all zeros; " - "using max frequency as single cutoff." + logger.info( + "dt = %.3g ps (%.3e s) | fs = %.3e Hz | Nyquist = %.3e Hz", + self.dt_ps, + dt, + fs, + 0.5 * fs, ) - logger.warning(warn_msg) - return [float(freq[-1])] if freq.size else [] - - # Normalize to cumulative fraction - cum = cum / total - # Calculate how many cutoffs in each region - n_low = max(1, round(num_levels * (low_ratio / (low_ratio + 1.0)))) - n_high = max(1, num_levels - n_low) - - # Define boundary for low-frequency region - low_hi = max(min(low_frac, max_frac - 1e-6), min_frac + 1e-6) - - # Create thresholds for low-frequency region - th_low = np.linspace(min_frac, low_hi, n_low, endpoint=True) + return dt, fs + + def _resolve_signals( + self, + signals: NDArray[np.float64] | None, + path: str | Path, + drop_first_frame: bool, + ) -> NDArray[np.float64]: + """Load or validate input signals.""" + if signals is None: + ds_path = _resolve_dataset_path(path) + signals = _load_array_any(ds_path) + + if signals.ndim != NDIM_EXPECTED: + msg = f"Expected 2D array (series x frames), got {signals.shape}" + raise ValueError(msg) - # Create thresholds for high-frequency region - th_high = np.linspace(low_hi + 1e-6, max_frac, n_high, endpoint=True) + if drop_first_frame: + if signals.shape[1] < MIN_FRAMES_TO_DROP: + msg = f"Need at least {MIN_FRAMES_TO_DROP} frames." 
+ raise ValueError(msg) + signals = signals[:, 1:] - # Combine all thresholds - thresholds = np.concatenate([th_low, th_high]) + logger.info("Using signals -> shape %s", signals.shape) - # Find frequency for each threshold - cutoffs: list[float] = [] - for th in thresholds: - # Find index where cumulative reaches threshold - idx = np.searchsorted(cum, th, side="left") + return signals - # Clamp to valid range - if idx >= len(freq): - idx = len(freq) - 1 + def _select_output_dir(self, out_dir: str | Path | None) -> Path: + """Create and return output directory path.""" + base = ( + Path(out_dir) + if out_dir is not None + else Path.cwd() / "autofilter_outputs" + ) - # Get frequency at this index - c = float(freq[int(idx)]) + _make_dir_safe(base) - # Only add if not duplicate - if len(cutoffs) == 0 or not np.isclose(c, cutoffs[-1]): - cutoffs.append(c) + return base - # Sort and remove duplicates - cutoffs = sorted(set(cutoffs)) + def compute_fft_and_cutoffs(self) -> None: + """Compute FFT and determine cutoff frequencies. - # Log information about split point - idx_split = np.searchsorted(cum, low_frac, side="left") - f_split = ( - float(freq[min(int(idx_split), len(freq) - 1)]) - if len(freq) - else float("nan") - ) - info_msg = ( - f"[INFO] Biased cutoffs: low_frac={low_frac:.2f} " - f"(~f={f_split:.3e} Hz) | total unique={len(cutoffs)}" - ) - logger.info(info_msg) + This must be called before apply_filtering(). 
+ """ + logger.info("[STEP] Computing summed FFT (original data) ...") + self.freq, self.mag = self._compute_fft_summed(self.signals, self.dt) - return cutoffs + logger.info( + "Frequency bins: %d | Min/Max freq: %.3e/%.3e Hz", + len(self.freq), + self.freq.min(), + self.freq.max(), + ) + logger.info( + "Selecting %d cutoff(s) with low-freq bias (<=%.2f cum |FFT|)", + self.levels, + self.low_frac, + ) + self.cutoffs = self._find_cutoffs_biased( + self.freq, + self.mag, + self.levels, + ) -def _butter_lowpass_filter( - signal: NDArray[np.float64], - cutoff: float, - fs: float, - order: int = DEFAULT_FILTER_ORDER, -) -> NDArray[np.float64]: - """Apply Butterworth lowpass filter to signal. + logger.info( + "Cutoffs (Hz, ascending): %s", + [f"{c:.2e}" for c in self.cutoffs], + ) - Uses zero-phase filtering (filtfilt) to avoid phase distortion. + def apply_filtering(self) -> None: + """Apply Butterworth filtering at all cutoff frequencies. - Args: - signal: 1D signal array - cutoff: Cutoff frequency in Hz - fs: Sampling frequency in Hz - order: Filter order (higher = sharper cutoff) + This must be called after compute_fft_and_cutoffs(). + """ + if not self.cutoffs: + msg = "Must call compute_fft_and_cutoffs() first" + raise RuntimeError(msg) - Returns: - Filtered signal array - """ - # Calculate Nyquist frequency - nyq = 0.5 * fs - - # Check if cutoff is valid - if cutoff >= nyq: - warn_msg = ( - f"[WARN] cutoff {cutoff:.3e} >= Nyquist {nyq:.3e}; " - "passing signal through." 
+ # Trim original signals + self.original_trim = self._remove_filter_artifacts( + self.signals, + frames_to_remove=self.frames_to_remove, ) - logger.warning(warn_msg) - return signal - - # Design Butterworth filter - b, a = butter(order, cutoff / nyq, btype="low") - # Apply zero-phase filter - return filtfilt(b, a, signal) + # Initialize global y-limits from original data + self.global_ymin = float(np.nanmin(self.original_trim)) + self.global_ymax = float(np.nanmax(self.original_trim)) + logger.info( + "Processing %d cutoff level(s); fs=%.3e Hz, Nyquist=%.3e Hz", + len(self.cutoffs), + self.fs, + 0.5 * self.fs, + ) -def _remove_filter_artifacts( - signals: NDArray[np.float64], - frames_to_remove: int = DEFAULT_FRAMES_TO_REMOVE, -) -> NDArray[np.float64]: - """Remove edge frames affected by filtering artifacts. + # Process each cutoff level + for i, cutoff in enumerate(sorted(self.cutoffs), start=1): + filtered_trim = self._apply_single_cutoff(i, cutoff) - Trims the same number of frames from both start and end. + # Update global y-limits + ymin_filt = float(np.nanmin(filtered_trim)) + self.global_ymin = min(self.global_ymin, ymin_filt) + ymax_filt = float(np.nanmax(filtered_trim)) + self.global_ymax = max(self.global_ymax, ymax_filt) - Args: - signals: 2D array (series x frames) - frames_to_remove: Number of frames to remove from each end + self.filtered_collection.append(filtered_trim) - Returns: - Trimmed signal array - """ - # Get shape - _n_series, n_frames = signals.shape - - # Check if we have enough frames to trim - if n_frames <= 2 * frames_to_remove: - warn_msg = ( - f"[WARN] Not enough frames ({n_frames}) to remove " - f"{frames_to_remove} per side. Skipping trim." 
+ logger.info( + "Filtering complete: %d levels processed", + len(self.cutoffs), ) - logger.warning(warn_msg) - return signals - - # Trim frames from both ends - trimmed = signals[:, frames_to_remove:-frames_to_remove] - # Log the operation - logger.info( - f"[STEP] Removed {frames_to_remove} frames per side " - f"-> new shape {trimmed.shape}" - ) + def _apply_single_cutoff(self, level_idx: int, cutoff: float) -> ArrayF64: + """Apply filtering for a single cutoff (core logic only).""" + logger.info( + "[LEVEL %d/%d] cutoff=%.3e Hz -> COMPUTE", + level_idx, + len(self.cutoffs), + cutoff, + ) - return trimmed + # Apply filter to each series + t0 = time.time() + filtered = np.array( + [ + self._butter_lowpass_filter(row, cutoff, self.fs) + for row in self.signals + ] + ) + logger.info( + "Applied Butterworth via filtfilt in %.2fs", + time.time() - t0, + ) + # Remove edge artifacts + return self._remove_filter_artifacts( + filtered, + frames_to_remove=self.frames_to_remove, + ) -# --------------------------- Video helpers --------------------------- + def save_filtered_collection( + self, filename: str = "filtered_collection.npz" + ) -> Path: + """Save all filtered arrays to single NPZ file (MANDATORY OUTPUT). + + Returns: + Path to saved NPZ file + """ + if not self.filtered_collection: + msg = "Must call apply_filtering() first" + raise RuntimeError(msg) + + npz_path = self.output_dir / filename + np.savez_compressed(npz_path, *self.filtered_collection) + + logger.info( + "[OUT] Saved filtered collection (%d levels): %s", + len(self.filtered_collection), + npz_path, + ) + return npz_path + + def save_cutoff_folders( + self, + save_fft_plots: bool = True, + save_kde_plots: bool = True, + save_comparison_plots: bool = True, + n_comparison_atoms: int = 3, + ) -> dict[float, Path]: + """Save individual folders for each cutoff with diagnostics. 
+ + Args: + save_fft_plots: Save FFT plot for each cutoff + save_kde_plots: Save signals + KDE plot for each cutoff + save_comparison_plots: Save original vs filtered comparison + n_comparison_atoms: Number of random series to compare + + Returns: + Dictionary mapping cutoff frequency to folder path + """ + if not self.filtered_collection: + msg = "Must call apply_filtering() first" + raise RuntimeError(msg) + + if self.original_trim is None: + msg = "original_trim is None" + raise RuntimeError(msg) + + folders = {} + + for i, (cutoff, filtered_trim) in enumerate( + zip(self.cutoffs, self.filtered_collection), start=1 + ): + label = _freq_label_for_folder(cutoff) + folder = self.output_dir / f"cutoff_{label}" + _make_dir_safe(folder) + + # Save filtered signals + out_path = folder / "filtered_signals.npy" + np.save(out_path, filtered_trim) + + # Save FFT plot + if save_fft_plots: + f_filt, mag_filt = self._compute_fft_summed( + filtered_trim, self.dt + ) + self._plot_fft( + f_filt, + mag_filt, + f"Summed FFT (filtered, cutoff={cutoff:.2e} Hz)", + folder / "fft_plot.png", + ) + + # Save KDE plot + if save_kde_plots: + self._plot_signals_with_kde( + filtered_trim, + "Filtered Data + KDE", + folder / "filt_kde.png", + ) + + # Save comparison plots + if save_comparison_plots: + n_pick = min(n_comparison_atoms, self.n_series) + random.seed(self.seed) + rand_atoms = random.sample(range(self.n_series), n_pick) + + length = filtered_trim.shape[1] + original_aligned = self.original_trim[:, -length:] + + for idx in rand_atoms: + self._plot_single_atom_comparison( + original_aligned, + filtered_trim, + idx, + folder / f"atom_{idx}_comparison.png", + ) + + folders[cutoff] = folder + self.filtered_paths[cutoff] = out_path + + logger.info( + "[OUT] Saved cutoff folder [%d/%d]: %s", + i, + len(self.cutoffs), + folder, + ) -def _draw_left_panel( - ax: Axes, - x: NDArray[np.float64], - mean: NDArray[np.float64], - std: NDArray[np.float64], - overlay: list[NDArray[np.float64]] | 
None, - title: str | None, - y_limits: tuple[float, float] | None, - show_legend: bool, -) -> None: - """Draw left panel of video frame showing signal traces. + return folders - Args: - ax: Matplotlib axes to draw on - x: X-axis values (frame indices) - mean: Mean signal across all series - std: Standard deviation across all series - overlay: Optional list of individual traces to overlay - title: Panel title - y_limits: Optional (ymin, ymax) to fix y-axis - show_legend: Whether to show legend - """ - # Plot mean signal - ax.plot(x, mean, lw=1.2, label="Mean") + def save_fft_plots(self, mark_cutoffs: bool = True) -> Path: + """Save FFT plot of original signals (OPTIONAL). - # Fill area for +/- 1 standard deviation - ax.fill_between(x, mean - std, mean + std, alpha=0.25, label="+/- 1 sigma") + Args: + mark_cutoffs: Whether to mark cutoff frequencies on plot - # Overlay individual traces if provided - if overlay: - for tr in overlay: - ax.plot(x, tr, lw=0.6, alpha=0.35) + Returns: + Path to saved plot + """ + if self.freq is None or self.mag is None: + msg = "Must call compute_fft_and_cutoffs() first" + raise RuntimeError(msg) - # Set labels and title - ax.set_title(title or "Filtered") - ax.set_xlabel("Frame (trimmed)") - ax.set_ylabel("Signal") - ax.grid(alpha=0.3) + plot_path = self.output_dir / "original_fft.png" - # Add legend if requested - if show_legend: - ax.legend() + self._plot_fft( + self.freq, + self.mag, + "Summed FFT of Original Signals", + plot_path, + mark_freqs=self.cutoffs if mark_cutoffs else None, + ) - # Set y-limits if provided - if y_limits is not None: - ax.set_ylim(*y_limits) + logger.info("[OUT] Saved original FFT plot: %s", plot_path) + + return plot_path + + def create_signal_video( + self, + max_overlay_traces: int = 5, + frame_duration: float = 0.25, + filename: str = "filtered_evolution", + ) -> Path: + """Create video showing signal evolution during filtering. 
+ + Args: + max_overlay_traces: Max number of individual traces to overlay + frame_duration: Duration of each frame in seconds + filename: Base filename (without extension) + + Returns: + Path to saved video file + """ + if not self.filtered_collection: + msg = "Must call apply_filtering() first" + raise RuntimeError(msg) + + if self.original_trim is None: + msg = "original_trim is None" + raise RuntimeError(msg) + + if self.global_ymin is None or self.global_ymax is None: + msg = "global y-limits are None" + raise RuntimeError(msg) + + # Add padding to y-limits + y_range = self.global_ymax - self.global_ymin + pad = 1e-7 * y_range if y_range > 0 else 1.0 + yl = (self.global_ymin - pad, self.global_ymax + pad) + + video_path = self._render_video( + self.original_trim, + self.filtered_collection, + self.cutoffs, + out_path_base=self.output_dir / filename, + y_limits=yl, + max_overlay_traces=max_overlay_traces, + seed=self.seed, + frame_duration=frame_duration, + ) + logger.info("[OUT] Saved signal evolution video: %s", video_path) + + return video_path + + def create_fft_video( + self, + frame_duration: float = 0.25, + filename: str = "fft_evolution", + ) -> Path: + """Create video showing FFT evolution during filtering. 
+ + Args: + frame_duration: Duration of each frame in seconds + filename: Base filename (without extension) + + Returns: + Path to saved video file + """ + if not self.filtered_collection: + msg = "Must call apply_filtering() first" + raise RuntimeError(msg) + + if self.freq is None or self.mag is None: + msg = "Must call compute_fft_and_cutoffs() first" + raise RuntimeError(msg) + + fft_video_path = self._render_fft_evolution_video( + self.freq, + self.mag, + self.filtered_collection, + self.cutoffs, + self.dt, + out_path_base=self.output_dir / filename, + frame_duration=frame_duration, + ) -def _draw_kde_panel( - ax: Axes, - dist_values: NDArray[np.float64] | None, - kde_bw: float, -) -> None: - """Draw right panel of video frame showing KDE distribution. + logger.info("[OUT] Saved FFT evolution video: %s", fft_video_path) + + return fft_video_path + + def get_result(self) -> AutoFiltInsight: + """Get result container with all metadata and paths. + + Returns: + AutoFiltInsight object with all result information + """ + # Find paths if they exist + collection_path = self.output_dir / "filtered_collection.npz" + video_path = self.output_dir / "filtered_evolution.avi" + fft_video_path = self.output_dir / "fft_evolution.avi" + + meta = { + "dt_ps": self.dt_ps, + "frames_to_remove": self.frames_to_remove, + "levels": self.levels, + "low_frac": self.low_frac, + "low_ratio": self.low_ratio, + "seed": self.seed, + } + + # Ensure collection_path exists before creating result + if not collection_path.exists(): + msg = "filtered_collection.npz does not exist" + raise RuntimeError(msg) + + return AutoFiltInsight( + cutoffs=self.cutoffs, + output_dir=self.output_dir, + collection_path=collection_path, + filtered_files=self.filtered_paths, + video_path=video_path if video_path.exists() else None, + fft_video_path=( + fft_video_path if fft_video_path.exists() else None + ), + meta=meta, + filtered_collection=tuple(self.filtered_collection), + ) - Args: - ax: Matplotlib axes to 
draw on - dist_values: Array of all signal values for KDE - kde_bw: Bandwidth adjustment for KDE - """ - # Set title and labels - ax.set_title("KDE") - ax.set_xlabel("Density") - ax.grid(alpha=0.3) - - # Return early if no data - if dist_values is None: - return - - # Convert to array and remove non-finite values - vals = np.asarray(dist_values) - vals = vals[np.isfinite(vals)] - - # Plot KDE if we have valid data with variance - if vals.size > 1 and np.nanstd(vals) > 0: - sns.kdeplot(y=vals, ax=ax, fill=True, alpha=0.3, bw_adjust=kde_bw) - # Just draw horizontal line if constant value - elif vals.size > 0: - ax.axhline(float(vals[0]), ls="--", alpha=0.6) - - # Move y-axis to right side - ax.yaxis.tick_right() - ax.yaxis.set_label_position("right") - - -def _render_frame_array( - mean: NDArray[np.float64], - std: NDArray[np.float64], - y_limits: tuple[float, float] | None = None, - title_override: str | None = None, - overlay: list[NDArray[np.float64]] | None = None, - dist_values: NDArray[np.float64] | None = None, - show_legend: bool = False, - kde_bw: float = 1.0, -) -> NDArray[np.uint8]: - """Render a single video frame as a numpy image array. - - Creates a two-panel figure (signals + KDE) and converts to RGB - array. 
+ # --------------------------- Private Helper Methods --------------- - Args: - mean: Mean signal - std: Standard deviation of signal - y_limits: Optional y-axis limits - title_override: Title for left panel - overlay: Optional traces to overlay - dist_values: Values for KDE plot - show_legend: Whether to show legend - kde_bw: KDE bandwidth adjustment + def _compute_fft_summed( + self, signals: NDArray[np.float64], dt: float + ) -> tuple[NDArray[np.float64], NDArray[np.float64]]: + """Compute FFT along time axis and sum magnitudes across series.""" + _n_series, n_frames = signals.shape - Returns: - RGB image array (height x width x 3) - """ - # Create x-axis values - x = np.arange(mean.size, dtype=np.float64) - - # Create figure with 2 columns sharing y-axis - fig, (ax1, ax2) = plt.subplots( - ncols=2, - figsize=(10, 4), - sharey=True, - facecolor="white", - ) + f_all = fftfreq(n_frames, d=dt) + pos_mask = f_all > 0 - # Draw left panel (signals) - _draw_left_panel( - ax1, x, mean, std, overlay, title_override, y_limits, show_legend - ) + fft_vals = fft(signals, axis=1)[:, pos_mask] - # Draw right panel (KDE) - _draw_kde_panel(ax2, dist_values, kde_bw) + mag_sum: NDArray[np.float64] = np.asarray( + np.abs(fft_vals), dtype=np.float64 + ).sum(axis=0) - # Adjust layout - fig.tight_layout() + freq: NDArray[np.float64] = np.asarray( + f_all[pos_mask], dtype=np.float64 + ) - # Convert to image array - return _finalize_frame(fig) + return freq, mag_sum + + def _find_cutoffs_biased( + self, + freq: NDArray[np.float64], + mag: NDArray[np.float64], + num_levels: int, + ) -> list[float]: + """Find cutoff frequencies biased toward low frequencies.""" + cum = np.cumsum(mag) + total = cum[-1] if cum.size else 0.0 + + if total == 0.0: + warn_msg = ( + "[WARN] Summed magnitude is all zeros; " + "using max frequency as single cutoff." 
+ ) + logger.warning(warn_msg) + return [float(freq[-1])] if freq.size else [] + cum = cum / total -def _finalize_frame(fig: Figure) -> NDArray[np.uint8]: - """Convert matplotlib figure to RGB numpy array. + ratio_sum = self.low_ratio + 1.0 + n_low = max(1, round(num_levels * (self.low_ratio / ratio_sum))) + n_high = max(1, num_levels - n_low) - Ensures dimensions are even (required for some video codecs). + min_frac = 0.05 + max_frac = 0.95 + low_hi = max(min(self.low_frac, max_frac - 1e-6), min_frac + 1e-6) - Args: - fig: Matplotlib figure + th_low = np.linspace(min_frac, low_hi, n_low, endpoint=True) + th_high = np.linspace(low_hi + 1e-6, max_frac, n_high, endpoint=True) - Returns: - RGB image array (height x width x 3) with even dimensions - """ - # Save figure to bytes buffer - buf = io.BytesIO() - fig.savefig(buf, format="png", dpi=220, facecolor="white") - plt.close(fig) - - # Read back as image - buf.seek(0) - img_raw = imageio.imread(buf) - - # Convert grayscale to RGB if needed - if img_raw.ndim == IMG_NDIM_GRAYSCALE: - img: NDArray[np.uint8] = np.stack([img_raw, img_raw, img_raw], axis=2) - else: - img = img_raw + thresholds = np.concatenate([th_low, th_high]) - # Remove alpha channel if present - if img.shape[2] == IMG_CHANNELS_RGBA: - img = img[:, :, :3] + cutoffs: list[float] = [] + for th in thresholds: + idx = np.searchsorted(cum, th, side="left") - # Ensure even dimensions (crop if needed) - h, w = img.shape[:2] - if h % 2 or w % 2: - img = img[: h - (h % 2), : w - (w % 2), :] + if idx >= len(freq): + idx = len(freq) - 1 - return img + c = float(freq[int(idx)]) + if len(cutoffs) == 0 or not np.isclose(c, cutoffs[-1]): + cutoffs.append(c) -def _render_video( - raw_trim: NDArray[np.float64], - filtered_list: list[NDArray[np.float64]], - cutoffs: list[float], - out_path_base: Path, - y_limits: tuple[float, float] | None = None, - max_overlay_traces: int = 10, - seed: int = 42, - frame_duration: float = 0.25, -) -> Path: - """Create videos showing 
filter evolution. + cutoffs = sorted(set(cutoffs)) - First frame is original (unfiltered), subsequent frames show - progressively filtered signals. + idx_split = np.searchsorted(cum, self.low_frac, side="left") + f_split = ( + float(freq[min(int(idx_split), len(freq) - 1)]) + if len(freq) + else float("nan") + ) + info_msg = ( + f"[INFO] Biased cutoffs: low_frac={self.low_frac:.2f} " + f"(~f={f_split:.3e} Hz) | total unique={len(cutoffs)}" + ) + logger.info(info_msg) + + return cutoffs + + def _butter_lowpass_filter( + self, + signal: NDArray[np.float64], + cutoff: float, + fs: float, + order: int = DEFAULT_FILTER_ORDER, + ) -> NDArray[np.float64]: + """Apply Butterworth lowpass filter to signal.""" + nyq = 0.5 * fs + + if cutoff >= nyq: + warn_msg = ( + f"[WARN] cutoff {cutoff:.3e} >= Nyquist {nyq:.3e}; " + "passing signal through." + ) + logger.warning(warn_msg) + return signal - Args: - raw_trim: Original trimmed signals - filtered_list: List of filtered signal arrays (one per cutoff) - cutoffs: List of cutoff frequencies - out_path_base: Base path for output videos (without extension) - y_limits: Optional fixed y-axis limits for all frames - max_overlay_traces: Max number of individual traces to overlay - seed: Random seed for trace selection - frame_duration: Duration of each frame in seconds + b, a = butter(order, cutoff / nyq, btype="low") - Returns: - Tuple of (video_path) - """ - # Initialize random number generator - rng = random.Random(seed) # noqa: S311 - frames: list[NDArray[np.uint8]] = [] + return filtfilt(b, a, signal) - # Simplify y_limits variable - yl = None if y_limits is None else y_limits + def _remove_filter_artifacts( + self, + signals: NDArray[np.float64], + frames_to_remove: int = DEFAULT_FRAMES_TO_REMOVE, + ) -> NDArray[np.float64]: + """Remove edge frames affected by filtering artifacts.""" + _n_series, n_frames = signals.shape - # ---- RAW frame (original, unfiltered) ---- + if n_frames <= 2 * frames_to_remove: + warn_msg = ( + f"[WARN] 
Not enough frames ({n_frames}) to remove " + f"{frames_to_remove} per side. Skipping trim." + ) + logger.warning(warn_msg) + return signals - # Randomly select traces to overlay - picks = ( - rng.sample( - range(raw_trim.shape[0]), - k=min(max_overlay_traces, raw_trim.shape[0]), - ) - if raw_trim.shape[0] - else [] - ) - traces = [raw_trim[i] for i in picks] + trimmed = signals[:, frames_to_remove:-frames_to_remove] - # Render original data frame - frames.append( - _render_frame_array( - raw_trim.mean(axis=0), - raw_trim.std(axis=0), - y_limits=yl, - title_override="Original (unfiltered, trimmed)", - overlay=traces, - dist_values=raw_trim.ravel(), - show_legend=False, + logger.info( + f"[STEP] Removed {frames_to_remove} frames per side " + f"-> new shape {trimmed.shape}" ) - ) - # ---- FILTERED frames ---- + return trimmed + + def _plot_fft( + self, + freq: NDArray[np.float64], + mag: NDArray[np.float64], + title: str, + path: Path, + mark_freqs: list[float] | None = None, + ) -> None: + """Create and save FFT magnitude plot.""" + plt.figure() + + plt.plot(freq / FREQ_GIGA, mag, lw=1.5, label="Summed |FFT|") + + if mark_freqs: + y_interp = np.interp(mark_freqs, freq, mag) + plt.scatter( + np.array(mark_freqs) / FREQ_GIGA, + y_interp, + s=30, + label="Cutoffs", + ) - # Create one frame for each cutoff level - for filt, cutoff in zip(filtered_list, cutoffs): - # Randomly select traces to overlay + plt.title(title) + plt.xlabel("Frequency (GHz)") + plt.ylabel("Summed Magnitude |FFT|") + plt.grid(alpha=0.3) + plt.legend() + plt.tight_layout() + + plt.savefig(path, dpi=200) + plt.close() + + def _plot_signals_with_kde( + self, signals: NDArray[np.float64], title: str, path: Path + ) -> None: + """Create dual-panel plot: signals + KDE distribution.""" + mean_signal = np.mean(signals, axis=0) + all_values = signals.ravel() + + _fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(10, 5)) + + ax1.plot(signals.T, lw=0.3, alpha=0.35, c="gray") + ax1.plot(mean_signal, 
color="red", lw=1.0, label="Mean") + ax1.set_title(title) + ax1.set_xlabel("Frame") + ax1.set_ylabel("Signal") + ax1.grid(alpha=0.3) + ax1.legend() + + sns.kdeplot(y=all_values, ax=ax2, fill=True, alpha=0.3) + ax2.set_title("KDE Distribution") + + plt.tight_layout() + plt.savefig(path, dpi=200) + plt.close() + + def _plot_single_atom_comparison( + self, + orig: NDArray[np.float64], + filt: NDArray[np.float64], + atom_idx: int, + path: Path, + ) -> None: + """Plot original vs filtered signal for a single series.""" + plt.figure() + + plt.plot(orig[atom_idx], label="Original", lw=1.0) + plt.plot(filt[atom_idx], label="Filtered", lw=1.0) + + plt.title(f"Atom/Series {atom_idx}: Original vs Filtered") + plt.xlabel("Frame") + plt.ylabel("Signal") + plt.legend() + plt.grid(alpha=0.3) + plt.tight_layout() + + plt.savefig(path, dpi=200) + plt.close() + + def _render_video( + self, + raw_trim: NDArray[np.float64], + filtered_list: list[NDArray[np.float64]], + cutoffs: list[float], + out_path_base: Path, + y_limits: tuple[float, float] | None = None, + max_overlay_traces: int = 10, + seed: int = 42, + frame_duration: float = 0.25, + ) -> Path: + """Create videos showing filter evolution.""" + rng = random.Random(seed) # noqa: S311 + frames: list[NDArray[np.uint8]] = [] + + yl = None if y_limits is None else y_limits + + # RAW frame (original, unfiltered) picks = ( rng.sample( - range(filt.shape[0]), k=min(max_overlay_traces, filt.shape[0]) + range(raw_trim.shape[0]), + k=min(max_overlay_traces, raw_trim.shape[0]), ) - if filt.shape[0] + if raw_trim.shape[0] else [] ) - traces = [filt[i] for i in picks] + traces = [raw_trim[i] for i in picks] - # Create title with cutoff frequency - title = f"Filtered @ {_freq_label_for_folder(cutoff)}" - - # Render filtered frame frames.append( - _render_frame_array( - filt.mean(axis=0), - filt.std(axis=0), + self._render_frame_array( + raw_trim.mean(axis=0), + raw_trim.std(axis=0), y_limits=yl, - title_override=title, + 
title_override="Original (unfiltered, trimmed)", overlay=traces, - dist_values=filt.ravel(), + dist_values=raw_trim.ravel(), show_legend=False, ) ) - # Calculate FPS from frame duration - fps = 1.0 / max(frame_duration, SMALL_EPSILON) - - # Define output paths - avi_path = out_path_base.with_suffix(".avi") - - # Try to write videos with detailed parameters - try: - # Write video - writer = imageio.get_writer( - avi_path, - format="FFMPEG", # type: ignore[arg-type] - mode="I", - fps=fps, - codec="mpeg4", - bitrate="10M", - macro_block_size=None, - output_params=["-pix_fmt", "yuv420p"], - ) - for f in frames[::-1]: - writer.append_data(f) - writer.close() - except TypeError as e: - msg = f"Video creation failed for {avi_path}: {e}" - raise RuntimeError(msg) from e - - # Log success - logger.info(f"[OUT ] Saved AVI ({len(frames)} frames): {avi_path}") - - return avi_path - - -# --------------------------- Core workflow --------------------------- - - -def _validate_params(dt_ps: float, levels: int) -> tuple[float, float]: - """Validate input parameters and compute derived values. 
- - Args: - dt_ps: Time step in picoseconds - levels: Number of filter levels - - Returns: - Tuple of (dt_seconds, sampling_frequency_hz) - - Raises: - ValueError: If parameters are invalid - """ - # Check dt is positive - if dt_ps <= 0: - msg = "dt_ps must be > 0" - raise ValueError(msg) - - # Check levels is at least 1 - if levels < 1: - msg = "levels must be >= 1" - raise ValueError(msg) - - # Convert dt to seconds - dt = dt_ps * 1e-12 - - # Calculate sampling frequency - fs = 1.0 / dt + # FILTERED frames + for filt, cutoff in zip(filtered_list, cutoffs): + n_series = filt.shape[0] + picks = ( + rng.sample( + range(n_series), k=min(max_overlay_traces, n_series) + ) + if n_series + else [] + ) + traces = [filt[i] for i in picks] + + title = f"Filtered @ {_freq_label_for_folder(cutoff)}" + + frames.append( + self._render_frame_array( + filt.mean(axis=0), + filt.std(axis=0), + y_limits=yl, + title_override=title, + overlay=traces, + dist_values=filt.ravel(), + show_legend=False, + ) + ) - # Log parameters - logger.info( - "dt = %.3g ps (%.3e s) | fs = %.3e Hz | Nyquist = %.3e Hz", - dt_ps, - dt, - fs, - 0.5 * fs, - ) + fps = 1.0 / max(frame_duration, SMALL_EPSILON) - return dt, fs + avi_path = out_path_base.with_suffix(".avi") + try: + writer = imageio.get_writer( + avi_path, + format="FFMPEG", # type: ignore[arg-type] + mode="I", + fps=fps, + codec="mpeg4", + bitrate="10M", + macro_block_size=None, + output_params=["-pix_fmt", "yuv420p"], + ) + for f in frames[::-1]: + writer.append_data(f) + writer.close() + except TypeError as e: + msg = f"Video creation failed for {avi_path}: {e}" + raise RuntimeError(msg) from e + + logger.info(f"[OUT ] Saved AVI ({len(frames)} frames): {avi_path}") + + return avi_path + + def _render_frame_array( + self, + mean: NDArray[np.float64], + std: NDArray[np.float64], + y_limits: tuple[float, float] | None = None, + title_override: str | None = None, + overlay: list[NDArray[np.float64]] | None = None, + dist_values: 
NDArray[np.float64] | None = None, + show_legend: bool = False, + kde_bw: float = 1.0, + ) -> NDArray[np.uint8]: + """Render a single video frame as a numpy image array.""" + x = np.arange(mean.size, dtype=np.float64) + + fig, (ax1, ax2) = plt.subplots( + ncols=2, + figsize=(10, 4), + sharey=True, + facecolor="white", + ) -def _resolve_signals( - signals: NDArray[np.float64] | None, - path: str | Path, - drop_first_frame: bool, -) -> NDArray[np.float64]: - """Load or validate input signals. + self._draw_left_panel( + ax1, + x, + mean, + std, + overlay, + title_override, + y_limits, + show_legend, + ) - If signals array is provided, use it. Otherwise load from path. - Optionally drops first frame to remove initialization artifacts. + self._draw_kde_panel(ax2, dist_values, kde_bw) - Args: - signals: Optional pre-loaded signals array - path: Path to load signals from (if signals is None) - drop_first_frame: Whether to drop the first frame + fig.tight_layout() - Returns: - Validated 2D signals array + return self._finalize_frame(fig) - Raises: - ValueError: If array is not 2D or too few frames to drop - """ - # Load signals if not provided - if signals is None: - ds_path = _resolve_dataset_path(path) - signals = _load_array_any(ds_path) - - # Validate dimensions - if signals.ndim != NDIM_EXPECTED: - msg = f"Expected 2D array (series x frames), got {signals.shape}" - raise ValueError(msg) + def _draw_left_panel( + self, + ax: Axes, + x: NDArray[np.float64], + mean: NDArray[np.float64], + std: NDArray[np.float64], + overlay: list[NDArray[np.float64]] | None, + title: str | None, + y_limits: tuple[float, float] | None, + show_legend: bool, + ) -> None: + """Draw left panel of video frame showing signal traces.""" + ax.plot(x, mean, lw=1.2, label="Mean") - # Drop first frame if requested - if drop_first_frame: - # Check we have enough frames - if signals.shape[1] < MIN_FRAMES_TO_DROP: - msg = f"Need at least {MIN_FRAMES_TO_DROP} frames." 
- raise ValueError(msg) - # Remove first frame - signals = signals[:, 1:] + ax.fill_between( + x, mean - std, mean + std, alpha=0.25, label="+/- 1 sigma" + ) - # Log final shape - logger.info("Using signals -> shape %s", signals.shape) + if overlay: + for tr in overlay: + ax.plot(x, tr, lw=0.6, alpha=0.35) + + ax.set_title(title or "Filtered") + ax.set_xlabel("Frame (trimmed)") + ax.set_ylabel("Signal") + ax.grid(alpha=0.3) + + if show_legend: + ax.legend() + + if y_limits is not None: + ax.set_ylim(*y_limits) + + def _draw_kde_panel( + self, + ax: Axes, + dist_values: NDArray[np.float64] | None, + kde_bw: float, + ) -> None: + """Draw right panel of video frame showing KDE distribution.""" + ax.set_title("KDE") + ax.set_xlabel("Density") + ax.grid(alpha=0.3) + + if dist_values is None: + return + + vals = np.asarray(dist_values) + vals = vals[np.isfinite(vals)] + + if vals.size > 1 and np.nanstd(vals) > 0: + sns.kdeplot( + y=vals, ax=ax, fill=True, alpha=0.3, bw_adjust=kde_bw + ) + elif vals.size > 0: + ax.axhline(float(vals[0]), ls="--", alpha=0.6) - return signals + ax.yaxis.tick_right() + ax.yaxis.set_label_position("right") + def _finalize_frame(self, fig: Figure) -> NDArray[np.uint8]: + """Convert matplotlib figure to RGB numpy array.""" + buf = io.BytesIO() + fig.savefig(buf, format="png", dpi=220, facecolor="white") + plt.close(fig) -def _select_output_dir(out_dir: str | Path | None) -> Path: - """Create and return output directory path. + buf.seek(0) + img_raw = imageio.imread(buf) - Uses provided path or creates default in current directory. 
+ if img_raw.ndim == IMG_NDIM_GRAYSCALE: + img: NDArray[np.uint8] = np.stack( + [img_raw, img_raw, img_raw], axis=2 + ) + else: + img = img_raw - Args: - out_dir: Optional output directory path + if img.shape[2] == IMG_CHANNELS_RGBA: + img = img[:, :, :3] - Returns: - Path to output directory (created if doesn't exist) - """ - # Use provided path or create default - base = ( - Path(out_dir) - if out_dir is not None - else Path.cwd() / "autofilter_outputs" - ) + h, w = img.shape[:2] + if h % 2 or w % 2: + img = img[: h - (h % 2), : w - (w % 2), :] - # Create directory if needed - _make_dir_safe(base) + return img - return base + def _render_fft_evolution_video( + self, + freq_original: NDArray[np.float64], + mag_original: NDArray[np.float64], + filtered_list: list[NDArray[np.float64]], + cutoffs: list[float], + dt: float, + out_path_base: Path, + frame_duration: float = 0.25, + ) -> Path: + """Create video showing FFT evolution during filtering.""" + frames: list[NDArray[np.uint8]] = [] + logger.info("[FFT VIDEO] Rendering original frame (no filtering)...") -def _fft_and_cutoffs( - signals: NDArray[np.float64], - dt: float, - levels: int, - low_frac: float, - low_ratio: float, -) -> tuple[NDArray[np.float64], NDArray[np.float64], list[float]]: - """Compute FFT and determine cutoff frequencies. 
+ frames.append( + self._render_fft_frame( + freq_original, + mag_original, + mag_original, + cutoff=None, + title_left="Original FFT (unfiltered)", + title_right="Original FFT (unfiltered)", + ) + ) - Args: - signals: 2D signals array - dt: Time step in seconds - levels: Number of cutoff levels to find - low_frac: Fraction defining low-frequency region - low_ratio: Ratio of low to high frequency cutoffs + for i, (filt, cutoff) in enumerate( + zip(filtered_list, cutoffs), start=1 + ): + logger.info( + "[FFT VIDEO] Rendering frame %d/%d (cutoff=%.3e Hz)...", + i, + len(cutoffs), + cutoff, + ) - Returns: - Tuple of (frequency_array, magnitude_array, cutoff_list) - """ - # Compute FFT - logger.info("[STEP] Computing summed FFT (original data) ...") - freq, mag = _compute_fft_summed(signals, dt) - - # Log frequency range - logger.info( - "Frequency bins: %d | Min/Max freq: %.3e/%.3e Hz", - len(freq), - freq.min(), - freq.max(), - ) + freq_filt, mag_filt = self._compute_fft_summed(filt, dt) - # Find cutoff frequencies - logger.info( - "Selecting %d cutoff(s) with low-freq bias (<=%.2f cum |FFT|)", - levels, - low_frac, - ) - cutoffs = _find_cutoffs_biased( - freq, - mag, - levels, - low_frac=low_frac, - low_ratio=low_ratio, - min_frac=0.05, - max_frac=0.95, - ) + mag_filt_interp = np.interp( + freq_original, + freq_filt, + mag_filt, + left=0.0, + right=0.0, + ) - # Log selected cutoffs - logger.info("Cutoffs (Hz, ascending): %s", [f"{c:.2e}" for c in cutoffs]) + cutoff_label = _freq_label_for_folder(cutoff) + frames.append( + self._render_fft_frame( + freq_original, + mag_original, + mag_filt_interp, + cutoff=cutoff, + title_left=f"Original FFT (cutoff @ {cutoff_label})", + title_right=f"Filtered FFT (cutoff @ {cutoff_label})", + ) + ) - return freq, mag, cutoffs + fps = 1.0 / max(frame_duration, SMALL_EPSILON) + avi_path = out_path_base.with_suffix(".avi") -def _process_level( - i: int, - total: int, - cutoff: float, - signals: NDArray[np.float64], - fs_hz: float, - 
dt: float, - folder: Path, - original_trim: NDArray[np.float64], - n_series: int, - frames_to_remove: int, - reuse_existing: bool, -) -> NDArray[np.float64]: - """Process a single filter level (cutoff frequency). + try: + writer = imageio.get_writer( + avi_path, + format="FFMPEG", # type: ignore[arg-type] + mode="I", + fps=fps, + codec="mpeg4", + bitrate="10M", + macro_block_size=None, + output_params=["-pix_fmt", "yuv420p"], + ) + for f in frames[::-1]: + writer.append_data(f) + writer.close() + except TypeError as e: + msg = f"FFT video creation failed for {avi_path}: {e}" + raise RuntimeError(msg) from e + + logger.info( + "[OUT ] Saved FFT evolution video (%d frames): %s", + len(frames), + avi_path, + ) - Applies Butterworth filter, removes edge artifacts, and saves - filtered signals along with diagnostic plots. + return avi_path + + def _render_fft_frame( + self, + freq: NDArray[np.float64], + mag_original: NDArray[np.float64], + mag_filtered: NDArray[np.float64], + cutoff: float | None, + title_left: str, + title_right: str, + ) -> NDArray[np.uint8]: + """Render a single FFT comparison frame.""" + fig, (ax1, ax2) = plt.subplots( + ncols=2, + figsize=(12, 4), + facecolor="white", + ) - Args: - i: Current level index (1-based) - total: Total number of levels - cutoff: Cutoff frequency for this level - signals: Original signals array - fs_hz: Sampling frequency in Hz - dt: Time step in seconds - folder: Output folder for this level - original_trim: Trimmed original signals (for comparison) - n_series: Number of series (for random sampling) - frames_to_remove: Number of frames to trim from edges - reuse_existing: Whether to reuse existing filtered file + freq_ghz = freq / FREQ_GIGA - Returns: - Filtered and trimmed signals array - """ - # Create output folder - _make_dir_safe(folder) - out_path = folder / "filtered_signals.npy" - - # Reuse existing file if requested and available - if out_path.exists() and reuse_existing: - logger.info("[LEVEL %d/%d] cutoff=%.3e 
Hz -> REUSE", i, total, cutoff) - return np.load(out_path) - - # Log start of computation - logger.info( - "[LEVEL %d/%d] cutoff=%.3e Hz -> COMPUTE & SAVE to: %s", - i, - total, - cutoff, - folder, - ) + ax1.plot( + freq_ghz, + mag_original, + lw=1.5, + label="Original FFT", + color="blue", + ) - # Apply filter to each series - t0 = time.time() - filtered = np.array( - [_butter_lowpass_filter(row, cutoff, fs_hz) for row in signals] - ) - logger.info("Applied Butterworth via filtfilt in %.2fs", time.time() - t0) + if cutoff is not None: + cutoff_ghz = cutoff / FREQ_GIGA + y_interp = np.interp(cutoff, freq, mag_original) - # Remove edge artifacts - filtered_trim = _remove_filter_artifacts( - filtered, - frames_to_remove=frames_to_remove, - ) + ax1.axvline(cutoff_ghz, color="red", ls="--", lw=1.5, alpha=0.7) - # Save filtered signals - np.save(out_path, filtered_trim) - logger.info( - "Saved filtered signals: %s -> %s", - filtered_trim.shape, - out_path, - ) + cutoff_label = _freq_label_for_folder(cutoff) + ax1.scatter( + [cutoff_ghz], + [y_interp], + s=80, + color="red", + zorder=5, + label=f"Cutoff: {cutoff_label}", + ) - # Create FFT plot of filtered data - f_filt, mag_filt = _compute_fft_summed(filtered_trim, dt) - _plot_fft( - f_filt, - mag_filt, - f"Summed FFT (filtered, cutoff={cutoff:.2e} Hz)", - folder / "fft_plot.png", - ) + ax1.set_title(title_left) + ax1.set_xlabel("Frequency (GHz)") + ax1.set_ylabel("Summed Magnitude |FFT|") + ax1.grid(alpha=0.3) + ax1.legend() + + ax2.plot( + freq_ghz, + mag_filtered, + lw=1.5, + label="Filtered FFT", + color="green", + ) + ax2.set_title(title_right) + ax2.set_xlabel("Frequency (GHz)") + ax2.set_ylabel("Summed Magnitude |FFT|") + ax2.grid(alpha=0.3) + ax2.legend() - # Create signals + KDE plot - kde_path = folder / "filt_kde.png" - _plot_signals_with_kde( - filtered_trim, - "Filt_Data + KDE", - kde_path, - ) + fig.tight_layout() - # Create comparison plots for random series - n_pick = min(3, n_series) # Pick up to 3 
series - random.seed(42) - rand_atoms = random.sample(range(n_series), n_pick) - - # Align original to same length as filtered - length = filtered_trim.shape[1] - original_aligned = original_trim[:, -length:] - - # Plot original vs filtered for selected series - for idx in rand_atoms: - _plot_single_atom_comparison( - original_aligned, - filtered_trim, - idx, - folder / f"atom_{idx}_comparison.png", - ) + return self._finalize_frame(fig) - # Log completion - logger.info( - "Saved %d Original vs Filtered overlays: %s", - n_pick, - rand_atoms, - ) - return filtered_trim +# --------------------------- Convenience Function --------------------------- def auto_filtering( signals: NDArray[np.float64] | None = None, - *, path: str | Path = ".", dt_ps: float = 100.0, levels: int = 50, @@ -1189,168 +1289,76 @@ def auto_filtering( max_overlay_traces: int = 5, frame_duration: float = 0.25, drop_first_frame: bool = True, + save_cutoff_folders: bool = True, + save_fft_plots: bool = True, + save_signal_video: bool = True, + save_fft_video: bool = True, ) -> AutoFiltInsight: - """Automatic multi-level Butterworth lowpass filtering. + """Convenience function for complete automatic filtering workflow. - Main workflow: - 1. Load/validate input signals - 2. Compute FFT to find frequency content - 3. Select multiple cutoff frequencies (biased to low freq) - 4. Apply Butterworth filter at each cutoff - 5. Create diagnostic plots and videos + This function provides the same interface as the original code but + uses the new class-based implementation internally. Each output type + can be controlled via parameters. 
Args: signals: Pre-loaded signals array (series x frames), or None - to load from path path: Path to dataset file/folder (used if signals is None) dt_ps: Time step in picoseconds levels: Number of filter cutoff frequencies to test out_dir: Output directory (default: ./autofilter_outputs) reuse_existing: Reuse existing filtered files if available - frames_to_remove: Frames to trim from edges (removes filter - artifacts) + frames_to_remove: Frames to trim from edges low_frac: Cumulative FFT fraction defining "low frequency" - (<0.20 = low) - low_ratio: Ratio of low-freq to high-freq cutoffs (2.0 = twice - as many low) + low_ratio: Ratio of low-freq to high-freq cutoffs seed: Random seed for reproducibility max_overlay_traces: Max traces to overlay in video frames frame_duration: Duration of each video frame in seconds - drop_first_frame: Drop first frame from input (removes - initialization) + drop_first_frame: Drop first frame from input + save_cutoff_folders: Save individual cutoff folders + save_fft_plots: Save FFT plots + save_signal_video: Create signal evolution video + save_fft_video: Create FFT evolution video Returns: - AutoFiltInsight object containing: - - cutoffs: List of cutoff frequencies used - - output_dir: Path to output directory - - filtered_files: Dict mapping cutoff to saved .npy file - - video_path: Path to forward evolution video - - meta: Dictionary of parameters used - - filtered_collection: Tuple of all filtered arrays + AutoFiltInsight object with all results and metadata """ - # Validate parameters and compute derived values - dt, fs = _validate_params(dt_ps, levels) - - # Load or validate signals - signals = _resolve_signals(signals, path, drop_first_frame) - - # Set up output directory - base = _select_output_dir(out_dir) - - # Compute FFT and determine cutoff frequencies - freq, mag, cutoffs = _fft_and_cutoffs( - signals, - dt, - levels, - low_frac, - low_ratio, - ) - - # Plot original FFT with cutoff markers - _plot_fft( - freq, - mag, 
- "Summed FFT of Original Signals", - base / "original_fft.png", - mark_freqs=cutoffs, - ) - logger.info("[PLOT] Saved original FFT with cutoff markers.") - - # Set up for filtering loop - fs_hz = fs # Sampling frequency - n_series, _ = signals.shape - logger.info( - "Processing %d cutoff level(s); fs=%.3e Hz, Nyquist=%.3e Hz", - len(cutoffs), - fs_hz, - 0.5 * fs_hz, - ) - - # Trim original signals (for comparison and video) - original_trim = _remove_filter_artifacts( - signals, + # Create pipeline + pipeline = AutoFilteringPipeline( + signals=signals, + path=path, + dt_ps=dt_ps, + levels=levels, + out_dir=out_dir, + reuse_existing=reuse_existing, frames_to_remove=frames_to_remove, + low_frac=low_frac, + low_ratio=low_ratio, + seed=seed, + drop_first_frame=drop_first_frame, ) - # Initialize result containers - filtered_collection: list[NDArray[np.float64]] = [] - used_cutoffs: list[float] = [] - filtered_paths: dict[float, Path] = {} - - # Initialize global y-limits from original data - global_ymin = float(np.nanmin(original_trim)) - global_ymax = float(np.nanmax(original_trim)) + # Run core analysis + pipeline.compute_fft_and_cutoffs() + pipeline.apply_filtering() - # Process each cutoff level - for i, cutoff in enumerate(sorted(cutoffs), start=1): - # Create folder for this level - folder = base / f"cutoff_{_freq_label_for_folder(cutoff)}" + # Save filtered collection (MANDATORY) + pipeline.save_filtered_collection() - # Filter signals at this cutoff - filtered_trim = _process_level( - i, - len(cutoffs), - cutoff, - signals, - fs_hz, - dt, - folder, - original_trim, - n_series, - frames_to_remove, - reuse_existing, - ) + # Save optional outputs based on parameters + if save_fft_plots: + pipeline.save_fft_plots() - # Update global y-limits to encompass this level - global_ymin = min(global_ymin, float(np.nanmin(filtered_trim))) - global_ymax = max(global_ymax, float(np.nanmax(filtered_trim))) + if save_cutoff_folders: + pipeline.save_cutoff_folders() - # Store 
results - filtered_collection.append(filtered_trim) - used_cutoffs.append(cutoff) - filtered_paths[cutoff] = folder / "filtered_signals.npy" + if save_signal_video: + pipeline.create_signal_video( + max_overlay_traces=max_overlay_traces, + frame_duration=frame_duration, + ) - # Add padding to y-limits - pad = ( - 1e-7 * (global_ymax - global_ymin) - if global_ymax > global_ymin - else 1.0 - ) - yl = (global_ymin - pad, global_ymax + pad) - - # Create evolution videos with consistent y-limits - video_path = _render_video( - original_trim, - filtered_collection, - used_cutoffs, - out_path_base=base / "filtered_evolution", - y_limits=yl, - max_overlay_traces=max_overlay_traces, - seed=seed, - frame_duration=frame_duration, - ) + if save_fft_video: + pipeline.create_fft_video(frame_duration=frame_duration) - # Save all filtered arrays to single compressed NPZ file - npz_path = base / "filtered_collection.npz" - np.savez_compressed(npz_path, *filtered_collection) - - # Collect metadata - meta = { - "dt_ps": dt_ps, - "frames_to_remove": frames_to_remove, - "levels": levels, - "low_frac": low_frac, - "low_ratio": low_ratio, - "seed": seed, - "frame_duration_s": frame_duration, - } - - # Return results container - return AutoFiltInsight( - cutoffs=used_cutoffs, - output_dir=base, - filtered_files=filtered_paths, - video_path=video_path, - meta=meta, - filtered_collection=tuple(filtered_collection), - ) + # Return result + return pipeline.get_result() diff --git a/src/dynsight/data_processing.py b/src/dynsight/data_processing.py index b5845793..b1ef7a0b 100644 --- a/src/dynsight/data_processing.py +++ b/src/dynsight/data_processing.py @@ -1,6 +1,7 @@ """data processing package.""" from dynsight._internal.data_processing.auto_filtering import ( + AutoFilteringPipeline, auto_filtering, ) from dynsight._internal.data_processing.classify import ( @@ -22,6 +23,7 @@ ) __all__ = [ + "AutoFilteringPipeline", "applyclassification", "auto_filtering", "createreferencesfromtrajectory", 
From aac9580f05f05f64e5d463d446b72a1259875344 Mon Sep 17 00:00:00 2001 From: Dominosauro Date: Thu, 27 Nov 2025 11:26:47 +0100 Subject: [PATCH 9/9] Format --- src/dynsight/_internal/data_processing/auto_filtering.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/dynsight/_internal/data_processing/auto_filtering.py b/src/dynsight/_internal/data_processing/auto_filtering.py index 70ce46b6..4f51c46c 100644 --- a/src/dynsight/_internal/data_processing/auto_filtering.py +++ b/src/dynsight/_internal/data_processing/auto_filtering.py @@ -1084,9 +1084,7 @@ def _draw_kde_panel( vals = vals[np.isfinite(vals)] if vals.size > 1 and np.nanstd(vals) > 0: - sns.kdeplot( - y=vals, ax=ax, fill=True, alpha=0.3, bw_adjust=kde_bw - ) + sns.kdeplot(y=vals, ax=ax, fill=True, alpha=0.3, bw_adjust=kde_bw) elif vals.size > 0: ax.axhline(float(vals[0]), ls="--", alpha=0.6)