Skip to content

Commit 92fde47

Browse files
committed
Splits error message into two for cross check validation
1 parent 58421b3 commit 92fde47

File tree

3 files changed

+186
-80
lines changed

3 files changed

+186
-80
lines changed

RATapi/project.py

Lines changed: 85 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,6 @@ def update_renamed_models(self) -> "Project":
557557
for index, param in all_matches:
558558
if param in params:
559559
setattr(project_field[index], param, new_name)
560-
self._all_names = self.get_all_names()
561560
return self
562561

563562
@model_validator(mode="after")
@@ -566,28 +565,45 @@ def cross_check_model_values(self) -> "Project":
566565
values = ["value_1", "value_2", "value_3", "value_4", "value_5"]
567566
for field in ["backgrounds", "resolutions"]:
568567
self.check_allowed_source(field)
569-
self.check_allowed_values(field, values, getattr(self, f"{field[:-1]}_parameters").get_names())
568+
self.check_allowed_values(
569+
field,
570+
values,
571+
getattr(self, f"{field[:-1]}_parameters").get_names(),
572+
self._all_names[f"{field[:-1]}_parameters"],
573+
)
570574

571575
self.check_allowed_values(
572576
"layers",
573577
["thickness", "SLD", "SLD_real", "SLD_imaginary", "roughness"],
574578
self.parameters.get_names(),
579+
self._all_names["parameters"],
575580
)
576581

577-
self.check_allowed_values("contrasts", ["data"], self.data.get_names())
578-
self.check_allowed_values("contrasts", ["background"], self.backgrounds.get_names())
579-
self.check_allowed_values("contrasts", ["bulk_in"], self.bulk_in.get_names())
580-
self.check_allowed_values("contrasts", ["bulk_out"], self.bulk_out.get_names())
581-
self.check_allowed_values("contrasts", ["scalefactor"], self.scalefactors.get_names())
582-
self.check_allowed_values("contrasts", ["resolution"], self.resolutions.get_names())
583-
self.check_allowed_values("contrasts", ["domain_ratio"], self.domain_ratios.get_names())
582+
self.check_allowed_values("contrasts", ["data"], self.data.get_names(), self._all_names["data"])
583+
self.check_allowed_values(
584+
"contrasts", ["background"], self.backgrounds.get_names(), self._all_names["backgrounds"]
585+
)
586+
self.check_allowed_values("contrasts", ["bulk_in"], self.bulk_in.get_names(), self._all_names["bulk_in"])
587+
self.check_allowed_values("contrasts", ["bulk_out"], self.bulk_out.get_names(), self._all_names["bulk_out"])
588+
self.check_allowed_values(
589+
"contrasts", ["scalefactor"], self.scalefactors.get_names(), self._all_names["scalefactors"]
590+
)
591+
self.check_allowed_values(
592+
"contrasts", ["resolution"], self.resolutions.get_names(), self._all_names["resolutions"]
593+
)
594+
self.check_allowed_values(
595+
"contrasts", ["domain_ratio"], self.domain_ratios.get_names(), self._all_names["domain_ratios"]
596+
)
584597

585598
self.check_contrast_model_allowed_values(
586599
"contrasts",
587600
getattr(self, self._contrast_model_field).get_names(),
601+
self._all_names[self._contrast_model_field],
588602
self._contrast_model_field,
589603
)
590-
self.check_contrast_model_allowed_values("domain_contrasts", self.layers.get_names(), "layers")
604+
self.check_contrast_model_allowed_values(
605+
"domain_contrasts", self.layers.get_names(), self._all_names["layers"], "layers"
606+
)
591607
return self
592608

593609
@model_validator(mode="after")
@@ -606,6 +622,12 @@ def check_protected_parameters(self) -> "Project":
606622
self._protected_parameters = self.get_all_protected_parameters()
607623
return self
608624

625+
@model_validator(mode="after")
626+
def update_names(self) -> "Project":
627+
"""Following validation, update the list of all parameter names."""
628+
self._all_names = self.get_all_names()
629+
return self
630+
609631
def __str__(self):
610632
output = ""
611633
for key, value in self.__dict__.items():
@@ -630,7 +652,9 @@ def get_all_protected_parameters(self):
630652
for class_list in parameter_class_lists
631653
}
632654

