Skip to content

Commit 6790470

Browse files
committed
refactor: remove redundant helpers and use setup_test_experiment
1 parent 8ef966a commit 6790470

2 files changed

Lines changed: 67 additions & 107 deletions

File tree

tests/create_experiment.py

Lines changed: 0 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,6 @@
11
import copy
22
import shutil
33
from contextlib import contextmanager
4-
from pathlib import Path
5-
6-
import numpy as np
7-
import yaml
84

95
from .create_sequence_data import _generate_sequence_data
106

@@ -52,63 +48,3 @@ def setup_test_experiment(
5248
finally:
5349
if tmp_path.exists():
5450
shutil.rmtree(tmp_path)
55-
56-
57-
@contextmanager
58-
def make_sequence_device(
59-
root, name, start, end, sampling_rate=10.0, n_signals=5, override_meta=None
60-
):
61-
"""Create a single sequence device folder under root."""
62-
device_root = root / name
63-
try:
64-
(device_root / "meta").mkdir(parents=True, exist_ok=True)
65-
n_samples = int((end - start) * sampling_rate) + 1
66-
timestamps = np.linspace(start, end, n_samples)
67-
data = np.random.rand(n_samples, n_signals)
68-
69-
np.save(device_root / "timestamps.npy", timestamps)
70-
np.save(device_root / "data.npy", data)
71-
72-
meta = {
73-
"start_time": start,
74-
"end_time": end,
75-
"modality": "sequence",
76-
"sampling_rate": sampling_rate,
77-
"phase_shift_per_signal": False,
78-
"is_mem_mapped": False,
79-
"n_signals": n_signals,
80-
"n_timestamps": n_samples,
81-
"dtype": "float64",
82-
}
83-
if override_meta:
84-
meta.update(override_meta)
85-
with open(device_root / "meta.yml", "w") as f:
86-
yaml.safe_dump(meta, f)
87-
yield device_root
88-
finally:
89-
if device_root.exists():
90-
shutil.rmtree(device_root)
91-
92-
93-
def make_modality_config(*device_names, sampling_rates=None, offsets=None):
94-
if sampling_rates is None:
95-
sampling_rates = [10.0] * len(device_names)
96-
elif isinstance(sampling_rates, (int, float)):
97-
sampling_rates = [sampling_rates] * len(device_names)
98-
99-
if offsets is None:
100-
offsets = [0.0] * len(device_names)
101-
elif isinstance(offsets, (int, float)):
102-
offsets = [offsets] * len(device_names)
103-
104-
assert len(device_names) == len(
105-
sampling_rates
106-
), f"sampling_rates length {len(sampling_rates)} does not match device_names length {len(device_names)}"
107-
assert len(device_names) == len(
108-
offsets
109-
), f"offsets length {len(offsets)} does not match device_names length {len(device_names)}"
110-
111-
return {
112-
name: {"interpolation": {"sampling_rate": sr, "offset": off}}
113-
for name, sr, off in zip(device_names, sampling_rates, offsets)
114-
}

tests/test_experiment.py

Lines changed: 67 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
11
import logging
2-
import shutil
3-
from contextlib import ExitStack, contextmanager
4-
from pathlib import Path
52
from unittest.mock import MagicMock
63

74
import numpy as np
@@ -15,8 +12,6 @@
1512

1613
from .create_experiment import (
1714
get_default_config,
18-
make_modality_config,
19-
make_sequence_device,
2015
setup_test_experiment,
2116
)
2217

@@ -236,40 +231,56 @@ def test_experiment_start_end_time_reflects_union(
236231
tmp_path, device_ranges, expected_start, expected_end, n_signals
237232
):
238233
"""Experiment.start_time and end_time should reflect the union of all device time ranges."""
239-
device_names = [f"device_{i}" for i in range(len(device_ranges))]
240-
with ExitStack() as stack:
241-
for name, (start, end) in zip(device_names, device_ranges):
242-
stack.enter_context(
243-
make_sequence_device(
244-
tmp_path,
245-
name,
246-
start=start,
247-
end=end,
248-
n_signals=n_signals,
249-
sampling_rate=float(np.random.randint(5, 30)),
250-
)
251-
)
234+
devices_kwargs = [
235+
{
236+
"start_time": start,
237+
"t_end": end,
238+
"n_signals": n_signals,
239+
"sampling_rate": float(np.random.randint(5, 30)),
240+
}
241+
for start, end in device_ranges
242+
]
243+
244+
with setup_test_experiment(
245+
tmp_path, n_devices=len(device_ranges), devices_kwargs=devices_kwargs
246+
) as experiment_path:
247+
# Manually build the config dict for however many devices were generated
248+
config = {}
249+
for i in range(len(device_ranges)):
250+
config[f"device_{i}"] = {
251+
"interpolation": {
252+
"sampling_rate": 10.0,
253+
"offset": float(np.random.rand()),
254+
}
255+
}
256+
252257
experiment = Experiment(
253-
root_folder=tmp_path,
254-
modality_config=make_modality_config(
255-
*device_names, offsets=[float(np.random.rand()) for _ in device_names]
256-
),
258+
root_folder=str(experiment_path), modality_config=config
257259
)
260+
258261
assert experiment.start_time == pytest.approx(expected_start)
259262
assert experiment.end_time == pytest.approx(expected_end)
260263

261264

262265
@pytest.mark.parametrize("override_meta", INVALID_META_CASES, ids=INVALID_META_IDS)
263266
def test_experiment_invalid_metadata(tmp_path, override_meta):
264-
with make_sequence_device(
265-
tmp_path, "device_0", start=0.0, end=10.0, override_meta=override_meta
266-
):
267+
with setup_test_experiment(
268+
tmp_path, n_devices=1, devices_kwargs=[{"start_time": 0.0, "t_end": 10.0}]
269+
) as experiment_path:
270+
# Explicitly corrupt the generated metadata file
271+
meta_file = experiment_path / "device_0" / "meta.yml"
272+
with open(meta_file, "r") as f:
273+
meta = yaml.safe_load(f)
274+
meta.update(override_meta)
275+
with open(meta_file, "w") as f:
276+
yaml.safe_dump(meta, f)
277+
278+
config = {"device_0": {"interpolation": {"sampling_rate": 10.0}}}
279+
267280
with pytest.raises(
268281
ValueError, match="Experiment time range could not be determined"
269282
):
270-
Experiment(
271-
root_folder=tmp_path, modality_config=make_modality_config("device_0")
272-
)
283+
Experiment(root_folder=str(experiment_path), modality_config=config)
273284

274285

275286
@pytest.mark.parametrize("override_meta", INVALID_META_CASES, ids=INVALID_META_IDS)
@@ -279,24 +290,37 @@ def test_experiment_skips_invalid_devices(tmp_path, override_meta, caplog):
279290
np.random.lognormal(0.0, 1.0),
280291
)
281292
end_val = start_val + duration_val
282-
with ExitStack() as stack:
283-
stack.enter_context(
284-
make_sequence_device(tmp_path, "valid_device", start=start_val, end=end_val)
285-
)
286-
stack.enter_context(
287-
make_sequence_device(
288-
tmp_path,
289-
"invalid_device",
290-
start=0.0,
291-
end=10.0,
292-
override_meta=override_meta,
293-
)
294-
)
293+
294+
devices_kwargs = [
295+
{"start_time": start_val, "t_end": end_val}, # valid device
296+
{"start_time": 0.0, "t_end": 10.0}, # invalid device
297+
]
298+
299+
with setup_test_experiment(
300+
tmp_path, n_devices=2, devices_kwargs=devices_kwargs
301+
) as experiment_path:
302+
# Rename the folders to match what the old test expected
303+
(experiment_path / "device_0").rename(experiment_path / "valid_device")
304+
(experiment_path / "device_1").rename(experiment_path / "invalid_device")
305+
306+
# Explicitly corrupt the metadata file for the invalid device
307+
meta_file = experiment_path / "invalid_device" / "meta.yml"
308+
with open(meta_file, "r") as f:
309+
meta = yaml.safe_load(f)
310+
meta.update(override_meta)
311+
with open(meta_file, "w") as f:
312+
yaml.safe_dump(meta, f)
313+
314+
config = {
315+
"valid_device": {"interpolation": {"sampling_rate": 10.0}},
316+
"invalid_device": {"interpolation": {"sampling_rate": 10.0}},
317+
}
318+
295319
with caplog.at_level(logging.WARNING, logger="experanto.experiment"):
296320
experiment = Experiment(
297-
root_folder=tmp_path,
298-
modality_config=make_modality_config("valid_device", "invalid_device"),
321+
root_folder=str(experiment_path), modality_config=config
299322
)
323+
300324
assert "valid_device" in experiment.devices
301325
assert "invalid_device" not in experiment.devices
302326
assert experiment.start_time == pytest.approx(start_val)

0 commit comments

Comments (0)