Skip to content

Commit 9a8a1fd

Browse files
committed
Adds code to save and load results objects
1 parent 37c0696 commit 9a8a1fd

File tree

5 files changed

+289
-64
lines changed

5 files changed

+289
-64
lines changed

RATapi/__init__.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,24 @@
66
from RATapi import events, models
77
from RATapi.classlist import ClassList
88
from RATapi.controls import Controls
9+
from RATapi.outputs import BayesResults, Results
910
from RATapi.project import Project
1011
from RATapi.run import run
1112
from RATapi.utils import convert, plotting
1213

1314
with suppress(ImportError): # orsopy is an optional dependency
1415
from RATapi.utils import orso as orso
1516

16-
__all__ = ["examples", "models", "events", "ClassList", "Controls", "Project", "run", "plotting", "convert"]
17+
__all__ = [
18+
"examples",
19+
"models",
20+
"events",
21+
"ClassList",
22+
"Controls",
23+
"BayesResults",
24+
"Results",
25+
"Project",
26+
"run",
27+
"plotting",
28+
"convert",
29+
]

RATapi/outputs.py

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,78 @@
11
"""Converts results from the compiled RAT code to python dataclasses."""
22

3+
import json
34
from dataclasses import dataclass
5+
from pathlib import Path
46
from typing import Any, Optional, Union
57

68
import numpy as np
79

810
import RATapi.rat_core
911
from RATapi.utils.enums import Procedures
1012

13+
bayes_results_subclasses = [
14+
"predictionIntervals",
15+
"confidenceIntervals",
16+
"dreamParams",
17+
"dreamOutput",
18+
"nestedSamplerOutput",
19+
]
20+
21+
bayes_results_fields = {
22+
"param_fields": {
23+
"predictionIntervals": [],
24+
"confidenceIntervals": [],
25+
"dreamParams": [
26+
"nParams",
27+
"nChains",
28+
"nGenerations",
29+
"parallel",
30+
"CPU",
31+
"jumpProbability",
32+
"pUnitGamma",
33+
"nCR",
34+
"delta",
35+
"steps",
36+
"zeta",
37+
"outlier",
38+
"adaptPCR",
39+
"thinning",
40+
"epsilon",
41+
"ABC",
42+
"IO",
43+
"storeOutput",
44+
],
45+
"dreamOutput": ["runtime", "iteration"],
46+
"nestedSamplerOutput": ["logZ", "logZErr"],
47+
},
48+
"list_fields": {
49+
"predictionIntervals": ["reflectivity"],
50+
"confidenceIntervals": [],
51+
"dreamParams": [],
52+
"dreamOutput": [],
53+
"nestedSamplerOutput": [],
54+
},
55+
"double_list_fields": {
56+
"predictionIntervals": ["sld"],
57+
"confidenceIntervals": [],
58+
"dreamParams": [],
59+
"dreamOutput": [],
60+
"nestedSamplerOutput": [],
61+
},
62+
"array_fields": {
63+
"predictionIntervals": ["sampleChi"],
64+
"confidenceIntervals": ["percentile65", "percentile95", "mean"],
65+
"dreamParams": ["R"],
66+
"dreamOutput": ["allChains", "outlierChains", "AR", "R_stat", "CR"],
67+
"nestedSamplerOutput": ["nestSamples", "postSamples"],
68+
},
69+
}
70+
71+
results_fields = {
72+
"list_fields": ["reflectivity", "simulation", "shiftedData", "backgrounds", "resolutions"],
73+
"double_list_fields": ["sldProfiles", "layers", "resampledLayers"],
74+
}
75+
1176

