Skip to content

Commit 89e5fdd

Browse files
authored
Improves error message for cross check validation (#147)
* Updates custom error messages * Splits error message into two for cross check validation * Tidies up tests and improves coverage * Reformats error messages * Improves error message for contrast model
1 parent 113729f commit 89e5fdd

File tree

3 files changed

+280
-129
lines changed

3 files changed

+280
-129
lines changed

RATapi/controls.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def warn_setting_incorrect_properties(self, handler: ValidatorFunctionWrapHandle
155155
f" controls procedure are:\n "
156156
f"{', '.join(fields.get('procedure', []))}\n",
157157
}
158-
custom_error_list = custom_pydantic_validation_error(exc.errors(), custom_error_msgs)
158+
custom_error_list = custom_pydantic_validation_error(exc.errors(include_url=False), custom_error_msgs)
159159
raise ValidationError.from_exception_data(exc.title, custom_error_list, hide_input=True) from None
160160

161161
if isinstance(model_input, validated_self.__class__):

RATapi/project.py

Lines changed: 89 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,6 @@ def update_renamed_models(self) -> "Project":
557557
for index, param in all_matches:
558558
if param in params:
559559
setattr(project_field[index], param, new_name)
560-
self._all_names = self.get_all_names()
561560
return self
562561

563562
@model_validator(mode="after")
@@ -566,28 +565,45 @@ def cross_check_model_values(self) -> "Project":
566565
values = ["value_1", "value_2", "value_3", "value_4", "value_5"]
567566
for field in ["backgrounds", "resolutions"]:
568567
self.check_allowed_source(field)
569-
self.check_allowed_values(field, values, getattr(self, f"{field[:-1]}_parameters").get_names())
568+
self.check_allowed_values(
569+
field,
570+
values,
571+
getattr(self, f"{field[:-1]}_parameters").get_names(),
572+
self._all_names[f"{field[:-1]}_parameters"],
573+
)
570574

571575
self.check_allowed_values(
572576
"layers",
573577
["thickness", "SLD", "SLD_real", "SLD_imaginary", "roughness"],
574578
self.parameters.get_names(),
579+
self._all_names["parameters"],
575580
)
576581

577-
self.check_allowed_values("contrasts", ["data"], self.data.get_names())
578-
self.check_allowed_values("contrasts", ["background"], self.backgrounds.get_names())
579-
self.check_allowed_values("contrasts", ["bulk_in"], self.bulk_in.get_names())
580-
self.check_allowed_values("contrasts", ["bulk_out"], self.bulk_out.get_names())
581-
self.check_allowed_values("contrasts", ["scalefactor"], self.scalefactors.get_names())
582-
self.check_allowed_values("contrasts", ["resolution"], self.resolutions.get_names())
583-
self.check_allowed_values("contrasts", ["domain_ratio"], self.domain_ratios.get_names())
582+
self.check_allowed_values("contrasts", ["data"], self.data.get_names(), self._all_names["data"])
583+
self.check_allowed_values(
584+
"contrasts", ["background"], self.backgrounds.get_names(), self._all_names["backgrounds"]
585+
)
586+
self.check_allowed_values("contrasts", ["bulk_in"], self.bulk_in.get_names(), self._all_names["bulk_in"])
587+
self.check_allowed_values("contrasts", ["bulk_out"], self.bulk_out.get_names(), self._all_names["bulk_out"])
588+
self.check_allowed_values(
589+
"contrasts", ["scalefactor"], self.scalefactors.get_names(), self._all_names["scalefactors"]
590+
)
591+
self.check_allowed_values(
592+
"contrasts", ["resolution"], self.resolutions.get_names(), self._all_names["resolutions"]
593+
)
594+
self.check_allowed_values(
595+
"contrasts", ["domain_ratio"], self.domain_ratios.get_names(), self._all_names["domain_ratios"]
596+
)
584597

585598
self.check_contrast_model_allowed_values(
586599
"contrasts",
587600
getattr(self, self._contrast_model_field).get_names(),
601+
self._all_names[self._contrast_model_field],
588602
self._contrast_model_field,
589603
)
590-
self.check_contrast_model_allowed_values("domain_contrasts", self.layers.get_names(), "layers")
604+
self.check_contrast_model_allowed_values(
605+
"domain_contrasts", self.layers.get_names(), self._all_names["layers"], "layers"
606+
)
591607
return self
592608

593609
@model_validator(mode="after")
@@ -606,6 +622,12 @@ def check_protected_parameters(self) -> "Project":
606622
self._protected_parameters = self.get_all_protected_parameters()
607623
return self
608624

625+
@model_validator(mode="after")
626+
def update_names(self) -> "Project":
627+
"""Following validation, update the list of all parameter names."""
628+
self._all_names = self.get_all_names()
629+
return self
630+
609631
def __str__(self):
610632
output = ""
611633
for key, value in self.__dict__.items():
@@ -630,7 +652,9 @@ def get_all_protected_parameters(self):
630652
for class_list in parameter_class_lists
631653
}
632654

