Skip to content

Commit b9a7b10

Browse files
authored
Adds code to save and load results objects (#161)
* Adds code to save and load results objects * Merges load methods * Addresses review comments
1 parent 37c0696 commit b9a7b10

File tree

5 files changed

+314
-64
lines changed

5 files changed

+314
-64
lines changed

RATapi/__init__.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,24 @@
66
from RATapi import events, models
77
from RATapi.classlist import ClassList
88
from RATapi.controls import Controls
9+
from RATapi.outputs import BayesResults, Results
910
from RATapi.project import Project
1011
from RATapi.run import run
1112
from RATapi.utils import convert, plotting
1213

1314
with suppress(ImportError): # orsopy is an optional dependency
1415
from RATapi.utils import orso as orso
1516

16-
__all__ = ["examples", "models", "events", "ClassList", "Controls", "Project", "run", "plotting", "convert"]
17+
__all__ = [
18+
"examples",
19+
"models",
20+
"events",
21+
"ClassList",
22+
"Controls",
23+
"BayesResults",
24+
"Results",
25+
"Project",
26+
"run",
27+
"plotting",
28+
"convert",
29+
]

RATapi/outputs.py

Lines changed: 275 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,78 @@
11
"""Converts results from the compiled RAT code to python dataclasses."""
22

3+
import json
34
from dataclasses import dataclass
5+
from pathlib import Path
46
from typing import Any, Optional, Union
57

68
import numpy as np
79

810
import RATapi.rat_core
911
from RATapi.utils.enums import Procedures
1012

13+
bayes_results_subclasses = [
14+
"predictionIntervals",
15+
"confidenceIntervals",
16+
"dreamParams",
17+
"dreamOutput",
18+
"nestedSamplerOutput",
19+
]
20+
21+
bayes_results_fields = {
22+
"param_fields": {
23+
"predictionIntervals": [],
24+
"confidenceIntervals": [],
25+
"dreamParams": [
26+
"nParams",
27+
"nChains",
28+
"nGenerations",
29+
"parallel",
30+
"CPU",
31+
"jumpProbability",
32+
"pUnitGamma",
33+
"nCR",
34+
"delta",
35+
"steps",
36+
"zeta",
37+
"outlier",
38+
"adaptPCR",
39+
"thinning",
40+
"epsilon",
41+
"ABC",
42+
"IO",
43+
"storeOutput",
44+
],
45+
"dreamOutput": ["runtime", "iteration"],
46+
"nestedSamplerOutput": ["logZ", "logZErr"],
47+
},
48+
"list_fields": {
49+
"predictionIntervals": ["reflectivity"],
50+
"confidenceIntervals": [],
51+
"dreamParams": [],
52+
"dreamOutput": [],
53+
"nestedSamplerOutput": [],
54+
},
55+
"double_list_fields": {
56+
"predictionIntervals": ["sld"],
57+
"confidenceIntervals": [],
58+
"dreamParams": [],
59+
"dreamOutput": [],
60+
"nestedSamplerOutput": [],
61+
},
62+
"array_fields": {
63+
"predictionIntervals": ["sampleChi"],
64+
"confidenceIntervals": ["percentile65", "percentile95", "mean"],
65+
"dreamParams": ["R"],
66+
"dreamOutput": ["allChains", "outlierChains", "AR", "R_stat", "CR"],
67+
"nestedSamplerOutput": ["nestSamples", "postSamples"],
68+
},
69+
}
70+
71+
results_fields = {
72+
"list_fields": ["reflectivity", "simulation", "shiftedData", "backgrounds", "resolutions"],
73+
"double_list_fields": ["sldProfiles", "layers", "resampledLayers"],
74+
}
75+
1176

1277
def get_field_string(field: str, value: Any, array_limit: int):
1378
"""Return a string representation of class fields where large arrays are represented by their shape.
@@ -179,6 +244,74 @@ def __str__(self):
179244
output += get_field_string(key, value, 100)
180245
return output
181246

247+
def save(self, filepath: Union[str, Path] = "./results.json"):
248+
"""Save the Results object to a JSON file.
249+
250+
Parameters
251+
----------
252+
filepath : str or Path
253+
The path to where the results file will be written.
254+
"""
255+
filepath = Path(filepath).with_suffix(".json")
256+
json_dict = write_core_results_fields(self)
257+
258+
filepath.write_text(json.dumps(json_dict))
259+
260+
@classmethod
261+
def load(cls, path: Union[str, Path]) -> Union["Results", "BayesResults"]:
262+
"""Load a Results object from file.
263+
264+
Parameters
265+
----------
266+
path : str or Path
267+
The path to the results json file.
268+
"""
269+
path = Path(path)
270+
input_data = path.read_text()
271+
results_dict = json.loads(input_data)
272+
273+
results_dict = read_core_results_fields(results_dict)
274+
275+
if all(key in results_dict for key in bayes_results_subclasses):
276+
results_dict = read_bayes_results_fields(results_dict)
277+
278+
return BayesResults(
279+
reflectivity=results_dict["reflectivity"],
280+
simulation=results_dict["simulation"],
281+
shiftedData=results_dict["shiftedData"],
282+
backgrounds=results_dict["backgrounds"],
283+
resolutions=results_dict["resolutions"],
284+
sldProfiles=results_dict["sldProfiles"],
285+
layers=results_dict["layers"],
286+
resampledLayers=results_dict["resampledLayers"],
287+
calculationResults=CalculationResults(**results_dict["calculationResults"]),
288+
contrastParams=ContrastParams(**results_dict["contrastParams"]),
289+
fitParams=np.array(results_dict["fitParams"]),
290+
fitNames=results_dict["fitNames"],
291+
predictionIntervals=PredictionIntervals(**results_dict["predictionIntervals"]),
292+
confidenceIntervals=ConfidenceIntervals(**results_dict["confidenceIntervals"]),
293+
dreamParams=DreamParams(**results_dict["dreamParams"]),
294+
dreamOutput=DreamOutput(**results_dict["dreamOutput"]),
295+
nestedSamplerOutput=NestedSamplerOutput(**results_dict["nestedSamplerOutput"]),
296+
chain=np.array(results_dict["chain"]),
297+
)
298+
299+
else:
300+
return Results(
301+
reflectivity=results_dict["reflectivity"],
302+
simulation=results_dict["simulation"],
303+
shiftedData=results_dict["shiftedData"],
304+
backgrounds=results_dict["backgrounds"],
305+
resolutions=results_dict["resolutions"],
306+
sldProfiles=results_dict["sldProfiles"],
307+
layers=results_dict["layers"],
308+
resampledLayers=results_dict["resampledLayers"],
309+
calculationResults=CalculationResults(**results_dict["calculationResults"]),
310+
contrastParams=ContrastParams(**results_dict["contrastParams"]),
311+
fitParams=np.array(results_dict["fitParams"]),
312+
fitNames=results_dict["fitNames"],
313+
)
314+
182315

183316
@dataclass
184317
class PredictionIntervals(RATResult):
@@ -405,6 +538,148 @@ class BayesResults(Results):
405538
nestedSamplerOutput: NestedSamplerOutput
406539
chain: np.ndarray
407540

541+
def save(self, filepath: Union[str, Path] = "./results.json"):
542+
"""Save the BayesResults object to a JSON file.
543+
544+
Parameters
545+
----------
546+
filepath : str or Path
547+
The path to where the results file will be written.
548+
"""
549+
filepath = Path(filepath).with_suffix(".json")
550+
json_dict = write_core_results_fields(self)
551+
552+
# Take each of the subclasses in a BayesResults instance and switch the numpy arrays to lists
553+
for subclass_name in bayes_results_subclasses:
554+
subclass = getattr(self, subclass_name)
555+
subclass_dict = {}
556+
557+
for field in bayes_results_fields["param_fields"][subclass_name]:
558+
subclass_dict[field] = getattr(subclass, field)
559+
560+
for field in bayes_results_fields["list_fields"][subclass_name]:
561+
subclass_dict[field] = [result_array.tolist() for result_array in getattr(subclass, field)]
562+
563+
for field in bayes_results_fields["double_list_fields"][subclass_name]:
564+
subclass_dict[field] = [
565+
[result_array.tolist() for result_array in inner_list] for inner_list in getattr(subclass, field)
566+
]
567+
568+
for field in bayes_results_fields["array_fields"][subclass_name]:
569+
subclass_dict[field] = getattr(subclass, field).tolist()
570+
571+
json_dict[subclass_name] = subclass_dict
572+
573+
json_dict["chain"] = self.chain.tolist()
574+
filepath.write_text(json.dumps(json_dict))
575+
576+
577+
def write_core_results_fields(results: Union[Results, BayesResults], json_dict: Optional[dict] = None) -> dict:
578+
"""Modify the values of the fields that appear in both Results and BayesResults when saving to a json file.
579+
580+
Parameters
581+
----------
582+
results: Union[Results, BayesResults]
583+
The results or BayesResults object we are writing to json.
584+
json_dict: Optional[dict]
585+
The dictionary containing the json output.
586+
587+
Returns
588+
-------
589+
json_dict: dict
590+
The output json dict updated with the fields that appear in both Results and BayesResults.
591+
"""
592+
if json_dict is None:
593+
json_dict = {}
594+
595+
for field in results_fields["list_fields"]:
596+
json_dict[field] = [result_array.tolist() for result_array in getattr(results, field)]
597+
598+
for field in results_fields["double_list_fields"]:
599+
json_dict[field] = [
600+
[result_array.tolist() for result_array in inner_list] for inner_list in getattr(results, field)
601+
]
602+
603+
json_dict["calculationResults"] = {}
604+
json_dict["calculationResults"]["chiValues"] = results.calculationResults.chiValues.tolist()
605+
json_dict["calculationResults"]["sumChi"] = results.calculationResults.sumChi
606+
607+
json_dict["contrastParams"] = {}
608+
for field in results.contrastParams.__dict__:
609+
json_dict["contrastParams"][field] = getattr(results.contrastParams, field).tolist()
610+
611+
json_dict["fitParams"] = results.fitParams.tolist()
612+
json_dict["fitNames"] = results.fitNames
613+
614+
return json_dict
615+
616+
617+
def read_core_results_fields(results_dict: dict) -> dict:
618+
"""Modify the values of the fields that appear in both Results and BayesResults when loading a json file.
619+
620+
Parameters
621+
----------
622+
results_dict: Optional[dict]
623+
The dictionary containing the json input.
624+
625+
Returns
626+
-------
627+
results_dict: dict
628+
The input json dict with the fields that appear in both Results and BayesResults converted to numpy arrays
629+
where necessary.
630+
"""
631+
for field in results_fields["list_fields"]:
632+
results_dict[field] = [np.array(result_array) for result_array in results_dict[field]]
633+
634+
for field in results_fields["double_list_fields"]:
635+
results_dict[field] = [
636+
[np.array(result_array) for result_array in inner_list] for inner_list in results_dict[field]
637+
]
638+
639+
results_dict["calculationResults"]["chiValues"] = np.array(results_dict["calculationResults"]["chiValues"])
640+
641+
for field in results_dict["contrastParams"]:
642+
results_dict["contrastParams"][field] = np.array(results_dict["contrastParams"][field])
643+
644+
return results_dict
645+
646+
647+
def read_bayes_results_fields(results_dict: dict) -> dict:
648+
"""Modify the values of the fields that appear only in BayesResults when loading a json file.
649+
650+
Parameters
651+
----------
652+
results_dict: Optional[dict]
653+
The dictionary containing the json input.
654+
655+
Returns
656+
-------
657+
results_dict: dict
658+
The input json dict with the fields that appear in both Results and BayesResults converted to numpy arrays
659+
where necessary.
660+
"""
661+
for subclass_name in bayes_results_subclasses:
662+
subclass_dict = {}
663+
664+
for field in bayes_results_fields["param_fields"][subclass_name]:
665+
subclass_dict[field] = results_dict[subclass_name][field]
666+
667+
for field in bayes_results_fields["list_fields"][subclass_name]:
668+
subclass_dict[field] = [np.array(result_array) for result_array in results_dict[subclass_name][field]]
669+
670+
for field in bayes_results_fields["double_list_fields"][subclass_name]:
671+
subclass_dict[field] = [
672+
[np.array(result_array) for result_array in inner_list]
673+
for inner_list in results_dict[subclass_name][field]
674+
]
675+
676+
for field in bayes_results_fields["array_fields"][subclass_name]:
677+
subclass_dict[field] = np.array(results_dict[subclass_name][field])
678+
679+
results_dict[subclass_name] = subclass_dict
680+
681+
return results_dict
682+
408683

409684
def make_results(
410685
procedure: Procedures,

RATapi/project.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -906,7 +906,6 @@ def save(self, filepath: Union[str, Path] = "./project.json"):
906906
----------
907907
filepath : str or Path
908908
The path to where the project file will be written.
909-
910909
"""
911910
filepath = Path(filepath).with_suffix(".json")
912911

tests/test_outputs.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
We use the example for both a reflectivity calculation, and Bayesian analysis using the Dream algorithm.
44
"""
55

6+
import tempfile
7+
from pathlib import Path
8+
69
import numpy as np
710
import pytest
811

@@ -203,3 +206,18 @@ def test_results_str(test_output_results, test_str, request) -> None:
203206
test_str = request.getfixturevalue(test_str)
204207

205208
assert test_output_results.__str__() == test_str
209+
210+
211+
@pytest.mark.parametrize("result_class", [RATapi.Results, RATapi.BayesResults])
212+
@pytest.mark.parametrize("test_results", ["reflectivity_calculation_results", "dream_results"])
213+
def test_save_load(result_class, test_results, request):
214+
"""Test that saving and loading an output object returns the same object."""
215+
test_results = request.getfixturevalue(test_results)
216+
217+
with tempfile.TemporaryDirectory() as tmp:
218+
# ignore relative path warnings
219+
path = Path(tmp, "results.json")
220+
test_results.save(path)
221+
loaded_results = result_class.load(path)
222+
223+
check_results_equal(test_results, loaded_results)

0 commit comments

Comments
 (0)