Skip to content

Commit 9b4fab8

Browse files
committed
Refactors "set_fields" routine with custom context manager
1 parent 082657c commit 9b4fab8

File tree

3 files changed

+82
-4
lines changed

3 files changed

+82
-4
lines changed

RATapi/classlist.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import collections
66
import contextlib
7+
import importlib
78
import warnings
89
from collections.abc import Sequence
910
from typing import Any, Generic, TypeVar, Union
@@ -261,9 +262,64 @@ def extend(self, other: Sequence[T]) -> None:
261262
def set_fields(self, index: int, **kwargs) -> None:
262263
"""Assign the values of an existing object's attributes using keyword arguments."""
263264
self._validate_name_field(kwargs)
264-
class_handle = self.data[index].__class__
265-
new_fields = {**self.data[index].__dict__, **kwargs}
266-
self.data[index] = class_handle(**new_fields)
265+
pydantic_object = False
266+
267+
if importlib.util.find_spec("pydantic"):
268+
# Pydantic is installed, so set up a context manager that will
269+
# suppress validation errors until all fields have been set.
270+
from pydantic import BaseModel, ValidationError
271+
272+
if isinstance(self.data[index], BaseModel):
273+
pydantic_object = True
274+
275+
# Define a custom context manager
276+
class SuppressCustomValidation(contextlib.AbstractContextManager):
277+
"""Context manager to suppress "value_error" based validation errors in pydantic.
278+
279+
This validation context is necessary because errors can occur whilst individual
280+
model values are set, which are resolved when all of the input values are set.
281+
282+
After the exception is suppressed, execution proceeds with the next
283+
statement following the with statement.
284+
285+
with SuppressCustomValidation():
286+
setattr(self.data[index], key, value)
287+
# Execution still resumes here if the attribute cannot be set
288+
"""
289+
290+
def __init__(self):
291+
pass
292+
293+
def __enter__(self):
294+
pass
295+
296+
def __exit__(self, exctype, excinst, exctb):
297+
# If the return of __exit__ is True or truthy, the exception is suppressed.
298+
# Otherwise, the default behaviour of raising the exception applies.
299+
#
300+
# To suppress errors arising from field and model validators in pydantic,
301+
# we will examine the validation errors raised. If all of the errors
302+
# listed in the exception have the type "value_error", this indicates
303+
# they have arisen from field or model validators and will be suppressed.
304+
# Otherwise, they will be raised.
305+
if exctype is None:
306+
return
307+
if issubclass(exctype, ValidationError) and all(
308+
[error["type"] == "value_error" for error in excinst.errors()]
309+
):
310+
return True
311+
return False
312+
313+
validation_context = SuppressCustomValidation()
314+
else:
315+
validation_context = contextlib.nullcontext()
316+
317+
for key, value in kwargs.items():
318+
with validation_context:
319+
setattr(self.data[index], key, value)
320+
321+
if pydantic_object:
322+
self._class_handle.model_validate(self.data[index])
267323

268324
def get_names(self) -> list[str]:
269325
"""Return a list of the values of the name_field attribute of each class object in the list.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ extend-exclude = ["*.ipynb"]
1212

1313
[tool.ruff.lint]
1414
select = ["E", "F", "UP", "B", "SIM", "I"]
15-
ignore = ["SIM108"]
15+
ignore = ["SIM103", "SIM108"]
1616

1717
[tool.ruff.lint.flake8-pytest-style]
1818
fixture-parentheses = false

tests/test_classlist.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1005,3 +1005,25 @@ class NestedModel(pydantic.BaseModel):
10051005
for submodel, exp_dict in zip(model.submodels, submodels_list):
10061006
for key, value in exp_dict.items():
10071007
assert getattr(submodel, key) == value
1008+
1009+
def test_set_pydantic_fields(self):
1010+
"""Test that intermediate validation errors for pydantic models are suppressed when using "set_fields"."""
1011+
from pydantic import BaseModel, model_validator
1012+
1013+
class MinMaxModel(BaseModel):
1014+
min: float
1015+
value: float
1016+
max: float
1017+
1018+
@model_validator(mode="after")
1019+
def check_value_in_range(self) -> "MinMaxModel":
1020+
if self.value < self.min or self.value > self.max:
1021+
raise ValueError(
1022+
f"value {self.value} is not within the defined range: {self.min} <= value <= {self.max}"
1023+
)
1024+
return self
1025+
1026+
model_list = ClassList([MinMaxModel(min=1, value=2, max=5)])
1027+
model_list.set_fields(0, min=3, value=4)
1028+
1029+
assert model_list == ClassList([MinMaxModel(min=3.0, value=4.0, max=5.0)])

0 commit comments

Comments
 (0)