Skip to content

Commit 3d940e7

Browse files
committed
Resolve test helper merge conflicts
2 parents 50664ef + d61d9b8 commit 3d940e7

3 files changed

Lines changed: 299 additions & 7 deletions

File tree

experanto/experiment.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,33 @@ def _load_devices(self) -> None:
118118
**{str(k): v for k, v in dict(interp_conf).items()},
119119
)
120120

121-
self.devices[d.name] = dev
122-
self.start_time = dev.start_time
123-
self.end_time = dev.end_time
121+
if (
122+
dev.start_time is None
123+
or dev.end_time is None
124+
or not np.isfinite(dev.start_time)
125+
or not np.isfinite(dev.end_time)
126+
):
127+
logger.warning(
128+
"Device %s has undefined start_time or end_time and will be "
129+
"excluded from the experiment-wide time range.",
130+
d.name,
131+
)
132+
else:
133+
self.start_time = min(self.start_time, dev.start_time)
134+
self.end_time = max(self.end_time, dev.end_time)
135+
self.devices[d.name] = dev
124136
logger.info("Parsing finished")
125137

138+
if not self.devices:
139+
raise ValueError(
140+
"Experiment time range could not be determined: no devices with valid start_time and end_time were found."
141+
)
142+
elif self.start_time > self.end_time:
143+
raise ValueError(
144+
"Experiment time range could not be determined: at least one device "
145+
"must define finite start_time and end_time."
146+
)
147+
126148
@property
127149
def device_names(self):
128150
return tuple(self.devices.keys())

tests/create_experiment.py

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
import shutil
33
from contextlib import contextmanager
44

5+
import numpy as np
6+
import yaml
7+
58
from .create_sequence_data import _generate_sequence_data
69

