Skip to content

Commit e9208d1

Browse files
committed
Unify save method for controls
1 parent f978f4c commit e9208d1

File tree

2 files changed

+28
-8
lines changed

2 files changed

+28
-8
lines changed

ratapi/controls.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -233,19 +233,16 @@ def delete_IPC(self):
233233
os.remove(self._IPCFilePath)
234234
return None
235235

236-
def save(self, path: Union[str, Path], filename: str = "controls"):
236+
def save(self, filepath: Union[str, Path] = "./controls.json"):
237237
"""Save a controls object to a JSON file.
238238
239239
Parameters
240240
----------
241-
path : str or Path
242-
The directory in which the controls object will be written.
243-
filename : str
244-
The name for the JSON file containing the controls object.
245-
241+
filepath : str or Path
242+
The path to where the controls file will be written.
246243
"""
247-
file = Path(path, f"{filename.removesuffix('.json')}.json")
248-
file.write_text(self.model_dump_json())
244+
filepath = Path(filepath).with_suffix(".json")
245+
filepath.write_text(self.model_dump_json())
249246

250247
@classmethod
251248
def load(cls, path: Union[str, Path]) -> "Controls":

tests/test_controls.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import contextlib
44
import os
5+
from pathlib import Path
56
import tempfile
67
from typing import Any, Union
78

@@ -44,6 +45,28 @@ def test_extra_property_error() -> None:
4445
with pytest.raises(pydantic.ValidationError, match="Object has no attribute 'test'"):
4546
controls.test = 1
4647

48+
@pytest.mark.parametrize(
49+
"inputs",
50+
[
51+
{"parallel": Parallel.Contrasts, "resampleMinAngle": 0.66},
52+
{"procedure": 'simplex'},
53+
{"procedure": 'dream', "nSamples": 504, "nChains": 1200},
54+
{"procedure": 'de', "crossoverProbability": 0.45, "strategy": Strategies.RandomEitherOrAlgorithm},
55+
{"procedure": 'ns', "nMCMC": 4, "propScale": 0.6},
56+
],
57+
)
58+
def test_save_load(inputs):
59+
"""Test that saving and loading controls returns the same controls."""
60+
61+
original_controls = Controls(**inputs)
62+
with tempfile.TemporaryDirectory() as tmp:
63+
# ignore relative path warnings
64+
path = Path(tmp, "controls.json")
65+
original_controls.save(path)
66+
converted_controls = Controls.load(path)
67+
68+
for field in Controls.model_fields:
69+
assert getattr(converted_controls, field) == getattr(original_controls, field)
4770

4871
class TestCalculate:
4972
"""Tests the Calculate class."""

0 commit comments

Comments
 (0)