633-
def check_allowed_values(self, attribute: str, field_list: list[str], allowed_values: list[str]) -> None:
655+
def check_allowed_values(
656+
self, attribute: str, field_list: list[str], allowed_values: list[str], previous_values: list[str]
657+
) -> None:
634658
"""Check the values of the given fields in the given model are in the supplied list of allowed values.
635659
636660
Parameters
@@ -641,6 +665,8 @@ def check_allowed_values(self, attribute: str, field_list: list[str], allowed_va
641665
The fields of the attribute to be checked for valid values.
642666
allowed_values : list [str]
643667
The list of allowed values for the fields given in field_list.
668+
previous_values : list [str]
669+
The list of allowed values for the fields given in field_list after the previous validation.
644670
645671
Raises
646672
------
@@ -649,14 +675,22 @@ def check_allowed_values(self, attribute: str, field_list: list[str], allowed_va
649675
650676
"""
651677
class_list = getattr(self, attribute)
652-
for model in class_list:
678+
for index, model in enumerate(class_list):
653679
for field in field_list:
654680
value = getattr(model, field, "")
655681
if value and value not in allowed_values:
656-
raise ValueError(
657-
f'The value "{value}" in the "{field}" field of "{attribute}" must be defined in '
658-
f'"{values_defined_in[f"{attribute}.{field}"]}".',
659-
)
682+
if value in previous_values:
683+
raise ValueError(
684+
f'The value "{value}" used in the "{field}" field of {attribute}[{index}] must be defined '
685+
f'in "{values_defined_in[f"{attribute}.{field}"]}". Please remove "{value}" from '
686+
f'"{attribute}[{index}].{field}" before attempting to delete it.',
687+
)
688+
else:
689+
raise ValueError(
690+
f'The value "{value}" used in the "{field}" field of {attribute}[{index}] must be defined '
691+
f'in "{values_defined_in[f"{attribute}.{field}"]}". Please add "{value}" to '
692+
f'"{values_defined_in[f"{attribute}.{field}"]}" before including it in "{attribute}".',
693+
)
660694

661695
def check_allowed_source(self, attribute: str) -> None:
662696
"""Check that the source of a background or resolution is defined in the relevant field for its type.
@@ -679,24 +713,37 @@ def check_allowed_source(self, attribute: str) -> None:
679713
680714
"""
681715
class_list = getattr(self, attribute)
682-
for model in class_list:
716+
for index, model in enumerate(class_list):
683717
if model.type == TypeOptions.Constant:
684718
allowed_values = getattr(self, f"{attribute[:-1]}_parameters").get_names()
719+
previous_values = self._all_names[f"{attribute[:-1]}_parameters"]
685720
elif model.type == TypeOptions.Data:
686721
allowed_values = self.data.get_names()
722+
previous_values = self._all_names["data"]
687723
else:
688724
allowed_values = self.custom_files.get_names()
725+
previous_values = self._all_names["custom_files"]
689726

