77
88import os .path as osp
99
10+ import numpy as np
1011import pytest
1112
1213import guidata .dataset as gds
1314from guidata .config import ValidationMode , get_validation_mode , set_validation_mode
1415from guidata .env import execenv
1516
1617
18+ def check_array (value : np .ndarray , raise_exception : bool = False ) -> bool :
19+ """Check if value is a valid 2D array of floats.
20+
21+ Args:
22+ value: value to check
23+ raise_exception: if True, raise an exception on invalid value
24+
25+ Returns:
26+ True if value is valid, False otherwise
27+ """
28+ if (
29+ not isinstance (value , np .ndarray )
30+ or value .ndim != 2
31+ or not np .issubdtype (value .dtype , np .floating )
32+ ):
33+ if raise_exception :
34+ raise TypeError ("Float array must be a 2D numpy array of floats" )
35+ return False
36+ return True
37+
38+
1739class Parameters (gds .DataSet ):
1840 """Example dataset"""
1941
2042 fitem = gds .FloatItem ("Float" , min = 1 , max = 250 )
2143 iitem = gds .IntItem ("Integer" , max = 20 , nonzero = True )
2244 sitem = gds .StringItem ("String" , notempty = True )
45+ aitem = gds .FloatArrayItem ("Array" , check_callback = check_array )
2346 fileopenitem = gds .FileOpenItem ("File" , ("py" ,))
2447 filesopenitem = gds .FilesOpenItem ("Files" , ("py" ,))
2548 filesaveitem = gds .FileSaveItem ("Save file" , ("py" ,))
@@ -30,6 +53,7 @@ class Parameters(gds.DataSet):
3053 "fitem" : 100.0 ,
3154 "iitem" : 10 ,
3255 "sitem" : "test" ,
56+ "aitem" : np .array ([[1.0 , 2.0 ], [3.0 , 4.0 ]]),
3357 "fileopenitem" : __file__ ,
3458 "filesopenitem" : [
3559 __file__ ,
@@ -50,6 +74,10 @@ class Parameters(gds.DataSet):
5074 "test" , # Not an integer
5175 23.2323 , # Not an integer
5276 ),
77+ "aitem" : (
78+ np .array ([1.0 , 2.0 ]), # Not a 2D array
79+ np .array ([[1 , 2 ], [3 , 4 ]]), # Not a float array
80+ ),
5381 "sitem" : (
5482 "" , # Empty string not allowed
5583 123 , # Not a string
@@ -74,18 +102,45 @@ def test_default_validation_mode():
74102 execenv .print ("OK" )
75103
76104
105+ def __check_assigned_value_is_equal (assigned_value , expected_value ):
106+ """Check if the assigned value is correctly set"""
107+ if isinstance (expected_value , np .ndarray ):
108+ # For arrays, we check if the value is set correctly
109+ assert isinstance (assigned_value , np .ndarray )
110+ assert assigned_value .shape == expected_value .shape
111+ assert np .all (assigned_value == expected_value )
112+ else :
113+ # For other types, we check if the value is set correctly
114+ assert assigned_value == expected_value
115+
116+
117+ def __check_assigned_value_is_not_equal (assigned_value , expected_value ):
118+ """Check if the assigned value is not equal to the real value"""
119+ if isinstance (expected_value , np .ndarray ):
120+ # For arrays, we check if the value is set correctly
121+ if isinstance (assigned_value , np .ndarray ):
122+ assert assigned_value .shape == expected_value .shape
123+ assert not np .all (assigned_value == expected_value )
124+ else :
125+ assert assigned_value is None
126+ else :
127+ # For other types, we check if the value is set correctly
128+ assert assigned_value != expected_value
129+
130+
77131def test_valid_data ():
78132 """Test valid data"""
79133 params = Parameters ()
80134 execenv .print ("Testing valid data: " , end = "" )
81135 for name , value in VALID_DATA .items ():
82136 setattr (params , name , value )
83- assert getattr (params , name ) == value
137+ __check_assigned_value_is_equal ( getattr (params , name ), value )
84138 execenv .print ("OK" )
85139
86140
87141def test_invalid_data_with_no_validation ():
88142 """Test invalid data with validation disabled"""
143+ old_mode = get_validation_mode ()
89144 params = Parameters ()
90145 set_validation_mode (ValidationMode .DISABLED )
91146 assert get_validation_mode () == ValidationMode .DISABLED
@@ -95,11 +150,13 @@ def test_invalid_data_with_no_validation():
95150 execenv .print (f" Testing { name } with value: { value } " )
96151 setattr (params , name , value )
97152 # No exception should be raised
98- assert getattr (params , name ) == value
153+ __check_assigned_value_is_equal (getattr (params , name ), value )
154+ set_validation_mode (old_mode )
99155
100156
101157def test_invalid_data_with_enabled_validation ():
102158 """Test invalid data with validation enabled"""
159+ old_mode = get_validation_mode ()
103160 params = Parameters ()
104161 set_validation_mode (ValidationMode .ENABLED )
105162 assert get_validation_mode () == ValidationMode .ENABLED
@@ -111,11 +168,13 @@ def test_invalid_data_with_enabled_validation():
111168 with pytest .warns (gds .DataItemValidationWarning ):
112169 setattr (params , name , value )
113170 # The value should be set anyway
114- assert getattr (params , name ) == value
171+ __check_assigned_value_is_equal (getattr (params , name ), value )
172+ set_validation_mode (old_mode )
115173
116174
117175def test_invalid_data_with_strict_validation ():
118176 """Test invalid data with strict validation"""
177+ old_mode = get_validation_mode ()
119178 params = Parameters ()
120179 set_validation_mode (ValidationMode .STRICT )
121180 assert get_validation_mode () == ValidationMode .STRICT
@@ -127,7 +186,8 @@ def test_invalid_data_with_strict_validation():
127186 with pytest .raises (gds .DataItemValidationError ):
128187 setattr (params , name , value )
129188 # The value should not be set
130- assert getattr (params , name ) != value
189+ __check_assigned_value_is_not_equal (getattr (params , name ), value )
190+ set_validation_mode (old_mode )
131191
132192
133193if __name__ == "__main__" :
0 commit comments