Skip to content

added ReBind QM9 baseline; organometallic models running#1

Open
jwtoney wants to merge 1 commit into
mainfrom
rebind_baseline
Open

added ReBind QM9 baseline; organometallic models running#1
jwtoney wants to merge 1 commit into
mainfrom
rebind_baseline

Conversation

@jwtoney
Copy link
Copy Markdown
Collaborator

@jwtoney jwtoney commented May 25, 2026

Added baselines to test ReBind (https://arxiv.org/abs/2410.14696) on QM9, tmQMg, and BOS-TMC datasets. QM9 has completed, tmQMg and BOS-TMC are still running. Using hyperparameters and epochs from original ReBind implementation.

@jwtoney jwtoney requested a review from sid-betalol May 25, 2026 00:34
Copy link
Copy Markdown
Collaborator

@sid-betalol sid-betalol left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the great work, @jwtoney! This will need some changes, though, for some minor issues.

Comment thread pyproject.toml
"ignore::UserWarning",
]

[tool.uv.sources]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch is globally pinned to the pytorch-cu124. On macOS arm64, uv run pytest -q fails before pytest starts:

error: Distribution `torch==2.6.0+cu124 ...` can't be installed because it doesn't have a source distribution or wheel for the current platform

I suggest making the default dependency set CPU/platform-compatible, and moving CUDA Torch into an optional group or environment-specific install path. For example:

  • keep normal torch in base dependencies so uv sync --dev works on macOS/Linux CPU;
  • add a CUDA extra or separate docs command for GPU training environments;

Comment thread README.md
```

```bash
sbatch scripts/train.slurm configs/qm9.yaml
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

scripts/train.slurm was not pushed. Do you mean scripts/train.sh?

padding_mask=node_mask,
compute_loss=True,
)
loss_cache, conformer_cache = cache_out["loss"], cache_out["conformer_hat"]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I am misunderstanding something for this comment, so please clarify.

The conformer head output is stored in conformer_cache:

loss_cache, conformer_cache = cache_out["loss"], cache_out["conformer_hat"]

but later the code sets:

inputs["pred_conformation"] = node_embedding

and passes that to the residual head as conformer_base:

conformer_base=inputs["pred_conformation"],

That looks wrong: node_embedding has shape (batch, atoms, d_model), while a conformer base should be coordinates with shape (batch, atoms, 3)? The natural value here appears to be conformer_cache, not node_embedding. If this path is exercised, it can either crash on a shape mismatch or train against the wrong tensor.

If what I am saying is correct, then the fix would be to change inputs["pred_conformation"] to use conformer_cache, then add/strengthen a forward-pass test that exercises the patched forward with a real collated batch and asserts finite loss/output shapes.

Comment thread tests/conftest.py

import pytest

QM9_PATH = Path("/home/gridsan/jtoney/ElemNet/benchmarking/datasets/QM9-full.csv")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test fixtures point at absolute cluster paths:

QM9_PATH = Path("/home/gridsan/jtoney/ElemNet/benchmarking/datasets/QM9-full.csv")
TMQMG_PATH = Path("/home/gridsan/jtoney/ElemNet/benchmarking/datasets/tmQMg-full.csv")
BOSTMC_PATH = Path("/home/gridsan/jtoney/BOSTMC/datasets/BOSTMC-low-spin.csv")

On any contributor machine or GitHub runner without those files, the dataset and forward tests will skip. That means CI will not actually validate the new featurization, MOL2 parser, LJ patch, or training path.

I suggest committing tiny synthetic fixture CSVs under tests/fixtures/ that cover:

  • one QM9-style XYZ row;
  • one MOL2+XYZ organometallic row with a transition metal;
  • at least one ring/aromatic case if ring flags matter.

We should use private full datasets only for optional/integration tests, not core CI tests.