690727
if (value := model.source) != "" and value not in allowed_values:
691-
raise ValueError(
692-
f'The value "{value}" in the "source" field of "{attribute}" must be defined in '
693-
f'"{values_defined_in[f"{attribute}.{model.type}.source"]}".',
694-
)
728+
if value in previous_values:
729+
raise ValueError(
730+
f'The value "{value}" used in the "source" field of {attribute}[{index}] must be defined in '
731+
f'"{values_defined_in[f"{attribute}.{model.type}.source"]}". Please remove "{value}" from '
732+
f'"{attribute}[{index}].source" before attempting to delete it.',
733+
)
734+
else:
735+
raise ValueError(
736+
f'The value "{value}" used in the "source" field of {attribute}[{index}] must be defined in '
737+
f'"{values_defined_in[f"{attribute}.{model.type}.source"]}". Please add "{value}" to '
738+
f'"{values_defined_in[f"{attribute}.{model.type}.source"]}" before including it in '
739+
f'"{attribute}".',
740+
)
695741

696742
def check_contrast_model_allowed_values(
697743
self,
698744
contrast_attribute: str,
699745
allowed_values: list[str],
746+
previous_values: list[str],
700747
allowed_field: str,
701748
) -> None:
702749
"""Ensure the contents of the ``model`` for a contrast or domain contrast exist in the required project fields.
@@ -707,6 +754,8 @@ def check_contrast_model_allowed_values(
707754
The specific contrast attribute of Project being validated (either "contrasts" or "domain_contrasts").
708755
allowed_values : list [str]
709756
The list of allowed values for the model of the contrast_attribute.
757+
previous_values : list [str]
758+
The list of allowed values for the model of the contrast_attribute after the previous validation.
710759
allowed_field : str
711760
The name of the field in the project in which the allowed_values are defined.
712761
@@ -717,13 +766,22 @@ def check_contrast_model_allowed_values(
717766
718767
"""
719768
class_list = getattr(self, contrast_attribute)
720-
for contrast in class_list:
721-
model_values = contrast.model
722-
if model_values and not all(value in allowed_values for value in model_values):
723-
raise ValueError(
724-
f'The values: "{", ".join(str(i) for i in model_values)}" in the "model" field of '
725-
f'"{contrast_attribute}" must be defined in "{allowed_field}".',
726-
)
769+
for index, contrast in enumerate(class_list):
770+
if (model_values := contrast.model) and (missing_values := list(set(model_values) - set(allowed_values))):
771+
if all(value in previous_values for value in model_values):
772+
raise ValueError(
773+
f"The value{'s' if len(missing_values) > 1 else ''}: "
774+
f'"{", ".join(str(i) for i in missing_values)}" used in the "model" field of '
775+
f'{contrast_attribute}[{index}] must be defined in "{allowed_field}". Please remove all '
776+
f'unnecessary values from "model" before attempting to delete them.',
777+
)
778+
else:
779+
raise ValueError(
780+
f"The value{'s' if len(missing_values) > 1 else ''}: "
781+
f'"{", ".join(str(i) for i in missing_values)}" used in the "model" field of '
782+
f'{contrast_attribute}[{index}] must be defined in "{allowed_field}". Please add all '
783+
f'required values to "{allowed_field}" before including them in "{contrast_attribute}".',
784+
)
727785

728786
def get_contrast_model_field(self):
729787
"""Get the field used to define the contents of the "model" field in contrasts.
@@ -945,7 +1003,7 @@ def wrapped_func(*args, **kwargs):
9451003
Project.model_validate(self)
9461004
except ValidationError as exc:
9471005
class_list.data = previous_state
948-
custom_error_list = custom_pydantic_validation_error(exc.errors())
1006+
custom_error_list = custom_pydantic_validation_error(exc.errors(include_url=False))
9491007
raise ValidationError.from_exception_data(exc.title, custom_error_list, hide_input=True) from None
9501008
except (TypeError, ValueError):
9511009
class_list.data = previous_state
@@ -980,9 +1038,8 @@ def try_relative_to(path: Path, relative_to: Path) -> str:
9801038
else:
9811039
warnings.warn(
9821040
"Could not save custom file path as relative to the project directory, "
983-
"which means that it may not work on other devices."
984-
"If you would like to share your project, make sure your custom files "
985-
"are in a subfolder of the project save location.",
1041+
"which means that it may not work on other devices. If you would like to share your project, "
1042+
"make sure your custom files are in a subfolder of the project save location.",
9861043
stacklevel=2,
9871044
)
9881045
return str(path.resolve())

0 commit comments

Comments
 (0)