Skip to content

Commit 815a399

Browse files
committed
moved json save/load to the project method
1 parent e88caf9 commit 815a399

File tree

4 files changed

+81
-107
lines changed

4 files changed

+81
-107
lines changed

RATapi/project.py

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import collections
44
import copy
55
import functools
6+
import json
67
from enum import Enum
78
from pathlib import Path
89
from textwrap import indent
@@ -23,7 +24,6 @@
2324

2425
import RATapi.models
2526
from RATapi.classlist import ClassList
26-
from RATapi.utils.convert import project_from_json, project_to_json
2727
from RATapi.utils.custom_errors import custom_pydantic_validation_error
2828
from RATapi.utils.enums import Calculations, Geometries, LayerModels, Priors, TypeOptions
2929

@@ -810,20 +810,55 @@ def classlist_script(name, classlist):
810810
+ "\n)"
811811
)
812812

813-
def save(self, path: str | Path, filename: str = "project"):
813+
def save(self, path: Union[str, Path], filename: str = "project"):
814814
"""Save a project to a JSON file.
815815
816816
Parameters
817817
----------
818818
path : str or Path
819-
The directory in which the project will be written.
819+
The path in which the project will be written.
820+
filename : str
821+
The name of the generated project file.
820822
821823
"""
824+
json_dict = {}
825+
for field in self.model_fields:
826+
attr = getattr(self, field)
827+
828+
if field == "data":
829+
830+
def make_data_dict(item):
831+
return {
832+
"name": item.name,
833+
"data": item.data.tolist(),
834+
"data_range": item.data_range,
835+
"simulation_range": item.simulation_range,
836+
}
837+
838+
json_dict["data"] = [make_data_dict(data) for data in attr]
839+
840+
elif field == "custom_files":
841+
842+
def make_custom_file_dict(item):
843+
return {
844+
"name": item.name,
845+
"filename": item.filename,
846+
"language": item.language,
847+
"path": str(item.path),
848+
}
849+
850+
json_dict["custom_files"] = [make_custom_file_dict(file) for file in attr]
851+
852+
elif isinstance(attr, ClassList):
853+
json_dict[field] = [dict(item) for item in attr]
854+
else:
855+
json_dict[field] = attr
856+
822857
file = Path(path, f"{filename.removesuffix('.json')}.json")
823-
file.write_text(project_to_json(self))
858+
file.write_text(json.dumps(json_dict))
824859

825860
@classmethod
826-
def load(cls, path: str | Path) -> "Project":
861+
def load(cls, path: Union[str, Path]) -> "Project":
827862
"""Load a project from file.
828863
829864
Parameters
@@ -832,8 +867,17 @@ def load(cls, path: str | Path) -> "Project":
832867
The path to the project file.
833868
834869
"""
835-
file = Path(path)
836-
return project_from_json(file.read_text())
870+
input = Path(path).read_text()
871+
model_dict = json.loads(input)
872+
for i in range(0, len(model_dict["data"])):
873+
if model_dict["data"][i]["name"] == "Simulation":
874+
model_dict["data"][i]["data"] = np.empty([0, 3])
875+
del model_dict["data"][i]["data_range"]
876+
else:
877+
data = model_dict["data"][i]["data"]
878+
model_dict["data"][i]["data"] = np.array(data)
879+
880+
return cls.model_validate(model_dict)
837881

