Skip to content

Commit 2b7ccf0

Browse files
committed
fea: use upstream fairchem
1 parent 8d3417a commit 2b7ccf0

6 files changed

Lines changed: 83 additions & 591 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ orb = ["orb-models>=0.6.0"]
5757
sevenn = ["sevenn[torchsim]>=0.12.1"]
5858
graphpes = ["graph-pes>=0.1", "mace-torch>=0.3.12"]
5959
nequip = ["nequip>=0.17.0"]
60+
fairchem = ["fairchem-core @ git+https://github.com/facebookresearch/fairchem.git@main#subdirectory=packages/fairchem-core"]
6061
nequix = ["nequix[torch-sim]>=0.4.5"]
61-
fairchem = ["fairchem-core>=2.7", "scipy<1.17.0"]
6262
docs = [
6363
"autodoc_pydantic==2.2.0",
6464
"furo==2024.8.6",

tests/models/test_fairchem.py

Lines changed: 0 additions & 289 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,14 @@
11
import traceback
22

33
import pytest
4-
import torch
54

6-
import torch_sim as ts
75
from tests.conftest import DEVICE, DTYPE
86
from tests.models.conftest import make_validate_model_outputs_test
97

108

119
try:
12-
from collections.abc import Callable
13-
14-
from ase.build import bulk, fcc100, molecule
15-
from fairchem.core.calculate.pretrained_mlip import (
16-
pretrained_checkpoint_path_from_name,
17-
)
1810
from huggingface_hub.utils._auth import get_token
1911

20-
import torch_sim as ts
2112
from torch_sim.models.fairchem import FairChemModel
2213

2314
except (ImportError, OSError, RuntimeError, AttributeError, ValueError):
@@ -33,205 +24,6 @@ def eqv2_uma_model_pbc() -> FairChemModel:
3324
return FairChemModel(model="uma-s-1p1", task_name="omat", device=DEVICE)
3425

3526

36-
@pytest.mark.skipif(
37-
get_token() is None, reason="Requires HuggingFace authentication for UMA model access"
38-
)
39-
@pytest.mark.parametrize("task_name", ["omat", "omol", "oc20"])
40-
def test_task_initialization(task_name: str) -> None:
41-
"""Test that different UMA task names work correctly."""
42-
model = FairChemModel(
43-
model="uma-s-1p1", task_name=task_name, device=torch.device("cpu")
44-
)
45-
assert model.task_name
46-
assert str(model.task_name.value) == task_name
47-
assert hasattr(model, "predictor")
48-
49-
50-
@pytest.mark.skipif(
51-
get_token() is None, reason="Requires HuggingFace authentication for UMA model access"
52-
)
53-
@pytest.mark.parametrize(
54-
("task_name", "systems_func"),
55-
[
56-
(
57-
"omat",
58-
lambda: [
59-
bulk("Si", "diamond", a=5.43),
60-
bulk("Al", "fcc", a=4.05),
61-
bulk("Fe", "bcc", a=2.87),
62-
bulk("Cu", "fcc", a=3.61),
63-
],
64-
),
65-
(
66-
"omol",
67-
lambda: [molecule("H2O"), molecule("CO2"), molecule("CH4"), molecule("NH3")],
68-
),
69-
],
70-
)
71-
def test_homogeneous_batching(task_name: str, systems_func: Callable) -> None:
72-
"""Test batching multiple systems with the same task."""
73-
systems = systems_func()
74-
75-
# Add molecular properties for molecules
76-
if task_name == "omol":
77-
for mol in systems:
78-
mol.info |= {"charge": 0, "spin": 1}
79-
80-
model = FairChemModel(model="uma-s-1p1", task_name=task_name, device=DEVICE)
81-
state = ts.io.atoms_to_state(systems, device=DEVICE, dtype=DTYPE)
82-
results = model(state)
83-
84-
# Check batch dimensions
85-
assert results["energy"].shape == (4,)
86-
assert results["forces"].shape[0] == sum(len(s) for s in systems)
87-
assert results["forces"].shape[1] == 3
88-
89-
# Check that different systems have different energies
90-
energies = results["energy"]
91-
uniq_energies = torch.unique(energies, dim=0)
92-
assert len(uniq_energies) > 1, "Different systems should have different energies"
93-
94-
95-
@pytest.mark.skipif(
96-
get_token() is None, reason="Requires HuggingFace authentication for UMA model access"
97-
)
98-
def test_heterogeneous_tasks() -> None:
99-
"""Test different task types work with appropriate systems."""
100-
# Test molecule, material, and catalysis systems separately
101-
test_cases = [
102-
("omol", [molecule("H2O")]),
103-
("omat", [bulk("Pt", cubic=True)]),
104-
("oc20", [fcc100("Cu", (2, 2, 3), vacuum=8, periodic=True)]),
105-
]
106-
107-
for task_name, systems in test_cases:
108-
if task_name == "omol":
109-
systems[0].info |= {"charge": 0, "spin": 1}
110-
111-
model = FairChemModel(
112-
model="uma-s-1p1",
113-
task_name=task_name,
114-
device=DEVICE,
115-
)
116-
state = ts.io.atoms_to_state(systems, device=DEVICE, dtype=DTYPE)
117-
results = model(state)
118-
119-
assert results["energy"].shape[0] == 1
120-
assert results["forces"].dim() == 2
121-
assert results["forces"].shape[1] == 3
122-
123-
124-
@pytest.mark.skipif(
125-
get_token() is None, reason="Requires HuggingFace authentication for UMA model access"
126-
)
127-
@pytest.mark.parametrize(
128-
("systems_func", "expected_count"),
129-
[
130-
(lambda: [bulk("Si", "diamond", a=5.43)], 1), # Single system
131-
(
132-
lambda: [
133-
bulk("H", "bcc", a=2.0),
134-
bulk("Li", "bcc", a=3.0),
135-
bulk("Si", "diamond", a=5.43),
136-
bulk("Al", "fcc", a=4.05).repeat((2, 1, 1)),
137-
],
138-
4,
139-
), # Mixed sizes
140-
(
141-
lambda: [
142-
bulk(element, "fcc", a=4.0)
143-
for element in ("Al", "Cu", "Ni", "Pd", "Pt") * 3
144-
],
145-
15,
146-
), # Large batch
147-
],
148-
)
149-
def test_batch_size_variations(systems_func: Callable, expected_count: int) -> None:
150-
"""Test batching with different numbers and sizes of systems."""
151-
systems = systems_func()
152-
153-
model = FairChemModel(model="uma-s-1p1", task_name="omat", device=DEVICE)
154-
state = ts.io.atoms_to_state(systems, device=DEVICE, dtype=DTYPE)
155-
results = model(state)
156-
157-
assert results["energy"].shape == (expected_count,)
158-
assert results["forces"].shape[0] == sum(len(s) for s in systems)
159-
assert results["forces"].shape[1] == 3
160-
assert torch.isfinite(results["energy"]).all()
161-
assert torch.isfinite(results["forces"]).all()
162-
163-
164-
@pytest.mark.skipif(
165-
get_token() is None, reason="Requires HuggingFace authentication for UMA model access"
166-
)
167-
@pytest.mark.parametrize("compute_stress", [True, False])
168-
def test_stress_computation(*, compute_stress: bool) -> None:
169-
"""Test stress tensor computation."""
170-
systems = [bulk("Si", "diamond", a=5.43), bulk("Al", "fcc", a=4.05)]
171-
172-
model = FairChemModel(
173-
model="uma-s-1p1",
174-
task_name="omat",
175-
device=DEVICE,
176-
compute_stress=compute_stress,
177-
)
178-
state = ts.io.atoms_to_state(systems, device=DEVICE, dtype=DTYPE)
179-
results = model(state)
180-
181-
if compute_stress:
182-
assert "stress" in results
183-
assert results["stress"].shape == (2, 3, 3)
184-
assert torch.isfinite(results["stress"]).all()
185-
else:
186-
assert "stress" not in results
187-
188-
189-
@pytest.mark.skipif(
190-
get_token() is None, reason="Requires HuggingFace authentication for UMA model access"
191-
)
192-
def test_device_consistency() -> None:
193-
"""Test device consistency between model and data."""
194-
model = FairChemModel(model="uma-s-1p1", task_name="omat", device=DEVICE)
195-
system = bulk("Si", "diamond", a=5.43)
196-
state = ts.io.atoms_to_state([system], device=DEVICE, dtype=DTYPE)
197-
198-
results = model(state)
199-
assert results["energy"].device == DEVICE
200-
assert results["forces"].device == DEVICE
201-
202-
203-
@pytest.mark.skipif(
204-
get_token() is None, reason="Requires HuggingFace authentication for UMA model access"
205-
)
206-
def test_empty_batch_error() -> None:
207-
"""Test that empty batches raise appropriate errors."""
208-
model = FairChemModel(model="uma-s-1p1", task_name="omat", device=torch.device("cpu"))
209-
with pytest.raises((ValueError, RuntimeError, IndexError)):
210-
model(ts.io.atoms_to_state([], device=torch.device("cpu"), dtype=torch.float32))
211-
212-
213-
@pytest.mark.skipif(
214-
get_token() is None, reason="Requires HuggingFace authentication for UMA model access"
215-
)
216-
def test_load_from_checkpoint_path() -> None:
217-
"""Test loading model from a saved checkpoint file path."""
218-
checkpoint_path = pretrained_checkpoint_path_from_name("uma-s-1p1")
219-
loaded_model = FairChemModel(
220-
model=str(checkpoint_path), task_name="omat", device=DEVICE
221-
)
222-
223-
# Verify the loaded model works
224-
system = bulk("Si", "diamond", a=5.43)
225-
state = ts.io.atoms_to_state([system], device=DEVICE, dtype=DTYPE)
226-
results = loaded_model(state)
227-
228-
assert "energy" in results
229-
assert "forces" in results
230-
assert results["energy"].shape == (1,)
231-
assert torch.isfinite(results["energy"]).all()
232-
assert torch.isfinite(results["forces"]).all()
233-
234-
23527
test_fairchem_uma_model_outputs = pytest.mark.skipif(
23628
get_token() is None,
23729
reason="Requires HuggingFace authentication for UMA model access",
@@ -240,84 +32,3 @@ def test_load_from_checkpoint_path() -> None:
24032
model_fixture_name="eqv2_uma_model_pbc", device=DEVICE, dtype=DTYPE
24133
)
24234
)
243-
244-
245-
@pytest.mark.skipif(
246-
get_token() is None, reason="Requires HuggingFace authentication for UMA model access"
247-
)
248-
@pytest.mark.parametrize(
249-
("charge", "spin"),
250-
[
251-
(0.0, 0.0), # Neutral, no spin
252-
(1.0, 1.0), # +1 charge, spin=1 (doublet)
253-
(-1.0, 0.0), # -1 charge, no spin (singlet)
254-
(0.0, 2.0), # Neutral, spin=2 (triplet)
255-
],
256-
)
257-
def test_fairchem_charge_spin(charge: float, spin: float) -> None:
258-
"""Test that FairChemModel correctly handles charge and spin from atoms.info."""
259-
# Create a water molecule
260-
mol = molecule("H2O")
261-
262-
# Set charge and spin in ASE atoms.info
263-
mol.info["charge"] = charge
264-
mol.info["spin"] = spin
265-
266-
# Convert to SimState (should extract charge/spin)
267-
state = ts.io.atoms_to_state([mol], device=DEVICE, dtype=DTYPE)
268-
269-
# Verify charge/spin were extracted correctly
270-
assert state.charge is not None
271-
assert state.spin is not None
272-
assert state.charge[0].item() == charge
273-
assert state.spin[0].item() == spin
274-
275-
# Create model with UMA omol task (supports charge/spin for molecules)
276-
model = FairChemModel(
277-
model="uma-s-1p1",
278-
task_name="omol",
279-
device=DEVICE,
280-
)
281-
282-
# This should not raise an error
283-
result = model(state)
284-
285-
# Verify outputs exist
286-
assert "energy" in result
287-
assert result["energy"].shape == (1,)
288-
assert "forces" in result
289-
assert result["forces"].shape == (len(mol), 3)
290-
291-
# Verify outputs are finite
292-
assert torch.isfinite(result["energy"]).all()
293-
assert torch.isfinite(result["forces"]).all()
294-
295-
296-
# TODO: we should perhaps put something like this inside `validate_model_outputs`
297-
# the question is how we can do this with creating a circular dependency
298-
@pytest.mark.skipif(
299-
get_token() is None, reason="Requires HuggingFace authentication for UMA model access"
300-
)
301-
def test_fairchem_single_step_relax(rattled_si_sim_state: ts.SimState) -> None:
302-
"""Test a single optimization step with FairChemModel.
303-
304-
This verifies that the model works correctly with optimizers, particularly
305-
that it doesn't have issues with the computational graph (e.g., missing
306-
.detach() calls).
307-
"""
308-
model = FairChemModel(model="uma-s-1p1", task_name="omat", device=DEVICE)
309-
state = rattled_si_sim_state.to(device=DEVICE, dtype=DTYPE)
310-
311-
# Initialize FIRE optimizer
312-
opt_state = ts.fire_init(state, model)
313-
initial_positions = opt_state.positions.clone()
314-
_initial_energy = opt_state.energy.item()
315-
316-
# Run exactly one step
317-
opt_state = ts.fire_step(opt_state, model)
318-
319-
# Verify positions changed
320-
assert not torch.allclose(opt_state.positions, initial_positions)
321-
# Verify energy is still available and finite
322-
assert torch.isfinite(opt_state.energy).all()
323-
assert isinstance(opt_state.energy.item(), float)

tests/test_neighbors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def _all_nl_backends() -> list[Any]:
171171
not neighbors.VESIN_AVAILABLE, reason="Vesin is not installed"
172172
)
173173
_skip_vesin_ts = pytest.mark.skipif(
174-
not neighbors.VESIN_TORCH_AVAILABLE, reason="Vesin is not installed"
174+
not neighbors.VESIN_TORCHSCRIPT_AVAILABLE, reason="Vesin is not installed"
175175
)
176176

177177
_skip_alchemiops = pytest.mark.skipif(

0 commit comments

Comments
 (0)