Skip to content

Commit 0a8fa3c

Browse files
authored
Introduces case insensitive enums and Classlists (#53)
* Adds case insensitive enums * Adds "__str__" method to format contrast models as vertical list * Adds code to ensure searching parameter names in project and classlist is case insensitive
1 parent 48b27ba commit 0a8fa3c

File tree

12 files changed

+305
-101
lines changed

12 files changed

+305
-101
lines changed

RATapi/classlist.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,10 @@ def __str__(self):
6868
+ list(
6969
f"{'Data array: ['+' x '.join(str(i) for i in v.shape) if v.size > 0 else '['}]"
7070
if isinstance(v, np.ndarray)
71+
else "\n".join(element for element in v)
72+
if k == "model"
7173
else str(v)
72-
for v in model.__dict__.values()
74+
for k, v in model.__dict__.items()
7375
)
7476
for index, model in enumerate(self.data)
7577
]
@@ -308,9 +310,9 @@ def _validate_name_field(self, input_args: dict[str, Any]) -> None:
308310
Raised if the input arguments contain a name_field value already defined in the ClassList.
309311
310312
"""
311-
names = self.get_names()
313+
names = [name.lower() for name in self.get_names()]
312314
with contextlib.suppress(KeyError):
313-
if input_args[self.name_field] in names:
315+
if input_args[self.name_field].lower() in names:
314316
raise ValueError(
315317
f"Input arguments contain the {self.name_field} '{input_args[self.name_field]}', "
316318
f"which is already specified in the ClassList",
@@ -331,7 +333,7 @@ def _check_unique_name_fields(self, input_list: Iterable[object]) -> None:
331333
Raised if the input list defines more than one object with the same value of name_field.
332334
333335
"""
334-
names = [getattr(model, self.name_field) for model in input_list if hasattr(model, self.name_field)]
336+
names = [getattr(model, self.name_field).lower() for model in input_list if hasattr(model, self.name_field)]
335337
if len(set(names)) != len(names):
336338
raise ValueError(f"Input list contains objects with the same value of the {self.name_field} attribute")
337339

