Skip to content

Commit ca325bc

Browse files
committed
Add warnings and protections around loading pickles
1 parent e19b5e5 commit ca325bc

4 files changed

Lines changed: 15 additions & 0 deletions

File tree

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
![header](envelope.png)
22

3+
**Warning: Orbformer checkpoints are stored using `pickle`. Never read a checkpoint from an untrusted source.**
4+
35
# OneQMC
46

57
This package provides an implementation of the [Orbformer wave function foundation model](https://arxiv.org/abs/2506.19960).
@@ -61,6 +63,9 @@ python scripts/transferable.py -d <subdirectory of ./data> -n <number of trainin
6163
We recommend using distinct output directories for every training run.
6264
Regarding other optional arguments, run `python scripts/transferable.py -h` for more information.
6365

66+
**Note**: Checkpoints are stored using `pickle`. This means that opening a checkpoint from an untrusted source is a major security risk. Only ever open checkpoint files from trusted sources. To prevent reading untrusted pickle files, checkpoint reading is disabled by default and can be
67+
re-enabled by settings the environment variables `ORBFORMER_PICKLE_LOADING=1`.
68+
6469
### Preparing new structure data for fine-tuning
6570

6671
To create a new dataset for fine-tuning, we recommend using [qcelemental](https://github.com/MolSSI/QCElemental) format.

src/oneqmc/entrypoint.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919

2020
def load_chkpt_file(chkpt: str, discard_sampler_state: bool) -> Tuple[TrainState, int]:
21+
if not os.environ["ORBFORMER_PICKLE_LOADING"] == "1":
22+
raise PermissionError("Loading pickle files is disable for security. Set ORBFORMER_PICKLE_LOADING=1 to allow")
2123
with open(chkpt, "rb") as chkpt_file:
2224
init_step, (smpl_state, param_state, opt_state) = pickle.load(chkpt_file)
2325
if discard_sampler_state:
@@ -28,6 +30,8 @@ def load_chkpt_file(chkpt: str, discard_sampler_state: bool) -> Tuple[TrainState
2830

2931

3032
def load_density_chkpt_file(chkpt: str, discard_sampler_state: bool) -> Tuple[Tuple, int]:
33+
if not os.environ["ORBFORMER_PICKLE_LOADING"] == "1":
34+
raise PermissionError("Loading pickle files is disable for security. Set ORBFORMER_PICKLE_LOADING=1 to allow")
3135
with open(chkpt, "rb") as chkpt_file:
3236
init_step, (param_state, opt_state) = pickle.load(chkpt_file)
3337
return (param_state, opt_state), init_step

src/oneqmc/log.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,8 @@ def close(self):
133133
def last(self):
134134
step_fast, step_slow = -1, -1 # account for the case where a queue is not initialized
135135
while self.fast_chkpts:
136+
if not os.environ["ORBFORMER_PICKLE_LOADING"] == "1":
137+
raise PermissionError("Loading pickle files is disable for security. Set ORBFORMER_PICKLE_LOADING=1 to allow")
136138
with self.fast_chkpts.pop(-1).path.open("rb") as f:
137139
step_fast, last_chkpt_fast = pickle.load(f)
138140
if not jax.tree.reduce(
@@ -142,6 +144,8 @@ def last(self):
142144
):
143145
break
144146
while self.slow_chkpts:
147+
if not os.environ["ORBFORMER_PICKLE_LOADING"] == "1":
148+
raise PermissionError("Loading pickle files is disable for security. Set ORBFORMER_PICKLE_LOADING=1 to allow")
145149
with self.slow_chkpts.pop(-1).path.open("rb") as f:
146150
step_slow, last_chkpt_slow = pickle.load(f)
147151
if not jax.tree.reduce(

tests/integration_tests/test_scripts.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def runner(extra_args):
5252
],
5353
cwd=project_root,
5454
capture_output=True,
55+
env={"ORBFORMER_PICKLE_LOADING": "1"}
5556
)
5657
if result.returncode != 0:
5758
raise OneQMCProcessError(result.stderr.decode())
@@ -90,6 +91,7 @@ def runner(extra_args):
9091
],
9192
cwd=project_root,
9293
capture_output=True,
94+
env={"ORBFORMER_PICKLE_LOADING": "1"}
9395
)
9496
if result.returncode != 0:
9597
raise OneQMCProcessError(result.stderr.decode())

0 commit comments

Comments
 (0)