1277
def get_field_string(field: str, value: Any, array_limit: int):
1378
"""Return a string representation of class fields where large arrays are represented by their shape.
@@ -179,6 +244,49 @@ def __str__(self):
179244
output += get_field_string(key, value, 100)
180245
return output
181246

247+
def save(self, filepath: Union[str, Path] = "./results.json"):
248+
"""Save the Results object to a JSON file.
249+
250+
Parameters
251+
----------
252+
filepath : str or Path
253+
The path to where the results file will be written.
254+
"""
255+
filepath = Path(filepath).with_suffix(".json")
256+
json_dict = write_core_results_fields(self)
257+
258+
filepath.write_text(json.dumps(json_dict))
259+
260+
@classmethod
261+
def load(cls, path: Union[str, Path]) -> "Results":
262+
"""Load a Results object from file.
263+
264+
Parameters
265+
----------
266+
path : str or Path
267+
The path to the results json file.
268+
"""
269+
path = Path(path)
270+
input_data = path.read_text()
271+
results_dict = json.loads(input_data)
272+
273+
results_dict = read_core_results_fields(results_dict)
274+
275+
return Results(
276+
reflectivity=results_dict["reflectivity"],
277+
simulation=results_dict["simulation"],
278+
shiftedData=results_dict["shiftedData"],
279+
backgrounds=results_dict["backgrounds"],
280+
resolutions=results_dict["resolutions"],
281+
sldProfiles=results_dict["sldProfiles"],
282+
layers=results_dict["layers"],
283+
resampledLayers=results_dict["resampledLayers"],
284+
calculationResults=CalculationResults(**results_dict["calculationResults"]),
285+
contrastParams=ContrastParams(**results_dict["contrastParams"]),
286+
fitParams=np.array(results_dict["fitParams"]),
287+
fitNames=results_dict["fitNames"],
288+
)
289+
182290

183291
@dataclass
184292
class PredictionIntervals(RATResult):
@@ -405,6 +513,143 @@ class BayesResults(Results):
405513
nestedSamplerOutput: NestedSamplerOutput
406514
chain: np.ndarray
407515

516+
def save(self, filepath: Union[str, Path] = "./results.json"):
517+
"""Save the BayesResults object to a JSON file.
518+
519+
Parameters
520+
----------
521+
filepath : str or Path
522+
The path to where the results file will be written.
523+
"""
524+
filepath = Path(filepath).with_suffix(".json")
525+
json_dict = write_core_results_fields(self)
526+
527+
# Take each of the subclasses in a BayesResults instance and switch the numpy arrays to lists
528+
for subclass_name in bayes_results_subclasses:
529+
subclass = getattr(self, subclass_name)
530+
subclass_dict = {}
531+
532+
for field in bayes_results_fields["param_fields"][subclass_name]:
533+
subclass_dict[field] = getattr(subclass, field)
534+
535+
for field in bayes_results_fields["list_fields"][subclass_name]:
536+
subclass_dict[field] = [result_array.tolist() for result_array in getattr(subclass, field)]
537+
538+
for field in bayes_results_fields["double_list_fields"][subclass_name]:
539+
subclass_dict[field] = [
540+
[result_array.tolist() for result_array in inner_list] for inner_list in getattr(subclass, field)
541+
]
542+
543+
for field in bayes_results_fields["array_fields"][subclass_name]:
544+
subclass_dict[field] = getattr(subclass, field).tolist()
545+
546+
json_dict[subclass_name] = subclass_dict
547+
548+
json_dict["chain"] = self.chain.tolist()
549+
filepath.write_text(json.dumps(json_dict))
550+
551+
@classmethod
552+
def load(cls, path: Union[str, Path]) -> "BayesResults":
553+
"""Load a BayesResults object from file.
554+
555+
Parameters
556+
----------
557+
path : str or Path
558+
The path to the results json file.
559+
"""
560+
path = Path(path)
561+
input_data = path.read_text()
562+
results_dict = json.loads(input_data)
563+
564+
results_dict = read_core_results_fields(results_dict)
565+
566+
# Take each of the subclasses in a BayesResults instance and convert to numpy arrays where necessary
567+
for subclass_name in bayes_results_subclasses:
568+
subclass_dict = {}
569+
570+
for field in bayes_results_fields["param_fields"][subclass_name]:
571+
subclass_dict[field] = results_dict[subclass_name][field]
572+
573+
for field in bayes_results_fields["list_fields"][subclass_name]:
574+
subclass_dict[field] = [np.array(result_array) for result_array in results_dict[subclass_name][field]]
575+
576+
for field in bayes_results_fields["double_list_fields"][subclass_name]:
577+
subclass_dict[field] = [
578+
[np.array(result_array) for result_array in inner_list]
579+
for inner_list in results_dict[subclass_name][field]
580+
]
581+
582+
for field in bayes_results_fields["array_fields"][subclass_name]:
583+
subclass_dict[field] = np.array(results_dict[subclass_name][field])
584+
585+
results_dict[subclass_name] = subclass_dict
586+
587+
return BayesResults(
588+
reflectivity=results_dict["reflectivity"],
589+
simulation=results_dict["simulation"],
590+
shiftedData=results_dict["shiftedData"],
591+
backgrounds=results_dict["backgrounds"],
592+
resolutions=results_dict["resolutions"],
593+
sldProfiles=results_dict["sldProfiles"],
594+
layers=results_dict["layers"],
595+
resampledLayers=results_dict["resampledLayers"],
596+
calculationResults=CalculationResults(**results_dict["calculationResults"]),
597+
contrastParams=ContrastParams(**results_dict["contrastParams"]),
598+
fitParams=np.array(results_dict["fitParams"]),
599+
fitNames=results_dict["fitNames"],
600+
predictionIntervals=PredictionIntervals(**results_dict["predictionIntervals"]),
601+
confidenceIntervals=ConfidenceIntervals(**results_dict["confidenceIntervals"]),
602+
dreamParams=DreamParams(**results_dict["dreamParams"]),
603+
dreamOutput=DreamOutput(**results_dict["dreamOutput"]),
604+
nestedSamplerOutput=NestedSamplerOutput(**results_dict["nestedSamplerOutput"]),
605+
chain=np.array(results_dict["chain"]),
606+
)
607+
608+
609+
def write_core_results_fields(results: Union[Results, BayesResults], json_dict: Optional[dict] = None) -> dict:
610+
"""Modify the values of the fields that appear in both Results and BayesResults when saving to a json file."""
611+
if json_dict is None:
612+
json_dict = {}
613+
614+
for field in results_fields["list_fields"]:
615+
json_dict[field] = [result_array.tolist() for result_array in getattr(results, field)]
616+
617+
for field in results_fields["double_list_fields"]:
618+
json_dict[field] = [
619+
[result_array.tolist() for result_array in inner_list] for inner_list in getattr(results, field)
620+
]
621+
622+
json_dict["calculationResults"] = {}
623+
json_dict["calculationResults"]["chiValues"] = results.calculationResults.chiValues.tolist()
624+
json_dict["calculationResults"]["sumChi"] = results.calculationResults.sumChi
625+
626+
json_dict["contrastParams"] = {}
627+
for field in results.contrastParams.__dict__:
628+
json_dict["contrastParams"][field] = getattr(results.contrastParams, field).tolist()
629+
630+
json_dict["fitParams"] = results.fitParams.tolist()
631+
json_dict["fitNames"] = results.fitNames
632+
633+
return json_dict
634+
635+
636+
def read_core_results_fields(results_dict: dict) -> dict:
637+
"""Modify the values of the fields that appear in both Results and BayesResults when loading a json file."""
638+
for field in results_fields["list_fields"]:
639+
results_dict[field] = [np.array(result_array) for result_array in results_dict[field]]
640+
641+
for field in results_fields["double_list_fields"]:
642+
results_dict[field] = [
643+
[np.array(result_array) for result_array in inner_list] for inner_list in results_dict[field]
644+
]
645+
646+
results_dict["calculationResults"]["chiValues"] = np.array(results_dict["calculationResults"]["chiValues"])
647+
648+
for field in results_dict["contrastParams"]:
649+
results_dict["contrastParams"][field] = np.array(results_dict["contrastParams"][field])
650+
651+
return results_dict
652+
408653

409654
def make_results(
410655
procedure: Procedures,

RATapi/project.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -906,7 +906,6 @@ def save(self, filepath: Union[str, Path] = "./project.json"):
906906
----------
907907
filepath : str or Path
908908
The path to where the project file will be written.
909-
910909
"""
911910
filepath = Path(filepath).with_suffix(".json")
912911

tests/test_outputs.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
We use the example for both a reflectivity calculation, and Bayesian analysis using the Dream algorithm.
44
"""
55

6+
import tempfile
7+
from pathlib import Path
8+
69
import numpy as np
710
import pytest
811

@@ -203,3 +206,23 @@ def test_results_str(test_output_results, test_str, request) -> None:
203206
test_str = request.getfixturevalue(test_str)
204207

205208
assert test_output_results.__str__() == test_str
209+
210+
211+
@pytest.mark.parametrize(
212+
["result_class", "test_results"],
213+
[
214+
(RATapi.Results, "reflectivity_calculation_results"),
215+
(RATapi.BayesResults, "dream_results"),
216+
],
217+
)
218+
def test_save_load(result_class, test_results, request):
219+
"""Test that saving and loading an output object returns the same object."""
220+
test_results = request.getfixturevalue(test_results)
221+
222+
with tempfile.TemporaryDirectory() as tmp:
223+
# ignore relative path warnings
224+
path = Path(tmp, "results.json")
225+
test_results.save(path)
226+
loaded_results = result_class.load(path)
227+
228+
check_results_equal(test_results, loaded_results)

0 commit comments

Comments
 (0)