@@ -367,7 +369,12 @@ def _get_item_from_name_field(self, value: Union[object, str]) -> Union[object,
367369
object with that value of the name_field attribute cannot be found.
368370
369371
"""
370-
return next((model for model in self.data if getattr(model, self.name_field) == value), value)
372+
try:
373+
lower_value = value.lower()
374+
except AttributeError:
375+
lower_value = value
376+
377+
return next((model for model in self.data if getattr(model, self.name_field).lower() == lower_value), value)
371378

372379
@staticmethod
373380
def _determine_class_handle(input_list: Sequence[object]):

RATapi/controls.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
fields = {
2020
"calculate": common_fields,
2121
"simplex": [*common_fields, "xTolerance", "funcTolerance", "maxFuncEvals", "maxIterations", *update_fields],
22-
"de": [
22+
"DE": [
2323
*common_fields,
2424
"populationSize",
2525
"fWeight",
@@ -29,8 +29,8 @@
2929
"numGenerations",
3030
*update_fields,
3131
],
32-
"ns": [*common_fields, "nLive", "nMCMC", "propScale", "nsTolerance"],
33-
"dream": [*common_fields, "nSamples", "nChains", "jumpProbability", "pUnitGamma", "boundHandling", "adaptPCR"],
32+
"NS": [*common_fields, "nLive", "nMCMC", "propScale", "nsTolerance"],
33+
"DREAM": [*common_fields, "nSamples", "nChains", "jumpProbability", "pUnitGamma", "boundHandling", "adaptPCR"],
3434
}
3535

3636

RATapi/models.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,26 @@ class Contrast(RATModel):
7777
resample: bool = False
7878
model: list[str] = []
7979

80+
def __str__(self):
81+
table = prettytable.PrettyTable()
82+
table.field_names = [key.replace("_", " ") for key in self.__dict__]
83+
model_entry = "\n".join(element for element in self.model)
84+
table.add_row(
85+
[
86+
self.name,
87+
self.data,
88+
self.background,
89+
self.background_action,
90+
self.bulk_in,
91+
self.bulk_out,
92+
self.scalefactor,
93+
self.resolution,
94+
self.resample,
95+
model_entry,
96+
]
97+
)
98+
return table.get_string()
99+
80100

81101
class ContrastWithRatio(RATModel):
82102
"""Groups together all of the components of the model including domain terms."""
@@ -93,6 +113,26 @@ class ContrastWithRatio(RATModel):
93113
domain_ratio: str = ""
94114
model: list[str] = []
95115

116+
def __str__(self):
117+
table = prettytable.PrettyTable()
118+
table.field_names = [key.replace("_", " ") for key in self.__dict__]
119+
model_entry = "\n".join(element for element in self.model)
120+
table.add_row(
121+
[
122+
self.name,
123+
self.data,
124+
self.background,
125+
self.background_action,
126+
self.bulk_in,
127+
self.bulk_out,
128+
self.scalefactor,
129+
self.resolution,
130+
self.resample,
131+
model_entry,
132+
]
133+
)
134+
return table.get_string()
135+
96136

97137
class CustomFile(RATModel):
98138
"""Defines the files containing functions to run when using custom models."""
@@ -219,6 +259,13 @@ class DomainContrast(RATModel):
219259
name: str = Field(default_factory=lambda: "New Domain Contrast " + next(domain_contrast_number), min_length=1)
220260
model: list[str] = []
221261

262+
def __str__(self):
263+
table = prettytable.PrettyTable()
264+
table.field_names = [key.replace("_", " ") for key in self.__dict__]
265+
model_entry = "\n".join(element for element in self.model)
266+
table.add_row([self.name, model_entry])
267+
return table.get_string()
268+
222269

223270
class Layer(RATModel, populate_by_name=True):
224271
"""Combines parameters into defined layers."""

RATapi/outputs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ def make_results(
185185
resample=output_results.contrastParams.resample,
186186
)
187187

188-
if procedure in [Procedures.NS, Procedures.Dream]:
188+
if procedure in [Procedures.NS, Procedures.DREAM]:
189189
prediction_intervals = PredictionIntervals(
190190
reflectivity=bayes_results.predictionIntervals.reflectivity,
191191
sld=bayes_results.predictionIntervals.sld,

RATapi/project.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def model_post_init(self, __context: Any) -> None:
269269
if not hasattr(field, "_class_handle"):
270270
field._class_handle = getattr(RATapi.models, model)
271271

272-
if "Substrate Roughness" not in self.parameters.get_names():
272+
if "Substrate Roughness" not in [name.title() for name in self.parameters.get_names()]:
273273
self.parameters.insert(
274274
0,
275275
RATapi.models.ProtectedParameter(
@@ -283,13 +283,13 @@ def model_post_init(self, __context: Any) -> None:
283283
sigma=np.inf,
284284
),
285285
)
286-
elif "Substrate Roughness" not in self.get_all_protected_parameters().values():
286+
elif "Substrate Roughness" not in [name.title() for name in self.get_all_protected_parameters()["parameters"]]:
287287
# If substrate roughness is included as a standard parameter replace it with a protected parameter
288288
substrate_roughness_values = self.parameters[self.parameters.index("Substrate Roughness")].model_dump()
289289
self.parameters.remove("Substrate Roughness")
290290
self.parameters.insert(0, RATapi.models.ProtectedParameter(**substrate_roughness_values))
291291

292-
if "Simulation" not in self.data.get_names():
292+
if "Simulation" not in [name.title() for name in self.data.get_names()]:
293293
self.data.insert(0, RATapi.models.Data(name="Simulation", simulation_range=[0.005, 0.7]))
294294

295295
self._all_names = self.get_all_names()

RATapi/utils/enums.py

Lines changed: 58 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -6,26 +6,40 @@
66
from strenum import StrEnum
77

88

9-
# Controls
10-
class Parallel(StrEnum):
11-
"""Defines the available options for parallelization"""
9+
class RATEnum(StrEnum):
10+
@classmethod
11+
def _missing_(cls, value: str):
12+
value = value.lower()
1213

13-
Single = "single"
14-
Points = "points"
15-
Contrasts = "contrasts"
14+
# Replace common alternative spellings
15+
value = value.replace("-", " ").replace("_", " ").replace("++", "pp").replace("polarized", "polarised")
16+
17+
for member in cls:
18+
if member.value.lower() == value:
19+
return member
20+
return None
1621

1722

18-
class Procedures(StrEnum):
23+
# Controls
24+
class Procedures(RATEnum):
1925
"""Defines the available options for procedures"""
2026

2127
Calculate = "calculate"
2228
Simplex = "simplex"
23-
DE = "de"
24-
NS = "ns"
25-
Dream = "dream"
29+
DE = "DE"
30+
NS = "NS"
31+
DREAM = "DREAM"
2632

2733

28-
class Display(StrEnum):
34+
class Parallel(RATEnum):
35+
"""Defines the available options for parallelization"""
36+
37+
Single = "single"
38+
Points = "points"
39+
Contrasts = "contrasts"
40+
41+
42+
class Display(RATEnum):
2943
"""Defines the available options for display"""
3044

3145
Off = "off"
@@ -34,15 +48,6 @@ class Display(StrEnum):
3448
Final = "final"
3549

3650

37-
class BoundHandling(StrEnum):
38-
"""Defines the available options for bound handling"""
39-
40-
Off = "off"
41-
Reflect = "reflect"
42-
Bound = "bound"
43-
Fold = "fold"
44-
45-
4651
class Strategies(Enum):
4752
"""Defines the available options for strategies"""
4853

@@ -54,48 +59,56 @@ class Strategies(Enum):
5459
RandomEitherOrAlgorithm = 6
5560

5661

57-
# Models
58-
class Hydration(StrEnum):
59-
None_ = "none"
60-
BulkIn = "bulk in"
61-
BulkOut = "bulk out"
62-
Oil = "oil"
63-
64-
65-
class Languages(StrEnum):
66-
Cpp = "cpp"
67-
Python = "python"
68-
Matlab = "matlab"
69-
62+
class BoundHandling(RATEnum):
63+
"""Defines the available options for bound handling"""
7064

71-
class Priors(StrEnum):
72-
Uniform = "uniform"
73-
Gaussian = "gaussian"
65+
Off = "off"
66+
Reflect = "reflect"
67+
Bound = "bound"
68+
Fold = "fold"
7469

7570

76-
class TypeOptions(StrEnum):
71+
# Models
72+
class TypeOptions(RATEnum):
7773
Constant = "constant"
7874
Data = "data"
7975
Function = "function"
8076

8177

82-
class BackgroundActions(StrEnum):
78+
class BackgroundActions(RATEnum):
8379
Add = "add"
8480
Subtract = "subtract"
8581

8682

83+
class Languages(RATEnum):
84+
Cpp = "Cpp"
85+
Python = "python"
86+
Matlab = "matlab"
87+
88+
89+
class Hydration(RATEnum):
90+
None_ = "none"
91+
BulkIn = "bulk in"
92+
BulkOut = "bulk out"
93+
94+
95+
class Priors(RATEnum):
96+
Uniform = "uniform"
97+
Gaussian = "gaussian"
98+
99+
87100
# Project
88-
class Calculations(StrEnum):
101+
class Calculations(RATEnum):
89102
NonPolarised = "non polarised"
90103
Domains = "domains"
91104

92105

93-
class Geometries(StrEnum):
94-
AirSubstrate = "air/substrate"
95-
SubstrateLiquid = "substrate/liquid"
96-
97-
98-
class LayerModels(StrEnum):
106+
class LayerModels(RATEnum):
99107
CustomLayers = "custom layers"
100108
CustomXY = "custom xy"
101109
StandardLayers = "standard layers"
110+
111+
112+
class Geometries(RATEnum):
113+
AirSubstrate = "air/substrate"
114+
SubstrateLiquid = "substrate/liquid"

tests/test_classlist.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -767,10 +767,13 @@ def test__validate_name_field(two_name_class_list: ClassList, input_dict: dict[s
767767
"input_dict",
768768
[
769769
({"name": "Alice"}),
770+
({"name": "ALICE"}),
771+
({"name": "alice"}),
770772
],
771773
)
772774
def test__validate_name_field_not_unique(two_name_class_list: ClassList, input_dict: dict[str, Any]) -> None:
773-
"""We should raise a ValueError if we input values containing a name_field defined in an object in the ClassList."""
775+
"""We should raise a ValueError if we input values containing a name_field defined in an object in the ClassList,
776+
accounting for case sensitivity."""
774777
with pytest.raises(
775778
ValueError,
776779
match=f"Input arguments contain the {two_name_class_list.name_field} "
@@ -801,11 +804,13 @@ def test__check_unique_name_fields(two_name_class_list: ClassList, input_list: I
801804
"input_list",
802805
[
803806
([InputAttributes(name="Alice"), InputAttributes(name="Alice")]),
807+
([InputAttributes(name="Alice"), InputAttributes(name="ALICE")]),
808+
([InputAttributes(name="Alice"), InputAttributes(name="alice")]),
804809
],
805810
)
806811
def test__check_unique_name_fields_not_unique(two_name_class_list: ClassList, input_list: Iterable) -> None:
807-
"""We should raise a ValueError if an input list contains multiple objects with matching name_field values
808-
defined.
812+
"""We should raise a ValueError if an input list contains multiple objects with (case-insensitive) matching
813+
name_field values defined.
809814
"""
810815
with pytest.raises(
811816
ValueError,
@@ -846,7 +851,11 @@ def test__check_classes_different_classes(input_list: Iterable) -> None:
846851
["value", "expected_output"],
847852
[
848853
("Alice", InputAttributes(name="Alice")),
854+
("ALICE", InputAttributes(name="Alice")),
855+
("alice", InputAttributes(name="Alice")),
849856
("Eve", "Eve"),
857+
("EVE", "EVE"),
858+
("eve", "eve"),
850859
],
851860
)
852861
def test__get_item_from_name_field(

0 commit comments

Comments
 (0)