Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions simpeg_drivers/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,10 +579,22 @@ def start_inversion_message(self):
)

@property
def mapping(self) -> list[maps.IdentityMap] | None:
def mapping(self) -> list[maps.Projection] | None:
"""Model mapping for the inversion."""
if self._mapping is None:
self.mapping = maps.IdentityMap(nP=self.n_values)
mapping = []
start = 0
n_blocks = 3 if self.models.is_vector else 1

for _ in range(n_blocks):
mapping.append(
maps.Projection(
self.n_values * n_blocks, slice(start, start + self.n_values)
)
)
start += self.n_values

self._mapping = mapping

return self._mapping

Expand Down Expand Up @@ -838,7 +850,7 @@ def get_path(self, filepath: str | Path) -> str:


if __name__ == "__main__":
file = Path(sys.argv[1]).resolve()
file = Path(r"C:\Users\dominiquef\Desktop\Tests\GEOPY-2620D.ui.json").resolve()
input_file = load_ui_json_as_dict(file)
n_workers = input_file.get("n_workers", None)
n_threads = input_file.get("n_threads", None)
Expand Down
36 changes: 19 additions & 17 deletions simpeg_drivers/joint/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,23 +164,7 @@ def inversion_data(self):
def directives(self):
if getattr(self, "_directives", None) is None and not self.params.forward_only:
with fetch_active_workspace(self.workspace, mode="r+"):
directives_list = self._get_drivers_directives()

directives_list += self._get_global_model_save_directives()
directives_list.append(
directives.SaveLPModelGroup(
self.inversion_mesh.entity,
self._directives.update_irls_directive,
)
)
directives_list.append(self._directives.save_iteration_log_files)
self._directives.directive_list = (
self._directives.inversion_directives + directives_list
)

DirectivesFactory.configure_save_directives(
self._directives.directive_list
)
self._directives.directive_list = self._get_joint_directives()

return self._directives

Expand Down Expand Up @@ -440,6 +424,24 @@ def _get_global_model_save_directives(self):
directives_list += self._get_local_model_save_directives(driver, wire)
return directives_list

def _get_joint_directives(self) -> list[directives.Directive]:
"""
Create a list of directives for the joint inversion.
"""
directives_list = self._get_drivers_directives()
directives_list += self._get_global_model_save_directives()
directives_list.append(
directives.SaveLPModelGroup(
self.inversion_mesh.entity,
self._directives.update_irls_directive,
)
)
directives_list.append(self._directives.save_iteration_log_files)
directives_list += self._directives.inversion_directives
DirectivesFactory.configure_save_directives(directives_list)

return directives_list

def _get_local_model_save_directives(
self, driver, wire
) -> list[directives.Directive]:
Expand Down
8 changes: 1 addition & 7 deletions simpeg_drivers/joint/joint_cross_gradient/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,11 @@

from itertools import combinations

import numpy as np
from geoh5py.groups.property_group_type import GroupTypeEnum
from geoh5py.shared.utils import fetch_active_workspace
from simpeg import directives, maps
from simpeg import maps
from simpeg.objective_function import ComboObjectiveFunction
from simpeg.regularization import CrossGradient

from simpeg_drivers.components.factories import (
DirectivesFactory,
SaveModelGeoh5Factory,
)
from simpeg_drivers.joint.driver import BaseJointDriver

from .options import JointCrossGradientOptions
Expand Down
21 changes: 2 additions & 19 deletions simpeg_drivers/joint/joint_petrophysics/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,14 @@

from __future__ import annotations

from itertools import combinations

import numpy as np
from geoh5py.groups.property_group_type import GroupTypeEnum
from geoh5py.shared.utils import fetch_active_workspace
from simpeg import directives, maps, utils
from simpeg.objective_function import ComboObjectiveFunction
from simpeg.regularization.pgi import PGIsmallness

from simpeg_drivers.components.factories import (
DirectivesFactory,
SaveModelGeoh5Factory,
)
from simpeg_drivers.joint.driver import BaseJointDriver

