Skip to content

Commit e84e1ed

Browse files
committed
Add check_callback parameter to FloatArrayItem and update tests
1 parent a709057 commit e84e1ed

4 files changed

Lines changed: 87 additions & 10 deletions

File tree

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@
1717
set_validation_mode(ValidationMode.STRICT)
1818
```
1919

20+
* New `check_callback` parameter for `FloatArrayItem`:
21+
* The `check_callback` parameter allows you to specify a custom validation function for the item.
22+
* This function will be called to validate the item's value whenever it is set.
23+
* If the function returns `False`, the value will be considered invalid.
24+
2025
* New `allow_none` parameter for `DataItem` objects:
2126
* The `allow_none` parameter allows you to specify whether `None` is a valid value for the item, which can be especially useful when validation modes are used.
2227
* If `allow_none` is set to `True`, `None` is considered a valid value regardless of other constraints.

guidata/dataset/dataitems.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1095,6 +1095,10 @@ class FloatArrayItem(DataItem):
10951095
variable_size: if True, allows to add/remove row/columns on all axis
10961096
allow_none: if True, None is a valid value regardless of other constraints
10971097
(optional, default=True)
1098+
check_callback: additional callback to check the value
1099+
(function of two arguments (value, raise_exception) returning a boolean,
1100+
where value is the value to check and raise_exception is a boolean
1101+
indicating whether to raise an exception on invalid value)
10981102
"""
10991103

11001104
type = np.ndarray
@@ -1110,12 +1114,14 @@ def __init__(
11101114
check: bool = True,
11111115
variable_size=False,
11121116
allow_none: bool = True,
1117+
check_callback: Callable[[np.ndarray, bool], bool] | None = None,
11131118
) -> None:
11141119
super().__init__(
11151120
label, default=default, help=help, check=check, allow_none=allow_none
11161121
)
11171122
self.set_prop("display", format=format, transpose=transpose, minmax=minmax)
11181123
self.set_prop("edit", variable_size=variable_size)
1124+
self.check_callback = check_callback
11191125

11201126
def check_value(self, value: np.ndarray, raise_exception: bool = False) -> bool:
11211127
"""Override DataItem method"""
@@ -1125,6 +1131,8 @@ def check_value(self, value: np.ndarray, raise_exception: bool = False) -> bool:
11251131
if raise_exception:
11261132
raise TypeError(f"Expected {self.type}, got {type(value)}")
11271133
return False
1134+
if self.check_callback is not None:
1135+
return self.check_callback(value, raise_exception)
11281136
return True
11291137

11301138
def format_string(

guidata/tests/dataset/test_all_features.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import numpy as np
1717

1818
import guidata.dataset as gds
19+
from guidata.config import ValidationMode, get_validation_mode, set_validation_mode
1920
from guidata.dataset.qtitemwidgets import DataSetWidget
2021
from guidata.dataset.qtwidgets import DataSetEditLayout, DataSetShowLayout
2122
from guidata.env import execenv
@@ -65,10 +66,10 @@ class Parameters(gds.DataSet):
6566
"float_col": 1.0,
6667
},
6768
)
68-
string = gds.StringItem("String")
69+
string = gds.StringItem("String", default="")
6970
string_regexp = gds.StringItem("String", regexp=r"^[a-z]+[0-9]$", default="abcd9")
70-
password = gds.StringItem("Password", password=True)
71-
text = gds.TextItem("Text")
71+
password = gds.StringItem("Password", default="", password=True)
72+
text = gds.TextItem("Text", default="")
7273
_bg = gds.BeginGroup("A sub group")
7374
float_slider = gds.FloatItem(
7475
"Float (with slider)", default=0.5, min=0, max=1, step=0.01, slider=True
@@ -121,11 +122,13 @@ class Parameters(gds.DataSet):
121122
integer_slider = gds.IntItem(
122123
"Integer (with slider)", default=5, min=-50, max=50, slider=True
123124
)
124-
integer = gds.IntItem("Integer", default=5, min=3, max=6).set_pos(col=1)
125+
integer = gds.IntItem("Integer", default=5, min=3, max=60).set_pos(col=1)
125126

126127

127128
def test_all_features():
128129
"""Test all guidata item/group features"""
130+
old_mode = get_validation_mode()
131+
set_validation_mode(ValidationMode.STRICT)
129132
with execenv.context(accept_dialogs=True):
130133
with qt_app_context():
131134
prm1 = Parameters()
@@ -137,8 +140,8 @@ def test_all_features():
137140
execenv.print(prm1)
138141
prm1.view()
139142

140-
prm2 = Parameters.create(integer=10101010, string="Using `create`")
141-
assert prm2.integer == 10101010
143+
prm2 = Parameters.create(integer=59, string="Using `create`")
144+
assert prm2.integer == 59
142145
print(prm2)
143146

144147
try:
@@ -150,6 +153,7 @@ def test_all_features():
150153
raise AssertionError("AttributeError not raised")
151154

152155
execenv.print("OK")
156+
set_validation_mode(old_mode)
153157

154158

155159
if __name__ == "__main__":

guidata/tests/unit/test_validationmodes.py

Lines changed: 64 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,42 @@
77

88
import os.path as osp
99

10+
import numpy as np
1011
import pytest
1112

1213
import guidata.dataset as gds
1314
from guidata.config import ValidationMode, get_validation_mode, set_validation_mode
1415
from guidata.env import execenv
1516

1617

18+
def check_array(value: np.ndarray, raise_exception: bool = False) -> bool:
19+
"""Check if value is a valid 2D array of floats.
20+
21+
Args:
22+
value: value to check
23+
raise_exception: if True, raise an exception on invalid value
24+
25+
Returns:
26+
True if value is valid, False otherwise
27+
"""
28+
if (
29+
not isinstance(value, np.ndarray)
30+
or value.ndim != 2
31+
or not np.issubdtype(value.dtype, np.floating)
32+
):
33+
if raise_exception:
34+
raise TypeError("Float array must be a 2D numpy array of floats")
35+
return False
36+
return True
37+
38+
1739
class Parameters(gds.DataSet):
1840
"""Example dataset"""
1941

2042
fitem = gds.FloatItem("Float", min=1, max=250)
2143
iitem = gds.IntItem("Integer", max=20, nonzero=True)
2244
sitem = gds.StringItem("String", notempty=True)
45+
aitem = gds.FloatArrayItem("Array", check_callback=check_array)
2346
fileopenitem = gds.FileOpenItem("File", ("py",))
2447
filesopenitem = gds.FilesOpenItem("Files", ("py",))
2548
filesaveitem = gds.FileSaveItem("Save file", ("py",))
@@ -30,6 +53,7 @@ class Parameters(gds.DataSet):
3053
"fitem": 100.0,
3154
"iitem": 10,
3255
"sitem": "test",
56+
"aitem": np.array([[1.0, 2.0], [3.0, 4.0]]),
3357
"fileopenitem": __file__,
3458
"filesopenitem": [
3559
__file__,
@@ -50,6 +74,10 @@ class Parameters(gds.DataSet):
5074
"test", # Not an integer
5175
23.2323, # Not an integer
5276
),
77+
"aitem": (
78+
np.array([1.0, 2.0]), # Not a 2D array
79+
np.array([[1, 2], [3, 4]]), # Not a float array
80+
),
5381
"sitem": (
5482
"", # Empty string not allowed
5583
123, # Not a string
@@ -74,18 +102,45 @@ def test_default_validation_mode():
74102
execenv.print("OK")
75103

76104

105+
def __check_assigned_value_is_equal(assigned_value, expected_value):
106+
"""Check if the assigned value is correctly set"""
107+
if isinstance(expected_value, np.ndarray):
108+
# For arrays, we check if the value is set correctly
109+
assert isinstance(assigned_value, np.ndarray)
110+
assert assigned_value.shape == expected_value.shape
111+
assert np.all(assigned_value == expected_value)
112+
else:
113+
# For other types, we check if the value is set correctly
114+
assert assigned_value == expected_value
115+
116+
117+
def __check_assigned_value_is_not_equal(assigned_value, expected_value):
118+
"""Check if the assigned value is not equal to the real value"""
119+
if isinstance(expected_value, np.ndarray):
120+
# For arrays, we check if the value is set correctly
121+
if isinstance(assigned_value, np.ndarray):
122+
assert assigned_value.shape == expected_value.shape
123+
assert not np.all(assigned_value == expected_value)
124+
else:
125+
assert assigned_value is None
126+
else:
127+
# For other types, we check if the value is set correctly
128+
assert assigned_value != expected_value
129+
130+
77131
def test_valid_data():
78132
"""Test valid data"""
79133
params = Parameters()
80134
execenv.print("Testing valid data: ", end="")
81135
for name, value in VALID_DATA.items():
82136
setattr(params, name, value)
83-
assert getattr(params, name) == value
137+
__check_assigned_value_is_equal(getattr(params, name), value)
84138
execenv.print("OK")
85139

86140

87141
def test_invalid_data_with_no_validation():
88142
"""Test invalid data with validation disabled"""
143+
old_mode = get_validation_mode()
89144
params = Parameters()
90145
set_validation_mode(ValidationMode.DISABLED)
91146
assert get_validation_mode() == ValidationMode.DISABLED
@@ -95,11 +150,13 @@ def test_invalid_data_with_no_validation():
95150
execenv.print(f" Testing {name} with value: {value}")
96151
setattr(params, name, value)
97152
# No exception should be raised
98-
assert getattr(params, name) == value
153+
__check_assigned_value_is_equal(getattr(params, name), value)
154+
set_validation_mode(old_mode)
99155

100156

101157
def test_invalid_data_with_enabled_validation():
102158
"""Test invalid data with validation enabled"""
159+
old_mode = get_validation_mode()
103160
params = Parameters()
104161
set_validation_mode(ValidationMode.ENABLED)
105162
assert get_validation_mode() == ValidationMode.ENABLED
@@ -111,11 +168,13 @@ def test_invalid_data_with_enabled_validation():
111168
with pytest.warns(gds.DataItemValidationWarning):
112169
setattr(params, name, value)
113170
# The value should be set anyway
114-
assert getattr(params, name) == value
171+
__check_assigned_value_is_equal(getattr(params, name), value)
172+
set_validation_mode(old_mode)
115173

116174

117175
def test_invalid_data_with_strict_validation():
118176
"""Test invalid data with strict validation"""
177+
old_mode = get_validation_mode()
119178
params = Parameters()
120179
set_validation_mode(ValidationMode.STRICT)
121180
assert get_validation_mode() == ValidationMode.STRICT
@@ -127,7 +186,8 @@ def test_invalid_data_with_strict_validation():
127186
with pytest.raises(gds.DataItemValidationError):
128187
setattr(params, name, value)
129188
# The value should not be set
130-
assert getattr(params, name) != value
189+
__check_assigned_value_is_not_equal(getattr(params, name), value)
190+
set_validation_mode(old_mode)
131191

132192

133193
if __name__ == "__main__":

0 commit comments

Comments
 (0)