Skip to content

Commit 9be7b4f

Browse files
committed
Merges load methods
1 parent c171762 commit 9be7b4f

File tree

2 files changed

+67
-79
lines changed

2 files changed

+67
-79
lines changed

RATapi/outputs.py

Lines changed: 65 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ def save(self, filepath: Union[str, Path] = "./results.json"):
258258
filepath.write_text(json.dumps(json_dict))
259259

260260
@classmethod
261-
def load(cls, path: Union[str, Path]) -> "Results":
261+
def load(cls, path: Union[str, Path]) -> Union["Results", "BayesResults"]:
262262
"""Load a Results object from file.
263263
264264
Parameters
@@ -272,20 +272,45 @@ def load(cls, path: Union[str, Path]) -> "Results":
272272

273273
results_dict = read_core_results_fields(results_dict)
274274

275-
return Results(
276-
reflectivity=results_dict["reflectivity"],
277-
simulation=results_dict["simulation"],
278-
shiftedData=results_dict["shiftedData"],
279-
backgrounds=results_dict["backgrounds"],
280-
resolutions=results_dict["resolutions"],
281-
sldProfiles=results_dict["sldProfiles"],
282-
layers=results_dict["layers"],
283-
resampledLayers=results_dict["resampledLayers"],
284-
calculationResults=CalculationResults(**results_dict["calculationResults"]),
285-
contrastParams=ContrastParams(**results_dict["contrastParams"]),
286-
fitParams=np.array(results_dict["fitParams"]),
287-
fitNames=results_dict["fitNames"],
288-
)
275+
if all(key in results_dict for key in bayes_results_subclasses):
276+
results_dict = read_bayes_results_fields(results_dict)
277+
278+
return BayesResults(
279+
reflectivity=results_dict["reflectivity"],
280+
simulation=results_dict["simulation"],
281+
shiftedData=results_dict["shiftedData"],
282+
backgrounds=results_dict["backgrounds"],
283+
resolutions=results_dict["resolutions"],
284+
sldProfiles=results_dict["sldProfiles"],
285+
layers=results_dict["layers"],
286+
resampledLayers=results_dict["resampledLayers"],
287+
calculationResults=CalculationResults(**results_dict["calculationResults"]),
288+
contrastParams=ContrastParams(**results_dict["contrastParams"]),
289+
fitParams=np.array(results_dict["fitParams"]),
290+
fitNames=results_dict["fitNames"],
291+
predictionIntervals=PredictionIntervals(**results_dict["predictionIntervals"]),
292+
confidenceIntervals=ConfidenceIntervals(**results_dict["confidenceIntervals"]),
293+
dreamParams=DreamParams(**results_dict["dreamParams"]),
294+
dreamOutput=DreamOutput(**results_dict["dreamOutput"]),
295+
nestedSamplerOutput=NestedSamplerOutput(**results_dict["nestedSamplerOutput"]),
296+
chain=np.array(results_dict["chain"]),
297+
)
298+
299+
else:
300+
return Results(
301+
reflectivity=results_dict["reflectivity"],
302+
simulation=results_dict["simulation"],
303+
shiftedData=results_dict["shiftedData"],
304+
backgrounds=results_dict["backgrounds"],
305+
resolutions=results_dict["resolutions"],
306+
sldProfiles=results_dict["sldProfiles"],
307+
layers=results_dict["layers"],
308+
resampledLayers=results_dict["resampledLayers"],
309+
calculationResults=CalculationResults(**results_dict["calculationResults"]),
310+
contrastParams=ContrastParams(**results_dict["contrastParams"]),
311+
fitParams=np.array(results_dict["fitParams"]),
312+
fitNames=results_dict["fitNames"],
313+
)
289314

290315

291316
@dataclass
@@ -548,63 +573,6 @@ def save(self, filepath: Union[str, Path] = "./results.json"):
548573
json_dict["chain"] = self.chain.tolist()
549574
filepath.write_text(json.dumps(json_dict))
550575

551-
@classmethod
552-
def load(cls, path: Union[str, Path]) -> "BayesResults":
553-
"""Load a BayesResults object from file.
554-
555-
Parameters
556-
----------
557-
path : str or Path
558-
The path to the results json file.
559-
"""
560-
path = Path(path)
561-
input_data = path.read_text()
562-
results_dict = json.loads(input_data)
563-
564-
results_dict = read_core_results_fields(results_dict)
565-
566-
# Take each of the subclasses in a BayesResults instance and convert to numpy arrays where necessary
567-
for subclass_name in bayes_results_subclasses:
568-
subclass_dict = {}
569-
570-
for field in bayes_results_fields["param_fields"][subclass_name]:
571-
subclass_dict[field] = results_dict[subclass_name][field]
572-
573-
for field in bayes_results_fields["list_fields"][subclass_name]:
574-
subclass_dict[field] = [np.array(result_array) for result_array in results_dict[subclass_name][field]]
575-
576-
for field in bayes_results_fields["double_list_fields"][subclass_name]:
577-
subclass_dict[field] = [
578-
[np.array(result_array) for result_array in inner_list]
579-
for inner_list in results_dict[subclass_name][field]
580-
]
581-
582-
for field in bayes_results_fields["array_fields"][subclass_name]:
583-
subclass_dict[field] = np.array(results_dict[subclass_name][field])
584-
585-
results_dict[subclass_name] = subclass_dict
586-
587-
return BayesResults(
588-
reflectivity=results_dict["reflectivity"],
589-
simulation=results_dict["simulation"],
590-
shiftedData=results_dict["shiftedData"],
591-
backgrounds=results_dict["backgrounds"],
592-
resolutions=results_dict["resolutions"],
593-
sldProfiles=results_dict["sldProfiles"],
594-
layers=results_dict["layers"],
595-
resampledLayers=results_dict["resampledLayers"],
596-
calculationResults=CalculationResults(**results_dict["calculationResults"]),
597-
contrastParams=ContrastParams(**results_dict["contrastParams"]),
598-
fitParams=np.array(results_dict["fitParams"]),
599-
fitNames=results_dict["fitNames"],
600-
predictionIntervals=PredictionIntervals(**results_dict["predictionIntervals"]),
601-
confidenceIntervals=ConfidenceIntervals(**results_dict["confidenceIntervals"]),
602-
dreamParams=DreamParams(**results_dict["dreamParams"]),
603-
dreamOutput=DreamOutput(**results_dict["dreamOutput"]),
604-
nestedSamplerOutput=NestedSamplerOutput(**results_dict["nestedSamplerOutput"]),
605-
chain=np.array(results_dict["chain"]),
606-
)
607-
608576

609577
def write_core_results_fields(results: Union[Results, BayesResults], json_dict: Optional[dict] = None) -> dict:
610578
"""Modify the values of the fields that appear in both Results and BayesResults when saving to a json file."""
@@ -651,6 +619,31 @@ def read_core_results_fields(results_dict: dict) -> dict:
651619
return results_dict
652620

653621

622+
def read_bayes_results_fields(results_dict: dict) -> dict:
623+
"""Modify the values of the fields that appear only in BayesResults when loading a json file."""
624+
for subclass_name in bayes_results_subclasses:
625+
subclass_dict = {}
626+
627+
for field in bayes_results_fields["param_fields"][subclass_name]:
628+
subclass_dict[field] = results_dict[subclass_name][field]
629+
630+
for field in bayes_results_fields["list_fields"][subclass_name]:
631+
subclass_dict[field] = [np.array(result_array) for result_array in results_dict[subclass_name][field]]
632+
633+
for field in bayes_results_fields["double_list_fields"][subclass_name]:
634+
subclass_dict[field] = [
635+
[np.array(result_array) for result_array in inner_list]
636+
for inner_list in results_dict[subclass_name][field]
637+
]
638+
639+
for field in bayes_results_fields["array_fields"][subclass_name]:
640+
subclass_dict[field] = np.array(results_dict[subclass_name][field])
641+
642+
results_dict[subclass_name] = subclass_dict
643+
644+
return results_dict
645+
646+
654647
def make_results(
655648
procedure: Procedures,
656649
output_results: RATapi.rat_core.OutputResult,

tests/test_outputs.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -208,13 +208,8 @@ def test_results_str(test_output_results, test_str, request) -> None:
208208
assert test_output_results.__str__() == test_str
209209

210210

211-
@pytest.mark.parametrize(
212-
["result_class", "test_results"],
213-
[
214-
(RATapi.Results, "reflectivity_calculation_results"),
215-
(RATapi.BayesResults, "dream_results"),
216-
],
217-
)
211+
@pytest.mark.parametrize("result_class", [RATapi.Results, RATapi.BayesResults])
212+
@pytest.mark.parametrize("test_results", ["reflectivity_calculation_results", "dream_results"])
218213
def test_save_load(result_class, test_results, request):
219214
"""Test that saving and loading an output object returns the same object."""
220215
test_results = request.getfixturevalue(test_results)

0 commit comments

Comments
 (0)