Official codebase for POYO+ from ICLR 2025 (Paper Link).
POYO+ is a multi-task version of POYO.
This is an example training script for the Calcium POYO+ model in Azabou and Pan et al. 2025, corresponding to the module torch_brain.models.CalciumPOYOPlus.
First, install uv by following the steps here. Then, create your Python environment:
uv venv venv -p 3.11
source venv/bin/activate
uv pip install -r requirements.txt

Data Preparation

There are 1304 sessions in the full Calcium POYO+ model and 30 holdout drifting gratings sessions. The raw data for all sessions is ~360GB and processed data uses ~58GB. To prepare the data, run:
brainsets prepare allen_visual_coding_ophys_2016 --raw-dir=data/raw --processed-dir=data/processed

Training
python train.py --config-name=train_calcium_poyo_plus.yaml

Check out configs/train_calcium_poyo_plus.yaml for the full-model config and configs/train_calcium_poyo_plus_single_session.yaml for a single-session example.
A Calcium POYO+ checkpoint trained on the full 1304-session Allen Visual Coding corpus is available here:
Download it and load it with the same model config used at training time
(configs/model/calcium_poyo_plus.yaml):
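
As a minimal sketch of the loading step, assuming the file is a PyTorch Lightning-style checkpoint (a dict with a `state_dict` entry whose keys carry a `model.` prefix), the model weights could be separated out as shown below. The helper name `extract_model_state_dict`, the `state_dict` key, and the `model.` prefix are all assumptions for illustration and are not confirmed by this README; inspect the checkpoint keys before relying on them.

```python
from typing import Any, Dict, Mapping


def extract_model_state_dict(
    checkpoint: Mapping[str, Any], prefix: str = "model."
) -> Dict[str, Any]:
    """Return the entries stored under `prefix`, with the prefix removed.

    Hypothetical helper: assumes a Lightning-style checkpoint layout.
    """
    # Lightning-style checkpoints nest weights under "state_dict"; otherwise
    # treat the checkpoint itself as a flat state dict.
    state_dict = checkpoint.get("state_dict", checkpoint)
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# Intended use (not run here; requires torch and the downloaded file):
#   import torch
#   ckpt = torch.load("/path/to/epoch_epoch=414.ckpt", map_location="cpu")
#   # Recent torch versions may need weights_only=False for full checkpoints.
#   weights = extract_model_state_dict(ckpt)
#   model.load_state_dict(weights)  # model built from configs/model/calcium_poyo_plus.yaml
```

If the keys turn out to have a different prefix, or none at all, pass the matching `prefix` (or an empty string) instead.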
The unit_emb and session_emb InfiniteVocabEmbedding layers will be
materialized automatically from the vocabulary saved alongside the weights
(116702 units, 1306 sessions).
Held-out test metrics for this checkpoint, averaged across recordings (overall
average_test_metric = 0.3352):
| Task | Metric | Value |
|---|---|---|
| drifting_gratings_orientation | Accuracy | 0.4719 |
| drifting_gratings_temporal_frequency | Accuracy | 0.4503 |
| static_gratings_orientation | Accuracy | 0.3347 |
| static_gratings_phase | Accuracy | 0.2664 |
| static_gratings_spatial_frequency | Accuracy | 0.4455 |
| natural_scenes | Accuracy | 0.0968 |
| natural_movie_one_frame | mean(Accuracy, WithinDeltaAcc) | 0.4296 |
| natural_movie_two_frame | mean(Accuracy, WithinDeltaAcc) | 0.3685 |
| natural_movie_three_frame | mean(Accuracy, WithinDeltaAcc) | 0.1598 |
| pupil_location | MeanSquaredError (z-normalized) | 0.2791 |
| running_speed | MeanSquaredError (z-normalized) | 0.4021 |
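
To make the combined movie-frame metric concrete, here is a rough sketch of averaging exact accuracy with a within-delta accuracy on predicted frame indices. The function names and the tolerance used in the toy example are illustrative assumptions; the actual tolerance and metric implementation behind the table are not stated in this README.

```python
from typing import Sequence


def accuracy(pred: Sequence[int], target: Sequence[int]) -> float:
    """Fraction of predictions that match the target exactly."""
    return sum(p == t for p, t in zip(pred, target)) / len(target)


def within_delta_accuracy(
    pred: Sequence[int], target: Sequence[int], delta: int
) -> float:
    """Fraction of predictions within `delta` frames of the target."""
    return sum(abs(p - t) <= delta for p, t in zip(pred, target)) / len(target)


def combined_frame_metric(
    pred: Sequence[int], target: Sequence[int], delta: int
) -> float:
    """Mean of exact accuracy and within-delta accuracy."""
    return 0.5 * (accuracy(pred, target) + within_delta_accuracy(pred, target, delta))


# Toy example with a hypothetical tolerance of one frame.
pred = [10, 12, 50, 7]
target = [10, 13, 40, 7]
score = combined_frame_metric(pred, target, delta=1)
```

In the toy example, two of four predictions are exact and three of four fall within one frame, so the combined score rewards near misses that plain accuracy would count as errors.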
Note
Performance numbers differ from those reported in the paper. In particular, our
running_speed prediction target was actually the monotonically increasing
timestamps rather than the running-speed value in m/s, so that row is not
directly comparable. The remaining differences are attributable to small
hyperparameter and implementation differences and to stochasticity in training.
To finetune a pre-trained model, download the checkpoint above and run:
python finetune.py ckpt_path="/path/to/epoch_epoch=414.ckpt"

- Set which model and dataset to use in configs/finetune.yaml.
- Update ckpt_path to your model checkpoint location.
If you use this code, please cite our paper:
@inproceedings{azabou2025multisession,
author = {Azabou, Mehdi and Pan, Krystal and Arora, Vinam and Knight, Ian and Dyer, Eva and Richards, Blake A},
booktitle = {International Conference on Learning Representations},
editor = {Y. Yue and A. Garg and N. Peng and F. Sha and R. Yu},
pages = {59654--59677},
title = {Multi-session, multi-task neural decoding from distinct cell-types and brain regions},
url = {https://proceedings.iclr.cc/paper_files/paper/2025/file/953390c834451505703c9da45de634d8-Paper-Conference.pdf},
volume = {2025},
year = {2025}
}