@@ -258,7 +258,7 @@ def save(self, filepath: Union[str, Path] = "./results.json"):
258258 filepath .write_text (json .dumps (json_dict ))
259259
260260 @classmethod
261- def load (cls , path : Union [str , Path ]) -> "Results" :
261+ def load (cls , path : Union [str , Path ]) -> Union [ "Results" , "BayesResults" ] :
262262 """Load a Results object from file.
263263
264264 Parameters
@@ -272,20 +272,45 @@ def load(cls, path: Union[str, Path]) -> "Results":
272272
273273 results_dict = read_core_results_fields (results_dict )
274274
275- return Results (
276- reflectivity = results_dict ["reflectivity" ],
277- simulation = results_dict ["simulation" ],
278- shiftedData = results_dict ["shiftedData" ],
279- backgrounds = results_dict ["backgrounds" ],
280- resolutions = results_dict ["resolutions" ],
281- sldProfiles = results_dict ["sldProfiles" ],
282- layers = results_dict ["layers" ],
283- resampledLayers = results_dict ["resampledLayers" ],
284- calculationResults = CalculationResults (** results_dict ["calculationResults" ]),
285- contrastParams = ContrastParams (** results_dict ["contrastParams" ]),
286- fitParams = np .array (results_dict ["fitParams" ]),
287- fitNames = results_dict ["fitNames" ],
288- )
275+ if all (key in results_dict for key in bayes_results_subclasses ):
276+ results_dict = read_bayes_results_fields (results_dict )
277+
278+ return BayesResults (
279+ reflectivity = results_dict ["reflectivity" ],
280+ simulation = results_dict ["simulation" ],
281+ shiftedData = results_dict ["shiftedData" ],
282+ backgrounds = results_dict ["backgrounds" ],
283+ resolutions = results_dict ["resolutions" ],
284+ sldProfiles = results_dict ["sldProfiles" ],
285+ layers = results_dict ["layers" ],
286+ resampledLayers = results_dict ["resampledLayers" ],
287+ calculationResults = CalculationResults (** results_dict ["calculationResults" ]),
288+ contrastParams = ContrastParams (** results_dict ["contrastParams" ]),
289+ fitParams = np .array (results_dict ["fitParams" ]),
290+ fitNames = results_dict ["fitNames" ],
291+ predictionIntervals = PredictionIntervals (** results_dict ["predictionIntervals" ]),
292+ confidenceIntervals = ConfidenceIntervals (** results_dict ["confidenceIntervals" ]),
293+ dreamParams = DreamParams (** results_dict ["dreamParams" ]),
294+ dreamOutput = DreamOutput (** results_dict ["dreamOutput" ]),
295+ nestedSamplerOutput = NestedSamplerOutput (** results_dict ["nestedSamplerOutput" ]),
296+ chain = np .array (results_dict ["chain" ]),
297+ )
298+
299+ else :
300+ return Results (
301+ reflectivity = results_dict ["reflectivity" ],
302+ simulation = results_dict ["simulation" ],
303+ shiftedData = results_dict ["shiftedData" ],
304+ backgrounds = results_dict ["backgrounds" ],
305+ resolutions = results_dict ["resolutions" ],
306+ sldProfiles = results_dict ["sldProfiles" ],
307+ layers = results_dict ["layers" ],
308+ resampledLayers = results_dict ["resampledLayers" ],
309+ calculationResults = CalculationResults (** results_dict ["calculationResults" ]),
310+ contrastParams = ContrastParams (** results_dict ["contrastParams" ]),
311+ fitParams = np .array (results_dict ["fitParams" ]),
312+ fitNames = results_dict ["fitNames" ],
313+ )
289314
290315
291316@dataclass
@@ -548,63 +573,6 @@ def save(self, filepath: Union[str, Path] = "./results.json"):
548573 json_dict ["chain" ] = self .chain .tolist ()
549574 filepath .write_text (json .dumps (json_dict ))
550575
551- @classmethod
552- def load (cls , path : Union [str , Path ]) -> "BayesResults" :
553- """Load a BayesResults object from file.
554-
555- Parameters
556- ----------
557- path : str or Path
558- The path to the results json file.
559- """
560- path = Path (path )
561- input_data = path .read_text ()
562- results_dict = json .loads (input_data )
563-
564- results_dict = read_core_results_fields (results_dict )
565-
566- # Take each of the subclasses in a BayesResults instance and convert to numpy arrays where necessary
567- for subclass_name in bayes_results_subclasses :
568- subclass_dict = {}
569-
570- for field in bayes_results_fields ["param_fields" ][subclass_name ]:
571- subclass_dict [field ] = results_dict [subclass_name ][field ]
572-
573- for field in bayes_results_fields ["list_fields" ][subclass_name ]:
574- subclass_dict [field ] = [np .array (result_array ) for result_array in results_dict [subclass_name ][field ]]
575-
576- for field in bayes_results_fields ["double_list_fields" ][subclass_name ]:
577- subclass_dict [field ] = [
578- [np .array (result_array ) for result_array in inner_list ]
579- for inner_list in results_dict [subclass_name ][field ]
580- ]
581-
582- for field in bayes_results_fields ["array_fields" ][subclass_name ]:
583- subclass_dict [field ] = np .array (results_dict [subclass_name ][field ])
584-
585- results_dict [subclass_name ] = subclass_dict
586-
587- return BayesResults (
588- reflectivity = results_dict ["reflectivity" ],
589- simulation = results_dict ["simulation" ],
590- shiftedData = results_dict ["shiftedData" ],
591- backgrounds = results_dict ["backgrounds" ],
592- resolutions = results_dict ["resolutions" ],
593- sldProfiles = results_dict ["sldProfiles" ],
594- layers = results_dict ["layers" ],
595- resampledLayers = results_dict ["resampledLayers" ],
596- calculationResults = CalculationResults (** results_dict ["calculationResults" ]),
597- contrastParams = ContrastParams (** results_dict ["contrastParams" ]),
598- fitParams = np .array (results_dict ["fitParams" ]),
599- fitNames = results_dict ["fitNames" ],
600- predictionIntervals = PredictionIntervals (** results_dict ["predictionIntervals" ]),
601- confidenceIntervals = ConfidenceIntervals (** results_dict ["confidenceIntervals" ]),
602- dreamParams = DreamParams (** results_dict ["dreamParams" ]),
603- dreamOutput = DreamOutput (** results_dict ["dreamOutput" ]),
604- nestedSamplerOutput = NestedSamplerOutput (** results_dict ["nestedSamplerOutput" ]),
605- chain = np .array (results_dict ["chain" ]),
606- )
607-
608576
609577def write_core_results_fields (results : Union [Results , BayesResults ], json_dict : Optional [dict ] = None ) -> dict :
610578 """Modify the values of the fields that appear in both Results and BayesResults when saving to a json file."""
@@ -651,6 +619,31 @@ def read_core_results_fields(results_dict: dict) -> dict:
651619 return results_dict
652620
653621
622+ def read_bayes_results_fields (results_dict : dict ) -> dict :
623+ """Modify the values of the fields that appear only in BayesResults when loading a json file."""
624+ for subclass_name in bayes_results_subclasses :
625+ subclass_dict = {}
626+
627+ for field in bayes_results_fields ["param_fields" ][subclass_name ]:
628+ subclass_dict [field ] = results_dict [subclass_name ][field ]
629+
630+ for field in bayes_results_fields ["list_fields" ][subclass_name ]:
631+ subclass_dict [field ] = [np .array (result_array ) for result_array in results_dict [subclass_name ][field ]]
632+
633+ for field in bayes_results_fields ["double_list_fields" ][subclass_name ]:
634+ subclass_dict [field ] = [
635+ [np .array (result_array ) for result_array in inner_list ]
636+ for inner_list in results_dict [subclass_name ][field ]
637+ ]
638+
639+ for field in bayes_results_fields ["array_fields" ][subclass_name ]:
640+ subclass_dict [field ] = np .array (results_dict [subclass_name ][field ])
641+
642+ results_dict [subclass_name ] = subclass_dict
643+
644+ return results_dict
645+
646+
654647def make_results (
655648 procedure : Procedures ,
656649 output_results : RATapi .rat_core .OutputResult ,
0 commit comments