838882
def _classlist_wrapper(self, class_list: ClassList, func: Callable):
839883
"""Defines the function used to wrap around ClassList routines to force revalidation.

RATapi/utils/convert.py

Lines changed: 0 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
"""Utilities for converting input files to Python `Project`s."""
22

3-
import json
43
import warnings
54
from collections.abc import Iterable
65
from os import PathLike
@@ -553,72 +552,3 @@ def convert_parameters(
553552
eng.save(str(filename), "problem", nargout=0)
554553
eng.exit()
555554
return None
556-
557-
558-
def project_to_json(project: Project) -> str:
559-
"""Write a Project as a JSON file.
560-
561-
Parameters
562-
----------
563-
project : Project
564-
The input Project object to convert.
565-
566-
Returns
567-
-------
568-
str
569-
A string representing the class in JSON format.
570-
"""
571-
json_dict = {}
572-
for field in project.model_fields:
573-
attr = getattr(project, field)
574-
575-
if field == "data":
576-
577-
def make_data_dict(item):
578-
return {
579-
"name": item.name,
580-
"data": item.data.tolist(),
581-
"data_range": item.data_range,
582-
"simulation_range": item.simulation_range,
583-
}
584-
585-
json_dict["data"] = [make_data_dict(data) for data in attr]
586-
587-
elif field == "custom_files":
588-
589-
def make_custom_file_dict(item):
590-
return {"name": item.name, "filename": item.filename, "language": item.language, "path": str(item.path)}
591-
592-
json_dict["custom_files"] = [make_custom_file_dict(file) for file in attr]
593-
594-
elif isinstance(attr, ClassList):
595-
json_dict[field] = [dict(item) for item in attr]
596-
else:
597-
json_dict[field] = attr
598-
599-
return json.dumps(json_dict)
600-
601-
602-
def project_from_json(input: str) -> Project:
603-
"""Read a Project from a JSON string generated by `to_json`.
604-
605-
Parameters
606-
----------
607-
input : str
608-
The JSON input as a string.
609-
610-
Returns
611-
-------
612-
Project
613-
The project corresponding to that JSON input.
614-
"""
615-
model_dict = json.loads(input)
616-
for i in range(0, len(model_dict["data"])):
617-
if model_dict["data"][i]["name"] == "Simulation":
618-
model_dict["data"][i]["data"] = empty([0, 3])
619-
del model_dict["data"][i]["data_range"]
620-
else:
621-
data = model_dict["data"][i]["data"]
622-
model_dict["data"][i]["data"] = array(data)
623-
624-
return Project.model_validate(model_dict)

tests/test_convert.py

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pytest
99

1010
import RATapi
11-
from RATapi.utils.convert import project_class_to_r1, project_from_json, project_to_json, r1_to_project_class
11+
from RATapi.utils.convert import project_class_to_r1, r1_to_project_class
1212

1313
TEST_DIR_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "test_data")
1414

@@ -110,35 +110,6 @@ def test_invalid_constraints():
110110
assert output_project.background_parameters[0].min == output_project.background_parameters[0].value
111111

112112

113-
@pytest.mark.parametrize(
114-
"project",
115-
[
116-
"r1_default_project",
117-
"r1_monolayer",
118-
"r1_monolayer_8_contrasts",
119-
"r1_orso_polymer",
120-
"r1_motofit_bench_mark",
121-
"dspc_bilayer",
122-
# "dspc_standard_layers",
123-
# "dspc_custom_layers",
124-
# "dspc_custom_xy",
125-
"domains_standard_layers",
126-
"domains_custom_layers",
127-
"domains_custom_xy",
128-
"absorption",
129-
],
130-
)
131-
def test_json_involution(project, request):
132-
"""Test that converting a Project to JSON and back returns the same project."""
133-
original_project = request.getfixturevalue(project)
134-
json_data = project_to_json(original_project)
135-
136-
converted_project = project_from_json(json_data)
137-
138-
for field in RATapi.Project.model_fields:
139-
assert getattr(converted_project, field) == getattr(original_project, field)
140-
141-
142113
@pytest.mark.skipif(importlib.util.find_spec("matlab") is None, reason="Matlab not installed")
143114
@pytest.mark.parametrize("path_type", [os.path.join, pathlib.Path])
144115
def test_matlab_save(path_type, request):

tests/test_project.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1536,3 +1536,32 @@ def test_wrap_extend(test_project, class_list: str, model_type: str, field: str,
15361536

15371537
# Ensure invalid model was not appended
15381538
assert test_attribute == orig_class_list
1539+
1540+
1541+
@pytest.mark.parametrize(
1542+
"project",
1543+
[
1544+
"r1_default_project",
1545+
"r1_monolayer",
1546+
"r1_monolayer_8_contrasts",
1547+
"r1_orso_polymer",
1548+
"r1_motofit_bench_mark",
1549+
"dspc_standard_layers",
1550+
"dspc_custom_layers",
1551+
"dspc_custom_xy",
1552+
"domains_standard_layers",
1553+
"domains_custom_layers",
1554+
"domains_custom_xy",
1555+
"absorption",
1556+
],
1557+
)
1558+
def test_save_load(project, request):
1559+
"""Test that saving and loading a project returns the same project."""
1560+
original_project = request.getfixturevalue(project)
1561+
1562+
with tempfile.TemporaryDirectory() as tmp:
1563+
original_project.save(tmp)
1564+
converted_project = RATapi.Project.load(Path(tmp, "project.json"))
1565+
1566+
for field in RATapi.Project.model_fields:
1567+
assert getattr(converted_project, field) == getattr(original_project, field)

0 commit comments

Comments
 (0)