710
DEFAULT_CONFIG = {
@@ -45,11 +48,73 @@ def create_experiment(
4548

4649
for device_id, device_kwargs in enumerate(devices_kwargs):
4750
device_path = tmp_path / f"device_{device_id}"
48-
_generate_sequence_data(
49-
str(device_path), **device_kwargs
50-
) # pyright: ignore
51+
_generate_sequence_data(str(device_path), **device_kwargs) # pyright: ignore
5152

5253
yield tmp_path
5354
finally:
5455
if tmp_path.exists():
5556
shutil.rmtree(tmp_path)
57+
58+
59+
@contextmanager
60+
def make_sequence_device(
61+
root, name, start, end, sampling_rate=10.0, n_signals=5, override_meta=None
62+
):
63+
"""Create a single sequence device folder under root."""
64+
device_root = root / name
65+
try:
66+
(device_root / "meta").mkdir(parents=True, exist_ok=True)
67+
68+
n_samples = (
69+
int((end - start) * sampling_rate) + 1
70+
) # +1 to include both start and end as sample points
71+
timestamps = np.linspace(start, end, n_samples)
72+
data = np.random.rand(n_samples, n_signals)
73+
74+
np.save(device_root / "timestamps.npy", timestamps)
75+
np.save(device_root / "data.npy", data)
76+
77+
meta = {
78+
"start_time": start,
79+
"end_time": end,
80+
"modality": "sequence",
81+
"sampling_rate": sampling_rate,
82+
"phase_shift_per_signal": False,
83+
"is_mem_mapped": False,
84+
"n_signals": n_signals,
85+
"n_timestamps": n_samples,
86+
"dtype": "float64",
87+
}
88+
if override_meta:
89+
meta.update(override_meta)
90+
with open(device_root / "meta.yml", "w") as f:
91+
yaml.safe_dump(meta, f)
92+
93+
yield device_root
94+
95+
finally:
96+
shutil.rmtree(device_root)
97+
98+
99+
def make_modality_config(*device_names, sampling_rates=None, offsets=None):
100+
if sampling_rates is None:
101+
sampling_rates = [10.0] * len(device_names)
102+
elif isinstance(sampling_rates, (int, float)):
103+
sampling_rates = [sampling_rates] * len(device_names)
104+
105+
if offsets is None:
106+
offsets = [0.0] * len(device_names)
107+
elif isinstance(offsets, (int, float)):
108+
offsets = [offsets] * len(device_names)
109+
110+
assert len(device_names) == len(sampling_rates), (
111+
f"sampling_rates length {len(sampling_rates)} does not match device_names length {len(device_names)}"
112+
)
113+
assert len(device_names) == len(offsets), (
114+
f"offsets length {len(offsets)} does not match device_names length {len(device_names)}"
115+
)
116+
117+
return {
118+
name: {"interpolation": {"sampling_rate": sr, "offset": off}}
119+
for name, sr, off in zip(device_names, sampling_rates, offsets, strict=True)
120+
}

tests/test_experiment.py

Lines changed: 206 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import logging
2+
from contextlib import ExitStack
13
from unittest.mock import MagicMock
24

35
import numpy as np
@@ -6,7 +8,12 @@
68
from experanto.experiment import Experiment
79
from experanto.interpolators import Interpolator
810

9-
from .create_experiment import create_experiment, get_default_config
11+
from .create_experiment import (
12+
create_experiment,
13+
get_default_config,
14+
make_modality_config,
15+
make_sequence_device,
16+
)
1017

1118

1219
class DummyInterpolator(Interpolator):
@@ -180,3 +187,201 @@ def test_experiment_multi_device_interpolation(tmp_path, return_valid, device):
180187
else:
181188
assert isinstance(data, np.ndarray)
182189
assert data.shape == (2, 10)
190+
191+
192+
DEVICE_TIME_RANGE_CASES = [
193+
# Single device: start and end should match that device's range
194+
([(2.0, 9.0)], 2.0, 9.0),
195+
# Two devices with different ranges: start should be min, end should be max
196+
([(1.0, 8.0), (0.0, 10.0)], 0.0, 10.0),
197+
# Three devices with different ranges: start should be min, end should be max
198+
([(0.0, 10.0), (1.0, 8.0), (2.0, 9.0)], 0.0, 10.0),
199+
# Devices with non-overlapping ranges: start should be min, end should be max
200+
([(0.0, 3.0), (7.0, 8.0)], 0.0, 8.0),
201+
# Devices with identical ranges: start and end should match that range
202+
([(1.0, 5.0), (1.0, 5.0)], 1.0, 5.0),
203+
# Large time stamps: start should be min, end should be max
204+
([(1e9, 1e9 + 100), (1e9 - 50, 1e9 + 50)], 1e9 - 50, 1e9 + 100),
205+
]
206+
207+
DEVICE_TIME_RANGE_IDS = [
208+
"single_device",
209+
"two_devices_different_ranges",
210+
"three_devices_different_ranges",
211+
"non_overlapping_ranges",
212+
"identical_ranges",
213+
"large_time_stamps",
214+
]
215+
216+
# Inverted range is intentionally separate from INVALID_META_CASES —
217+
# None/NaN/inf are caught per-device before being added to self.devices,
218+
# whereas start > end is only caught after all devices are loaded.
219+
INVALID_META_CASES = [
220+
{"start_time": None, "end_time": None}, # Both missing
221+
{"start_time": None, "end_time": 10.0}, # Missing start_time
222+
{"start_time": 0.0, "end_time": None}, # Missing end_time
223+
{"start_time": float("inf"), "end_time": 10.0}, # Infinite start_time
224+
{"start_time": 0.0, "end_time": float("inf")}, # Infinite end_time
225+
{"start_time": float("-inf"), "end_time": 10.0}, # Negative Infinite start_time
226+
{"start_time": 0.0, "end_time": float("-inf")}, # Negative Infinite end_time
227+
{"start_time": float("nan"), "end_time": 10.0}, # NaN start_time
228+
{"start_time": 0.0, "end_time": float("nan")}, # NaN end_time
229+
]
230+
231+
INVALID_META_IDS = [
232+
"both_missing",
233+
"missing_start_time",
234+
"missing_end_time",
235+
"infinite_start_time",
236+
"infinite_end_time",
237+
"negative_infinite_start_time",
238+
"negative_infinite_end_time",
239+
"nan_start_time",
240+
"nan_end_time",
241+
]
242+
243+
244+
# Test for union of device time ranges
245+
@pytest.mark.parametrize("n_signals", [5, 20])
246+
@pytest.mark.parametrize(
247+
"device_ranges, expected_start, expected_end",
248+
DEVICE_TIME_RANGE_CASES,
249+
ids=DEVICE_TIME_RANGE_IDS,
250+
)
251+
def test_experiment_start_end_time_reflects_union(
252+
tmp_path, device_ranges, expected_start, expected_end, n_signals
253+
):
254+
"""
255+
Experiment.start_time and end_time should reflect the union of all
256+
device time ranges — earliest start and latest end across all devices.
257+
"""
258+
device_names = [f"device_{i}" for i in range(len(device_ranges))]
259+
260+
with ExitStack() as stack:
261+
for name, (start, end) in zip(device_names, device_ranges, strict=True):
262+
stack.enter_context(
263+
make_sequence_device(
264+
tmp_path,
265+
name,
266+
start=start,
267+
end=end,
268+
n_signals=n_signals,
269+
sampling_rate=float(np.random.randint(5, 30)),
270+
)
271+
)
272+
273+
experiment = Experiment(
274+
root_folder=tmp_path,
275+
modality_config=make_modality_config(
276+
*device_names, offsets=[float(np.random.rand()) for _ in device_names]
277+
),
278+
)
279+
280+
assert experiment.start_time == (expected_start), (
281+
f"Expected start_time={expected_start}, got {experiment.start_time}"
282+
)
283+
assert experiment.end_time == (expected_end), (
284+
f"Expected end_time={expected_end}, got {experiment.end_time}"
285+
)
286+
287+
288+
# Safety check
289+
@pytest.mark.parametrize("override_meta", INVALID_META_CASES, ids=INVALID_META_IDS)
290+
def test_experiment_invalid_metadata(tmp_path, override_meta):
291+
"""
292+
Experiment should raise an error when initialized with invalid metadata.
293+
Covers cases where start_time or end_time is None, NaN, or infinite.
294+
"""
295+
with make_sequence_device(
296+
tmp_path,
297+
"device_0",
298+
start=0.0,
299+
end=10.0,
300+
override_meta=override_meta,
301+
):
302+
with pytest.raises(
303+
ValueError, match="Experiment time range could not be determined"
304+
):
305+
Experiment(
306+
root_folder=tmp_path,
307+
modality_config=make_modality_config("device_0"),
308+
)
309+
310+
311+
def test_experiment_inverted_time_range_raises(tmp_path):
312+
"""
313+
Experiment should raise ValueError when start_time > end_time.
314+
This is a separate guard from invalid metadata (None/NaN/inf) because it
315+
only becomes apparent after all devices are loaded and the overall time range is computed.
316+
"""
317+
with make_sequence_device(
318+
tmp_path,
319+
"device_0",
320+
start=0.0,
321+
end=10.0,
322+
override_meta={"start_time": 5.0, "end_time": 2.0},
323+
):
324+
with pytest.raises(
325+
ValueError, match="Experiment time range could not be determined"
326+
):
327+
Experiment(
328+
root_folder=tmp_path,
329+
modality_config=make_modality_config("device_0"),
330+
)
331+
332+
333+
@pytest.mark.parametrize("override_meta", INVALID_META_CASES, ids=INVALID_META_IDS)
334+
def test_experiment_skips_invalid_devices(tmp_path, override_meta, caplog):
335+
"""
336+
Experiment should skip devices with invalid start_time or end_time and
337+
log a warning, but still initialize successfully if at least one valid
338+
device is present. The experiment time range should reflect only the
339+
valid device.
340+
"""
341+
start_val = np.random.lognormal(mean=0.0, sigma=1.0) # Strictly positive float
342+
duration_val = np.random.lognormal(mean=0.0, sigma=1.0)
343+
end_val = start_val + duration_val
344+
345+
start_nonval = np.random.lognormal(mean=0.0, sigma=1.0)
346+
duration_nonval = np.random.lognormal(mean=0.0, sigma=1.0)
347+
end_nonval = start_nonval + duration_nonval
348+
349+
with ExitStack() as stack:
350+
# Valid device with proper metadata
351+
stack.enter_context(
352+
make_sequence_device(
353+
tmp_path,
354+
"valid_device",
355+
start=start_val,
356+
end=end_val,
357+
)
358+
)
359+
# Invalid device with missing start_time and end_time
360+
stack.enter_context(
361+
make_sequence_device(
362+
tmp_path,
363+
"invalid_device",
364+
start=start_nonval,
365+
end=end_nonval,
366+
override_meta=override_meta,
367+
)
368+
)
369+
370+
with caplog.at_level(logging.WARNING, logger="experanto.experiment"):
371+
experiment = Experiment(
372+
root_folder=tmp_path,
373+
modality_config=make_modality_config("valid_device", "invalid_device"),
374+
)
375+
376+
assert "valid_device" in experiment.devices
377+
assert "invalid_device" not in experiment.devices
378+
379+
assert experiment.start_time == (start_val), (
380+
f"Expected start_time={start_val}, got {experiment.start_time}"
381+
)
382+
assert experiment.end_time == (end_val), (
383+
f"Expected end_time={end_val}, got {experiment.end_time}"
384+
)
385+
assert any("invalid_device" in message for message in caplog.messages), (
386+
"Expected warning about invalid_device was skipped"
387+
)

0 commit comments

Comments
 (0)