|
2 | 2 |
|
3 | 3 | import contextlib |
4 | 4 | import os |
| 5 | +from pathlib import Path |
5 | 6 | import tempfile |
6 | 7 | from typing import Any, Union |
7 | 8 |
|
@@ -44,6 +45,28 @@ def test_extra_property_error() -> None: |
44 | 45 | with pytest.raises(pydantic.ValidationError, match="Object has no attribute 'test'"): |
45 | 46 | controls.test = 1 |
46 | 47 |
|
| 48 | +@pytest.mark.parametrize( |
| 49 | + "inputs", |
| 50 | + [ |
| 51 | + {"parallel": Parallel.Contrasts, "resampleMinAngle": 0.66}, |
| 52 | + {"procedure": 'simplex'}, |
| 53 | + {"procedure": 'dream', "nSamples": 504, "nChains": 1200}, |
| 54 | + {"procedure": 'de', "crossoverProbability": 0.45, "strategy": Strategies.RandomEitherOrAlgorithm}, |
| 55 | + {"procedure": 'ns', "nMCMC": 4, "propScale": 0.6}, |
| 56 | + ], |
| 57 | +) |
| 58 | +def test_save_load(inputs): |
| 59 | + """Test that saving and loading controls returns the same controls.""" |
| 60 | + |
| 61 | + original_controls = Controls(**inputs) |
| 62 | + with tempfile.TemporaryDirectory() as tmp: |
| 63 | + # ignore relative path warnings |
| 64 | + path = Path(tmp, "controls.json") |
| 65 | + original_controls.save(path) |
| 66 | + converted_controls = Controls.load(path) |
| 67 | + |
| 68 | + for field in Controls.model_fields: |
| 69 | + assert getattr(converted_controls, field) == getattr(original_controls, field) |
47 | 70 |
|
48 | 71 | class TestCalculate: |
49 | 72 | """Tests the Calculate class.""" |
|
0 commit comments