633-
def check_allowed_values(self, attribute: str, field_list: list[str], allowed_values: list[str]) -> None:
655+
def check_allowed_values(
656+
self, attribute: str, field_list: list[str], allowed_values: list[str], previous_values: list[str]
657+
) -> None:
634658
"""Check the values of the given fields in the given model are in the supplied list of allowed values.
635659
636660
Parameters
@@ -641,6 +665,8 @@ def check_allowed_values(self, attribute: str, field_list: list[str], allowed_va
641665
The fields of the attribute to be checked for valid values.
642666
allowed_values : list [str]
643667
The list of allowed values for the fields given in field_list.
668+
previous_values : list [str]
669+
The list of allowed values for the fields given in field_list after the previous validation.
644670
645671
Raises
646672
------
@@ -653,12 +679,19 @@ def check_allowed_values(self, attribute: str, field_list: list[str], allowed_va
653679
for field in field_list:
654680
value = getattr(model, field, "")
655681
if value and value not in allowed_values:
656-
raise ValueError(
657-
f'The value "{value}" used in the "{field}" field at index {index} of "{attribute}" '
658-
f'must be defined in "{values_defined_in[f"{attribute}.{field}"]}". Please either add '
659-
f'"{value}" to "{values_defined_in[f"{attribute}.{field}"]}" before including it in'
660-
f' "{attribute}", or remove it from "{attribute}.{field}" before attempting to delete it.',
661-
)
682+
if value in previous_values:
683+
raise ValueError(
684+
f'The value "{value}" used in the "{field}" field at index {index} of "{attribute}" '
685+
f'must be defined in "{values_defined_in[f"{attribute}.{field}"]}". Please remove '
686+
f'"{value}" from "{attribute}{index}.{field}" before attempting to delete it.',
687+
)
688+
else:
689+
raise ValueError(
690+
f'The value "{value}" used in the "{field}" field at index {index} of "{attribute}" '
691+
f'must be defined in "{values_defined_in[f"{attribute}.{field}"]}". Please add '
692+
f'"{value}" to "{values_defined_in[f"{attribute}.{field}"]}" before including it in '
693+
f'"{attribute}".',
694+
)
662695

663696
def check_allowed_source(self, attribute: str) -> None:
664697
"""Check that the source of a background or resolution is defined in the relevant field for its type.
@@ -681,24 +714,37 @@ def check_allowed_source(self, attribute: str) -> None:
681714
682715
"""
683716
class_list = getattr(self, attribute)
684-
for model in class_list:
717+
for index, model in enumerate(class_list):
685718
if model.type == TypeOptions.Constant:
686719
allowed_values = getattr(self, f"{attribute[:-1]}_parameters").get_names()
720+
previous_values = self._all_names[f"{attribute[:-1]}_parameters"]
687721
elif model.type == TypeOptions.Data:
688722
allowed_values = self.data.get_names()
723+
previous_values = self._all_names["data"]
689724
else:
690725
allowed_values = self.custom_files.get_names()
726+
previous_values = self._all_names["custom_files"]
691727

692728
if (value := model.source) != "" and value not in allowed_values:
693-
raise ValueError(
694-
f'The value "{value}" in the "source" field of "{attribute}" must be defined in '
695-
f'"{values_defined_in[f"{attribute}.{model.type}.source"]}".',
696-
)
729+
if value in previous_values:
730+
raise ValueError(
731+
f'The value "{value}" used in the "source" field at index {index} of "{attribute}" '
732+
f'must be defined in "{values_defined_in[f"{attribute}.{model.type}.source"]}". Please remove '
733+
f'"{value}" from "{attribute}{index}.source" before attempting to delete it.',
734+
)
735+
else:
736+
raise ValueError(
737+
f'The value "{value}" used in the "source" field at index {index} of "{attribute}" '
738+
f'must be defined in "{values_defined_in[f"{attribute}.{model.type}.source"]}". Please add '
739+
f'"{value}" to "{values_defined_in[f"{attribute}.{model.type}.source"]}" before including it '
740+
f'in "{attribute}".',
741+
)
697742

698743
def check_contrast_model_allowed_values(
699744
self,
700745
contrast_attribute: str,
701746
allowed_values: list[str],
747+
previous_values: list[str],
702748
allowed_field: str,
703749
) -> None:
704750
"""Ensure the contents of the ``model`` for a contrast or domain contrast exist in the required project fields.
@@ -709,6 +755,8 @@ def check_contrast_model_allowed_values(
709755
The specific contrast attribute of Project being validated (either "contrasts" or "domain_contrasts").
710756
allowed_values : list [str]
711757
The list of allowed values for the model of the contrast_attribute.
758+
previous_values : list [str]
759+
The list of allowed values for the model of the contrast_attribute after the previous validation.
712760
allowed_field : str
713761
The name of the field in the project in which the allowed_values are defined.
714762
@@ -719,13 +767,21 @@ def check_contrast_model_allowed_values(
719767
720768
"""
721769
class_list = getattr(self, contrast_attribute)
722-
for contrast in class_list:
723-
model_values = contrast.model
724-
if model_values and not all(value in allowed_values for value in model_values):
725-
raise ValueError(
726-
f'The values: "{", ".join(str(i) for i in model_values)}" in the "model" field of '
727-
f'"{contrast_attribute}" must be defined in "{allowed_field}".',
728-
)
770+
for index, contrast in enumerate(class_list):
771+
if (model_values := contrast.model) and not all(value in allowed_values for value in model_values):
772+
if all(value in previous_values for value in model_values):
773+
raise ValueError(
774+
f'The values: "{", ".join(str(i) for i in model_values)}" used in the "model" field at index '
775+
f'{index} of "{contrast_attribute}" must be defined in "{allowed_field}". Please remove '
776+
f'all unnecessary values from "model" before attempting to delete them.',
777+
)
778+
else:
779+
raise ValueError(
780+
f'The values: "{", ".join(str(i) for i in model_values)}" used in the "model" field at index '
781+
f'{index} of "{contrast_attribute}" must be defined in "{allowed_field}". Please add '
782+
f'all required values to "{allowed_field}" '
783+
f'before including them in "{contrast_attribute}".',
784+
)
729785

730786
def get_contrast_model_field(self):
731787
"""Get the field used to define the contents of the "model" field in contrasts.

RATapi/utils/custom_errors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def custom_pydantic_validation_error(
3535
if error["type"] in custom_error_msgs:
3636
custom_error = pydantic_core.PydanticCustomError(error["type"], custom_error_msgs[error["type"]])
3737
else:
38-
custom_error = pydantic_core.PydanticCustomError(error["type"], error["msg"].replace(",", ":", 1))
38+
custom_error = pydantic_core.PydanticCustomError(error["type"], error["msg"])
3939
error["type"] = custom_error
4040
custom_error_list.append(error)
4141

0 commit comments

Comments
 (0)