Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion brainscore_vision/models/cornet_s/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, model):
relative_path="cornet_s/cornet_s_epoch43.pth.tar",
version_id="null",
sha1="a4bfd8eda33b45fd945da1b972ab0b7cad38d60f")
checkpoint = torch.load(weights_path, map_location=lambda storage, loc: storage) # map onto cpu
checkpoint = torch.load(weights_path, map_location=lambda storage, loc: storage, weights_only=False) # map onto cpu
model.load_state_dict(checkpoint['state_dict'])
model = model.module # unwrap
preprocessing = functools.partial(load_preprocess_images, image_size=224)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from brainscore_vision.model_helpers.check_submission import check_models
import functools
from transformers import AutoFeatureExtractor, CvtForImageClassification
from transformers import CvtForImageClassification
from brainscore_vision.model_helpers.activations.pytorch import PytorchWrapper
from PIL import Image
import numpy as np
Expand All @@ -15,9 +15,9 @@ def get_model(name):
assert name == 'cvt_cvt-w24-384-in22k_finetuned-in1k_4'
# https://huggingface.co/models?sort=downloads&search=cvt
image_size = 384
processor = AutoFeatureExtractor.from_pretrained('microsoft/cvt-w24-384-22k')
model = CvtForImageClassification.from_pretrained('microsoft/cvt-w24-384-22k')
preprocessing = functools.partial(load_preprocess_images, processor=processor, image_size=image_size)
# Use torchvision preprocessing (standard ImageNet normalization) instead of HuggingFace processor
preprocessing = functools.partial(load_preprocess_images, processor=None, image_size=image_size)
wrapper = PytorchWrapper(identifier=name, model=model, preprocessing=preprocessing)
wrapper.image_size = image_size

Expand Down
14 changes: 12 additions & 2 deletions environment_lock.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ dependencies:
- xz=5.4.6
- zlib=1.2.13
- pip:
- albucore==0.0.24
- albumentations==2.0.8
- anyio==4.4.0
- appnope==0.1.4
- argon2-cffi==23.1.0
Expand All @@ -44,6 +46,7 @@ dependencies:
- cloudpickle==3.0.0
- comm==0.2.2
- contourpy==1.2.1
- cornet==0.1.0
- cycler==0.12.1
- dask==2024.8.1
- debugpy==1.8.5
Expand All @@ -64,6 +67,7 @@ dependencies:
- h5py==3.11.0
- httpcore==1.0.5
- httpx==0.27.0
- huggingface_hub==0.36.2
- idna==3.7
- importlib-metadata==4.13.0
- iniconfig==2.0.0
Expand Down Expand Up @@ -128,6 +132,8 @@ dependencies:
- pure-eval==0.2.3
- pybtex==0.24.0
- pycparser==2.22
- pydantic==2.12.5
- pydantic_core==2.41.5
- pygments==2.18.0
- pyparsing==3.1.2
- pytest==8.3.2
Expand All @@ -148,6 +154,7 @@ dependencies:
- rfc3986-validator==0.1.1
- rpds-py==0.20.0
- s3transfer==0.10.2
- safetensors==0.7.0
- scikit-learn==1.5.1
- scipy==1.14.1
- send2trash==1.8.3
Expand All @@ -161,12 +168,15 @@ dependencies:
- terminado==0.18.1
- threadpoolctl==3.5.0
- tinycss2==1.3.0
- timm==1.0.24
- tokenizers==0.22.2
- toolz==0.12.1
- torch==2.4.0
- torchvision==0.19.0
- torch==2.6.0
- torchvision==0.21.0
- tornado==6.4.1
- tqdm==4.66.5
- traitlets==5.14.3
- transformers==4.47.0
- types-python-dateutil==2.9.0.20240821
- typing-extensions==4.12.2
- tzdata==2024.1
Expand Down
16 changes: 11 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dependencies = [
"brainscore-core",
"result-caching",
"importlib-metadata<5", # workaround to https://github.com/brain-score/brainio/issues/28
"scikit-learn<1.8", # for metric_helpers/transformations.py cross-validation
"scikit-learn<1.6", # multi_class param removed in 1.6; also for metric_helpers/transformations.py cross-validation
"scipy", # for benchmark_helpers/properties_common.py
"opencv-python", # for microsaccades
"h5py",
Expand All @@ -31,6 +31,8 @@ dependencies = [
"networkx",
"eva-decord",
"psutil",
"torch>=2.6",
"torchvision>=0.21",
]

[project.optional-dependencies]
Expand All @@ -39,10 +41,14 @@ test = [
"pytest_check",
"pytest-mock",
"pytest-timeout",
"torch",
"torchvision",
"matplotlib", # for examples
"pytest-mock",
"matplotlib",
]

competition-models = [
"transformers>=4.45",
"albumentations>=2.0",
"timm>=1.0",
"cornet @ git+https://github.com/dicarlolab/CORnet.git",
]

[build-system]
Expand Down
Loading