Skip to content

Commit 4314ecd

Browse files
committed
fix: restore project files accidentally overwritten during rebase
1 parent 66fe819 commit 4314ecd

6 files changed

Lines changed: 51 additions & 19 deletions

File tree

.github/workflows/docs.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@ on:
77
pull_request:
88
branches:
99
- '**'
10+
workflow_dispatch:
1011

1112
jobs:
1213
build-docs:
1314
name: Build Sphinx Docs
1415
runs-on: ubuntu-latest
1516
if: |
1617
github.event_name == 'push' ||
18+
github.event_name == 'workflow_dispatch' ||
1719
(
1820
github.event_name == 'pull_request' &&
1921
github.repository != github.event.pull_request.head.repo.full_name

.github/workflows/ruff.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@ on:
77
pull_request:
88
branches:
99
- '**'
10+
workflow_dispatch:
1011

1112
jobs:
1213
ruff:
1314
name: Ruff Linting
1415
runs-on: ubuntu-latest
1516
if: |
1617
github.event_name == 'push' ||
18+
github.event_name == 'workflow_dispatch' ||
1719
(
1820
github.event_name == 'pull_request' &&
1921
github.repository != github.event.pull_request.head.repo.full_name

.github/workflows/test.yml

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@ on:
77
pull_request:
88
branches:
99
- '**'
10+
workflow_dispatch:
1011

1112
jobs:
1213
test:
1314
name: Run Tests
1415
runs-on: ubuntu-latest
1516
if: |
1617
github.event_name == 'push' ||
18+
github.event_name == 'workflow_dispatch' ||
1719
(
1820
github.event_name == 'pull_request' &&
1921
github.repository != github.event.pull_request.head.repo.full_name
@@ -47,6 +49,7 @@ jobs:
4749
runs-on: ubuntu-latest
4850
if: |
4951
github.event_name == 'push' ||
52+
github.event_name == 'workflow_dispatch' ||
5053
(
5154
github.event_name == 'pull_request' &&
5255
github.repository != github.event.pull_request.head.repo.full_name
@@ -64,25 +67,25 @@ jobs:
6467
- name: Install tools
6568
run: |
6669
python -m pip install --upgrade pip
67-
pip install "black[jupyter]==25.1.0" isort==6.0.1
70+
pip install -e ".[dev]"
6871
6972
- name: Run Black and Isort
7073
run: |
7174
if [[ "${{ github.event_name }}" == "push" && "${{ github.repository }}" == "${{ github.event.repository.full_name }}" ]]; then
7275
echo "Push event in same repo: Running black and isort with auto-format and commit"
7376
black .
7477
isort .
75-
git config user.name "github-actions"
76-
git config user.email "github-actions@github.com"
78+
git config user.name "github-actions[bot]"
79+
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
7780
if ! git diff --quiet; then
7881
git commit -am "style: auto-format with black and isort"
7982
git push
8083
else
8184
echo "No formatting changes to commit."
8285
fi
8386
84-
elif [[ "${{ github.event_name }}" == "pull_request" ]]; then
85-
echo "PR from fork: Running black and isort in check mode"
87+
elif [[ "${{ github.event_name }}" == "pull_request" || "${{ github.event_name }}" == "workflow_dispatch" ]]; then
88+
echo "PR or manual run: Running black and isort in check mode"
8689
black --check . || { echo "Black formatting issues found. Run 'black .' to fix."; exit 1; }
8790
isort --check-only . || { echo "isort import order issues found. Run 'isort .' to fix."; exit 1; }
8891
@@ -95,6 +98,7 @@ jobs:
9598
runs-on: ubuntu-latest
9699
if: |
97100
github.event_name == 'push' ||
101+
github.event_name == 'workflow_dispatch' ||
98102
(
99103
github.event_name == 'pull_request' &&
100104
github.repository != github.event.pull_request.head.repo.full_name

docs/source/conf.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# -- Project information -----------------------------------------------------
1414
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
1515

16-
project = "experanto"
16+
project = "Experanto"
1717
copyright = f"{datetime.today().strftime('%Y')}, sinzlab"
1818
author = "sinzlab"
1919
release = "0.1"
@@ -83,6 +83,8 @@
8383
"optree",
8484
"rootutils",
8585
"numba",
86+
"numpy",
87+
"yaml",
8688
]
8789

8890
# -- Intersphinx settings (cross-references to external docs) ----------------

experanto/interpolators.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,7 @@ def __exit__(self, *exc):
8484
self.close()
8585

8686
@staticmethod
87-
def create(
88-
root_folder: str | Path, cache_data: bool = False, **kwargs
89-
) -> "Interpolator":
87+
def create(root_folder: str, cache_data: bool = False, **kwargs) -> Interpolator:
9088
"""Factory method to create the appropriate interpolator for a modality.
9189
9290
Reads the ``meta.yml`` file in the folder to determine the modality type
@@ -111,24 +109,35 @@ def create(
111109
ValueError
112110
If the modality type is not supported.
113111
"""
114-
root_folder = str(root_folder)
115112
with open(Path(root_folder) / "meta.yml") as file:
116113
meta_data = yaml.safe_load(file)
117114
modality = meta_data.get("modality")
118115

119116
if modality == "sequence":
120117
if meta_data.get("phase_shift_per_signal", False):
121118
return PhaseShiftedSequenceInterpolator(
122-
root_folder, cache_data, **kwargs
119+
root_folder, cache_data=cache_data, **kwargs
123120
)
124121
else:
125-
return SequenceInterpolator(root_folder, cache_data, **kwargs)
122+
return SequenceInterpolator(
123+
root_folder, cache_data=cache_data, **kwargs
124+
)
126125
elif modality == "screen":
127-
return ScreenInterpolator(root_folder, cache_data, **kwargs)
126+
use_stimuli_names = kwargs.pop(
127+
"use_stimuli_names", meta_data.get("use_stimuli_names", False)
128+
)
129+
return ScreenInterpolator(
130+
root_folder,
131+
cache_data=cache_data,
132+
use_stimuli_names=use_stimuli_names,
133+
**kwargs,
134+
)
128135
elif modality == "time_interval":
129-
return TimeIntervalInterpolator(root_folder, cache_data, **kwargs)
136+
return TimeIntervalInterpolator(
137+
root_folder, cache_data=cache_data, **kwargs
138+
)
130139
elif modality == "spikes":
131-
return SpikeInterpolator(root_folder, cache_data, **kwargs)
140+
return SpikeInterpolator(root_folder, cache_data=cache_data, **kwargs)
132141
else:
133142
raise ValueError(
134143
f"There is no interpolator for {modality}. Please use 'sequence', 'screen', 'time_interval' as modality or provide a custom interpolator."
@@ -488,6 +497,8 @@ class ScreenInterpolator(Interpolator):
488497
native image size from metadata.
489498
normalize : bool, default=False
490499
If True, normalizes frames using stored mean/std statistics.
500+
use_stimuli_names : bool, default=False
501+
If True, uses ``stimulus_name`` from metadata to locate data files instead of trial keys.
491502
**kwargs
492503
Additional keyword arguments (ignored).
493504
@@ -508,10 +519,11 @@ class ScreenInterpolator(Interpolator):
508519
def __init__(
509520
self,
510521
root_folder: str,
511-
cache_data: bool = False, # New parameter
522+
cache_data: bool = False,
512523
rescale: bool = False,
513524
rescale_size: tuple[int, int] | None = None,
514525
normalize: bool = False,
526+
use_stimuli_names: bool = False,
515527
**kwargs,
516528
) -> None:
517529
super().__init__(root_folder)
@@ -521,6 +533,7 @@ def __init__(
521533
self.valid_interval = TimeInterval(self.start_time, self.end_time)
522534
self.rescale = rescale
523535
self.cache_trials = cache_data # Store the cache preference
536+
self.use_stimuli_names = use_stimuli_names
524537
self._parse_trials()
525538

526539
# create mapping from image index to file index
@@ -605,8 +618,14 @@ def _parse_trials(self) -> None:
605618
metadatas, keys = self.read_combined_meta()
606619

607620
for key, metadata in zip(keys, metadatas, strict=True):
608-
data_file_name = self.root_folder / "data" / f"{key}.npy"
609-
# Pass the cache_trials parameter when creating trials
621+
if self.use_stimuli_names:
622+
stimulus_name = metadata.get("stimulus_name")
623+
assert (
624+
stimulus_name is not None
625+
), f"stimulus_name is required in metadata when use_stimuli_names is True, but not found for key: {key}"
626+
data_file_name = self.root_folder / "data" / f"{stimulus_name}.npy"
627+
else:
628+
data_file_name = self.root_folder / "data" / f"{key}.npy"
610629
self.trials.append(
611630
ScreenTrial.create(
612631
data_file_name, metadata, cache_data=self.cache_trials

pyproject.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ version = "0.1"
44
description = "Python package to interpolate recordings and stimuli of neuroscience experiments"
55
readme = "README.md"
66
requires-python = ">=3.9"
7+
# When adding or removing dependencies, also update autodoc_mock_imports in docs/source/conf.py if necessary.
78
dependencies = [
89
"scipy>=1.13.1",
910
"jaxtyping>=0.2.30",
@@ -21,7 +22,9 @@ dependencies = [
2122
dev = [
2223
"pytest==8.3.5",
2324
"pytest-cov>=7.0.0",
24-
"pyright",
25+
"pyright==1.1.408",
26+
"black[jupyter]==25.1.0",
27+
"isort==6.0.1",
2528
"hypothesis>=6.0",
2629
"ruff==0.15.6",
2730
]

0 commit comments

Comments
 (0)