11import logging
2- import shutil
3- from contextlib import ExitStack , contextmanager
4- from pathlib import Path
52from unittest .mock import MagicMock
63
74import numpy as np
1512
1613from .create_experiment import (
1714 get_default_config ,
18- make_modality_config ,
19- make_sequence_device ,
2015 setup_test_experiment ,
2116)
2217
@@ -236,40 +231,56 @@ def test_experiment_start_end_time_reflects_union(
236231 tmp_path , device_ranges , expected_start , expected_end , n_signals
237232):
238233 """Experiment.start_time and end_time should reflect the union of all device time ranges."""
239- device_names = [f"device_{ i } " for i in range (len (device_ranges ))]
240- with ExitStack () as stack :
241- for name , (start , end ) in zip (device_names , device_ranges ):
242- stack .enter_context (
243- make_sequence_device (
244- tmp_path ,
245- name ,
246- start = start ,
247- end = end ,
248- n_signals = n_signals ,
249- sampling_rate = float (np .random .randint (5 , 30 )),
250- )
251- )
234+ devices_kwargs = [
235+ {
236+ "start_time" : start ,
237+ "t_end" : end ,
238+ "n_signals" : n_signals ,
239+ "sampling_rate" : float (np .random .randint (5 , 30 )),
240+ }
241+ for start , end in device_ranges
242+ ]
243+
244+ with setup_test_experiment (
245+ tmp_path , n_devices = len (device_ranges ), devices_kwargs = devices_kwargs
246+ ) as experiment_path :
247+ # Manually build the config dict for however many devices were generated
248+ config = {}
249+ for i in range (len (device_ranges )):
250+ config [f"device_{ i } " ] = {
251+ "interpolation" : {
252+ "sampling_rate" : 10.0 ,
253+ "offset" : float (np .random .rand ()),
254+ }
255+ }
256+
252257 experiment = Experiment (
253- root_folder = tmp_path ,
254- modality_config = make_modality_config (
255- * device_names , offsets = [float (np .random .rand ()) for _ in device_names ]
256- ),
258+ root_folder = str (experiment_path ), modality_config = config
257259 )
260+
258261 assert experiment .start_time == pytest .approx (expected_start )
259262 assert experiment .end_time == pytest .approx (expected_end )
260263
261264
262265@pytest .mark .parametrize ("override_meta" , INVALID_META_CASES , ids = INVALID_META_IDS )
263266def test_experiment_invalid_metadata (tmp_path , override_meta ):
264- with make_sequence_device (
265- tmp_path , "device_0" , start = 0.0 , end = 10.0 , override_meta = override_meta
266- ):
267+ with setup_test_experiment (
268+ tmp_path , n_devices = 1 , devices_kwargs = [{"start_time" : 0.0 , "t_end" : 10.0 }]
269+ ) as experiment_path :
270+ # Explicitly corrupt the generated metadata file
271+ meta_file = experiment_path / "device_0" / "meta.yml"
272+ with open (meta_file , "r" ) as f :
273+ meta = yaml .safe_load (f )
274+ meta .update (override_meta )
275+ with open (meta_file , "w" ) as f :
276+ yaml .safe_dump (meta , f )
277+
278+ config = {"device_0" : {"interpolation" : {"sampling_rate" : 10.0 }}}
279+
267280 with pytest .raises (
268281 ValueError , match = "Experiment time range could not be determined"
269282 ):
270- Experiment (
271- root_folder = tmp_path , modality_config = make_modality_config ("device_0" )
272- )
283+ Experiment (root_folder = str (experiment_path ), modality_config = config )
273284
274285
275286@pytest .mark .parametrize ("override_meta" , INVALID_META_CASES , ids = INVALID_META_IDS )
@@ -279,24 +290,37 @@ def test_experiment_skips_invalid_devices(tmp_path, override_meta, caplog):
279290 np .random .lognormal (0.0 , 1.0 ),
280291 )
281292 end_val = start_val + duration_val
282- with ExitStack () as stack :
283- stack .enter_context (
284- make_sequence_device (tmp_path , "valid_device" , start = start_val , end = end_val )
285- )
286- stack .enter_context (
287- make_sequence_device (
288- tmp_path ,
289- "invalid_device" ,
290- start = 0.0 ,
291- end = 10.0 ,
292- override_meta = override_meta ,
293- )
294- )
293+
294+ devices_kwargs = [
295+ {"start_time" : start_val , "t_end" : end_val }, # valid device
296+ {"start_time" : 0.0 , "t_end" : 10.0 }, # invalid device
297+ ]
298+
299+ with setup_test_experiment (
300+ tmp_path , n_devices = 2 , devices_kwargs = devices_kwargs
301+ ) as experiment_path :
302+ # Rename the folders to match what the old test expected
303+ (experiment_path / "device_0" ).rename (experiment_path / "valid_device" )
304+ (experiment_path / "device_1" ).rename (experiment_path / "invalid_device" )
305+
306+ # Explicitly corrupt the metadata file for the invalid device
307+ meta_file = experiment_path / "invalid_device" / "meta.yml"
308+ with open (meta_file , "r" ) as f :
309+ meta = yaml .safe_load (f )
310+ meta .update (override_meta )
311+ with open (meta_file , "w" ) as f :
312+ yaml .safe_dump (meta , f )
313+
314+ config = {
315+ "valid_device" : {"interpolation" : {"sampling_rate" : 10.0 }},
316+ "invalid_device" : {"interpolation" : {"sampling_rate" : 10.0 }},
317+ }
318+
295319 with caplog .at_level (logging .WARNING , logger = "experanto.experiment" ):
296320 experiment = Experiment (
297- root_folder = tmp_path ,
298- modality_config = make_modality_config ("valid_device" , "invalid_device" ),
321+ root_folder = str (experiment_path ), modality_config = config
299322 )
323+
300324 assert "valid_device" in experiment .devices
301325 assert "invalid_device" not in experiment .devices
302326 assert experiment .start_time == pytest .approx (start_val )
0 commit comments