Skip to content

Commit ef6f0c3

Browse files
committed
Adds checks for undefined fields in layers and contrasts
1 parent 700f6b7 commit ef6f0c3

File tree

2 files changed

+107
-18
lines changed

2 files changed

+107
-18
lines changed

RATapi/inputs.py

Lines changed: 44 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -151,11 +151,23 @@ def make_problem(project: RATapi.Project) -> ProblemDefinition:
151151
The problem input used in the compiled RAT code.
152152
153153
"""
154-
hydrate_id = {"bulk in": 1, "bulk out": 2}
155154
prior_id = {"uniform": 1, "gaussian": 2, "jeffreys": 3}
156155

157-
# Ensure backgrounds and resolutions have a source defined
156+
# Ensure all contrast fields are properly defined
158157
for contrast in project.contrasts:
158+
contrast_fields = ["data", "background", "bulk_in", "bulk_out", "scalefactor", "resolution"]
159+
160+
if project.calculation == Calculations.Domains:
161+
contrast_fields.append("domain_ratio")
162+
163+
for field in contrast_fields:
164+
if getattr(contrast, field) == "":
165+
raise ValueError(
166+
f'In the input project, the {field} of contrast "{contrast.name}" does not have a '
167+
f"value defined. A value must be supplied before running the project."
168+
)
169+
170+
# Ensure backgrounds and resolutions have a source defined
159171
background = project.backgrounds[contrast.background]
160172
resolution = project.resolutions[contrast.resolution]
161173
if background.source == "":
@@ -191,22 +203,7 @@ def make_problem(project: RATapi.Project) -> ProblemDefinition:
191203
contrast_custom_files = [project.custom_files.index(contrast.model[0], True) for contrast in project.contrasts]
192204

193205
# Get details of defined layers
194-
layer_details = []
195-
for layer in project.layers:
196-
if project.absorption:
197-
layer_params = [
198-
project.parameters.index(getattr(layer, attribute), True)
199-
for attribute in list(RATapi.models.AbsorptionLayer.model_fields.keys())[1:-2]
200-
]
201-
else:
202-
layer_params = [
203-
project.parameters.index(getattr(layer, attribute), True)
204-
for attribute in list(RATapi.models.Layer.model_fields.keys())[1:-2]
205-
]
206-
layer_params.append(project.parameters.index(layer.hydration, True) if layer.hydration else float("NaN"))
207-
layer_params.append(hydrate_id[layer.hydrate_with])
208-
209-
layer_details.append(layer_params)
206+
layer_details = get_layer_details(project)
210207

211208
contrast_background_params = []
212209
contrast_background_types = []
@@ -387,6 +384,35 @@ def make_problem(project: RATapi.Project) -> ProblemDefinition:
387384
return problem
388385

389386

387+
def get_layer_details(project: RATapi.Project) -> list[int]:
388+
"""Get parameter indices for all layers defined in the project."""
389+
hydrate_id = {"bulk in": 1, "bulk out": 2}
390+
layer_details = []
391+
392+
# Get the thickness, SLD, roughness fields from the appropriate model
393+
if project.absorption:
394+
layer_fields = list(RATapi.models.AbsorptionLayer.model_fields.keys())[1:-2]
395+
else:
396+
layer_fields = list(RATapi.models.Layer.model_fields.keys())[1:-2]
397+
398+
for layer in project.layers:
399+
for field in layer_fields:
400+
if getattr(layer, field) == "":
401+
raise ValueError(
402+
f"In the input project, the {field} field of layer {layer.name} does not have a value "
403+
f"defined. A value must be supplied before running the project."
404+
)
405+
406+
layer_params = [project.parameters.index(getattr(layer, attribute), True) for attribute in list(layer_fields)]
407+
408+
layer_params.append(project.parameters.index(layer.hydration, True) if layer.hydration else float("NaN"))
409+
layer_params.append(hydrate_id[layer.hydrate_with])
410+
411+
layer_details.append(layer_params)
412+
413+
return layer_details
414+
415+
390416
def make_resample(project: RATapi.Project) -> list[int]:
391417
"""Construct the "resample" field of the problem input required for the compiled RAT code.
392418

tests/test_inputs.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,69 @@ def test_background_params_value_indices(self, test_problem, bad_value, request)
572572
check_indices(test_problem)
573573

574574

575+
@pytest.mark.parametrize("test_project", ["standard_layers_project", "custom_xy_project", "domains_project"])
576+
@pytest.mark.parametrize("field", ["data", "background", "bulk_in", "bulk_out", "scalefactor", "resolution"])
577+
def test_undefined_contrast_fields(test_project, field, request):
578+
"""If a field in a contrast is empty, we should raise an error."""
579+
test_project = request.getfixturevalue(test_project)
580+
setattr(test_project.contrasts[0], field, "")
581+
582+
with pytest.raises(
583+
ValueError,
584+
match=f"In the input project, the {field} of contrast "
585+
f'"{test_project.contrasts[0].name}" does not have a value defined. '
586+
f"A value must be supplied before running the project.",
587+
):
588+
make_problem(test_project)
589+
590+
591+
@pytest.mark.parametrize("test_project", ["standard_layers_project", "custom_xy_project", "domains_project"])
592+
def test_undefined_background(test_project, request):
593+
"""If the source field of a background defined in a contrast is empty, we should raise an error."""
594+
test_project = request.getfixturevalue(test_project)
595+
background = test_project.backgrounds[test_project.contrasts[0].background]
596+
background.source = ""
597+
598+
with pytest.raises(
599+
ValueError,
600+
match=f"All backgrounds must have a source defined. For a {background.type} type "
601+
f"background, the source must be defined in "
602+
f'"{RATapi.project.values_defined_in[f"backgrounds.{background.type}.source"]}"',
603+
):
604+
make_problem(test_project)
605+
606+
607+
@pytest.mark.parametrize("test_project", ["standard_layers_project", "custom_xy_project", "domains_project"])
608+
def test_undefined_resolution(test_project, request):
609+
"""If the source field of a resolution defined in a contrast is empty, we should raise an error."""
610+
test_project = request.getfixturevalue(test_project)
611+
resolution = test_project.resolutions[test_project.contrasts[0].resolution]
612+
resolution.source = ""
613+
614+
with pytest.raises(
615+
ValueError,
616+
match=f"Constant resolutions must have a source defined. The source must be defined in "
617+
f'"{RATapi.project.values_defined_in[f"resolutions.{resolution.type}.source"]}"',
618+
):
619+
make_problem(test_project)
620+
621+
622+
@pytest.mark.parametrize("test_project", ["standard_layers_project", "domains_project"])
623+
@pytest.mark.parametrize("field", ["thickness", "SLD", "roughness"])
624+
def test_undefined_layers(test_project, field, request):
625+
"""If the thickness, SLD, or roughness fields of a layer defined in the project are empty, we should raise an
626+
error."""
627+
test_project = request.getfixturevalue(test_project)
628+
setattr(test_project.layers[0], field, "")
629+
630+
with pytest.raises(
631+
ValueError,
632+
match=f"In the input project, the {field} field of layer {test_project.layers[0].name} "
633+
f"does not have a value defined. A value must be supplied before running the project.",
634+
):
635+
make_problem(test_project)
636+
637+
575638
def test_append_data_background():
576639
"""Test that background data is correctly added to contrast data."""
577640
data = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

0 commit comments

Comments
 (0)