Skip to content

Commit 66fe819

Browse files
committed
test: add comprehensive coverage for Experiment class and data generation utilities
1 parent d851484 commit 66fe819

11 files changed

Lines changed: 378 additions & 237 deletions

File tree

.github/workflows/docs.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,13 @@ on:
77
pull_request:
88
branches:
99
- '**'
10-
workflow_dispatch:
1110

1211
jobs:
1312
build-docs:
1413
name: Build Sphinx Docs
1514
runs-on: ubuntu-latest
1615
if: |
1716
github.event_name == 'push' ||
18-
github.event_name == 'workflow_dispatch' ||
1917
(
2018
github.event_name == 'pull_request' &&
2119
github.repository != github.event.pull_request.head.repo.full_name

.github/workflows/ruff.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,13 @@ on:
77
pull_request:
88
branches:
99
- '**'
10-
workflow_dispatch:
1110

1211
jobs:
1312
ruff:
1413
name: Ruff Linting
1514
runs-on: ubuntu-latest
1615
if: |
1716
github.event_name == 'push' ||
18-
github.event_name == 'workflow_dispatch' ||
1917
(
2018
github.event_name == 'pull_request' &&
2119
github.repository != github.event.pull_request.head.repo.full_name

.github/workflows/test.yml

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,13 @@ on:
77
pull_request:
88
branches:
99
- '**'
10-
workflow_dispatch:
1110

1211
jobs:
1312
test:
1413
name: Run Tests
1514
runs-on: ubuntu-latest
1615
if: |
1716
github.event_name == 'push' ||
18-
github.event_name == 'workflow_dispatch' ||
1917
(
2018
github.event_name == 'pull_request' &&
2119
github.repository != github.event.pull_request.head.repo.full_name
@@ -49,7 +47,6 @@ jobs:
4947
runs-on: ubuntu-latest
5048
if: |
5149
github.event_name == 'push' ||
52-
github.event_name == 'workflow_dispatch' ||
5350
(
5451
github.event_name == 'pull_request' &&
5552
github.repository != github.event.pull_request.head.repo.full_name
@@ -67,25 +64,25 @@ jobs:
6764
- name: Install tools
6865
run: |
6966
python -m pip install --upgrade pip
70-
pip install -e ".[dev]"
67+
pip install "black[jupyter]==25.1.0" isort==6.0.1
7168
7269
- name: Run Black and Isort
7370
run: |
7471
if [[ "${{ github.event_name }}" == "push" && "${{ github.repository }}" == "${{ github.event.repository.full_name }}" ]]; then
7572
echo "Push event in same repo: Running black and isort with auto-format and commit"
7673
black .
7774
isort .
78-
git config user.name "github-actions[bot]"
79-
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
75+
git config user.name "github-actions"
76+
git config user.email "github-actions@github.com"
8077
if ! git diff --quiet; then
8178
git commit -am "style: auto-format with black and isort"
8279
git push
8380
else
8481
echo "No formatting changes to commit."
8582
fi
8683
87-
elif [[ "${{ github.event_name }}" == "pull_request" || "${{ github.event_name }}" == "workflow_dispatch" ]]; then
88-
echo "PR or manual run: Running black and isort in check mode"
84+
elif [[ "${{ github.event_name }}" == "pull_request" ]]; then
85+
echo "PR from fork: Running black and isort in check mode"
8986
black --check . || { echo "Black formatting issues found. Run 'black .' to fix."; exit 1; }
9087
isort --check-only . || { echo "isort import order issues found. Run 'isort .' to fix."; exit 1; }
9188
@@ -98,7 +95,6 @@ jobs:
9895
runs-on: ubuntu-latest
9996
if: |
10097
github.event_name == 'push' ||
101-
github.event_name == 'workflow_dispatch' ||
10298
(
10399
github.event_name == 'pull_request' &&
104100
github.repository != github.event.pull_request.head.repo.full_name

.readthedocs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ version: 2
44
build:
55
os: ubuntu-22.04
66
tools:
7-
python: "3.12"
7+
python: "3.9"
88

99
python:
1010
install:

docs/source/conf.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# -- Project information -----------------------------------------------------
1414
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
1515

16-
project = "Experanto"
16+
project = "experanto"
1717
copyright = f"{datetime.today().strftime('%Y')}, sinzlab"
1818
author = "sinzlab"
1919
release = "0.1"
@@ -83,8 +83,6 @@
8383
"optree",
8484
"rootutils",
8585
"numba",
86-
"numpy",
87-
"yaml",
8886
]
8987

9088
# -- Intersphinx settings (cross-references to external docs) ----------------

experanto/experiment.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from __future__ import annotations
22

33
import logging
4+
import re
45
import warnings
6+
from collections.abc import Sequence
57
from pathlib import Path
68

79
import numpy as np
10+
811
from hydra.utils import instantiate
912
from omegaconf import DictConfig
1013

@@ -76,7 +79,7 @@ def __init__(
7679
def _load_devices(self) -> None:
7780
# Populate devices by going through subfolders
7881
# Assumption: blocks are sorted by start time
79-
device_folders = [d for d in self.root_folder.iterdir() if (d.is_dir())]
82+
device_folders = [d for d in self.root_folder.iterdir() if d.is_dir()]
8083

8184
for d in device_folders:
8285
if d.name not in self.modality_config:
@@ -95,14 +98,14 @@ def _load_devices(self) -> None:
9598
dev = instantiate(
9699
interp_conf, root_folder=d, cache_data=self.cache_data
97100
)
101+
98102
# Check if instantiated object is proper Interpolator
99103
if not isinstance(dev, Interpolator):
100104
raise ValueError(
101-
"Please provide an Interpolator which inherits from experantos Interpolator class."
105+
"Instantiated object must inherit from Interpolator class."
102106
)
103107

104108
elif isinstance(interp_conf, Interpolator):
105-
# Already instantiated Interpolator
106109
dev = interp_conf
107110

108111
else:
@@ -207,26 +210,22 @@ def interpolate(
207210
dict_keys(['screen', 'responses', 'eye_tracker'])
208211
"""
209212
if device is None:
210-
values = {}
211-
valid = {}
213+
values, valid = {}, {}
212214
for d, interp in self.devices.items():
213215
res = interp.interpolate(times, return_valid=return_valid)
214216
if return_valid:
215217
vals, vlds = res
216-
values[d] = vals
217-
valid[d] = vlds
218+
values[d], valid[d] = vals, vlds
218219
else:
219220
values[d] = res
220-
if return_valid:
221-
return values, valid
222-
else:
223-
return values
221+
return (values, valid) if return_valid else values
222+
224223
elif isinstance(device, str):
225-
assert device in self.devices, f"Unknown device '{device}'"
226-
res = self.devices[device].interpolate(times, return_valid=return_valid)
227-
return res
228-
else:
229-
raise ValueError(f"Unsupported device type: {type(device)}")
224+
if device not in self.devices:
225+
raise KeyError(f"Unknown device '{device}'")
226+
return self.devices[device].interpolate(times, return_valid=return_valid)
227+
228+
raise ValueError(f"Unsupported device type: {type(device)}")
230229

231230
def get_valid_range(self, device_name: str) -> tuple[float, float]:
232231
"""Get the valid time range for a specific device.
@@ -239,7 +238,7 @@ def get_valid_range(self, device_name: str) -> tuple[float, float]:
239238
Returns
240239
-------
241240
tuple
242-
A tuple ``(start_time, end_time)`` representing the valid
241+
A tuple `(start_time, end_time)` representing the valid
243242
time interval in seconds.
244243
"""
245244
return tuple(self.devices[device_name].valid_interval)

experanto/interpolators.py

Lines changed: 12 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,9 @@ def __exit__(self, *exc):
8484
self.close()
8585

8686
@staticmethod
87-
def create(root_folder: str, cache_data: bool = False, **kwargs) -> Interpolator:
87+
def create(
88+
root_folder: str | Path, cache_data: bool = False, **kwargs
89+
) -> "Interpolator":
8890
"""Factory method to create the appropriate interpolator for a modality.
8991
9092
Reads the ``meta.yml`` file in the folder to determine the modality type
@@ -109,35 +111,24 @@ def create(root_folder: str, cache_data: bool = False, **kwargs) -> Interpolator
109111
ValueError
110112
If the modality type is not supported.
111113
"""
114+
root_folder = str(root_folder)
112115
with open(Path(root_folder) / "meta.yml") as file:
113116
meta_data = yaml.safe_load(file)
114117
modality = meta_data.get("modality")
115118

116119
if modality == "sequence":
117120
if meta_data.get("phase_shift_per_signal", False):
118121
return PhaseShiftedSequenceInterpolator(
119-
root_folder, cache_data=cache_data, **kwargs
122+
root_folder, cache_data, **kwargs
120123
)
121124
else:
122-
return SequenceInterpolator(
123-
root_folder, cache_data=cache_data, **kwargs
124-
)
125+
return SequenceInterpolator(root_folder, cache_data, **kwargs)
125126
elif modality == "screen":
126-
use_stimuli_names = kwargs.pop(
127-
"use_stimuli_names", meta_data.get("use_stimuli_names", False)
128-
)
129-
return ScreenInterpolator(
130-
root_folder,
131-
cache_data=cache_data,
132-
use_stimuli_names=use_stimuli_names,
133-
**kwargs,
134-
)
127+
return ScreenInterpolator(root_folder, cache_data, **kwargs)
135128
elif modality == "time_interval":
136-
return TimeIntervalInterpolator(
137-
root_folder, cache_data=cache_data, **kwargs
138-
)
129+
return TimeIntervalInterpolator(root_folder, cache_data, **kwargs)
139130
elif modality == "spikes":
140-
return SpikeInterpolator(root_folder, cache_data=cache_data, **kwargs)
131+
return SpikeInterpolator(root_folder, cache_data, **kwargs)
141132
else:
142133
raise ValueError(
143134
f"There is no interpolator for {modality}. Please use 'sequence', 'screen', 'time_interval' as modality or provide a custom interpolator."
@@ -497,8 +488,6 @@ class ScreenInterpolator(Interpolator):
497488
native image size from metadata.
498489
normalize : bool, default=False
499490
If True, normalizes frames using stored mean/std statistics.
500-
use_stimuli_names : bool, default=False
501-
If True, uses ``stimulus_name`` from metadata to locate data files instead of trial keys.
502491
**kwargs
503492
Additional keyword arguments (ignored).
504493
@@ -519,11 +508,10 @@ class ScreenInterpolator(Interpolator):
519508
def __init__(
520509
self,
521510
root_folder: str,
522-
cache_data: bool = False,
511+
cache_data: bool = False, # New parameter
523512
rescale: bool = False,
524513
rescale_size: tuple[int, int] | None = None,
525514
normalize: bool = False,
526-
use_stimuli_names: bool = False,
527515
**kwargs,
528516
) -> None:
529517
super().__init__(root_folder)
@@ -533,7 +521,6 @@ def __init__(
533521
self.valid_interval = TimeInterval(self.start_time, self.end_time)
534522
self.rescale = rescale
535523
self.cache_trials = cache_data # Store the cache preference
536-
self.use_stimuli_names = use_stimuli_names
537524
self._parse_trials()
538525

539526
# create mapping from image index to file index
@@ -618,14 +605,8 @@ def _parse_trials(self) -> None:
618605
metadatas, keys = self.read_combined_meta()
619606

620607
for key, metadata in zip(keys, metadatas, strict=True):
621-
if self.use_stimuli_names:
622-
stimulus_name = metadata.get("stimulus_name")
623-
assert (
624-
stimulus_name is not None
625-
), f"stimulus_name is required in metadata when use_stimuli_names is True, but not found for key: {key}"
626-
data_file_name = self.root_folder / "data" / f"{stimulus_name}.npy"
627-
else:
628-
data_file_name = self.root_folder / "data" / f"{key}.npy"
608+
data_file_name = self.root_folder / "data" / f"{key}.npy"
609+
# Pass the cache_trials parameter when creating trials
629610
self.trials.append(
630611
ScreenTrial.create(
631612
data_file_name, metadata, cache_data=self.cache_trials

pyproject.toml

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ version = "0.1"
44
description = "Python package to interpolate recordings and stimuli of neuroscience experiments"
55
readme = "README.md"
66
requires-python = ">=3.9"
7-
# When adding or removing dependencies, also update autodoc_mock_imports in docs/source/conf.py if necessary.
87
dependencies = [
98
"scipy>=1.13.1",
109
"jaxtyping>=0.2.30",
@@ -22,9 +21,7 @@ dependencies = [
2221
dev = [
2322
"pytest==8.3.5",
2423
"pytest-cov>=7.0.0",
25-
"pyright==1.1.408",
26-
"black[jupyter]==25.1.0",
27-
"isort==6.0.1",
24+
"pyright",
2825
"hypothesis>=6.0",
2926
"ruff==0.15.6",
3027
]

0 commit comments

Comments
 (0)