11import traceback
22
33import pytest
4- import torch
54
6- import torch_sim as ts
75from tests .conftest import DEVICE , DTYPE
86from tests .models .conftest import make_validate_model_outputs_test
97
108
119try :
12- from collections .abc import Callable
13-
14- from ase .build import bulk , fcc100 , molecule
15- from fairchem .core .calculate .pretrained_mlip import (
16- pretrained_checkpoint_path_from_name ,
17- )
1810 from huggingface_hub .utils ._auth import get_token
1911
20- import torch_sim as ts
2112 from torch_sim .models .fairchem import FairChemModel
2213
2314except (ImportError , OSError , RuntimeError , AttributeError , ValueError ):
@@ -33,205 +24,6 @@ def eqv2_uma_model_pbc() -> FairChemModel:
3324 return FairChemModel (model = "uma-s-1p1" , task_name = "omat" , device = DEVICE )
3425
3526
36- @pytest .mark .skipif (
37- get_token () is None , reason = "Requires HuggingFace authentication for UMA model access"
38- )
39- @pytest .mark .parametrize ("task_name" , ["omat" , "omol" , "oc20" ])
40- def test_task_initialization (task_name : str ) -> None :
41- """Test that different UMA task names work correctly."""
42- model = FairChemModel (
43- model = "uma-s-1p1" , task_name = task_name , device = torch .device ("cpu" )
44- )
45- assert model .task_name
46- assert str (model .task_name .value ) == task_name
47- assert hasattr (model , "predictor" )
48-
49-
50- @pytest .mark .skipif (
51- get_token () is None , reason = "Requires HuggingFace authentication for UMA model access"
52- )
53- @pytest .mark .parametrize (
54- ("task_name" , "systems_func" ),
55- [
56- (
57- "omat" ,
58- lambda : [
59- bulk ("Si" , "diamond" , a = 5.43 ),
60- bulk ("Al" , "fcc" , a = 4.05 ),
61- bulk ("Fe" , "bcc" , a = 2.87 ),
62- bulk ("Cu" , "fcc" , a = 3.61 ),
63- ],
64- ),
65- (
66- "omol" ,
67- lambda : [molecule ("H2O" ), molecule ("CO2" ), molecule ("CH4" ), molecule ("NH3" )],
68- ),
69- ],
70- )
71- def test_homogeneous_batching (task_name : str , systems_func : Callable ) -> None :
72- """Test batching multiple systems with the same task."""
73- systems = systems_func ()
74-
75- # Add molecular properties for molecules
76- if task_name == "omol" :
77- for mol in systems :
78- mol .info |= {"charge" : 0 , "spin" : 1 }
79-
80- model = FairChemModel (model = "uma-s-1p1" , task_name = task_name , device = DEVICE )
81- state = ts .io .atoms_to_state (systems , device = DEVICE , dtype = DTYPE )
82- results = model (state )
83-
84- # Check batch dimensions
85- assert results ["energy" ].shape == (4 ,)
86- assert results ["forces" ].shape [0 ] == sum (len (s ) for s in systems )
87- assert results ["forces" ].shape [1 ] == 3
88-
89- # Check that different systems have different energies
90- energies = results ["energy" ]
91- uniq_energies = torch .unique (energies , dim = 0 )
92- assert len (uniq_energies ) > 1 , "Different systems should have different energies"
93-
94-
95- @pytest .mark .skipif (
96- get_token () is None , reason = "Requires HuggingFace authentication for UMA model access"
97- )
98- def test_heterogeneous_tasks () -> None :
99- """Test different task types work with appropriate systems."""
100- # Test molecule, material, and catalysis systems separately
101- test_cases = [
102- ("omol" , [molecule ("H2O" )]),
103- ("omat" , [bulk ("Pt" , cubic = True )]),
104- ("oc20" , [fcc100 ("Cu" , (2 , 2 , 3 ), vacuum = 8 , periodic = True )]),
105- ]
106-
107- for task_name , systems in test_cases :
108- if task_name == "omol" :
109- systems [0 ].info |= {"charge" : 0 , "spin" : 1 }
110-
111- model = FairChemModel (
112- model = "uma-s-1p1" ,
113- task_name = task_name ,
114- device = DEVICE ,
115- )
116- state = ts .io .atoms_to_state (systems , device = DEVICE , dtype = DTYPE )
117- results = model (state )
118-
119- assert results ["energy" ].shape [0 ] == 1
120- assert results ["forces" ].dim () == 2
121- assert results ["forces" ].shape [1 ] == 3
122-
123-
124- @pytest .mark .skipif (
125- get_token () is None , reason = "Requires HuggingFace authentication for UMA model access"
126- )
127- @pytest .mark .parametrize (
128- ("systems_func" , "expected_count" ),
129- [
130- (lambda : [bulk ("Si" , "diamond" , a = 5.43 )], 1 ), # Single system
131- (
132- lambda : [
133- bulk ("H" , "bcc" , a = 2.0 ),
134- bulk ("Li" , "bcc" , a = 3.0 ),
135- bulk ("Si" , "diamond" , a = 5.43 ),
136- bulk ("Al" , "fcc" , a = 4.05 ).repeat ((2 , 1 , 1 )),
137- ],
138- 4 ,
139- ), # Mixed sizes
140- (
141- lambda : [
142- bulk (element , "fcc" , a = 4.0 )
143- for element in ("Al" , "Cu" , "Ni" , "Pd" , "Pt" ) * 3
144- ],
145- 15 ,
146- ), # Large batch
147- ],
148- )
149- def test_batch_size_variations (systems_func : Callable , expected_count : int ) -> None :
150- """Test batching with different numbers and sizes of systems."""
151- systems = systems_func ()
152-
153- model = FairChemModel (model = "uma-s-1p1" , task_name = "omat" , device = DEVICE )
154- state = ts .io .atoms_to_state (systems , device = DEVICE , dtype = DTYPE )
155- results = model (state )
156-
157- assert results ["energy" ].shape == (expected_count ,)
158- assert results ["forces" ].shape [0 ] == sum (len (s ) for s in systems )
159- assert results ["forces" ].shape [1 ] == 3
160- assert torch .isfinite (results ["energy" ]).all ()
161- assert torch .isfinite (results ["forces" ]).all ()
162-
163-
164- @pytest .mark .skipif (
165- get_token () is None , reason = "Requires HuggingFace authentication for UMA model access"
166- )
167- @pytest .mark .parametrize ("compute_stress" , [True , False ])
168- def test_stress_computation (* , compute_stress : bool ) -> None :
169- """Test stress tensor computation."""
170- systems = [bulk ("Si" , "diamond" , a = 5.43 ), bulk ("Al" , "fcc" , a = 4.05 )]
171-
172- model = FairChemModel (
173- model = "uma-s-1p1" ,
174- task_name = "omat" ,
175- device = DEVICE ,
176- compute_stress = compute_stress ,
177- )
178- state = ts .io .atoms_to_state (systems , device = DEVICE , dtype = DTYPE )
179- results = model (state )
180-
181- if compute_stress :
182- assert "stress" in results
183- assert results ["stress" ].shape == (2 , 3 , 3 )
184- assert torch .isfinite (results ["stress" ]).all ()
185- else :
186- assert "stress" not in results
187-
188-
189- @pytest .mark .skipif (
190- get_token () is None , reason = "Requires HuggingFace authentication for UMA model access"
191- )
192- def test_device_consistency () -> None :
193- """Test device consistency between model and data."""
194- model = FairChemModel (model = "uma-s-1p1" , task_name = "omat" , device = DEVICE )
195- system = bulk ("Si" , "diamond" , a = 5.43 )
196- state = ts .io .atoms_to_state ([system ], device = DEVICE , dtype = DTYPE )
197-
198- results = model (state )
199- assert results ["energy" ].device == DEVICE
200- assert results ["forces" ].device == DEVICE
201-
202-
203- @pytest .mark .skipif (
204- get_token () is None , reason = "Requires HuggingFace authentication for UMA model access"
205- )
206- def test_empty_batch_error () -> None :
207- """Test that empty batches raise appropriate errors."""
208- model = FairChemModel (model = "uma-s-1p1" , task_name = "omat" , device = torch .device ("cpu" ))
209- with pytest .raises ((ValueError , RuntimeError , IndexError )):
210- model (ts .io .atoms_to_state ([], device = torch .device ("cpu" ), dtype = torch .float32 ))
211-
212-
213- @pytest .mark .skipif (
214- get_token () is None , reason = "Requires HuggingFace authentication for UMA model access"
215- )
216- def test_load_from_checkpoint_path () -> None :
217- """Test loading model from a saved checkpoint file path."""
218- checkpoint_path = pretrained_checkpoint_path_from_name ("uma-s-1p1" )
219- loaded_model = FairChemModel (
220- model = str (checkpoint_path ), task_name = "omat" , device = DEVICE
221- )
222-
223- # Verify the loaded model works
224- system = bulk ("Si" , "diamond" , a = 5.43 )
225- state = ts .io .atoms_to_state ([system ], device = DEVICE , dtype = DTYPE )
226- results = loaded_model (state )
227-
228- assert "energy" in results
229- assert "forces" in results
230- assert results ["energy" ].shape == (1 ,)
231- assert torch .isfinite (results ["energy" ]).all ()
232- assert torch .isfinite (results ["forces" ]).all ()
233-
234-
23527test_fairchem_uma_model_outputs = pytest .mark .skipif (
23628 get_token () is None ,
23729 reason = "Requires HuggingFace authentication for UMA model access" ,
@@ -240,84 +32,3 @@ def test_load_from_checkpoint_path() -> None:
24032 model_fixture_name = "eqv2_uma_model_pbc" , device = DEVICE , dtype = DTYPE
24133 )
24234)
243-
244-
245- @pytest .mark .skipif (
246- get_token () is None , reason = "Requires HuggingFace authentication for UMA model access"
247- )
248- @pytest .mark .parametrize (
249- ("charge" , "spin" ),
250- [
251- (0.0 , 0.0 ), # Neutral, no spin
252- (1.0 , 1.0 ), # +1 charge, spin=1 (doublet)
253- (- 1.0 , 0.0 ), # -1 charge, no spin (singlet)
254- (0.0 , 2.0 ), # Neutral, spin=2 (triplet)
255- ],
256- )
257- def test_fairchem_charge_spin (charge : float , spin : float ) -> None :
258- """Test that FairChemModel correctly handles charge and spin from atoms.info."""
259- # Create a water molecule
260- mol = molecule ("H2O" )
261-
262- # Set charge and spin in ASE atoms.info
263- mol .info ["charge" ] = charge
264- mol .info ["spin" ] = spin
265-
266- # Convert to SimState (should extract charge/spin)
267- state = ts .io .atoms_to_state ([mol ], device = DEVICE , dtype = DTYPE )
268-
269- # Verify charge/spin were extracted correctly
270- assert state .charge is not None
271- assert state .spin is not None
272- assert state .charge [0 ].item () == charge
273- assert state .spin [0 ].item () == spin
274-
275- # Create model with UMA omol task (supports charge/spin for molecules)
276- model = FairChemModel (
277- model = "uma-s-1p1" ,
278- task_name = "omol" ,
279- device = DEVICE ,
280- )
281-
282- # This should not raise an error
283- result = model (state )
284-
285- # Verify outputs exist
286- assert "energy" in result
287- assert result ["energy" ].shape == (1 ,)
288- assert "forces" in result
289- assert result ["forces" ].shape == (len (mol ), 3 )
290-
291- # Verify outputs are finite
292- assert torch .isfinite (result ["energy" ]).all ()
293- assert torch .isfinite (result ["forces" ]).all ()
294-
295-
296- # TODO: we should perhaps put something like this inside `validate_model_outputs`
297- # the question is how we can do this with creating a circular dependency
298- @pytest .mark .skipif (
299- get_token () is None , reason = "Requires HuggingFace authentication for UMA model access"
300- )
301- def test_fairchem_single_step_relax (rattled_si_sim_state : ts .SimState ) -> None :
302- """Test a single optimization step with FairChemModel.
303-
304- This verifies that the model works correctly with optimizers, particularly
305- that it doesn't have issues with the computational graph (e.g., missing
306- .detach() calls).
307- """
308- model = FairChemModel (model = "uma-s-1p1" , task_name = "omat" , device = DEVICE )
309- state = rattled_si_sim_state .to (device = DEVICE , dtype = DTYPE )
310-
311- # Initialize FIRE optimizer
312- opt_state = ts .fire_init (state , model )
313- initial_positions = opt_state .positions .clone ()
314- _initial_energy = opt_state .energy .item ()
315-
316- # Run exactly one step
317- opt_state = ts .fire_step (opt_state , model )
318-
319- # Verify positions changed
320- assert not torch .allclose (opt_state .positions , initial_positions )
321- # Verify energy is still available and finite
322- assert torch .isfinite (opt_state .energy ).all ()
323- assert isinstance (opt_state .energy .item (), float )
0 commit comments