Expand All @@ -48,7 +44,7 @@ def __init__(self, params: JointPetrophysicsOptions):
def directives(self):
if getattr(self, "_directives", None) is None and not self.params.forward_only:
with fetch_active_workspace(self.workspace, mode="r+"):
directives_list = self._get_drivers_directives()
directives_list = self._get_joint_directives()
directives_list.append(
directives.PGI_UpdateParameters(
update_gmm=True,
Expand All @@ -58,7 +54,6 @@ def directives(self):
],
)
)
directives_list += self._get_global_model_save_directives()

# TODO: To bring back once we let the classification change
# directives_list.append(
Expand All @@ -78,20 +73,8 @@ def directives(self):
# reference_type=self.params.models.petrophysical_model.entity_type,
# )
# )
directives_list.append(
directives.SaveLPModelGroup(
self.inversion_mesh.entity,
self._directives.update_irls_directive,
)
)
directives_list.append(self._directives.save_iteration_log_files)
self._directives.directive_list = (
self._directives.inversion_directives + directives_list
)

DirectivesFactory.configure_save_directives(
self._directives.directive_list
)
self._directives.directive_list = directives_list

return self._directives

Expand Down
44 changes: 43 additions & 1 deletion simpeg_drivers/joint/joint_surveys/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import numpy as np
from geoh5py.shared.utils import fetch_active_workspace
from simpeg import maps
from simpeg import directives, maps

from simpeg_drivers.driver import InversionDriver
from simpeg_drivers.joint.driver import BaseJointDriver
Expand Down Expand Up @@ -56,6 +56,15 @@ def validate_create_models(self):
continue

model_local_values = getattr(self.drivers[0].models, model_type)

if (
self.drivers[0].models.is_vector
and len(model_local_values) > self.drivers[0].models.n_active
):
model_local_values = np.linalg.norm(
model_local_values.reshape((-1, 3), order="F"), axis=1
)

model = (
projection * model_local_values[: self.drivers[0].models.n_active]
) / (norm + 1e-8)
Expand Down Expand Up @@ -101,6 +110,39 @@ def _get_global_model_save_directives(self):

return directives_list

@property
def directives(self):
if getattr(self, "_directives", None) is None and not self.params.forward_only:
with fetch_active_workspace(self.workspace, mode="r+"):
directives_list = self._get_joint_directives()

if self.models.is_vector:
for directive in directives_list:
if isinstance(directive, directives.VectorInversion):
directives_list.remove(directive)

reference_angles = (
getattr(self.params.models, "reference_model", None)
is not None,
getattr(self.params.models, "reference_inclination", None)
is not None,
getattr(self.params.models, "reference_declination", None)
is not None,
)

vector_directive = directives.VectorInversion(
self.data_misfit.objfcts,
self.regularization,
chifact_target=self.params.cooling_schedule.chi_factor * 2,
reference_angles=reference_angles,
)

directives_list = [vector_directive] + directives_list

self._directives.directive_list = directives_list

return self._directives


JointSurveyDriver.n_values = InversionDriver.n_values
JointSurveyDriver.mapping = InversionDriver.mapping
38 changes: 0 additions & 38 deletions simpeg_drivers/potential_fields/magnetic_vector/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,41 +28,3 @@ class MVIInversionDriver(InversionDriver):
"""Magnetic Vector inversion driver."""

_params_class = MVIInversionOptions

@property
def mapping(self) -> list[maps.Projection] | None:
"""Model mapping for the inversion."""
if self._mapping is None:
mapping = []
start = 0
for _ in range(3):
mapping.append(
maps.Projection(
self.n_values * 3, slice(start, start + self.n_values)
)
)
start += self.n_values

self._mapping = mapping

return self._mapping

@mapping.setter
def mapping(self, value: list[maps.Projection]):
if not isinstance(value, list) or len(value) != 3:
raise TypeError(
"'mapping' must be a list of 3 instances of maps.IdentityMap. "
f"Provided {value}"
)

if not all(
isinstance(val, maps.Projection)
and val.shape == (self.n_values, 3 * self.n_values)
for val in value
):
raise TypeError(
"'mapping' must be an instance of maps.Projection with shape (n_values, 3 * self.n_values). "
f"Provided {value}"
)

self._mapping = value
Loading