nerdslab/poyo_plus

POYO+ 🧠

Official codebase for POYO+ from ICLR 2025 (Paper Link).


POYO+ is a multi-task version of POYO. This repository provides an example training script for the Calcium POYO+ model from Azabou and Pan et al. 2025, which corresponds to the module torch_brain.models.CalciumPOYOPlus.

Installation

First, install uv by following its official installation instructions. Then create your Python environment:

uv venv venv -p 3.11
source venv/bin/activate
uv pip install -r requirements.txt

Training POYO+ on Calcium Traces

Data Preparation

The full Calcium POYO+ model uses 1304 sessions, and there are an additional 30 held-out drifting gratings sessions. The raw data for all sessions takes up ~360 GB and the processed data ~58 GB. To prepare the data, run:

brainsets prepare allen_visual_coding_ophys_2016 --raw-dir=data/raw --processed-dir=data/processed

Training

python train.py --config-name=train_calcium_poyo_plus.yaml

See configs/train_calcium_poyo_plus.yaml for the full-model config and configs/train_calcium_poyo_plus_single_session.yaml for a single-session example.

Pretrained weights

A Calcium POYO+ checkpoint trained on the full 1304-session Allen Visual Coding corpus is available here:

Download it and load it with the same model config used at training time (configs/model/calcium_poyo_plus.yaml).
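PyTorch Lightning checkpoints store the network weights under a `state_dict` key, usually with each key prefixed by the attribute name under which the training module holds the network. The sketch below assumes (hypothetically) that this prefix is `model.` and strips it so the keys match a bare model. It works on a plain dict so it runs without PyTorch; `extract_model_state_dict` and the toy keys are illustrative, not part of this repository.

```python
from typing import Any, Dict, Mapping


def extract_model_state_dict(
    checkpoint: Mapping[str, Any], prefix: str = "model."
) -> Dict[str, Any]:
    """Return the wrapped network's weights from a Lightning-style checkpoint.

    `prefix` is an assumption: it must match the attribute name under
    which the training module stored the network.
    """
    state_dict = checkpoint["state_dict"]
    # Keep only entries that belong to the wrapped network and drop the
    # prefix, so the keys line up with a bare model's own state_dict().
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# Toy checkpoint with hypothetical keys; real values would be tensors.
ckpt = {
    "state_dict": {
        "model.unit_emb.weight": [0.1, 0.2],
        "model.session_emb.weight": [0.3],
        "loss_fn.scale": 1.0,
    },
}
print(extract_model_state_dict(ckpt))
# {'unit_emb.weight': [0.1, 0.2], 'session_emb.weight': [0.3]}
```

With a real file, the dict would come from `torch.load(path, map_location="cpu")` and the result would be passed to `model.load_state_dict(...)`. Recent PyTorch versions default to `weights_only=True`, so loading a full Lightning checkpoint may require `weights_only=False`, which should only be used for trusted files.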

The unit_emb and session_emb InfiniteVocabEmbedding layers will be materialized automatically from the vocabulary saved alongside the weights (116702 units, 1306 sessions).
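Why the vocabulary has to travel with the weights can be illustrated with a toy class: the saved vocabulary fixes both how many rows the embedding table has and which row each unit or session ID maps to, so the checkpoint's rows only line up if the vocabulary is restored in the same order. `VocabEmbeddingSketch` and the ID strings below are hypothetical and are not torch_brain's `InfiniteVocabEmbedding` API. For this checkpoint, the unit table would have 116702 rows and the session table 1306.

```python
from typing import Dict, List, Sequence, Tuple


class VocabEmbeddingSketch:
    """Hypothetical stand-in, not torch_brain's InfiniteVocabEmbedding.

    Maps string IDs (e.g. unit or session names) to rows of an embedding
    table whose size is fixed by a saved vocabulary.
    """

    def __init__(self, vocab: Sequence[str], dim: int) -> None:
        # Row i belongs to vocab[i]; the order must match training time.
        self.vocab: List[str] = list(vocab)
        self.index: Dict[str, int] = {tok: i for i, tok in enumerate(self.vocab)}
        if len(self.index) != len(self.vocab):
            raise ValueError("vocabulary contains duplicate IDs")
        self.dim = dim
        # Placeholder rows; real values would be copied from the checkpoint.
        self.weight: List[List[float]] = [[0.0] * dim for _ in self.vocab]

    @property
    def shape(self) -> Tuple[int, int]:
        return (len(self.weight), self.dim)

    def tokenize(self, ids: Sequence[str]) -> List[int]:
        # An ID missing from the saved vocabulary has no trained row,
        # so this raises KeyError instead of silently adding one.
        return [self.index[i] for i in ids]


emb = VocabEmbeddingSketch(["sess_a/unit_0", "sess_a/unit_1", "sess_b/unit_0"], dim=4)
print(emb.shape)  # (3, 4)
print(emb.tokenize(["sess_b/unit_0", "sess_a/unit_0"]))  # [2, 0]
```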

Test results

Held-out test metrics for this checkpoint, averaged across recordings (overall average_test_metric = 0.3352):

| Task | Metric | Value |
|---|---|---|
| drifting_gratings_orientation | Accuracy | 0.4719 |
| drifting_gratings_temporal_frequency | Accuracy | 0.4503 |
| static_gratings_orientation | Accuracy | 0.3347 |
| static_gratings_phase | Accuracy | 0.2664 |
| static_gratings_spatial_frequency | Accuracy | 0.4455 |
| natural_scenes | Accuracy | 0.0968 |
| natural_movie_one_frame | mean(Accuracy, WithinDeltaAcc) | 0.4296 |
| natural_movie_two_frame | mean(Accuracy, WithinDeltaAcc) | 0.3685 |
| natural_movie_three_frame | mean(Accuracy, WithinDeltaAcc) | 0.1598 |
| pupil_location | MeanSquaredError (z-normalized) | 0.2791 |
| running_speed | MeanSquaredError (z-normalized) | 0.4021 |
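The natural_movie_*_frame rows average exact-frame accuracy with a within-delta accuracy, and the pupil_location and running_speed rows report MSE on z-normalized targets. A minimal pure-Python sketch of these quantities follows. The tolerance `delta`, the helper names, and normalizing with the evaluated targets' own mean and standard deviation (rather than, say, training-set statistics) are assumptions, not the repository's implementation.

```python
from statistics import mean, pstdev
from typing import Sequence


def accuracy(pred: Sequence[int], target: Sequence[int]) -> float:
    # Fraction of exact matches (e.g. predicted frame index == true frame index).
    return mean(1.0 if p == t else 0.0 for p, t in zip(pred, target))


def within_delta_accuracy(pred: Sequence[int], target: Sequence[int], delta: int) -> float:
    # A prediction counts as correct if it is at most `delta` frames from the target.
    return mean(1.0 if abs(p - t) <= delta else 0.0 for p, t in zip(pred, target))


def frame_score(pred: Sequence[int], target: Sequence[int], delta: int) -> float:
    # Mirrors the "mean(Accuracy, WithinDeltaAcc)" column of the table.
    return (accuracy(pred, target) + within_delta_accuracy(pred, target, delta)) / 2


def z_normalized_mse(pred: Sequence[float], target: Sequence[float]) -> float:
    # Standardize predictions and targets with the targets' mean and
    # population std, then average the squared errors (= MSE / var(target)).
    mu, sigma = mean(target), pstdev(target)
    return mean(((p - mu) / sigma - (t - mu) / sigma) ** 2 for p, t in zip(pred, target))


frames_true = [10, 22, 30, 40]
frames_pred = [10, 20, 31, 45]
print(frame_score(frames_pred, frames_true, delta=2))  # 0.5
print(z_normalized_mse([1.0, 2.0], [0.0, 2.0]))  # 0.5
```

Because the MSE is computed on z-normalized targets, a model that always predicts the target mean scores about 1.0, so lower values in those two rows are better.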

Note

Performance numbers differ from those reported in the paper. In particular, our running_speed prediction target was actually the monotonically increasing timestamps rather than the running speed in m/s, so that row is not directly comparable. The remaining differences are attributable to small hyperparameter and implementation differences and to stochasticity in training.

Finetuning

To finetune a pre-trained model, download the checkpoint above and run:

python finetune.py ckpt_path="/path/to/epoch_epoch=414.ckpt"
  • Set which model and dataset to use in configs/finetune.yaml.
  • Update ckpt_path to your model checkpoint location.
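The `ckpt_path=...` argument is a Hydra-style command-line override that replaces the matching entry in the YAML config. The simplified pure-Python sketch below only illustrates that merge. `apply_overrides` and the `trainer.max_epochs` key are hypothetical, and real Hydra additionally parses value types and supports extra syntax such as `+key=value` for adding keys that are not already in the config.

```python
import copy
from typing import Any, Dict, List


def apply_overrides(config: Dict[str, Any], overrides: List[str]) -> Dict[str, Any]:
    """Apply 'dotted.key=value' overrides to a nested dict (simplified sketch)."""
    result = copy.deepcopy(config)  # leave the base config untouched
    for item in overrides:
        # Split on the FIRST '=' only, so values may themselves contain '='.
        key, sep, value = item.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {item!r}")
        *parents, leaf = key.split(".")
        node = result
        for part in parents:
            node = node[part]  # KeyError if the group does not exist
        if leaf not in node:
            # Like Hydra's default, refuse to silently add unknown keys.
            raise KeyError(f"unknown config key {key!r}")
        node[leaf] = value  # values stay strings in this sketch
    return result


base = {"ckpt_path": None, "trainer": {"max_epochs": 100}}  # hypothetical keys
cfg = apply_overrides(
    base, ["ckpt_path=/path/to/epoch_epoch=414.ckpt", "trainer.max_epochs=50"]
)
print(cfg["ckpt_path"])  # /path/to/epoch_epoch=414.ckpt
```

Splitting on the first `=` matters here because the checkpoint filename itself contains `=`.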

Cite

If you use this code, please cite our paper:

@inproceedings{azabou2025multisession,
  author = {Azabou, Mehdi and Pan, Krystal and Arora, Vinam and Knight, Ian and Dyer, Eva and Richards, Blake A},
  booktitle = {International Conference on Learning Representations},
  editor = {Y. Yue and A. Garg and N. Peng and F. Sha and R. Yu},
  pages = {59654--59677},
  title = {Multi-session, multi-task neural decoding from distinct cell-types and brain regions},
  url = {https://proceedings.iclr.cc/paper_files/paper/2025/file/953390c834451505703c9da45de634d8-Paper-Conference.pdf},
  volume = {2025},
  year = {2025}
}
