From c796d33245e26fe96e926b6555554f1686449bc2 Mon Sep 17 00:00:00 2001 From: dominiquef Date: Mon, 12 Jan 2026 14:44:29 -0800 Subject: [PATCH 01/12] Initial commit for match options and driver class. --- .../plate_simulation/match/__init__.py | 9 ++ .../plate_simulation/match/driver.py | 117 ++++++++++++++++++ .../plate_simulation/match/options.py | 51 ++++++++ 3 files changed, 177 insertions(+) create mode 100644 simpeg_drivers/plate_simulation/match/__init__.py create mode 100644 simpeg_drivers/plate_simulation/match/driver.py create mode 100644 simpeg_drivers/plate_simulation/match/options.py diff --git a/simpeg_drivers/plate_simulation/match/__init__.py b/simpeg_drivers/plate_simulation/match/__init__.py new file mode 100644 index 00000000..df32b204 --- /dev/null +++ b/simpeg_drivers/plate_simulation/match/__init__.py @@ -0,0 +1,9 @@ +# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' +# Copyright (c) 2023-2026 Mira Geoscience Ltd. ' +# ' +# This file is part of simpeg-drivers package. ' +# ' +# simpeg-drivers is distributed under the terms and conditions of the MIT License ' +# (see LICENSE file at the root of this source code package). ' +# ' +# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' diff --git a/simpeg_drivers/plate_simulation/match/driver.py b/simpeg_drivers/plate_simulation/match/driver.py new file mode 100644 index 00000000..07386167 --- /dev/null +++ b/simpeg_drivers/plate_simulation/match/driver.py @@ -0,0 +1,117 @@ +# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' +# Copyright (c) 2023-2026 Mira Geoscience Ltd. ' +# ' +# This file is part of simpeg-drivers package. ' +# ' +# simpeg-drivers is distributed under the terms and conditions of the MIT License ' +# (see LICENSE file at the root of this source code package). 
' +# ' +# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' + +from __future__ import annotations + +import sys +from pathlib import Path + +import numpy as np +from geoapps_utils.utils.importing import GeoAppsError +from geoapps_utils.utils.logger import get_logger +from geoh5py import Workspace +from geoh5py.groups import UIJsonGroup +from geoh5py.shared.utils import ( + dict_to_json_str, + fetch_active_workspace, + uuid_from_values, +) +from geoh5py.ui_json.ui_json import BaseUIJson +from geoh5py.ui_json.utils import flatten +from typing_extensions import Self + +from simpeg_drivers.driver import BaseDriver +from simpeg_drivers.plate_simulation.match.options import MatchOptions + + +logger = get_logger(name=__name__, level_name=False, propagate=False, add_name=False) + + +# TODO: Can we make this generic (PlateMatchDriver -> MatchDriver)? +class PlateMatchDriver(BaseDriver): + """Sets up and manages workers to run all combinations of swepts parameters.""" + + _params_class = MatchOptions + + def __init__(self, params: MatchOptions, workers: list[tuple[str]] | None = None): + super().__init__(params, workers=workers) + + self.out_group = self.validate_out_group(self.params.out_group) + + @property + def out_group(self) -> UIJsonGroup: + """ + Returns the output group for the simulation. + """ + return self._out_group + + @out_group.setter + def out_group(self, value: UIJsonGroup): + if not isinstance(value, UIJsonGroup): + raise TypeError("Output group must be a UIJsonGroup.") + + if self.params.out_group != value: + self.params.out_group = value + self.params.update_out_group_options() + + self._out_group = value + + def validate_out_group(self, out_group: UIJsonGroup | None) -> UIJsonGroup: + """ + Validate or create a UIJsonGroup to store results. + + :param value: Output group from selection. 
+ """ + if isinstance(out_group, UIJsonGroup): + return out_group + + with fetch_active_workspace(self.params.geoh5, mode="r+"): + out_group = UIJsonGroup.create( + self.params.geoh5, + name=self.params.title, + ) + out_group.entity_type.name = self.params.title + + return out_group + + @classmethod + def start(cls, filepath: str | Path, mode="r", **_) -> Self: + """Start the parameter matching from a ui.json file.""" + logger.info("Loading input file . . .") + filepath = Path(filepath).resolve() + uijson = BaseUIJson.read(filepath) + + with Workspace(uijson.geoh5, mode=mode) as workspace: + try: + options = MatchOptions.build(uijson.to_params(workspace=workspace)) + logger.info("Initializing application . . .") + driver = cls(options) + logger.info("Running application . . .") + driver.run() + logger.info("Results saved to %s", options.geoh5.h5file) + + except GeoAppsError as error: + logger.warning("\n\nApplicationError: %s\n\n", error) + sys.exit(1) + + return driver + + def run(self): + """Loop over all trials and run a worker for each unique parameter set.""" + + logger.info( + "Running %s . . .", + self.params.template.options["title"], + ) + + +if __name__ == "__main__": + file = Path(sys.argv[1]) + PlateMatchDriver.start(file) diff --git a/simpeg_drivers/plate_simulation/match/options.py b/simpeg_drivers/plate_simulation/match/options.py new file mode 100644 index 00000000..abaabb3d --- /dev/null +++ b/simpeg_drivers/plate_simulation/match/options.py @@ -0,0 +1,51 @@ +# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' +# Copyright (c) 2023-2026 Mira Geoscience Ltd. ' +# ' +# This file is part of simpeg-drivers package. ' +# ' +# simpeg-drivers is distributed under the terms and conditions of the MIT License ' +# (see LICENSE file at the root of this source code package). 
' +# ' +# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' + +import itertools +from pathlib import Path +from typing import ClassVar + +import numpy as np +from geoapps_utils.base import Options +from geoh5py.groups import PropertyGroup, UIJsonGroup +from geoh5py.objects import Points +from geoh5py.objects.surveys.electromagnetics.airborne_tem import AirborneTEMReceivers +from geoh5py.shared.utils import stringify +from geoh5py.ui_json import InputFile +from pydantic import BaseModel, ConfigDict, field_serializer + +from simpeg_drivers import assets_path + + +class MatchOptions(Options): + """ + Options for matching signal from a survey against a library of simulations. + + :param survey: A Time-Domain Airborne TEM survey object. + :param data: A property group containing observed data. + :param targets: A Points object containing the target locations. + :param strike_angles: An optional data array containing strike angles for each + target location. + :param simulations: Directory to store simulation files. 
+ """ + + model_config = ConfigDict(frozen=False, arbitrary_types_allowed=True) + + name: ClassVar[str] = "plate_match" + default_ui_json: ClassVar[Path] = assets_path() / "uijson/plate_match.ui.json" + title: ClassVar[str] = "Plate Match" + run_command: ClassVar[str] = "simpeg_drivers.plate_simulation.match.driver" + out_group: UIJsonGroup | None = None + + survey: AirborneTEMReceivers + data: PropertyGroup + targets: Points + strike_angles: np.ndarray | None = None + simulations: ClassVar[Path] From ea56d4e774636cf0ab06e8131a0ac923ceac223f Mon Sep 17 00:00:00 2001 From: dominiquef Date: Tue, 13 Jan 2026 12:21:38 -0800 Subject: [PATCH 02/12] Continue work --- .../uijson/plate_match.ui.json | 111 ++++++++++++++++ .../plate_simulation/match/driver.py | 125 +++++++++++++++--- .../plate_simulation/match/options.py | 17 +-- .../plate_simulation/match/uijson.py | 26 ++++ 4 files changed, 256 insertions(+), 23 deletions(-) create mode 100644 simpeg_drivers-assets/uijson/plate_match.ui.json create mode 100644 simpeg_drivers/plate_simulation/match/uijson.py diff --git a/simpeg_drivers-assets/uijson/plate_match.ui.json b/simpeg_drivers-assets/uijson/plate_match.ui.json new file mode 100644 index 00000000..86ec9a2d --- /dev/null +++ b/simpeg_drivers-assets/uijson/plate_match.ui.json @@ -0,0 +1,111 @@ +{ + "version": "0.2.0-alpha.1", + "title": "Plate Match", + "icon": "maxwellplate", + "documentation": "https://mirageoscience-simpeg-drivers.readthedocs-hosted.com/en/latest/plate-simulation/", + "conda_environment": "simpeg_drivers", + "run_command": "simpeg_drivers.driver", + "geoh5": "", + "monitoring_directory": "", + "inversion_type": "plate match", + "survey": { + "main": true, + "label": "Survey", + "meshType": [ + "{6a057fdc-b355-11e3-95be-fd84a7ffcb88}", + "{19730589-fd28-4649-9de0-ad47249d9aba}" + ], + "tooltip": "Airborne TEM survey containing transmitter and receiver locations.", + "value": "" + }, + "data": { + "association": [ + "Cell", + "Vertex" + ], + 
"dataType": "Float", + "dataGroupType": "Multi-element", + "main": true, + "label": "EM Data", + "parent": "survey", + "tooltip": "Observed EM data to fit during plate matching.", + "value": "" + }, + "queries": { + "main": true, + "label": "Query Points", + "meshType": [ + "{6a057fdc-b355-11e3-95be-fd84a7ffcb88}", + "{202C5DB1-A56D-4004-9CAD-BAAFD8899406}" + ], + "tooltip": "Locations of EM anomalies on the survey to evaluate the match.", + "value": "" + }, + "max_distance": { + "main": true, + "label": "Query Max Distance (m)", + "value": 1000.0, + "tooltip": "Length on either side of the query points to evaluate the match." + }, + "strike_angles": { + "association": [ + "Cell", + "Vertex" + ], + "dataType": "Float", + "main": true, + "label": "Strike Angles (degrees)", + "parent": "queries", + "optional": true, + "enabled": false, + "tooltip": "Data containing the estimated angles between the flight path and dipping body.", + "value": "" + }, + "topography_object": { + "main": true, + "group": "Topography", + "label": "Topography", + "meshType": [ + "{202c5db1-a56d-4004-9cad-baafd8899406}", + "{6a057fdc-b355-11e3-95be-fd84a7ffcb88}", + "{f26feba3-aded-494b-b9e9-b2bbcbe298e1}", + "{48f5054a-1c5c-4ca4-9048-80f36dc60a06}", + "{b020a277-90e2-4cd7-84d6-612ee3f25051}" + ], + "value": "", + "tooltip": "Select a topography object to define the drape height for the survey." + }, + "topography": { + "association": [ + "Vertex", + "Cell" + ], + "dataType": "Float", + "group": "Topography", + "main": true, + "optional": true, + "enabled": false, + "label": "Elevation channel", + "tooltip": "Set elevation from channel. 
If not set the topography will be set from the geometry of the selected 'topography' object", + "parent": "topography_object", + "dependency": "topography_object", + "dependencyType": "enabled", + "value": "", + "verbose": 2 + }, + "simulations": { + "main": true, + "label": "Simulations directory", + "directoryOnly": true, + "tootip": "Directory where pre-computed simulations are stored.", + "value": "./simulations" + }, + "out_group": { + "label": "UIJsonGroup", + "value": "", + "groupType": "{BB50AC61-A657-4926-9C82-067658E246A0}", + "visible": true, + "optional": true, + "enabled": false + } +} diff --git a/simpeg_drivers/plate_simulation/match/driver.py b/simpeg_drivers/plate_simulation/match/driver.py index 07386167..2ccb796c 100644 --- a/simpeg_drivers/plate_simulation/match/driver.py +++ b/simpeg_drivers/plate_simulation/match/driver.py @@ -15,16 +15,17 @@ import numpy as np from geoapps_utils.utils.importing import GeoAppsError +from geoapps_utils.utils.locations import topo_drape_elevation from geoapps_utils.utils.logger import get_logger +from geoapps_utils.utils.transformations import rotate_xyz from geoh5py import Workspace -from geoh5py.groups import UIJsonGroup +from geoh5py.groups import SimPEGGroup +from geoh5py.objects import AirborneTEMReceivers, Surface from geoh5py.shared.utils import ( - dict_to_json_str, fetch_active_workspace, - uuid_from_values, ) from geoh5py.ui_json.ui_json import BaseUIJson -from geoh5py.ui_json.utils import flatten +from scipy.spatial import cKDTree from typing_extensions import Self from simpeg_drivers.driver import BaseDriver @@ -34,7 +35,6 @@ logger = get_logger(name=__name__, level_name=False, propagate=False, add_name=False) -# TODO: Can we make this generic (PlateMatchDriver -> MatchDriver)? 
class PlateMatchDriver(BaseDriver): """Sets up and manages workers to run all combinations of swepts parameters.""" @@ -46,16 +46,16 @@ def __init__(self, params: MatchOptions, workers: list[tuple[str]] | None = None self.out_group = self.validate_out_group(self.params.out_group) @property - def out_group(self) -> UIJsonGroup: + def out_group(self) -> SimPEGGroup: """ Returns the output group for the simulation. """ return self._out_group @out_group.setter - def out_group(self, value: UIJsonGroup): - if not isinstance(value, UIJsonGroup): - raise TypeError("Output group must be a UIJsonGroup.") + def out_group(self, value: SimPEGGroup): + if not isinstance(value, SimPEGGroup): + raise TypeError("Output group must be a SimPEGGroup.") if self.params.out_group != value: self.params.out_group = value @@ -63,17 +63,17 @@ def out_group(self, value: UIJsonGroup): self._out_group = value - def validate_out_group(self, out_group: UIJsonGroup | None) -> UIJsonGroup: + def validate_out_group(self, out_group: SimPEGGroup | None) -> SimPEGGroup: """ - Validate or create a UIJsonGroup to store results. + Validate or create a SimPEGGroup to store results. - :param value: Output group from selection. + :param out_group: Output group from selection. 
""" - if isinstance(out_group, UIJsonGroup): + if isinstance(out_group, SimPEGGroup): return out_group with fetch_active_workspace(self.params.geoh5, mode="r+"): - out_group = UIJsonGroup.create( + out_group = SimPEGGroup.create( self.params.geoh5, name=self.params.title, ) @@ -111,7 +111,102 @@ def run(self): self.params.template.options["title"], ) + tree = cKDTree(self.params.survey.vertices[:, :2]) + + topo_z = ( + self.params.topography.values + if self.params.topography.values is not None + else self.params.topography_object.locations[:, 2] + ) + topo_drape_z = topo_drape_elevation( + self.params.survey.vertices, + self.params.topography_object.locations, + topo_z, + triangulation=self.params.topography_object.cells + if isinstance(self.params.topography_object, Surface) + else None, + ) + + for ii, query in enumerate(self.params.queries.vertices): + nearest = tree.query(query[:2], k=1)[0] + line_mask = np.where( + self.params.survey.parts == self.params.survey.parts[nearest] + )[0] + distances = np.linalg.norm( + self.params.survey.vertices[nearest, :2] + - self.params.survey.vertices[line_mask, :2], + axis=1, + ) + dist_mask = distances < self.params.max_distance + indices = line_mask[dist_mask] + + # Compute local coordinates for the current line segment + line_dist = distances[dist_mask] + line_dist[indices < nearest] *= -1.0 + local_xyz = np.c_[ + line_dist, + np.zeros_like(line_dist), + self.params.survey.vertices[indices, 2] - topo_drape_z[indices], + ] + + if self.params.strike_angles is not None: + angle = self.params.strike_angles[ii] + local_xyz = rotate_xyz( + local_xyz, [0, 0, 0], self.params.strike_angles[ii], 0 + ) + + # Convert to polar coordinates (distance, azimuth, height) + local_polar = np.c_[ + line_dist, + 90 - (np.rad2deg(np.arctan2(local_xyz[:, 0], local_xyz[:, 1])) % 180), + local_xyz[:, 2], + ] + + projection = None + data = {} + for file in self.params.simulations.iterdir(): + if Path(file).resolve().suffix == ".geoh5": + with 
Workspace(file, mode="r") as ws: + sim = next( + group + for group in ws.groups + if isinstance(group, SimPEGGroup) and "Plate" in group.name + ) + fwr = next( + child + for child in sim.children + if isinstance(child, SimPEGGroup) + ) + survey = next( + child + for child in fwr.children + if isinstance(child, AirborneTEMReceivers) + ) + group = survey.get_entity("Iteration_0_z")[0] + data[Path(file).stem] = group.table() + + # Create a projection matrix to interpolate simulated data to the observation locations + if projection is None: + dist = np.sign(survey.vertices[:, 0]) * np.linalg.norm( + survey.vertices[:, :2], axis=1 + ) + azm = ( + 90 + - np.rad2deg( + np.arctan2( + survey.vertices[:, 0], survey.vertices[:, 1] + ) + ) + % 180 + ).round(decimals=1) + height = (survey.vertices[:, 2]).round(decimals=1) + polar_coordinates = np.c_[dist, height, azm] + + shape = (-1, len(np.unique(azm)), len(np.unique(height))) + polar_coordinates.reshape(shape) + if __name__ == "__main__": - file = Path(sys.argv[1]) + # file = Path(sys.argv[1]) + file = Path(r"C:\Users\dominiquef\Documents\Workspace\Teck\RnD\plate_match.ui.json") PlateMatchDriver.start(file) diff --git a/simpeg_drivers/plate_simulation/match/options.py b/simpeg_drivers/plate_simulation/match/options.py index abaabb3d..ad0308a7 100644 --- a/simpeg_drivers/plate_simulation/match/options.py +++ b/simpeg_drivers/plate_simulation/match/options.py @@ -14,12 +14,10 @@ import numpy as np from geoapps_utils.base import Options -from geoh5py.groups import PropertyGroup, UIJsonGroup -from geoh5py.objects import Points +from geoh5py.groups import PropertyGroup, SimPEGGroup +from geoh5py.objects import Grid2D, Points from geoh5py.objects.surveys.electromagnetics.airborne_tem import AirborneTEMReceivers -from geoh5py.shared.utils import stringify -from geoh5py.ui_json import InputFile -from pydantic import BaseModel, ConfigDict, field_serializer +from pydantic import ConfigDict from simpeg_drivers import assets_path @@ 
-30,7 +28,7 @@ class MatchOptions(Options): :param survey: A Time-Domain Airborne TEM survey object. :param data: A property group containing observed data. - :param targets: A Points object containing the target locations. + :param queries: A Points object containing the target locations. :param strike_angles: An optional data array containing strike angles for each target location. :param simulations: Directory to store simulation files. @@ -42,10 +40,13 @@ class MatchOptions(Options): default_ui_json: ClassVar[Path] = assets_path() / "uijson/plate_match.ui.json" title: ClassVar[str] = "Plate Match" run_command: ClassVar[str] = "simpeg_drivers.plate_simulation.match.driver" - out_group: UIJsonGroup | None = None + out_group: SimPEGGroup | None = None survey: AirborneTEMReceivers data: PropertyGroup - targets: Points + queries: Points strike_angles: np.ndarray | None = None + max_distance: float = 1000.0 + topography_object: Points | Grid2D + topography: np.ndarray | None = None simulations: ClassVar[Path] diff --git a/simpeg_drivers/plate_simulation/match/uijson.py b/simpeg_drivers/plate_simulation/match/uijson.py new file mode 100644 index 00000000..e2fc3c23 --- /dev/null +++ b/simpeg_drivers/plate_simulation/match/uijson.py @@ -0,0 +1,26 @@ +# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' +# Copyright (c) 2023-2026 Mira Geoscience Ltd. ' +# ' +# This file is part of simpeg-drivers package. ' +# ' +# simpeg-drivers is distributed under the terms and conditions of the MIT License ' +# (see LICENSE file at the root of this source code package). 
' +# ' +# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' + +from geoh5py.ui_json.forms import DataForm, FloatForm, ObjectForm, StringForm +from geoh5py.ui_json.ui_json import BaseUIJson +from pydantic import ConfigDict + + +class PlateMatchUIJson(BaseUIJson): + model_config = ConfigDict(arbitrary_types_allowed=True) + + survey: ObjectForm + data: DataForm + queries: ObjectForm + strike_angles: DataForm + max_distance: FloatForm + topography_object: ObjectForm + topography: DataForm + simulations: StringForm From afd31c4061f946e8b5382bb8977a3c1d7df9cd58 Mon Sep 17 00:00:00 2001 From: dominiquef Date: Thu, 15 Jan 2026 20:29:15 -0800 Subject: [PATCH 03/12] First full run --- .../uijson/plate_match.ui.json | 4 +- .../plate_simulation/match/driver.py | 346 +++++++++++++----- .../plate_simulation/match/options.py | 18 +- 3 files changed, 268 insertions(+), 100 deletions(-) diff --git a/simpeg_drivers-assets/uijson/plate_match.ui.json b/simpeg_drivers-assets/uijson/plate_match.ui.json index 86ec9a2d..04a42d51 100644 --- a/simpeg_drivers-assets/uijson/plate_match.ui.json +++ b/simpeg_drivers-assets/uijson/plate_match.ui.json @@ -4,7 +4,7 @@ "icon": "maxwellplate", "documentation": "https://mirageoscience-simpeg-drivers.readthedocs-hosted.com/en/latest/plate-simulation/", "conda_environment": "simpeg_drivers", - "run_command": "simpeg_drivers.driver", + "run_command": "simpeg_drivers.plate_simulation.match.driver", "geoh5": "", "monitoring_directory": "", "inversion_type": "plate match", @@ -103,7 +103,7 @@ "out_group": { "label": "UIJsonGroup", "value": "", - "groupType": "{BB50AC61-A657-4926-9C82-067658E246A0}", + "groupType": "{55ed3daf-c192-4d4b-a439-60fa987fe2b8}", "visible": true, "optional": true, "enabled": false diff --git a/simpeg_drivers/plate_simulation/match/driver.py b/simpeg_drivers/plate_simulation/match/driver.py index 2ccb796c..797a0743 100644 --- a/simpeg_drivers/plate_simulation/match/driver.py +++ 
b/simpeg_drivers/plate_simulation/match/driver.py @@ -17,21 +17,28 @@ from geoapps_utils.utils.importing import GeoAppsError from geoapps_utils.utils.locations import topo_drape_elevation from geoapps_utils.utils.logger import get_logger -from geoapps_utils.utils.transformations import rotate_xyz +from geoapps_utils.utils.plotting import symlog from geoh5py import Workspace -from geoh5py.groups import SimPEGGroup +from geoh5py.groups import PropertyGroup, SimPEGGroup from geoh5py.objects import AirborneTEMReceivers, Surface from geoh5py.shared.utils import ( fetch_active_workspace, ) +from geoh5py.ui_json import InputFile from geoh5py.ui_json.ui_json import BaseUIJson +from scipy import signal +from scipy.sparse import csr_matrix, diags from scipy.spatial import cKDTree +from tqdm import tqdm from typing_extensions import Self from simpeg_drivers.driver import BaseDriver from simpeg_drivers.plate_simulation.match.options import MatchOptions +from simpeg_drivers.plate_simulation.options import PlateSimulationOptions +# import matplotlib.pyplot as plt + logger = get_logger(name=__name__, level_name=False, propagate=False, add_name=False) @@ -43,6 +50,8 @@ class PlateMatchDriver(BaseDriver): def __init__(self, params: MatchOptions, workers: list[tuple[str]] | None = None): super().__init__(params, workers=workers) + self._drape_heights = self.set_drape_height() + self.out_group = self.validate_out_group(self.params.out_group) @property @@ -82,15 +91,16 @@ def validate_out_group(self, out_group: SimPEGGroup | None) -> SimPEGGroup: return out_group @classmethod - def start(cls, filepath: str | Path, mode="r", **_) -> Self: + def start(cls, filepath: str | Path, mode="r+", **_) -> Self: """Start the parameter matching from a ui.json file.""" logger.info("Loading input file . . 
.") filepath = Path(filepath).resolve() - uijson = BaseUIJson.read(filepath) + # uijson = BaseUIJson.read(filepath) + uijson = InputFile.read_ui_json(filepath) - with Workspace(uijson.geoh5, mode=mode) as workspace: + with uijson.geoh5.open(mode=mode): try: - options = MatchOptions.build(uijson.to_params(workspace=workspace)) + options = MatchOptions.build(uijson) logger.info("Initializing application . . .") driver = cls(options) logger.info("Running application . . .") @@ -103,110 +113,256 @@ def start(cls, filepath: str | Path, mode="r", **_) -> Self: return driver - def run(self): - """Loop over all trials and run a worker for each unique parameter set.""" + def set_drape_height(self) -> np.ndarray: + """Set drape heights based on topography object and optional topography data.""" - logger.info( - "Running %s . . .", - self.params.template.options["title"], - ) + topo = self.params.topography_object.locations - tree = cKDTree(self.params.survey.vertices[:, :2]) + if self.params.topography is not None: + topo[:, 2] = self.params.topography.values - topo_z = ( - self.params.topography.values - if self.params.topography.values is not None - else self.params.topography_object.locations[:, 2] - ) topo_drape_z = topo_drape_elevation( self.params.survey.vertices, - self.params.topography_object.locations, - topo_z, + topo, triangulation=self.params.topography_object.cells if isinstance(self.params.topography_object, Surface) else None, ) + return topo_drape_z[:, 2] + + def normalized_data(self, property_group: PropertyGroup, threshold=5) -> np.ndarray: + """ + Return data from a property group with symlog scaling and zero mean. + + :param property_group: Property group containing data channels. + :param threshold: Percentile threshold for symlog normalization. + + :return: Normalized data array. 
+ """ + table = property_group.table() + data_array = np.vstack([table[name] for name in table.dtype.names]) + thresh = np.percentile(np.abs(data_array), threshold) + log_data = symlog(data_array, thresh) + return log_data - np.mean(log_data, axis=1)[:, None] + + def fetch_survey(self, workspace: Workspace) -> AirborneTEMReceivers | None: + """Fetch the survey from the workspace.""" + for group in workspace.groups: + if isinstance(group, SimPEGGroup): + for child in group.children: + if isinstance(child, AirborneTEMReceivers): + return child + + return None + + def spatial_interpolation( + self, + indices: np.ndarray, + locations: np.ndarray, + strike_angle: float | None = None, + ) -> csr_matrix: + """ + Create a spatial interpolation matrix from simulation to observation locations. + + :param indices: Indices for the line segment of the observation locations. + :param locations: Positions to interpolate from. + :param strike_angle: Optional strike angle to correct azimuths. + + :return: Spatial interpolation matrix. 
+ """ + # Compute local coordinates for the current line segment + local_polar = self.xyz_to_polar(self.params.survey.vertices[indices, :]) + local_polar[:, 1] = ( + 0.0 if strike_angle is None else strike_angle + ) # Align azimuths to zero + + # Convert to polar coordinates (distance, azimuth, height) + query_polar = self.xyz_to_polar(locations) + + # Get the 8 nearest neighbors in the simulation to each observation point + sim_tree = cKDTree(query_polar) + rad, inds = sim_tree.query(local_polar, k=8) + + weights = (rad**2.0 + 1e-1) ** -1 + row_ids = np.kron(np.arange(local_polar.shape[0]), np.ones(8)) + inv_dist_op = csr_matrix( + (weights.flatten(), (row_ids, np.hstack(inds.flatten()))), + shape=(local_polar.shape[0], locations.shape[0]), + ) + + # Normalize the rows + row_sum = np.asarray(inv_dist_op.sum(axis=1)).flatten() ** -1.0 + return diags(row_sum) @ inv_dist_op + + @staticmethod + def xyz_to_polar(xyz: np.ndarray) -> np.ndarray: + """ + Convert Cartesian coordinates to polar coordinates defined as + (distance, azimuth, height), where distance is signed based on the + x-coordinate relative to the mean location. + + :param xyz: Cartesian coordinates. + + :return: Polar coordinates (distance, azimuth, height). + """ + mean_loc = np.mean(xyz, axis=0) + distances = np.sign(xyz[:, 0] - mean_loc[0]) * np.linalg.norm( + xyz[:, :2] - mean_loc[:2], axis=1 + ) + azimuths = 90 - (np.rad2deg(np.arctan2(xyz[:, 0], xyz[:, 1])) % 180) + return np.c_[distances, azimuths, xyz[:, 2]] + + @staticmethod + def time_interpolation( + query_times: np.ndarray, sim_times: np.ndarray + ) -> csr_matrix: + """ + Create a time interpolation matrix from simulation to observation times. + + :param query_times: Observation times. + :param sim_times: Simulation times. + + :return: Time interpolation matrix. 
+ """ + right = np.searchsorted(sim_times, query_times) + + inds = np.r_[right - 1, right] + + row_ids = np.tile(np.arange(len(query_times)), 2) + weights = (np.abs(query_times[row_ids] - sim_times[inds]) + 1e-12) ** -1 + + time_projection = csr_matrix( + (weights.flatten(), (row_ids, np.hstack(inds.flatten()))), + shape=(len(query_times), len(sim_times)), + ) + row_sum = np.asarray(time_projection.sum(axis=1)).flatten() ** -1.0 + return diags(row_sum) @ time_projection + + def get_segment_indices(self, nearest: int) -> np.ndarray: + """ + Get indices of line segment for a given nearest vertex. + + :param nearest: Nearest vertex index. + """ + line_mask = np.where( + self.params.survey.parts == self.params.survey.parts[nearest] + )[0] + distances = np.linalg.norm( + self.params.survey.vertices[nearest, :2] + - self.params.survey.vertices[line_mask, :2], + axis=1, + ) + dist_mask = distances < self.params.max_distance + indices = line_mask[dist_mask] + return indices + + def run(self): + """Loop over all trials and run a worker for each unique parameter set.""" + + logger.info( + "Running %s . . 
.", + self.params.title, + ) + observed = self.normalized_data(self.params.data) + + scores = [] + files_id = [] + tree = cKDTree(self.params.survey.vertices[:, :2]) + spatial_projection = None + time_projection = None for ii, query in enumerate(self.params.queries.vertices): - nearest = tree.query(query[:2], k=1)[0] - line_mask = np.where( - self.params.survey.parts == self.params.survey.parts[nearest] - )[0] - distances = np.linalg.norm( - self.params.survey.vertices[nearest, :2] - - self.params.survey.vertices[line_mask, :2], - axis=1, - ) - dist_mask = distances < self.params.max_distance - indices = line_mask[dist_mask] - - # Compute local coordinates for the current line segment - line_dist = distances[dist_mask] - line_dist[indices < nearest] *= -1.0 - local_xyz = np.c_[ - line_dist, - np.zeros_like(line_dist), - self.params.survey.vertices[indices, 2] - topo_drape_z[indices], - ] - - if self.params.strike_angles is not None: - angle = self.params.strike_angles[ii] - local_xyz = rotate_xyz( - local_xyz, [0, 0, 0], self.params.strike_angles[ii], 0 - ) - - # Convert to polar coordinates (distance, azimuth, height) - local_polar = np.c_[ - line_dist, - 90 - (np.rad2deg(np.arctan2(local_xyz[:, 0], local_xyz[:, 1])) % 180), - local_xyz[:, 2], - ] - - projection = None - data = {} - for file in self.params.simulations.iterdir(): - if Path(file).resolve().suffix == ".geoh5": - with Workspace(file, mode="r") as ws: - sim = next( - group - for group in ws.groups - if isinstance(group, SimPEGGroup) and "Plate" in group.name + for sim_file in tqdm(self.params.simulation_files): + with Workspace(sim_file, mode="r") as ws: + survey = self.fetch_survey(ws) + + if survey is None: + logger.warning("No survey found in %s, skipping.", sim_file) + continue + + simulated = self.normalized_data( + survey.get_entity("Iteration_0_z")[0] + ) + + # Create a projection matrix to interpolate simulated data to the observation locations + # Assume that lines of simulations are centered 
at origin + if spatial_projection is None: + nearest = tree.query(query[:2], k=1)[1] + indices = self.get_segment_indices(nearest) + spatial_projection = self.spatial_interpolation( + indices, + survey.vertices, + self.params.strike_angles.values[ii], ) - fwr = next( - child - for child in sim.children - if isinstance(child, SimPEGGroup) + + if time_projection is None: + query_times = np.asarray(self.params.survey.channels) + simulated_times = np.asarray(survey.channels) + + # Only interpolate for times within the simulated range + time_mask = (query_times > simulated_times.min()) & ( + query_times < simulated_times.max() ) - survey = next( - child - for child in fwr.children - if isinstance(child, AirborneTEMReceivers) + time_projection = self.time_interpolation( + query_times[time_mask], simulated_times ) - group = survey.get_entity("Iteration_0_z")[0] - data[Path(file).stem] = group.table() - - # Create a projection matrix to interpolate simulated data to the observation locations - if projection is None: - dist = np.sign(survey.vertices[:, 0]) * np.linalg.norm( - survey.vertices[:, :2], axis=1 - ) - azm = ( - 90 - - np.rad2deg( - np.arctan2( - survey.vertices[:, 0], survey.vertices[:, 1] - ) - ) - % 180 - ).round(decimals=1) - height = (survey.vertices[:, 2]).round(decimals=1) - polar_coordinates = np.c_[dist, height, azm] - - shape = (-1, len(np.unique(azm)), len(np.unique(height))) - polar_coordinates.reshape(shape) + observed = observed[time_mask, :] + + pred = time_projection @ (spatial_projection @ simulated.T).T + + score = 0.0 + + # if sim_file.stem == "0e50d2da-7ab0-5484-9ffd-365f076cce98": + # + # fig, ax = plt.figure(), plt.subplot() + + # Metric: normalized cross-correlation + for obs, pre in zip(observed[:, indices], pred, strict=True): + # Full cross-correlation + corr = signal.correlate( + obs, pre, mode="full" + ) # corr[k] ~ sum_t y[t] * x[t - k] + # Normalize by energy to get correlation coefficient in [-1, 1] + denom = np.linalg.norm(pre) * 
np.linalg.norm(obs) + if denom == 0: + corr_norm = np.zeros_like(corr) + else: + corr_norm = corr / denom + + score += np.max(corr_norm) + # if sim_file.stem == "0e50d2da-7ab0-5484-9ffd-365f076cce98": + # ax.plot(obs , 'r') + # ax.plot(pre, 'k') + + # if sim_file.stem == "0e50d2da-7ab0-5484-9ffd-365f076cce98": + # plt.show() + + scores.append(score) + files_id.append(sim_file) + + spatial_projection = None + time_projection = None + + ranked = np.argsort(scores) + print("Top 3 matches:") + for rank in ranked[-1:][::-1]: + print(f"File: {files_id[rank].stem:30s} Score: {scores[rank]:.4f}") + with Workspace(files_id[rank], mode="r") as ws: + survey = self.fetch_survey(ws) + ui_json = survey.parent.parent.options + ui_json["geoh5"] = ws + ifile = InputFile(ui_json=ui_json) + options = PlateSimulationOptions.build(ifile) + + plate = survey.parent.parent.get_entity("plate")[0].copy( + parent=self.params.out_group + ) + plate.vertices = plate.vertices + query + + print(f"Best parameters:{options.model.model_dump_json(indent=2)}") if __name__ == "__main__": - # file = Path(sys.argv[1]) - file = Path(r"C:\Users\dominiquef\Documents\Workspace\Teck\RnD\plate_match.ui.json") + file = Path(sys.argv[1]) + # file = Path(r"C:\Users\dominiquef\Documents\Workspace\Teck\RnD\plate_match_v2.ui.json") PlateMatchDriver.start(file) diff --git a/simpeg_drivers/plate_simulation/match/options.py b/simpeg_drivers/plate_simulation/match/options.py index ad0308a7..2d5a32c8 100644 --- a/simpeg_drivers/plate_simulation/match/options.py +++ b/simpeg_drivers/plate_simulation/match/options.py @@ -14,6 +14,7 @@ import numpy as np from geoapps_utils.base import Options +from geoh5py.data import FloatData from geoh5py.groups import PropertyGroup, SimPEGGroup from geoh5py.objects import Grid2D, Points from geoh5py.objects.surveys.electromagnetics.airborne_tem import AirborneTEMReceivers @@ -45,8 +46,19 @@ class MatchOptions(Options): survey: AirborneTEMReceivers data: PropertyGroup queries: Points - 
strike_angles: np.ndarray | None = None + strike_angles: FloatData | None = None max_distance: float = 1000.0 topography_object: Points | Grid2D - topography: np.ndarray | None = None - simulations: ClassVar[Path] + topography: FloatData | None = None + simulations: str + + @property + def simulation_files(self) -> list[Path]: + """Path to simulation files directory.""" + sim_dir = self.geoh5.h5file.parent / self.simulations + simulation_files = [] + for file in sim_dir.iterdir(): + if Path(file).resolve().suffix == ".geoh5": + simulation_files.append(Path(file)) + + return simulation_files From 5bea23ca37de507d76bf5e92f2c6947aa105b19c Mon Sep 17 00:00:00 2001 From: dominiquef Date: Thu, 15 Jan 2026 20:36:41 -0800 Subject: [PATCH 04/12] Fix to run from main driver --- simpeg_drivers-assets/uijson/plate_match.ui.json | 1 + simpeg_drivers/__init__.py | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/simpeg_drivers-assets/uijson/plate_match.ui.json b/simpeg_drivers-assets/uijson/plate_match.ui.json index 04a42d51..e536f3fd 100644 --- a/simpeg_drivers-assets/uijson/plate_match.ui.json +++ b/simpeg_drivers-assets/uijson/plate_match.ui.json @@ -8,6 +8,7 @@ "geoh5": "", "monitoring_directory": "", "inversion_type": "plate match", + "forward_only": true, "survey": { "main": true, "label": "Survey", diff --git a/simpeg_drivers/__init__.py b/simpeg_drivers/__init__.py index 09790b88..a2aefb59 100644 --- a/simpeg_drivers/__init__.py +++ b/simpeg_drivers/__init__.py @@ -176,4 +176,8 @@ def assets_path() -> Path: "simpeg_drivers.plate_simulation.sweep.driver", {"forward": "PlateSweepDriver"}, ), + "plate match": ( + "simpeg_drivers.plate_simulation.match.driver", + {"forward": "PlateMatchDriver"}, + ), } From 66c0e3fc66f5ee1c906de5f1a4d0d7b49051c7b1 Mon Sep 17 00:00:00 2001 From: dominiquef Date: Thu, 15 Jan 2026 20:39:00 -0800 Subject: [PATCH 05/12] Attach plate simulation params to plate surface --- simpeg_drivers/plate_simulation/match/driver.py | 1 + 1 file changed, 1 
insertion(+) diff --git a/simpeg_drivers/plate_simulation/match/driver.py b/simpeg_drivers/plate_simulation/match/driver.py index 797a0743..e443734f 100644 --- a/simpeg_drivers/plate_simulation/match/driver.py +++ b/simpeg_drivers/plate_simulation/match/driver.py @@ -358,6 +358,7 @@ def run(self): parent=self.params.out_group ) plate.vertices = plate.vertices + query + plate.metadata = options.model.model_dump() print(f"Best parameters:{options.model.model_dump_json(indent=2)}") From 65c4c9d751ea679b3b303c3410064a2861e87224 Mon Sep 17 00:00:00 2001 From: dominiquef Date: Fri, 16 Jan 2026 12:44:07 -0800 Subject: [PATCH 06/12] Re jig dask setup on base class. Parallel run on plate_sim.match --- simpeg_drivers/driver.py | 133 ++++--- .../electromagnetics/base_1d_driver.py | 24 ++ .../plate_simulation/match/driver.py | 337 ++++++++++-------- 3 files changed, 282 insertions(+), 212 deletions(-) diff --git a/simpeg_drivers/driver.py b/simpeg_drivers/driver.py index 8867a5d5..90978e0a 100644 --- a/simpeg_drivers/driver.py +++ b/simpeg_drivers/driver.py @@ -16,21 +16,17 @@ import cProfile import pstats -import multiprocessing import contextlib from copy import deepcopy import sys from datetime import datetime, timedelta import logging -from multiprocessing.pool import ThreadPool from pathlib import Path from time import time from typing_extensions import Self import numpy as np -from dask import config as dconf - from dask.distributed import get_client, Client, LocalCluster, performance_report from geoapps_utils.base import Driver, Options @@ -41,7 +37,6 @@ from geoh5py.groups import SimPEGGroup from geoh5py.objects import FEMSurvey from geoh5py.shared.utils import fetch_active_workspace -from geoh5py.shared.exceptions import Geoh5FileClosedError from geoh5py.ui_json import InputFile from simpeg import ( @@ -158,6 +153,62 @@ def validate_workers(self, workers: list[tuple[str]] | None) -> list[tuple[str]] return workers + @classmethod + def start_dask_run( + cls, + ifile, + 
n_workers: int | None = None, + n_threads: int | None = None, + save_report: bool = True, + ): + """ + Sets Dask config settings. + + :param ifile: Input file path. + :param n_workers: Number of Dask workers. + :param n_threads: Number of threads per Dask worker. + :param save_report: Whether to save a performance report. + """ + distributed_process = ( + n_workers is not None and n_workers > 1 + ) or n_threads is not None + + cluster = ( + LocalCluster( + processes=True, + n_workers=n_workers, + threads_per_worker=n_threads, + ) + if distributed_process + else None + ) + profiler = cProfile.Profile() + profiler.enable() + + with ( + cluster.get_client() + if cluster is not None + else contextlib.nullcontext() as context_client + ): + # Full run + with ( + performance_report(filename=ifile.parent / "dask_profile.html") + if (save_report and isinstance(context_client, Client)) + else contextlib.nullcontext() + ): + cls.start(ifile) + sys.stdout.close() + + profiler.disable() + + if save_report: + with open( + ifile.parent / "runtime_profile.txt", encoding="utf-8", mode="w" + ) as s: + ps = pstats.Stats(profiler, stream=s) + ps.sort_stats("cumulative") + ps.print_stats() + class InversionDriver(BaseDriver): _params_class = BaseForwardOptions | BaseInversionOptions @@ -485,8 +536,6 @@ def run(self): sys.stdout = self.logger self.logger.start() - self.configure_dask() - with fetch_active_workspace(self.workspace, mode="r+"): simpeg_inversion = self.inversion @@ -728,18 +777,6 @@ def get_tiles(self): sorting=self.simulation.survey.sorting, ) - def configure_dask(self): - """Sets Dask config settings.""" - - if self.client: - dconf.set(scheduler=self.client) - else: - n_cpu = self.params.compute.n_cpu - if n_cpu is None: - n_cpu = int(multiprocessing.cpu_count()) - - dconf.set(scheduler="threads", pool=ThreadPool(n_cpu)) - @classmethod def start(cls, filepath: str | Path | InputFile, **kwargs) -> Self: """ @@ -840,57 +877,11 @@ def get_path(self, filepath: str | Path) -> 
str: if __name__ == "__main__": file = Path(sys.argv[1]).resolve() input_file = load_ui_json_as_dict(file) - n_workers = input_file.get("n_workers", None) - n_threads = input_file.get("n_threads", None) - save_report = input_file.get("performance_report", False) - - # Force distributed on 1D problems - if "1D" in input_file.get("title") and n_workers is None: - cpu_count = multiprocessing.cpu_count() - - if cpu_count < 16: - n_threads = n_threads or 2 - else: - n_threads = n_threads or 4 - - n_workers = cpu_count // n_threads - - distributed_process = ( - n_workers is not None and n_workers > 1 - ) or n_threads is not None - + # Need to know the driver class before starting dask driver_class = InversionDriver.from_input_file(input_file) - - cluster = ( - LocalCluster( - processes=True, - n_workers=n_workers, - threads_per_worker=n_threads, - ) - if distributed_process - else None + driver_class.start_dask_run( + file, + n_workers=input_file.get("n_workers", None), + n_threads=input_file.get("n_threads", None), + save_report=input_file.get("performance_report", False), ) - profiler = cProfile.Profile() - profiler.enable() - - with ( - cluster.get_client() - if cluster is not None - else contextlib.nullcontext() as context_client - ): - # Full run - with ( - performance_report(filename=file.parent / "dask_profile.html") - if (save_report and isinstance(context_client, Client)) - else contextlib.nullcontext() - ): - driver_class.start(file) - sys.stdout.close() - - profiler.disable() - - if save_report: - with open(file.parent / "runtime_profile.txt", encoding="utf-8", mode="w") as s: - ps = pstats.Stats(profiler, stream=s) - ps.sort_stats("cumulative") - ps.print_stats() diff --git a/simpeg_drivers/electromagnetics/base_1d_driver.py b/simpeg_drivers/electromagnetics/base_1d_driver.py index 7d5bed8f..c0c95771 100644 --- a/simpeg_drivers/electromagnetics/base_1d_driver.py +++ b/simpeg_drivers/electromagnetics/base_1d_driver.py @@ -125,3 +125,27 @@ def workers(self): 
else: self._workers = np.arange(multiprocessing.cpu_count()).tolist() return self._workers + + @classmethod + def start_dask_run( + cls, + ifile, + n_workers: int | None = None, + n_threads: int | None = None, + save_report: bool = True, + ): + """Overload configurations of BaseDriver Dask config settings.""" + # Force distributed on 1D problems + if n_workers is None: + cpu_count = multiprocessing.cpu_count() + + if cpu_count < 16: + n_threads = n_threads or 2 + else: + n_threads = n_threads or 4 + + n_workers = cpu_count // n_threads + + super().start_dask_run( + ifile, n_workers=n_workers, n_threads=n_threads, save_report=save_report + ) diff --git a/simpeg_drivers/plate_simulation/match/driver.py b/simpeg_drivers/plate_simulation/match/driver.py index e443734f..bd6820ba 100644 --- a/simpeg_drivers/plate_simulation/match/driver.py +++ b/simpeg_drivers/plate_simulation/match/driver.py @@ -10,10 +10,13 @@ from __future__ import annotations +import multiprocessing import sys from pathlib import Path +from time import time import numpy as np +from geoapps_utils.run import load_ui_json_as_dict from geoapps_utils.utils.importing import GeoAppsError from geoapps_utils.utils.locations import topo_drape_elevation from geoapps_utils.utils.logger import get_logger @@ -25,15 +28,15 @@ fetch_active_workspace, ) from geoh5py.ui_json import InputFile -from geoh5py.ui_json.ui_json import BaseUIJson from scipy import signal from scipy.sparse import csr_matrix, diags from scipy.spatial import cKDTree -from tqdm import tqdm from typing_extensions import Self from simpeg_drivers.driver import BaseDriver from simpeg_drivers.plate_simulation.match.options import MatchOptions + +# from simpeg_drivers.plate_simulation.match.uijson import PlateMatchUIJson from simpeg_drivers.plate_simulation.options import PlateSimulationOptions @@ -51,9 +54,60 @@ def __init__(self, params: MatchOptions, workers: list[tuple[str]] | None = None super().__init__(params, workers=workers) 
self._drape_heights = self.set_drape_height() - + self._template = self.get_template() + self._time_mask, self._time_projection = self.time_mask_and_projection() self.out_group = self.validate_out_group(self.params.out_group) + def get_template(self): + """ + Get a template simulation to extract time sampling. + """ + with Workspace(self.params.simulation_files[0], mode="r") as ws: + survey = fetch_survey(ws) + if survey.channels is None: + raise GeoAppsError( + f"No time channels found in survey of {self.params.simulation_files[0]}" + ) + + if survey.vertices is None: + raise GeoAppsError( + f"No receiver locations found in survey of {self.params.simulation_files[0]}" + ) + + return survey + + def time_mask_and_projection(self) -> tuple[np.ndarray, csr_matrix]: + """ + Create a time mask and interpolation matrix from simulation to observation times. + + Assumes that all simulations in the directory have the same time sampling. + + :return: Time mask and time interpolation matrix. + """ + + simulated_times = np.asarray(self._template.channels) + + query_times = np.asarray(self.params.survey.channels) + # Only interpolate for times within the simulated range + time_mask = (query_times > simulated_times.min()) & ( + query_times < simulated_times.max() + ) + query_times = query_times[time_mask] + right = np.searchsorted(simulated_times, query_times) + inds = np.r_[right - 1, right] + row_ids = np.tile(np.arange(len(query_times)), 2) + + # Create inverse distance weighting matrix + weights = (np.abs(query_times[row_ids] - simulated_times[inds]) + 1e-12) ** -1 + time_projection = csr_matrix( + (weights.flatten(), (row_ids, np.hstack(inds.flatten()))), + shape=(len(query_times), len(simulated_times)), + ) + row_sum = np.asarray(time_projection.sum(axis=1)).flatten() ** -1.0 + time_projection = diags(row_sum) @ time_projection + + return time_mask, time_projection + @property def out_group(self) -> SimPEGGroup: """ @@ -95,7 +149,9 @@ def start(cls, filepath: str | Path, 
mode="r+", **_) -> Self: """Start the parameter matching from a ui.json file.""" logger.info("Loading input file . . .") filepath = Path(filepath).resolve() - # uijson = BaseUIJson.read(filepath) + + # TODO: Replace with UIJson when fully implemented + # uijson = PlateMatchUIJson.read(filepath) uijson = InputFile.read_ui_json(filepath) with uijson.geoh5.open(mode=mode): @@ -130,42 +186,15 @@ def set_drape_height(self) -> np.ndarray: ) return topo_drape_z[:, 2] - def normalized_data(self, property_group: PropertyGroup, threshold=5) -> np.ndarray: - """ - Return data from a property group with symlog scaling and zero mean. - - :param property_group: Property group containing data channels. - :param threshold: Percentile threshold for symlog normalization. - - :return: Normalized data array. - """ - table = property_group.table() - data_array = np.vstack([table[name] for name in table.dtype.names]) - thresh = np.percentile(np.abs(data_array), threshold) - log_data = symlog(data_array, thresh) - return log_data - np.mean(log_data, axis=1)[:, None] - - def fetch_survey(self, workspace: Workspace) -> AirborneTEMReceivers | None: - """Fetch the survey from the workspace.""" - for group in workspace.groups: - if isinstance(group, SimPEGGroup): - for child in group.children: - if isinstance(child, AirborneTEMReceivers): - return child - - return None - def spatial_interpolation( self, indices: np.ndarray, - locations: np.ndarray, strike_angle: float | None = None, ) -> csr_matrix: """ Create a spatial interpolation matrix from simulation to observation locations. :param indices: Indices for the line segment of the observation locations. - :param locations: Positions to interpolate from. :param strike_angle: Optional strike angle to correct azimuths. :return: Spatial interpolation matrix. 
@@ -177,7 +206,7 @@ def spatial_interpolation( ) # Align azimuths to zero # Convert to polar coordinates (distance, azimuth, height) - query_polar = self.xyz_to_polar(locations) + query_polar = self.xyz_to_polar(self._template.vertices) # Get the 8 nearest neighbors in the simulation to each observation point sim_tree = cKDTree(query_polar) @@ -187,7 +216,7 @@ def spatial_interpolation( row_ids = np.kron(np.arange(local_polar.shape[0]), np.ones(8)) inv_dist_op = csr_matrix( (weights.flatten(), (row_ids, np.hstack(inds.flatten()))), - shape=(local_polar.shape[0], locations.shape[0]), + shape=(local_polar.shape[0], self._template.vertices.shape[0]), ) # Normalize the rows @@ -213,32 +242,6 @@ def xyz_to_polar(xyz: np.ndarray) -> np.ndarray: azimuths = 90 - (np.rad2deg(np.arctan2(xyz[:, 0], xyz[:, 1])) % 180) return np.c_[distances, azimuths, xyz[:, 2]] - @staticmethod - def time_interpolation( - query_times: np.ndarray, sim_times: np.ndarray - ) -> csr_matrix: - """ - Create a time interpolation matrix from simulation to observation times. - - :param query_times: Observation times. - :param sim_times: Simulation times. - - :return: Time interpolation matrix. - """ - right = np.searchsorted(sim_times, query_times) - - inds = np.r_[right - 1, right] - - row_ids = np.tile(np.arange(len(query_times)), 2) - weights = (np.abs(query_times[row_ids] - sim_times[inds]) + 1e-12) ** -1 - - time_projection = csr_matrix( - (weights.flatten(), (row_ids, np.hstack(inds.flatten()))), - shape=(len(query_times), len(sim_times)), - ) - row_sum = np.asarray(time_projection.sum(axis=1)).flatten() ** -1.0 - return diags(row_sum) @ time_projection - def get_segment_indices(self, nearest: int) -> np.ndarray: """ Get indices of line segment for a given nearest vertex. @@ -259,96 +262,49 @@ def get_segment_indices(self, nearest: int) -> np.ndarray: def run(self): """Loop over all trials and run a worker for each unique parameter set.""" - logger.info( "Running %s . . 
.", self.params.title, ) - observed = self.normalized_data(self.params.data) - - scores = [] - files_id = [] + observed = normalized_data(self.params.data)[self._time_mask, :] tree = cKDTree(self.params.survey.vertices[:, :2]) - spatial_projection = None - time_projection = None - for ii, query in enumerate(self.params.queries.vertices): - for sim_file in tqdm(self.params.simulation_files): - with Workspace(sim_file, mode="r") as ws: - survey = self.fetch_survey(ws) - if survey is None: - logger.warning("No survey found in %s, skipping.", sim_file) - continue + for ii, query in enumerate(self.params.queries.vertices): + tasks = [] + nearest = tree.query(query[:2], k=1)[1] + indices = self.get_segment_indices(nearest) + spatial_projection = self.spatial_interpolation( + indices, + self.params.strike_angles.values[ii], + ) - simulated = self.normalized_data( - survey.get_entity("Iteration_0_z")[0] + file_split = np.array_split(self.params.simulation_files, len(self.workers)) + + ct = time() + for file_batch in file_split: + tasks.append( + self.client.submit( + process_files_batch, + file_batch, + spatial_projection, + self._time_projection, + observed[:, indices], ) + ) - # Create a projection matrix to interpolate simulated data to the observation locations - # Assume that lines of simulations are centered at origin - if spatial_projection is None: - nearest = tree.query(query[:2], k=1)[1] - indices = self.get_segment_indices(nearest) - spatial_projection = self.spatial_interpolation( - indices, - survey.vertices, - self.params.strike_angles.values[ii], - ) - - if time_projection is None: - query_times = np.asarray(self.params.survey.channels) - simulated_times = np.asarray(survey.channels) - - # Only interpolate for times within the simulated range - time_mask = (query_times > simulated_times.min()) & ( - query_times < simulated_times.max() - ) - time_projection = self.time_interpolation( - query_times[time_mask], simulated_times - ) - observed = 
observed[time_mask, :] - - pred = time_projection @ (spatial_projection @ simulated.T).T - - score = 0.0 - - # if sim_file.stem == "0e50d2da-7ab0-5484-9ffd-365f076cce98": - # - # fig, ax = plt.figure(), plt.subplot() - - # Metric: normalized cross-correlation - for obs, pre in zip(observed[:, indices], pred, strict=True): - # Full cross-correlation - corr = signal.correlate( - obs, pre, mode="full" - ) # corr[k] ~ sum_t y[t] * x[t - k] - # Normalize by energy to get correlation coefficient in [-1, 1] - denom = np.linalg.norm(pre) * np.linalg.norm(obs) - if denom == 0: - corr_norm = np.zeros_like(corr) - else: - corr_norm = corr / denom - - score += np.max(corr_norm) - # if sim_file.stem == "0e50d2da-7ab0-5484-9ffd-365f076cce98": - # ax.plot(obs , 'r') - # ax.plot(pre, 'k') - - # if sim_file.stem == "0e50d2da-7ab0-5484-9ffd-365f076cce98": - # plt.show() - - scores.append(score) - files_id.append(sim_file) - - spatial_projection = None - time_projection = None + scores = np.hstack(self.client.gather(tasks)) + print(f"Processing time: {time() - ct:.1f} seconds") ranked = np.argsort(scores) - print("Top 3 matches:") + for rank in ranked[-1:][::-1]: - print(f"File: {files_id[rank].stem:30s} Score: {scores[rank]:.4f}") - with Workspace(files_id[rank], mode="r") as ws: - survey = self.fetch_survey(ws) + logger.info( + "File: %s \nScore: %.4f", + self.params.simulation_files[rank].name, + scores[rank], + ) + with Workspace(self.params.simulation_files[rank], mode="r") as ws: + survey = fetch_survey(ws) ui_json = survey.parent.parent.options ui_json["geoh5"] = ws ifile = InputFile(ui_json=ui_json) @@ -357,13 +313,112 @@ def run(self): plate = survey.parent.parent.get_entity("plate")[0].copy( parent=self.params.out_group ) - plate.vertices = plate.vertices + query + + # Set position of plate to query location + center = self.params.survey.vertices[nearest] + center[2] = self._drape_heights[nearest] + plate.vertices = plate.vertices + center plate.metadata = 
options.model.model_dump() print(f"Best parameters:{options.model.model_dump_json(indent=2)}") + @classmethod + def start_dask_run( + cls, + ifile, + n_workers: int | None = None, + n_threads: int | None = None, + save_report: bool = True, + ): + """Overload configurations of BaseDriver Dask config settings.""" + # Force distributed on 1D problems + if n_workers is None: + cpu_count = multiprocessing.cpu_count() + + if cpu_count < 16: + n_threads = n_threads or 2 + else: + n_threads = n_threads or 4 + + n_workers = cpu_count // n_threads + + super().start_dask_run( + ifile, n_workers=n_workers, n_threads=n_threads, save_report=save_report + ) + + +def normalized_data(property_group: PropertyGroup, threshold=5) -> np.ndarray: + """ + Return data from a property group with symlog scaling and zero mean. + + :param property_group: Property group containing data channels. + :param threshold: Percentile threshold for symlog normalization. + + :return: Normalized data array. + """ + table = property_group.table() + data_array = np.vstack([table[name] for name in table.dtype.names]) + thresh = np.percentile(np.abs(data_array), threshold) + log_data = symlog(data_array, thresh) + return log_data - np.mean(log_data, axis=1)[:, None] + + +def fetch_survey(workspace: Workspace) -> AirborneTEMReceivers | None: + """Fetch the survey from the workspace.""" + for group in workspace.groups: + if isinstance(group, SimPEGGroup): + for child in group.children: + if isinstance(child, AirborneTEMReceivers): + return child + + return None + + +def process_files_batch( + files: Path | list[Path], spatial_projection, time_projection, observed +): + scores = [] + + if isinstance(files, Path): + files = [files] + + for sim_file in files: + with Workspace(sim_file, mode="r") as ws: + survey = fetch_survey(ws) + + if survey is None: + logger.warning("No survey found in %s, skipping.", sim_file) + continue + + simulated = normalized_data(survey.get_entity("Iteration_0_z")[0]) + + pred = 
time_projection @ (spatial_projection @ simulated.T).T + score = 0.0 + + # Metric: normalized cross-correlation + for obs, pre in zip(observed, pred, strict=True): + # Full cross-correlation + corr = signal.correlate(obs, pre, mode="full") + # Normalize by energy to get correlation coefficient in [-1, 1] + denom = np.linalg.norm(pre) * np.linalg.norm(obs) + if denom == 0: + corr_norm = np.zeros_like(corr) + else: + corr_norm = corr / denom + + score += np.max(corr_norm) + + scores.append(score) + + return scores + if __name__ == "__main__": - file = Path(sys.argv[1]) - # file = Path(r"C:\Users\dominiquef\Documents\Workspace\Teck\RnD\plate_match_v2.ui.json") - PlateMatchDriver.start(file) + file = Path(sys.argv[1]).resolve() + input_file = load_ui_json_as_dict(file) + PlateMatchDriver.start_dask_run( + file, + n_workers=input_file.get("n_workers", None), + n_threads=input_file.get("n_threads", None), + save_report=input_file.get("performance_report", False), + ) From 12c5416256ccc8ef466e8540e95768ede2af8b76 Mon Sep 17 00:00:00 2001 From: dominiquef Date: Fri, 16 Jan 2026 12:56:36 -0800 Subject: [PATCH 07/12] Add progress bar --- simpeg_drivers/plate_simulation/match/driver.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/simpeg_drivers/plate_simulation/match/driver.py b/simpeg_drivers/plate_simulation/match/driver.py index bd6820ba..fece2d6b 100644 --- a/simpeg_drivers/plate_simulation/match/driver.py +++ b/simpeg_drivers/plate_simulation/match/driver.py @@ -13,9 +13,9 @@ import multiprocessing import sys from pathlib import Path -from time import time import numpy as np +from dask.distributed import progress from geoapps_utils.run import load_ui_json_as_dict from geoapps_utils.utils.importing import GeoAppsError from geoapps_utils.utils.locations import topo_drape_elevation @@ -278,9 +278,10 @@ def run(self): self.params.strike_angles.values[ii], ) - file_split = np.array_split(self.params.simulation_files, len(self.workers)) + 
file_split = np.array_split( + self.params.simulation_files, len(self.workers) * 10 + ) - ct = time() for file_batch in file_split: tasks.append( self.client.submit( @@ -291,10 +292,9 @@ def run(self): observed[:, indices], ) ) - + # Display progress bar + progress(tasks) scores = np.hstack(self.client.gather(tasks)) - - print(f"Processing time: {time() - ct:.1f} seconds") ranked = np.argsort(scores) for rank in ranked[-1:][::-1]: From 6741a6668338afba1434981c1e068b9fb5636831 Mon Sep 17 00:00:00 2001 From: dominiquef Date: Fri, 16 Jan 2026 14:58:25 -0800 Subject: [PATCH 08/12] Refactor code to create sparse inv interpolation --- .../plate_simulation/match/driver.py | 100 ++++++++---------- 1 file changed, 43 insertions(+), 57 deletions(-) diff --git a/simpeg_drivers/plate_simulation/match/driver.py b/simpeg_drivers/plate_simulation/match/driver.py index fece2d6b..4e98199e 100644 --- a/simpeg_drivers/plate_simulation/match/driver.py +++ b/simpeg_drivers/plate_simulation/match/driver.py @@ -56,7 +56,6 @@ def __init__(self, params: MatchOptions, workers: list[tuple[str]] | None = None self._drape_heights = self.set_drape_height() self._template = self.get_template() self._time_mask, self._time_projection = self.time_mask_and_projection() - self.out_group = self.validate_out_group(self.params.out_group) def get_template(self): """ @@ -84,9 +83,7 @@ def time_mask_and_projection(self) -> tuple[np.ndarray, csr_matrix]: :return: Time mask and time interpolation matrix. 
""" - simulated_times = np.asarray(self._template.channels) - query_times = np.asarray(self.params.survey.channels) # Only interpolate for times within the simulated range time_mask = (query_times > simulated_times.min()) & ( @@ -94,56 +91,16 @@ def time_mask_and_projection(self) -> tuple[np.ndarray, csr_matrix]: ) query_times = query_times[time_mask] right = np.searchsorted(simulated_times, query_times) - inds = np.r_[right - 1, right] - row_ids = np.tile(np.arange(len(query_times)), 2) - - # Create inverse distance weighting matrix - weights = (np.abs(query_times[row_ids] - simulated_times[inds]) + 1e-12) ** -1 - time_projection = csr_matrix( - (weights.flatten(), (row_ids, np.hstack(inds.flatten()))), - shape=(len(query_times), len(simulated_times)), - ) - row_sum = np.asarray(time_projection.sum(axis=1)).flatten() ** -1.0 - time_projection = diags(row_sum) @ time_projection + inds = np.c_[right - 1, right].flatten() + row_ids = np.repeat(np.arange(len(query_times)), 2) + # Create inverse distance weighting matrix based on time difference + time_diff = np.abs(query_times[row_ids] - simulated_times[inds]) + time_projection = self.inverse_weighted_operator( + time_diff, inds, (len(query_times), len(simulated_times)), 1.0, 1e-12 + ) return time_mask, time_projection - @property - def out_group(self) -> SimPEGGroup: - """ - Returns the output group for the simulation. - """ - return self._out_group - - @out_group.setter - def out_group(self, value: SimPEGGroup): - if not isinstance(value, SimPEGGroup): - raise TypeError("Output group must be a SimPEGGroup.") - - if self.params.out_group != value: - self.params.out_group = value - self.params.update_out_group_options() - - self._out_group = value - - def validate_out_group(self, out_group: SimPEGGroup | None) -> SimPEGGroup: - """ - Validate or create a SimPEGGroup to store results. - - :param out_group: Output group from selection. 
- """ - if isinstance(out_group, SimPEGGroup): - return out_group - - with fetch_active_workspace(self.params.geoh5, mode="r+"): - out_group = SimPEGGroup.create( - self.params.geoh5, - name=self.params.title, - ) - out_group.entity_type.name = self.params.title - - return out_group - @classmethod def start(cls, filepath: str | Path, mode="r+", **_) -> Self: """Start the parameter matching from a ui.json file.""" @@ -212,13 +169,40 @@ def spatial_interpolation( sim_tree = cKDTree(query_polar) rad, inds = sim_tree.query(local_polar, k=8) - weights = (rad**2.0 + 1e-1) ** -1 - row_ids = np.kron(np.arange(local_polar.shape[0]), np.ones(8)) - inv_dist_op = csr_matrix( - (weights.flatten(), (row_ids, np.hstack(inds.flatten()))), - shape=(local_polar.shape[0], self._template.vertices.shape[0]), + return self.inverse_weighted_operator( + rad.flatten(), + inds.flatten(), + (local_polar.shape[0], self._template.vertices.shape[0]), + 2.0, + 1e-1, ) + @staticmethod + def inverse_weighted_operator( + values: np.ndarray, + col_indices: np.ndarray, + shape: tuple, + power: float, + threshold: float, + ) -> csr_matrix: + """ + Create an inverse distance weighted sparse matrix. + + :param values: Distance values. + :param col_indices: Column indices for the sparse matrix. + :param shape: Shape of the sparse matrix. + :param power: Power for the inverse distance weighting. + :param threshold: Threshold to avoid singularities. + + :return: Inverse distance weighted sparse matrix. 
+ """ + weights = (values**power + threshold) ** -1 + n_vals_row = weights.shape[0] // shape[0] + row_ids = np.repeat(np.arange(shape[0]), n_vals_row) + inv_dist_op = csr_matrix( + (weights, (row_ids, col_indices)), + shape=shape, + ) # Normalize the rows row_sum = np.asarray(inv_dist_op.sum(axis=1)).flatten() ** -1.0 return diags(row_sum) @ inv_dist_op @@ -391,7 +375,6 @@ def process_files_batch( continue simulated = normalized_data(survey.get_entity("Iteration_0_z")[0]) - pred = time_projection @ (spatial_projection @ simulated.T).T score = 0.0 @@ -414,7 +397,10 @@ def process_files_batch( if __name__ == "__main__": - file = Path(sys.argv[1]).resolve() + # file = Path(sys.argv[1]).resolve() + file = Path( + r"C:\Users\dominiquef\Documents\Workspace\Teck\RnD\plate_match_v2.ui.json" + ) input_file = load_ui_json_as_dict(file) PlateMatchDriver.start_dask_run( file, From 6396a83846f5a8cfe4f5f81f9d127198834bb251 Mon Sep 17 00:00:00 2001 From: dominiquef Date: Fri, 16 Jan 2026 15:01:46 -0800 Subject: [PATCH 09/12] Move out_group setter to BaseDriver. Remove duplicate sub-class setters --- simpeg_drivers/driver.py | 68 ++++++++++--------- simpeg_drivers/plate_simulation/driver.py | 8 --- .../plate_simulation/sweep/driver.py | 36 ---------- 3 files changed, 37 insertions(+), 75 deletions(-) diff --git a/simpeg_drivers/driver.py b/simpeg_drivers/driver.py index 90978e0a..84245c87 100644 --- a/simpeg_drivers/driver.py +++ b/simpeg_drivers/driver.py @@ -95,6 +95,7 @@ def __init__( workers: list[str] | None = None, ): super().__init__(params) + self.out_group = self.validate_out_group(self.params.out_group) self._client: Client | bool = self.validate_client(client) if getattr(self.params, "store_sensitivities", None) == "disk" and self.client: @@ -104,6 +105,42 @@ def __init__( self._workers: list[tuple[str]] | None = self.validate_workers(workers) + @property + def out_group(self) -> SimPEGGroup: + """ + Returns the output group for the simulation. 
+ """ + return self._out_group + + @out_group.setter + def out_group(self, value: SimPEGGroup): + if not isinstance(value, SimPEGGroup): + raise TypeError("Output group must be a SimPEGGroup.") + + if self.params.out_group != value: + self.params.out_group = value + self.params.update_out_group_options() + + self._out_group = value + + def validate_out_group(self, out_group: SimPEGGroup | None) -> SimPEGGroup: + """ + Validate or create a SimPEGGroup to store results. + + :param out_group: Output group from selection. + """ + if isinstance(out_group, SimPEGGroup): + return out_group + + with fetch_active_workspace(self.params.geoh5, mode="r+"): + out_group = SimPEGGroup.create( + self.params.geoh5, + name=self.params.title, + ) + out_group.entity_type.name = self.params.title + + return out_group + @property def client(self) -> Client | bool | None: """ @@ -224,7 +261,6 @@ def __init__( super().__init__(params, client=client, workers=workers) self.inversion_type = self.params.inversion_type - self.out_group = self.validate_out_group(self.params.out_group) self._data_misfit: objective_function.ComboObjectiveFunction | None = None self._directives: list[directives.InversionDirective] | None = None self._inverse_problem: inverse_problem.BaseInvProblem | None = None @@ -432,36 +468,6 @@ def ordering(self): """List of ordering of the data.""" return self.inversion_data.survey.ordering - @property - def out_group(self) -> SimPEGGroup: - """ - Returns the output group for the simulation. - """ - return self._out_group - - @out_group.setter - def out_group(self, value: SimPEGGroup): - if not isinstance(value, SimPEGGroup): - raise TypeError("Output group must be a SimPEGGroup.") - - self.params.out_group = value - self.params.update_out_group_options() - self._out_group = value - - def validate_out_group(self, out_group: SimPEGGroup | None) -> SimPEGGroup: - """ - Validate or create a SimPEGGroup to store results. - - :param out_group: Output group from selection. 
- """ - if isinstance(out_group, SimPEGGroup): - return out_group - - with fetch_active_workspace(self.workspace, mode="r+"): - out_group = SimPEGGroup.create(self.workspace, name=self.params.title) - - return out_group - @property def params(self) -> BaseForwardOptions | BaseInversionOptions: """Application parameters.""" diff --git a/simpeg_drivers/plate_simulation/driver.py b/simpeg_drivers/plate_simulation/driver.py index 0b8fd40c..d30ef5c1 100644 --- a/simpeg_drivers/plate_simulation/driver.py +++ b/simpeg_drivers/plate_simulation/driver.py @@ -62,7 +62,6 @@ def __init__( self._model: FloatData | None = None self._simulation_parameters: BaseForwardOptions | None = None self._simulation_driver: InversionDriver | None = None - self._out_group = self.validate_out_group(self.params.out_group) def run(self) -> InversionDriver: """Create octree mesh, fill model, and simulate.""" @@ -77,13 +76,6 @@ def run(self) -> InversionDriver: return self.simulation_driver - @property - def out_group(self) -> SimPEGGroup: - """ - Returns the output group for the simulation. - """ - return self._out_group - def validate_out_group(self, out_group: SimPEGGroup | None) -> SimPEGGroup: """ Validate or create a SimPEGGroup to store results. diff --git a/simpeg_drivers/plate_simulation/sweep/driver.py b/simpeg_drivers/plate_simulation/sweep/driver.py index f1cee0c9..2ca7921c 100644 --- a/simpeg_drivers/plate_simulation/sweep/driver.py +++ b/simpeg_drivers/plate_simulation/sweep/driver.py @@ -48,42 +48,6 @@ def __init__(self, params: SweepOptions, workers: list[tuple[str]] | None = None self.out_group = self.validate_out_group(self.params.out_group) - @property - def out_group(self) -> SimPEGGroup: - """ - Returns the output group for the simulation. 
- """ - return self._out_group - - @out_group.setter - def out_group(self, value: SimPEGGroup): - if not isinstance(value, SimPEGGroup): - raise TypeError("Output group must be a SimPEGGroup.") - - if self.params.out_group != value: - self.params.out_group = value - self.params.update_out_group_options() - - self._out_group = value - - def validate_out_group(self, out_group: SimPEGGroup | None) -> SimPEGGroup: - """ - Validate or create a UIJsonGroup to store results. - - :param value: Output group from selection. - """ - if isinstance(out_group, SimPEGGroup): - return out_group - - with fetch_active_workspace(self.params.geoh5, mode="r+"): - out_group = SimPEGGroup.create( - self.params.geoh5, - name=self.params.title, - ) - out_group.entity_type.name = self.params.title - - return out_group - @classmethod def start(cls, filepath: str | Path, mode="r", **_) -> Self: """Start the parameter sweep from a ui.json file.""" From 3c25bf14761a5a77934ee9d35afb5719cd40cad9 Mon Sep 17 00:00:00 2001 From: dominiquef Date: Mon, 19 Jan 2026 11:01:42 -0800 Subject: [PATCH 10/12] Move basic functions to utils. 
Start adding unit tests --- simpeg_drivers/plate_simulation/driver.py | 8 +- .../plate_simulation/match/driver.py | 74 +++--------------- .../plate_simulation/match/options.py | 8 +- simpeg_drivers/utils/utils.py | 48 ++++++++++++ .../plate_simulation/runtest/gravity_test.py | 1 - tests/plate_simulation/runtest/match_test.py | 77 +++++++++++++++++++ tests/utils_test.py | 57 ++++++++++++++ 7 files changed, 201 insertions(+), 72 deletions(-) create mode 100644 tests/plate_simulation/runtest/match_test.py create mode 100644 tests/utils_test.py diff --git a/simpeg_drivers/plate_simulation/driver.py b/simpeg_drivers/plate_simulation/driver.py index d30ef5c1..0ba111bf 100644 --- a/simpeg_drivers/plate_simulation/driver.py +++ b/simpeg_drivers/plate_simulation/driver.py @@ -15,7 +15,7 @@ import numpy as np from dask.distributed import Client -from geoapps_utils.base import Driver, get_logger +from geoapps_utils.base import get_logger from geoapps_utils.utils.transformations import azimuth_to_unit_vector from geoh5py.data import FloatData, ReferencedData from geoh5py.groups import SimPEGGroup @@ -40,10 +40,8 @@ class PlateSimulationDriver(BaseDriver): :param params: Parameters for plate simulation (mesh, model and series). - :param plate: Plate object used to add anomaly to the model. - :param mesh: Octree mesh in which model is built for the simulation. - :param model: Model to simulate. - :param survey: Survey object for the simulation + :param client: Dask client for parallel processing. + :param workers: List of worker addresses for Dask client.
""" _params_class = PlateSimulationOptions diff --git a/simpeg_drivers/plate_simulation/match/driver.py b/simpeg_drivers/plate_simulation/match/driver.py index 4e98199e..e961031f 100644 --- a/simpeg_drivers/plate_simulation/match/driver.py +++ b/simpeg_drivers/plate_simulation/match/driver.py @@ -29,28 +29,27 @@ ) from geoh5py.ui_json import InputFile from scipy import signal -from scipy.sparse import csr_matrix, diags +from scipy.sparse import csr_matrix from scipy.spatial import cKDTree from typing_extensions import Self from simpeg_drivers.driver import BaseDriver -from simpeg_drivers.plate_simulation.match.options import MatchOptions - -# from simpeg_drivers.plate_simulation.match.uijson import PlateMatchUIJson +from simpeg_drivers.plate_simulation.match.options import PlateMatchOptions from simpeg_drivers.plate_simulation.options import PlateSimulationOptions +from simpeg_drivers.utils.utils import inverse_weighted_operator, xyz_to_polar -# import matplotlib.pyplot as plt - logger = get_logger(name=__name__, level_name=False, propagate=False, add_name=False) class PlateMatchDriver(BaseDriver): """Sets up and manages workers to run all combinations of swepts parameters.""" - _params_class = MatchOptions + _params_class = PlateMatchOptions - def __init__(self, params: MatchOptions, workers: list[tuple[str]] | None = None): + def __init__( + self, params: PlateMatchOptions, workers: list[tuple[str]] | None = None + ): super().__init__(params, workers=workers) self._drape_heights = self.set_drape_height() @@ -96,7 +95,7 @@ def time_mask_and_projection(self) -> tuple[np.ndarray, csr_matrix]: # Create inverse distance weighting matrix based on time difference time_diff = np.abs(query_times[row_ids] - simulated_times[inds]) - time_projection = self.inverse_weighted_operator( + time_projection = inverse_weighted_operator( time_diff, inds, (len(query_times), len(simulated_times)), 1.0, 1e-12 ) return time_mask, time_projection @@ -113,7 +112,7 @@ def start(cls, 
filepath: str | Path, mode="r+", **_) -> Self: with uijson.geoh5.open(mode=mode): try: - options = MatchOptions.build(uijson) + options = PlateMatchOptions.build(uijson) logger.info("Initializing application . . .") driver = cls(options) logger.info("Running application . . .") @@ -157,19 +156,19 @@ def spatial_interpolation( :return: Spatial interpolation matrix. """ # Compute local coordinates for the current line segment - local_polar = self.xyz_to_polar(self.params.survey.vertices[indices, :]) + local_polar = xyz_to_polar(self.params.survey.vertices[indices, :]) local_polar[:, 1] = ( 0.0 if strike_angle is None else strike_angle ) # Align azimuths to zero # Convert to polar coordinates (distance, azimuth, height) - query_polar = self.xyz_to_polar(self._template.vertices) + query_polar = xyz_to_polar(self._template.vertices) # Get the 8 nearest neighbors in the simulation to each observation point sim_tree = cKDTree(query_polar) rad, inds = sim_tree.query(local_polar, k=8) - return self.inverse_weighted_operator( + return inverse_weighted_operator( rad.flatten(), inds.flatten(), (local_polar.shape[0], self._template.vertices.shape[0]), @@ -177,55 +176,6 @@ def spatial_interpolation( 1e-1, ) - @staticmethod - def inverse_weighted_operator( - values: np.ndarray, - col_indices: np.ndarray, - shape: tuple, - power: float, - threshold: float, - ) -> csr_matrix: - """ - Create an inverse distance weighted sparse matrix. - - :param values: Distance values. - :param col_indices: Column indices for the sparse matrix. - :param shape: Shape of the sparse matrix. - :param power: Power for the inverse distance weighting. - :param threshold: Threshold to avoid singularities. - - :return: Inverse distance weighted sparse matrix. 
- """ - weights = (values**power + threshold) ** -1 - n_vals_row = weights.shape[0] // shape[0] - row_ids = np.repeat(np.arange(shape[0]), n_vals_row) - inv_dist_op = csr_matrix( - (weights, (row_ids, col_indices)), - shape=shape, - ) - # Normalize the rows - row_sum = np.asarray(inv_dist_op.sum(axis=1)).flatten() ** -1.0 - return diags(row_sum) @ inv_dist_op - - @staticmethod - def xyz_to_polar(xyz: np.ndarray) -> np.ndarray: - """ - Convert Cartesian coordinates to polar coordinates defined as - (distance, azimuth, height), where distance is signed based on the - x-coordinate relative to the mean location. - - :param xyz: Cartesian coordinates. - - :return: Polar coordinates (distance, azimuth, height). - """ - mean_loc = np.mean(xyz, axis=0) - distances = np.sign(xyz[:, 0] - mean_loc[0]) * np.linalg.norm( - xyz[:, :2] - mean_loc[:2], axis=1 - ) - - azimuths = 90 - (np.rad2deg(np.arctan2(xyz[:, 0], xyz[:, 1])) % 180) - return np.c_[distances, azimuths, xyz[:, 2]] - def get_segment_indices(self, nearest: int) -> np.ndarray: """ Get indices of line segment for a given nearest vertex. diff --git a/simpeg_drivers/plate_simulation/match/options.py b/simpeg_drivers/plate_simulation/match/options.py index 2d5a32c8..90a655ff 100644 --- a/simpeg_drivers/plate_simulation/match/options.py +++ b/simpeg_drivers/plate_simulation/match/options.py @@ -23,7 +23,7 @@ from simpeg_drivers import assets_path -class MatchOptions(Options): +class PlateMatchOptions(Options): """ Options for matching signal from a survey against a library of simulations. 
@@ -39,8 +39,8 @@ class MatchOptions(Options): name: ClassVar[str] = "plate_match" default_ui_json: ClassVar[Path] = assets_path() / "uijson/plate_match.ui.json" - title: ClassVar[str] = "Plate Match" - run_command: ClassVar[str] = "simpeg_drivers.plate_simulation.match.driver" + title: str = "Plate Match" + run_command: str = "simpeg_drivers.plate_simulation.match.driver" out_group: SimPEGGroup | None = None survey: AirborneTEMReceivers @@ -50,7 +50,7 @@ class MatchOptions(Options): max_distance: float = 1000.0 topography_object: Points | Grid2D topography: FloatData | None = None - simulations: str + simulations: str | Path @property def simulation_files(self) -> list[Path]: diff --git a/simpeg_drivers/utils/utils.py b/simpeg_drivers/utils/utils.py index e79e52e4..7e5ab46f 100644 --- a/simpeg_drivers/utils/utils.py +++ b/simpeg_drivers/utils/utils.py @@ -29,6 +29,7 @@ from geoh5py.ui_json import InputFile from grid_apps.utils import octree_2_treemesh from scipy.interpolate import LinearNDInterpolator, NearestNDInterpolator, interp1d +from scipy.sparse import csr_matrix, diags from scipy.spatial import ConvexHull, Delaunay, cKDTree from simpeg_drivers import DRIVER_MAP @@ -572,3 +573,50 @@ def simpeg_group_to_driver(group: SimPEGGroup, workspace: Workspace) -> Inversio params = inversion_driver._params_class.build(ifile) # pylint: disable=protected-access return inversion_driver(params) + + +def xyz_to_polar(locations: np.ndarray) -> np.ndarray: + """ + Convert Cartesian coordinates to polar coordinates defined as + (distance, azimuth, height), where distance is signed based on the + x-coordinate relative to the mean location. + + :param locations: Cartesian coordinates. + + :return: Polar coordinates (distance, azimuth, height). 
+ """ + xyz = locations - np.mean(locations, axis=0) + distances = np.sign(xyz[:, 0]) * np.linalg.norm(xyz[:, :2], axis=1) + + azimuths = 90 - (np.rad2deg(np.arctan2(xyz[:, 0], xyz[:, 1])) % 180) + return np.c_[distances, azimuths, locations[:, 2]] + + +def inverse_weighted_operator( + values: np.ndarray, + col_indices: np.ndarray, + shape: tuple, + power: float, + threshold: float, +) -> csr_matrix: + """ + Create an inverse distance weighted sparse matrix. + + :param values: Distance values. + :param col_indices: Column indices for the sparse matrix. + :param shape: Shape of the sparse matrix. + :param power: Power for the inverse distance weighting. + :param threshold: Threshold to avoid singularities. + + :return: Inverse distance weighted sparse matrix. + """ + weights = (values**power + threshold) ** -1 + n_vals_row = weights.shape[0] // shape[0] + row_ids = np.repeat(np.arange(shape[0]), n_vals_row) + inv_dist_op = csr_matrix( + (weights, (row_ids, col_indices)), + shape=shape, + ) + # Normalize the rows + row_sum = np.asarray(inv_dist_op.sum(axis=1)).flatten() ** -1.0 + return diags(row_sum) @ inv_dist_op diff --git a/tests/plate_simulation/runtest/gravity_test.py b/tests/plate_simulation/runtest/gravity_test.py index 3c702d4b..130517bb 100644 --- a/tests/plate_simulation/runtest/gravity_test.py +++ b/tests/plate_simulation/runtest/gravity_test.py @@ -9,7 +9,6 @@ # ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' import numpy as np -from geoh5py import Workspace from geoh5py.groups import SimPEGGroup from simpeg_drivers.plate_simulation.driver import PlateSimulationDriver diff --git a/tests/plate_simulation/runtest/match_test.py b/tests/plate_simulation/runtest/match_test.py new file mode 100644 index 00000000..d8c3082f --- /dev/null +++ b/tests/plate_simulation/runtest/match_test.py @@ -0,0 +1,77 @@ +# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' +# Copyright (c) 2023-2026 Mira 
Geoscience Ltd. ' +# ' +# This file is part of simpeg-drivers package. ' +# ' +# simpeg-drivers is distributed under the terms and conditions of the MIT License ' +# (see LICENSE file at the root of this source code package). ' +# ' +# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' + +from pathlib import Path + +import numpy as np +from geoh5py import Workspace +from geoh5py.groups import PropertyGroup +from geoh5py.objects import Points + +from simpeg_drivers.plate_simulation.match.options import PlateMatchOptions +from simpeg_drivers.utils.synthetics.driver import ( + SyntheticsComponents, +) +from simpeg_drivers.utils.synthetics.options import ( + MeshOptions, + ModelOptions, + SurveyOptions, + SyntheticsComponentsOptions, +) +from tests.utils.targets import get_workspace + + +def generate_example(geoh5: Workspace, n_grid_points: int, refinement: tuple[int]): + opts = SyntheticsComponentsOptions( + method="airborne tdem", + survey=SurveyOptions( + n_stations=n_grid_points, n_lines=n_grid_points, drape=10.0 + ), + mesh=MeshOptions(refinement=refinement, padding_distance=400.0), + model=ModelOptions(background=0.001), + ) + components = SyntheticsComponents(geoh5, options=opts) + vals = components.survey.add_data( + {"observed_data": {"values": np.random.randn(components.survey.n_vertices)}}, + ) + components.property_group = PropertyGroup(components.survey, properties=vals) + components.queries = Points.create(geoh5, vertices=np.random.randn(1, 3)) + + return components + + +def test_file_parsing(tmp_path: Path): + """ + Generate a few files and test the + plate_simulation.match.Options.simulation_files() method. 
+ """ + filenames = [ + "sim_001.txt", + "sim_002.txt", + "sim_010.txt", + "sim_011.txt", + ] + for fname in filenames: + (tmp_path / fname).touch() + + with get_workspace(tmp_path / f"{__name__}.geoh5") as geoh5: + components = generate_example(geoh5, n_grid_points=3, refinement=(2,)) + options = PlateMatchOptions( + geoh5=geoh5, + survey=components.survey, + data=components.property_group, + queries=components.queries, + topography_object=components.topography, + simulations=tmp_path, + ) + + sim_files = options.simulation_files + assert len(sim_files) == 1 + assert sim_files[0].name == f"{__name__}.geoh5" diff --git a/tests/utils_test.py b/tests/utils_test.py new file mode 100644 index 00000000..13888682 --- /dev/null +++ b/tests/utils_test.py @@ -0,0 +1,57 @@ +# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' +# Copyright (c) 2023-2026 Mira Geoscience Ltd. ' +# ' +# This file is part of simpeg-drivers package. ' +# ' +# simpeg-drivers is distributed under the terms and conditions of the MIT License ' +# (see LICENSE file at the root of this source code package). ' +# ' +# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' + +import numpy as np + +from simpeg_drivers.utils.utils import inverse_weighted_operator, xyz_to_polar + + +def test_xyz_to_polar(): + """ + Test the xyz_to_polar utility function. + """ + for _ in range(100): + rad = np.abs(np.random.randn()) + azm = np.random.randn() + + # Create x, y, z coordinates + x = rad * np.cos(azm) + y = rad * np.sin(azm) + z = np.random.randn() + + polar = xyz_to_polar(np.vstack([[[0, 0, 0]], [[x, y, z]]])) + np.testing.assert_almost_equal( + polar[0, 0], -polar[1, 0] + ) # Opposite side of center + np.testing.assert_almost_equal(polar[0, 1], polar[1, 1]) # Same azimuth + np.testing.assert_allclose([0, z], polar[:, 2]) # Preserves z + + +def test_inverse_weighted_operator(): + """ + Test the inverse_weighted_operator utility function. 
+ + For a constant input, the output should be the same constant. + """ + power = 2.0 + threshold = 1e-12 + shape = (100, 1000) + values = np.random.randn(shape[0] * 2) + indices = np.c_[ + np.random.randint(0, shape[1] - 1, shape[0]), + np.random.randint(0, shape[1] - 1, shape[0]), + ].flatten() + + opt = inverse_weighted_operator(values, indices, shape, power, threshold) + test_val = np.random.randn() + interp = opt * np.full(shape[1], test_val) + + assert opt.shape == shape + np.testing.assert_allclose(interp, test_val, rtol=1e-3) From dc43052ead263a90051e1d7663a1da5736d1b973 Mon Sep 17 00:00:00 2001 From: dominiquef Date: Mon, 19 Jan 2026 11:18:09 -0800 Subject: [PATCH 11/12] Remove deprecated configure_dask in joint. Rename opt options --- simpeg_drivers/driver.py | 8 ++++---- simpeg_drivers/joint/driver.py | 1 - 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/simpeg_drivers/driver.py b/simpeg_drivers/driver.py index 84245c87..c156d919 100644 --- a/simpeg_drivers/driver.py +++ b/simpeg_drivers/driver.py @@ -448,16 +448,16 @@ def n_values(self): def optimization(self): if getattr(self, "_optimization", None) is None: if self.params.forward_only: - return optimization.ProjectedGNCG() + return optimization.ProjectedGNCG(cg_rtol=1.0) self._optimization = optimization.ProjectedGNCG( maxIter=self.params.optimization.max_global_iterations, lower=self.models.lower_bound, upper=self.models.upper_bound, maxIterLS=self.params.optimization.max_line_search_iterations, - maxIterCG=self.params.optimization.max_cg_iterations, - tolCG=self.params.optimization.tol_cg, - stepOffBoundsFact=1e-8, + cg_maxiter=self.params.optimization.max_cg_iterations, + cg_rtol=self.params.optimization.tol_cg, + active_set_grad_scale=1e-8, LSshorten=0.25, require_decrease=False, ) diff --git a/simpeg_drivers/joint/driver.py b/simpeg_drivers/joint/driver.py index 07d5e30f..cd945bfa 100644 --- a/simpeg_drivers/joint/driver.py +++ b/simpeg_drivers/joint/driver.py @@ -234,7 +234,6 @@ def 
run(self): if self.logger: sys.stdout = self.logger self.logger.start() - self.configure_dask() if Path(self.params.input_file.path_name).is_file(): with fetch_active_workspace(self.workspace, mode="r+"): From 7062d7b7e9f07f46d0b44f93029f3ac647cb4340 Mon Sep 17 00:00:00 2001 From: dominiquef Date: Mon, 19 Jan 2026 16:56:50 -0800 Subject: [PATCH 12/12] Move utils to geoh5py and geoapps_driver. Add unit tests --- .../plate_simulation/match/driver.py | 138 +++++++++--------- simpeg_drivers/utils/utils.py | 47 ------ tests/plate_simulation/runtest/match_test.py | 85 ++++++++++- tests/utils_test.py | 57 -------- 4 files changed, 152 insertions(+), 175 deletions(-) delete mode 100644 tests/utils_test.py diff --git a/simpeg_drivers/plate_simulation/match/driver.py b/simpeg_drivers/plate_simulation/match/driver.py index e961031f..16c0cf37 100644 --- a/simpeg_drivers/plate_simulation/match/driver.py +++ b/simpeg_drivers/plate_simulation/match/driver.py @@ -15,12 +15,14 @@ from pathlib import Path import numpy as np -from dask.distributed import progress +from dask.distributed import Future, progress from geoapps_utils.run import load_ui_json_as_dict from geoapps_utils.utils.importing import GeoAppsError from geoapps_utils.utils.locations import topo_drape_elevation from geoapps_utils.utils.logger import get_logger +from geoapps_utils.utils.numerical import inverse_weighted_operator from geoapps_utils.utils.plotting import symlog +from geoapps_utils.utils.transformations import xyz_to_polar from geoh5py import Workspace from geoh5py.groups import PropertyGroup, SimPEGGroup from geoh5py.objects import AirborneTEMReceivers, Surface @@ -36,7 +38,6 @@ from simpeg_drivers.driver import BaseDriver from simpeg_drivers.plate_simulation.match.options import PlateMatchOptions from simpeg_drivers.plate_simulation.options import PlateSimulationOptions -from simpeg_drivers.utils.utils import inverse_weighted_operator, xyz_to_polar logger = get_logger(name=__name__, level_name=False, 
propagate=False, add_name=False) @@ -52,7 +53,7 @@ def __init__( ): super().__init__(params, workers=workers) - self._drape_heights = self.set_drape_height() + self._drape_heights = self._get_drape_heights() self._template = self.get_template() self._time_mask, self._time_projection = self.time_mask_and_projection() @@ -85,12 +86,12 @@ def time_mask_and_projection(self) -> tuple[np.ndarray, csr_matrix]: simulated_times = np.asarray(self._template.channels) query_times = np.asarray(self.params.survey.channels) # Only interpolate for times within the simulated range - time_mask = (query_times > simulated_times.min()) & ( - query_times < simulated_times.max() + time_mask = (query_times >= simulated_times.min()) & ( + query_times <= simulated_times.max() ) query_times = query_times[time_mask] right = np.searchsorted(simulated_times, query_times) - inds = np.c_[right - 1, right].flatten() + inds = np.c_[np.maximum(0, right - 1), right].flatten() row_ids = np.repeat(np.arange(len(query_times)), 2) # Create inverse distance weighting matrix based on time difference @@ -125,7 +126,7 @@ def start(cls, filepath: str | Path, mode="r+", **_) -> Self: return driver - def set_drape_height(self) -> np.ndarray: + def _get_drape_heights(self) -> np.ndarray: """Set drape heights based on topography object and optional topography data.""" topo = self.params.topography_object.locations @@ -156,18 +157,24 @@ def spatial_interpolation( :return: Spatial interpolation matrix. 
""" # Compute local coordinates for the current line segment - local_polar = xyz_to_polar(self.params.survey.vertices[indices, :]) + local_polar = xyz_to_polar( + self.params.survey.vertices[indices] + - np.r_[self.params.survey.vertices[indices, :2].mean(axis=0), 0] + ) + local_polar[local_polar[:, 1] >= 180, 0] *= -1 # Wrap azimuths local_polar[:, 1] = ( 0.0 if strike_angle is None else strike_angle ) # Align azimuths to zero # Convert to polar coordinates (distance, azimuth, height) query_polar = xyz_to_polar(self._template.vertices) + query_polar[query_polar[:, 1] >= 180, 0] *= -1 + query_polar[:, 1] = query_polar[:, 1] % 180 # Wrap azimuths # Get the 8 nearest neighbors in the simulation to each observation point sim_tree = cKDTree(query_polar) rad, inds = sim_tree.query(local_polar, k=8) - + inds = np.minimum(query_polar.shape[0] - 1, inds) return inverse_weighted_operator( rad.flatten(), inds.flatten(), @@ -176,24 +183,6 @@ def spatial_interpolation( 1e-1, ) - def get_segment_indices(self, nearest: int) -> np.ndarray: - """ - Get indices of line segment for a given nearest vertex. - - :param nearest: Nearest vertex index. 
- """ - line_mask = np.where( - self.params.survey.parts == self.params.survey.parts[nearest] - )[0] - distances = np.linalg.norm( - self.params.survey.vertices[nearest, :2] - - self.params.survey.vertices[line_mask, :2], - axis=1, - ) - dist_mask = distances < self.params.max_distance - indices = line_mask[dist_mask] - return indices - def run(self): """Loop over all trials and run a worker for each unique parameter set.""" logger.info( @@ -202,59 +191,74 @@ def run(self): ) observed = normalized_data(self.params.data)[self._time_mask, :] tree = cKDTree(self.params.survey.vertices[:, :2]) - + results = [] for ii, query in enumerate(self.params.queries.vertices): - tasks = [] + # Find the nearest survey location to the query point nearest = tree.query(query[:2], k=1)[1] - indices = self.get_segment_indices(nearest) + indices = self.params.survey.get_segment_indices( + nearest, self.params.max_distance + ) spatial_projection = self.spatial_interpolation( indices, - self.params.strike_angles.values[ii], + 0 + if self.params.strike_angles is None + else self.params.strike_angles.values[ii], ) - file_split = np.array_split( - self.params.simulation_files, len(self.workers) * 10 + self.params.simulation_files, np.maximum(1, len(self.workers) * 10) ) + tasks = [] for file_batch in file_split: + args = ( + file_batch, + spatial_projection, + self._time_projection, + observed[:, indices], + ) + tasks.append( - self.client.submit( - process_files_batch, - file_batch, - spatial_projection, - self._time_projection, - observed[:, indices], - ) + self.client.submit(process_files_batch, *args) + if self.client + else process_files_batch(*args) ) + # Display progress bar - progress(tasks) - scores = np.hstack(self.client.gather(tasks)) - ranked = np.argsort(scores) - - for rank in ranked[-1:][::-1]: - logger.info( - "File: %s \nScore: %.4f", - self.params.simulation_files[rank].name, - scores[rank], + if isinstance(tasks[0], Future): + progress(tasks) + self.client.gather(tasks) 
+ + scores = np.hstack(tasks) + ranked = np.argsort(scores)[::-1] + + # TODO: Return top N matches + # for rank in ranked[-1:][::-1]: + logger.info( + "File: %s \nScore: %.4f", + self.params.simulation_files[ranked[0]].name, + scores[ranked[0]], + ) + with Workspace(self.params.simulation_files[ranked[0]], mode="r") as ws: + survey = fetch_survey(ws) + ui_json = survey.parent.parent.options + ui_json["geoh5"] = ws + ifile = InputFile(ui_json=ui_json) + options = PlateSimulationOptions.build(ifile) + + plate = survey.parent.parent.get_entity("plate")[0].copy( + parent=self.params.out_group ) - with Workspace(self.params.simulation_files[rank], mode="r") as ws: - survey = fetch_survey(ws) - ui_json = survey.parent.parent.options - ui_json["geoh5"] = ws - ifile = InputFile(ui_json=ui_json) - options = PlateSimulationOptions.build(ifile) - - plate = survey.parent.parent.get_entity("plate")[0].copy( - parent=self.params.out_group - ) - - # Set position of plate to query location - center = self.params.survey.vertices[nearest] - center[2] = self._drape_heights[nearest] - plate.vertices = plate.vertices + center - plate.metadata = options.model.model_dump() - - print(f"Best parameters:{options.model.model_dump_json(indent=2)}") + + # Set position of plate to query location + center = self.params.survey.vertices[nearest] + center[2] = self._drape_heights[nearest] + plate.vertices = plate.vertices + center + plate.metadata = options.model.model_dump() + + print(f"Best parameters:{options.model.model_dump_json(indent=2)}") + results.append(self.params.simulation_files[ranked[0]].name) + + return results @classmethod def start_dask_run( diff --git a/simpeg_drivers/utils/utils.py b/simpeg_drivers/utils/utils.py index 7e5ab46f..64bb1cd9 100644 --- a/simpeg_drivers/utils/utils.py +++ b/simpeg_drivers/utils/utils.py @@ -573,50 +573,3 @@ def simpeg_group_to_driver(group: SimPEGGroup, workspace: Workspace) -> Inversio params = inversion_driver._params_class.build(ifile) # pylint: 
disable=protected-access return inversion_driver(params) - - -def xyz_to_polar(locations: np.ndarray) -> np.ndarray: - """ - Convert Cartesian coordinates to polar coordinates defined as - (distance, azimuth, height), where distance is signed based on the - x-coordinate relative to the mean location. - - :param locations: Cartesian coordinates. - - :return: Polar coordinates (distance, azimuth, height). - """ - xyz = locations - np.mean(locations, axis=0) - distances = np.sign(xyz[:, 0]) * np.linalg.norm(xyz[:, :2], axis=1) - - azimuths = 90 - (np.rad2deg(np.arctan2(xyz[:, 0], xyz[:, 1])) % 180) - return np.c_[distances, azimuths, locations[:, 2]] - - -def inverse_weighted_operator( - values: np.ndarray, - col_indices: np.ndarray, - shape: tuple, - power: float, - threshold: float, -) -> csr_matrix: - """ - Create an inverse distance weighted sparse matrix. - - :param values: Distance values. - :param col_indices: Column indices for the sparse matrix. - :param shape: Shape of the sparse matrix. - :param power: Power for the inverse distance weighting. - :param threshold: Threshold to avoid singularities. - - :return: Inverse distance weighted sparse matrix. - """ - weights = (values**power + threshold) ** -1 - n_vals_row = weights.shape[0] // shape[0] - row_ids = np.repeat(np.arange(shape[0]), n_vals_row) - inv_dist_op = csr_matrix( - (weights, (row_ids, col_indices)), - shape=shape, - ) - # Normalize the rows - row_sum = np.asarray(inv_dist_op.sum(axis=1)).flatten() ** -1.0 - return diags(row_sum) @ inv_dist_op diff --git a/tests/plate_simulation/runtest/match_test.py b/tests/plate_simulation/runtest/match_test.py index d8c3082f..c2d116d2 100644 --- a/tests/plate_simulation/runtest/match_test.py +++ b/tests/plate_simulation/runtest/match_test.py @@ -7,15 +7,22 @@ # (see LICENSE file at the root of this source code package). 
' # ' # ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' - +import shutil from pathlib import Path import numpy as np from geoh5py import Workspace from geoh5py.groups import PropertyGroup from geoh5py.objects import Points +from geoh5py.ui_json import InputFile +from simpeg_drivers import assets_path +from simpeg_drivers.electromagnetics.time_domain.driver import TDEMForwardDriver +from simpeg_drivers.electromagnetics.time_domain.options import TDEMForwardOptions +from simpeg_drivers.plate_simulation.driver import PlateSimulationDriver +from simpeg_drivers.plate_simulation.match.driver import PlateMatchDriver, fetch_survey from simpeg_drivers.plate_simulation.match.options import PlateMatchOptions +from simpeg_drivers.plate_simulation.options import PlateSimulationOptions from simpeg_drivers.utils.synthetics.driver import ( SyntheticsComponents, ) @@ -31,9 +38,7 @@ def generate_example(geoh5: Workspace, n_grid_points: int, refinement: tuple[int]): opts = SyntheticsComponentsOptions( method="airborne tdem", - survey=SurveyOptions( - n_stations=n_grid_points, n_lines=n_grid_points, drape=10.0 - ), + survey=SurveyOptions(n_stations=n_grid_points, n_lines=1, drape=10.0), mesh=MeshOptions(refinement=refinement, padding_distance=400.0), model=ModelOptions(background=0.001), ) @@ -41,6 +46,9 @@ def generate_example(geoh5: Workspace, n_grid_points: int, refinement: tuple[int vals = components.survey.add_data( {"observed_data": {"values": np.random.randn(components.survey.n_vertices)}}, ) + + # Shift survey up (for single line) + components.survey.vertices = components.survey.vertices + np.r_[0, 100, 0] components.property_group = PropertyGroup(components.survey, properties=vals) components.queries = Points.create(geoh5, vertices=np.random.randn(1, 3)) @@ -75,3 +83,72 @@ def test_file_parsing(tmp_path: Path): sim_files = options.simulation_files assert len(sim_files) == 1 assert sim_files[0].name == f"{__name__}.geoh5" + + +def 
test_matching_driver(tmp_path: Path): + """ + Generate a few files and test the + plate_simulation.match.Options.simulation_files() method. + """ + + # Generate simulation files + with get_workspace(tmp_path / f"{__name__}.geoh5") as geoh5: + components = generate_example(geoh5, n_grid_points=5, refinement=(2,)) + + params = TDEMForwardOptions.build( + geoh5=geoh5, + mesh=components.mesh, + topography_object=components.topography, + data_object=components.survey, + starting_model=components.model, + x_channel_bool=True, + y_channel_bool=True, + z_channel_bool=True, + ) + + fwr_driver = TDEMForwardDriver(params) + + ifile = InputFile.read_ui_json( + assets_path() / "uijson" / "plate_simulation.ui.json", validate=False + ) + ifile.data["geoh5"] = geoh5 + ifile.data["simulation"] = fwr_driver.out_group + + plate_options = PlateSimulationOptions.build(ifile.data) + driver = PlateSimulationDriver(plate_options) + driver.run() + + # Make copies of the generated simulation file to emulate a sweep + file = tmp_path / f"{__name__}.geoh5" + new_dir = tmp_path / "simulations" + new_dir.mkdir(parents=True, exist_ok=True) + + for ii in range(1, 5): + new_file = new_dir / (file.stem + f"_[{ii}].geoh5") + shutil.copy(file, new_file) + + # Modify the data slightly + with Workspace(new_file) as sim_geoh5: + survey = fetch_survey(sim_geoh5) + prop_group = survey.get_entity("Iteration_0_z")[0] + scale = np.cos(np.linspace(-np.pi / ii, np.pi / ii, survey.n_vertices)) + + for uid in prop_group.properties: + child = survey.get_entity(uid)[0] + child.values = child.values * scale + + # Random choice of file + with geoh5.open(): + survey = fetch_survey(geoh5) + options = PlateMatchOptions( + geoh5=geoh5, + survey=survey, + data=survey.get_entity("Iteration_0_z")[0], + queries=components.queries, + topography_object=components.topography, + simulations=new_dir, + ) + match_driver = PlateMatchDriver(options) + results = match_driver.run() + + assert results[0] == file.stem + f"_[{4}].geoh5" 
diff --git a/tests/utils_test.py b/tests/utils_test.py deleted file mode 100644 index 13888682..00000000 --- a/tests/utils_test.py +++ /dev/null @@ -1,57 +0,0 @@ -# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' -# Copyright (c) 2023-2026 Mira Geoscience Ltd. ' -# ' -# This file is part of simpeg-drivers package. ' -# ' -# simpeg-drivers is distributed under the terms and conditions of the MIT License ' -# (see LICENSE file at the root of this source code package). ' -# ' -# ''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''' - -import numpy as np - -from simpeg_drivers.utils.utils import inverse_weighted_operator, xyz_to_polar - - -def test_xyz_to_polar(): - """ - Test the xyz_to_polar utility function. - """ - for _ in range(100): - rad = np.abs(np.random.randn()) - azm = np.random.randn() - - # Create x, y, z coordinates - x = rad * np.cos(azm) - y = rad * np.sin(azm) - z = np.random.randn() - - polar = xyz_to_polar(np.vstack([[[0, 0, 0]], [[x, y, z]]])) - np.testing.assert_almost_equal( - polar[0, 0], -polar[1, 0] - ) # Opposite side of center - np.testing.assert_almost_equal(polar[0, 1], polar[1, 1]) # Same azimuth - np.testing.assert_allclose([0, z], polar[:, 2]) # Preserves z - - -def test_inverse_weighted_operator(): - """ - Test the inverse_weighted_operator utility function. - - For a constant input, the output should be the same constant. - """ - power = 2.0 - threshold = 1e-12 - shape = (100, 1000) - values = np.random.randn(shape[0] * 2) - indices = np.c_[ - np.random.randint(0, shape[1] - 1, shape[0]), - np.random.randint(0, shape[1] - 1, shape[0]), - ].flatten() - - opt = inverse_weighted_operator(values, indices, shape, power, threshold) - test_val = np.random.randn() - interp = opt * np.full(shape[1], test_val) - - assert opt.shape == shape - np.testing.assert_allclose(interp, test_val, rtol=1e-3)