@@ -0,0 +1,123 @@
"""Convert raw XYZ / MOL2 blocks into ReBind-format graph dicts.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This module touches external/ReBIND as soon as it is imported. That means a normal import path like:

from step_up.data.csv_dataset import CSVMoleculeDataset

can try to load vendored ReBIND code before any test body runs. This is fragile because pytest imports test modules during collection. If the repo was cloned without --recursive, external/ReBIND may exist as an empty submodule directory, but the actual files such as external/ReBIND/models and external/ReBIND/data/utils.py are missing. In that case, pytest can fail during collection with an import error instead of reaching a test fixture that could skip with a useful message.

There is also an inconsistency between the two modules: featurize.py checks for a concrete file (data/utils.py) and raises a helpful FileNotFoundError, while rebind.py only checks whether the submodule root path exists. An empty submodule directory passes that check, then fails later with a less helpful ModuleNotFoundError.

I suggest we make rebind.py check for concrete submodule contents, e.g. external/ReBIND/models or external/ReBIND/data/utils.py, and raise the same actionable error: Run: git submodule update --init --recursive. Also ensure CI and docs use recursive submodule checkout. Longer term, defer vendored ReBIND imports until build_rebind() or the specific featurization function is called, so importing dataset utilities does not require the model submodule to be initialized.

Comment thread configs/bostmc.yaml
# - `charge` and `spinmult` columns are currently IGNORED.
# - Recommendation for first publishable run: pre-filter to singlets-only (spinmult == 1) for the cleanest comparison to tmQMg. The CSVMoleculeDataset does not yet expose a filter knob, add one when this config is first run.
# - Follow-up: condition the model on (charge, spinmult) as a global feature.
dataset_path: /home/gridsan/jtoney/BOSTMC/datasets/BOSTMC-low-spin.csv
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The config comment says charge and spinmult are ignored and recommends pre-filtering to singlets before a publishable run, but the config still points at BOSTMC-low-spin.csv with no filter support:

dataset_path: /home/gridsan/jtoney/BOSTMC/datasets/BOSTMC-low-spin.csv

For a benchmark, mixing singlets and doublets without conditioning on spin/charge can make comparisons ambiguous and hard to reproduce.

It'd be good to either add filtering to CSVMoleculeDataset / config and make this config singlet-only, or rename/mark the BOSTMC full config as experimental until spin/charge conditioning exists.

Comment thread src/step_up/train.py

with open(out_dir / "history.json", "w") as f:
json.dump(history, f, indent=2)
del test_set # held out; first round does not evaluate on test
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code creates a test split but then discards it:

train_set, val_set, test_set = random_split(...)
...
del test_set  # held out; first round does not evaluate on test

Do you think we should save test metrics also?
After training, we could load/use the best checkpoint and evaluate on test_set, then write test_metrics.json. It might be worth setting up a wandb for this project @luispintoc?

zs = atomic_numbers.to(torch.long)
per_z: dict[int, float] = {}
for z in torch.unique(zs):
if int(z.item()) == 0:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function says atomic_numbers should contain true atomic numbers and uses 0 as padding:

if int(z.item()) == 0:
    continue

But the graph data stores node_type as Z - 1, where hydrogen is 0. The test also passes [5, 6, 7] while commenting that these are C/N/O, which are actually Z - 1 indices, not atomic numbers. If callers pass node_type directly, hydrogen is silently dropped, and all element keys are off by one.

We should choose one representation and enforce it. Either:

  • rename the argument to atomic_number_indices, document Z - 1, and use a separate padding mask instead of skipping 0; or
  • require true atomic numbers and explicitly convert node_type + 1 before calling.

Add a test that includes hydrogen to prevent this regression.

self._df = _read_csv_subset(self.path, cols, nrows=subset_size)

if validate:
self._valid_indices = self._validate_rows()
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CSVMoleculeDataset._validate_rows catches every exception, records the class name, and keeps going. If every row fails, len(dataset) becomes 0. train.py then uses max(len(train_loader), 1) and can continue into a run with empty loaders, meaningless metrics, and no useful checkpoint.

We should fail fast when validation keeps zero rows, and consider adding a configurable maximum drop rate for full runs.


The CSV is loaded into a pandas DataFrame at construction time (the
``subset_size`` knob keeps memory bounded for smoke runs). When
``validate=True`` (the default), the dataset then walks every row, calls
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Construction with validate=True featurizes every row once to find valid indices. __getitem__ featurizes the same row again during training.

We can cache validated graph dicts for small/subset runs, or store validation results plus a cheaper row-level validity marker? At minimum, we should expose a config knob to disable validation for already-clean production datasets.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants