Skip to content

Commit 16a2593

Browse files
committed
Incorporate maintainer feedback for experiment tests
- Fix Pyright type errors breaking CI - Switch to tmp_path for test isolation - Plug Windows file locks with try/finally - Remove redundant irregular timestamp test
1 parent f0d2623 commit 16a2593

3 files changed

Lines changed: 38 additions & 42 deletions

File tree

tests/create_experiment.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,6 @@
55

66
from .create_sequence_data import _generate_sequence_data
77

8-
# Temporary directory for test execution
9-
EXPERIMENT_ROOT = Path("tests/experiment")
10-
118
DEFAULT_CONFIG = {
129
"device_0": {
1310
"sampling_rate": 1.0,
@@ -33,6 +30,7 @@ def get_default_config():
3330

3431
@contextmanager
3532
def create_experiment(
33+
tmp_path,
3634
n_devices=2,
3735
devices_kwargs=None,
3836
):
@@ -41,14 +39,18 @@ def create_experiment(
4139

4240
devices_kwargs = [default_params | kwargs for kwargs in devices_kwargs]
4341

42+
assert len(devices_kwargs) == n_devices, "wrong experiment creation"
43+
4444
try:
45-
EXPERIMENT_ROOT.mkdir(parents=True, exist_ok=True)
45+
tmp_path.mkdir(parents=True, exist_ok=True)
4646

4747
for device_id, device_kwargs in enumerate(devices_kwargs):
48-
device_path = EXPERIMENT_ROOT / f"device_{device_id}"
49-
_generate_sequence_data(str(device_path), **device_kwargs) # pyright: ignore
48+
device_path = tmp_path / f"device_{device_id}"
49+
_generate_sequence_data(
50+
str(device_path), **device_kwargs
51+
) # pyright: ignore
5052

51-
yield EXPERIMENT_ROOT
53+
yield tmp_path
5254
finally:
53-
if EXPERIMENT_ROOT.exists():
54-
shutil.rmtree(EXPERIMENT_ROOT)
55+
if tmp_path.exists():
56+
shutil.rmtree(tmp_path)

tests/create_sequence_data.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,9 @@ def _generate_sequence_data(
2020
irregular=False,
2121
):
2222
"""Generates synthetic sequence data folders for testing interpolator logic."""
23+
from pathlib import Path
24+
25+
sequence_root = Path(sequence_root)
2326
sequence_root.mkdir(parents=True, exist_ok=True)
2427
(sequence_root / "meta").mkdir(parents=True, exist_ok=True)
2528

@@ -87,6 +90,7 @@ def create_sequence_data(
8790
t_end=10.0,
8891
sampling_rate=10.0,
8992
contain_nans=False,
93+
start_time=0.0,
9094
):
9195
"""Context manager for temporary sequence data creation and cleanup."""
9296
try:
@@ -98,6 +102,7 @@ def create_sequence_data(
98102
t_end=t_end,
99103
sampling_rate=sampling_rate,
100104
contain_nans=contain_nans,
105+
start_time=start_time,
101106
)
102107
finally:
103108
if SEQUENCE_ROOT.exists():
@@ -113,4 +118,7 @@ def sequence_data_and_interpolator(data_kwargs=None, interp_kwargs=None):
113118
from experanto.interpolators import Interpolator
114119

115120
seq_interp = Interpolator.create(str(SEQUENCE_ROOT), **interp_kwargs)
116-
yield timestamps, data, shifts, seq_interp
121+
try:
122+
yield timestamps, data, shifts, seq_interp
123+
finally:
124+
seq_interp.close()

tests/test_experiment.py

Lines changed: 18 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from .create_experiment import create_experiment, get_default_config
1010

1111

12-
class TestInterpolator(Interpolator):
12+
class DummyInterpolator(Interpolator):
1313
"""Small concrete interpolator used for testing Experiment routing logic."""
1414

1515
def __init__(self):
@@ -22,7 +22,7 @@ def __init__(self):
2222
@pytest.fixture
2323
def mock_interpolator():
2424
"""Shared interpolator instance to isolate Experiment logic from interpolation math."""
25-
return TestInterpolator()
25+
return DummyInterpolator()
2626

2727

2828
def test_experiment_initialization_and_device_loading(tmp_path, mock_interpolator):
@@ -60,17 +60,18 @@ def test_experiment_interpolate_routing(tmp_path, mock_interpolator):
6060

6161
# Bulk routing (device=None)
6262
res_dict = exp.interpolate(test_times, device=None)
63-
assert isinstance(res, dict)
64-
assert isinstance(res["screen"], np.ndarray)
63+
assert isinstance(res_dict, dict)
64+
assert isinstance(res_dict["screen"], np.ndarray)
6565
np.testing.assert_array_equal(res_dict["screen"], np.array([1, 2, 3]))
6666

6767

6868
@pytest.mark.parametrize(
6969
"device_name, start_t, end_t", [("device_0", 0.0, 10.0), ("device_1", 0.0, 20.0)]
7070
)
71-
def test_get_valid_range_all_devices(device_name, start_t, end_t):
71+
def test_get_valid_range_all_devices(tmp_path, device_name, start_t, end_t):
7272
"""Integration test for valid_interval propagation from disk to object."""
7373
with create_experiment(
74+
tmp_path,
7475
devices_kwargs=[{"t_end": 10.0}, {"t_end": 20.0}],
7576
) as experiment_path:
7677
experiment = Experiment(
@@ -83,8 +84,8 @@ def test_get_valid_range_all_devices(device_name, start_t, end_t):
8384
assert isinstance(valid_range, tuple)
8485

8586

86-
def test_get_valid_range_raises_for_invalid_device():
87-
with create_experiment() as experiment_path:
87+
def test_get_valid_range_raises_for_invalid_device(tmp_path):
88+
with create_experiment(tmp_path) as experiment_path:
8889
experiment = Experiment(
8990
root_folder=str(experiment_path),
9091
modality_config=get_default_config(),
@@ -93,12 +94,14 @@ def test_get_valid_range_raises_for_invalid_device():
9394
experiment.get_valid_range("device_does_not_exist")
9495

9596

96-
def test_experiment_with_non_zero_start_time():
97+
def test_experiment_with_non_zero_start_time(tmp_path):
9798
"""Test boundary conditions for data not starting at t=0."""
9899
start_offset, duration = 1.5, 10.0
99100

100101
with create_experiment(
101-
devices_kwargs=[{"t_end": start_offset + duration, "start_time": start_offset}]
102+
tmp_path,
103+
n_devices=1,
104+
devices_kwargs=[{"t_end": start_offset + duration, "start_time": start_offset}],
102105
) as experiment_path:
103106
experiment = Experiment(
104107
root_folder=str(experiment_path),
@@ -110,20 +113,22 @@ def test_experiment_with_non_zero_start_time():
110113
assert res is not None
111114

112115

113-
def test_experiment_numeric_precision_offset():
116+
def test_experiment_numeric_precision_offset(tmp_path):
114117
"""Stress test using non-integer rates and offsets to catch float drift."""
115118
start_offset = 0.123456789
116119
sampling_rate = 33.3333333
117120
duration = 1.0
118121

119122
with create_experiment(
123+
tmp_path,
124+
n_devices=1,
120125
devices_kwargs=[
121126
{
122127
"start_time": start_offset,
123128
"t_end": start_offset + duration,
124129
"sampling_rate": sampling_rate,
125130
}
126-
]
131+
],
127132
) as experiment_path:
128133
experiment = Experiment(
129134
root_folder=str(experiment_path),
@@ -138,28 +143,9 @@ def test_experiment_numeric_precision_offset():
138143
assert res is not None
139144

140145

141-
def test_experiment_irregular_timestamps():
142-
"""Verify interpolation stability with jittered (non-linear) time steps."""
143-
with create_experiment(
144-
devices_kwargs=[{"irregular": True, "sampling_rate": 10.0}]
145-
) as experiment_path:
146-
experiment = Experiment(
147-
root_folder=str(experiment_path),
148-
modality_config=get_default_config(),
149-
)
150-
151-
valid_range = experiment.get_valid_range("device_0")
152-
mid_point = (valid_range[0] + valid_range[1]) / 2
153-
154-
res = experiment.interpolate(np.array([mid_point]), device="device_0")
155-
assert res is not None
156-
assert isinstance(res, np.ndarray)
157-
assert res.shape == (1, 10)
158-
159-
160-
def test_experiment_multi_device_interpolation():
146+
def test_experiment_multi_device_interpolation(tmp_path):
161147
"""Check data consistency when interpolating across multiple modalities."""
162-
with create_experiment(n_devices=2) as experiment_path:
148+
with create_experiment(tmp_path, n_devices=2) as experiment_path:
163149
exp = Experiment(
164150
root_folder=str(experiment_path), modality_config=get_default_config()
165151
)

0 commit comments

Comments
 (0)