|
| 1 | +import logging |
| 2 | +from contextlib import ExitStack |
1 | 3 | from unittest.mock import MagicMock |
2 | 4 |
|
3 | 5 | import numpy as np |
|
6 | 8 | from experanto.experiment import Experiment |
7 | 9 | from experanto.interpolators import Interpolator |
8 | 10 |
|
9 | | -from .create_experiment import create_experiment, get_default_config |
| 11 | +from .create_experiment import ( |
| 12 | + create_experiment, |
| 13 | + get_default_config, |
| 14 | + make_modality_config, |
| 15 | + make_sequence_device, |
| 16 | +) |
10 | 17 |
|
11 | 18 |
|
12 | 19 | class DummyInterpolator(Interpolator): |
@@ -180,3 +187,201 @@ def test_experiment_multi_device_interpolation(tmp_path, return_valid, device): |
180 | 187 | else: |
181 | 188 | assert isinstance(data, np.ndarray) |
182 | 189 | assert data.shape == (2, 10) |
| 190 | + |
| 191 | + |
| 192 | +DEVICE_TIME_RANGE_CASES = [ |
| 193 | + # Single device: start and end should match that device's range |
| 194 | + ([(2.0, 9.0)], 2.0, 9.0), |
| 195 | + # Two devices with different ranges: start should be min, end should be max |
| 196 | + ([(1.0, 8.0), (0.0, 10.0)], 0.0, 10.0), |
| 197 | + # Three devices with different ranges: start should be min, end should be max |
| 198 | + ([(0.0, 10.0), (1.0, 8.0), (2.0, 9.0)], 0.0, 10.0), |
| 199 | + # Devices with non-overlapping ranges: start should be min, end should be max |
| 200 | + ([(0.0, 3.0), (7.0, 8.0)], 0.0, 8.0), |
| 201 | + # Devices with identical ranges: start and end should match that range |
| 202 | + ([(1.0, 5.0), (1.0, 5.0)], 1.0, 5.0), |
| 203 | + # Large time stamps: start should be min, end should be max |
| 204 | + ([(1e9, 1e9 + 100), (1e9 - 50, 1e9 + 50)], 1e9 - 50, 1e9 + 100), |
| 205 | +] |
| 206 | + |
| 207 | +DEVICE_TIME_RANGE_IDS = [ |
| 208 | + "single_device", |
| 209 | + "two_devices_different_ranges", |
| 210 | + "three_devices_different_ranges", |
| 211 | + "non_overlapping_ranges", |
| 212 | + "identical_ranges", |
| 213 | + "large_time_stamps", |
| 214 | +] |
| 215 | + |
| 216 | +# Inverted range is intentionally separate from INVALID_META_CASES — |
| 217 | +# None/NaN/inf are caught per-device before being added to self.devices, |
| 218 | +# whereas start > end is only caught after all devices are loaded. |
| 219 | +INVALID_META_CASES = [ |
| 220 | + {"start_time": None, "end_time": None}, # Both missing |
| 221 | + {"start_time": None, "end_time": 10.0}, # Missing start_time |
| 222 | + {"start_time": 0.0, "end_time": None}, # Missing end_time |
| 223 | + {"start_time": float("inf"), "end_time": 10.0}, # Infinite start_time |
| 224 | + {"start_time": 0.0, "end_time": float("inf")}, # Infinite end_time |
| 225 | + {"start_time": float("-inf"), "end_time": 10.0}, # Negative Infinite start_time |
| 226 | + {"start_time": 0.0, "end_time": float("-inf")}, # Negative Infinite end_time |
| 227 | + {"start_time": float("nan"), "end_time": 10.0}, # NaN start_time |
| 228 | + {"start_time": 0.0, "end_time": float("nan")}, # NaN end_time |
| 229 | +] |
| 230 | + |
| 231 | +INVALID_META_IDS = [ |
| 232 | + "both_missing", |
| 233 | + "missing_start_time", |
| 234 | + "missing_end_time", |
| 235 | + "infinite_start_time", |
| 236 | + "infinite_end_time", |
| 237 | + "negative_infinite_start_time", |
| 238 | + "negative_infinite_end_time", |
| 239 | + "nan_start_time", |
| 240 | + "nan_end_time", |
| 241 | +] |
| 242 | + |
| 243 | + |
| 244 | +# Test for union of device time ranges |
| 245 | +@pytest.mark.parametrize("n_signals", [5, 20]) |
| 246 | +@pytest.mark.parametrize( |
| 247 | + "device_ranges, expected_start, expected_end", |
| 248 | + DEVICE_TIME_RANGE_CASES, |
| 249 | + ids=DEVICE_TIME_RANGE_IDS, |
| 250 | +) |
| 251 | +def test_experiment_start_end_time_reflects_union( |
| 252 | + tmp_path, device_ranges, expected_start, expected_end, n_signals |
| 253 | +): |
| 254 | + """ |
| 255 | + Experiment.start_time and end_time should reflect the union of all |
| 256 | + device time ranges — earliest start and latest end across all devices. |
| 257 | + """ |
| 258 | + device_names = [f"device_{i}" for i in range(len(device_ranges))] |
| 259 | + |
| 260 | + with ExitStack() as stack: |
| 261 | + for name, (start, end) in zip(device_names, device_ranges, strict=True): |
| 262 | + stack.enter_context( |
| 263 | + make_sequence_device( |
| 264 | + tmp_path, |
| 265 | + name, |
| 266 | + start=start, |
| 267 | + end=end, |
| 268 | + n_signals=n_signals, |
| 269 | + sampling_rate=float(np.random.randint(5, 30)), |
| 270 | + ) |
| 271 | + ) |
| 272 | + |
| 273 | + experiment = Experiment( |
| 274 | + root_folder=tmp_path, |
| 275 | + modality_config=make_modality_config( |
| 276 | + *device_names, offsets=[float(np.random.rand()) for _ in device_names] |
| 277 | + ), |
| 278 | + ) |
| 279 | + |
| 280 | + assert experiment.start_time == (expected_start), ( |
| 281 | + f"Expected start_time={expected_start}, got {experiment.start_time}" |
| 282 | + ) |
| 283 | + assert experiment.end_time == (expected_end), ( |
| 284 | + f"Expected end_time={expected_end}, got {experiment.end_time}" |
| 285 | + ) |
| 286 | + |
| 287 | + |
| 288 | +# Safety check |
| 289 | +@pytest.mark.parametrize("override_meta", INVALID_META_CASES, ids=INVALID_META_IDS) |
| 290 | +def test_experiment_invalid_metadata(tmp_path, override_meta): |
| 291 | + """ |
| 292 | + Experiment should raise an error when initialized with invalid metadata. |
| 293 | + Covers cases where start_time or end_time is None, NaN, or infinite. |
| 294 | + """ |
| 295 | + with make_sequence_device( |
| 296 | + tmp_path, |
| 297 | + "device_0", |
| 298 | + start=0.0, |
| 299 | + end=10.0, |
| 300 | + override_meta=override_meta, |
| 301 | + ): |
| 302 | + with pytest.raises( |
| 303 | + ValueError, match="Experiment time range could not be determined" |
| 304 | + ): |
| 305 | + Experiment( |
| 306 | + root_folder=tmp_path, |
| 307 | + modality_config=make_modality_config("device_0"), |
| 308 | + ) |
| 309 | + |
| 310 | + |
| 311 | +def test_experiment_inverted_time_range_raises(tmp_path): |
| 312 | + """ |
| 313 | + Experiment should raise ValueError when start_time > end_time. |
| 314 | + This is a separate guard from invalid metadata (None/NaN/inf) because it |
| 315 | + only becomes apparent after all devices are loaded and the overall time range is computed. |
| 316 | + """ |
| 317 | + with make_sequence_device( |
| 318 | + tmp_path, |
| 319 | + "device_0", |
| 320 | + start=0.0, |
| 321 | + end=10.0, |
| 322 | + override_meta={"start_time": 5.0, "end_time": 2.0}, |
| 323 | + ): |
| 324 | + with pytest.raises( |
| 325 | + ValueError, match="Experiment time range could not be determined" |
| 326 | + ): |
| 327 | + Experiment( |
| 328 | + root_folder=tmp_path, |
| 329 | + modality_config=make_modality_config("device_0"), |
| 330 | + ) |
| 331 | + |
| 332 | + |
| 333 | +@pytest.mark.parametrize("override_meta", INVALID_META_CASES, ids=INVALID_META_IDS) |
| 334 | +def test_experiment_skips_invalid_devices(tmp_path, override_meta, caplog): |
| 335 | + """ |
| 336 | + Experiment should skip devices with invalid start_time or end_time and |
| 337 | + log a warning, but still initialize successfully if at least one valid |
| 338 | + device is present. The experiment time range should reflect only the |
| 339 | + valid device. |
| 340 | + """ |
| 341 | + start_val = np.random.lognormal(mean=0.0, sigma=1.0) # Strictly positive float |
| 342 | + duration_val = np.random.lognormal(mean=0.0, sigma=1.0) |
| 343 | + end_val = start_val + duration_val |
| 344 | + |
| 345 | + start_nonval = np.random.lognormal(mean=0.0, sigma=1.0) |
| 346 | + duration_nonval = np.random.lognormal(mean=0.0, sigma=1.0) |
| 347 | + end_nonval = start_nonval + duration_nonval |
| 348 | + |
| 349 | + with ExitStack() as stack: |
| 350 | + # Valid device with proper metadata |
| 351 | + stack.enter_context( |
| 352 | + make_sequence_device( |
| 353 | + tmp_path, |
| 354 | + "valid_device", |
| 355 | + start=start_val, |
| 356 | + end=end_val, |
| 357 | + ) |
| 358 | + ) |
| 359 | + # Invalid device with missing start_time and end_time |
| 360 | + stack.enter_context( |
| 361 | + make_sequence_device( |
| 362 | + tmp_path, |
| 363 | + "invalid_device", |
| 364 | + start=start_nonval, |
| 365 | + end=end_nonval, |
| 366 | + override_meta=override_meta, |
| 367 | + ) |
| 368 | + ) |
| 369 | + |
| 370 | + with caplog.at_level(logging.WARNING, logger="experanto.experiment"): |
| 371 | + experiment = Experiment( |
| 372 | + root_folder=tmp_path, |
| 373 | + modality_config=make_modality_config("valid_device", "invalid_device"), |
| 374 | + ) |
| 375 | + |
| 376 | + assert "valid_device" in experiment.devices |
| 377 | + assert "invalid_device" not in experiment.devices |
| 378 | + |
| 379 | + assert experiment.start_time == (start_val), ( |
| 380 | + f"Expected start_time={start_val}, got {experiment.start_time}" |
| 381 | + ) |
| 382 | + assert experiment.end_time == (end_val), ( |
| 383 | + f"Expected end_time={end_val}, got {experiment.end_time}" |
| 384 | + ) |
| 385 | + assert any("invalid_device" in message for message in caplog.messages), ( |
| 386 | + "Expected warning about invalid_device was skipped" |
| 387 | + ) |
0 commit comments