Skip to content

Commit 13e4872

Browse files
GardevoirXLuthaf
authored andcommitted
Add a HeatFlux model that wraps existing models
`model = HeatFlux.wrap(model)` will create a new model with the same outputs, and (multiple) extra outputs for the heat_flux, one for each energy variant.
1 parent 2e87d7b commit 13e4872

5 files changed

Lines changed: 849 additions & 0 deletions

File tree

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import numpy as np
2+
import pytest
3+
import torch
4+
from ase import Atoms
5+
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
6+
7+
import metatomic_lj_test
8+
from metatomic.torch import ModelOutput
9+
from metatomic.torch.heat_flux import (
10+
HeatFlux,
11+
)
12+
from metatomic_ase import MetatomicCalculator
13+
14+
15+
@pytest.fixture
16+
def model():
17+
return metatomic_lj_test.lennard_jones_model(
18+
atomic_type=18,
19+
cutoff=7.0,
20+
sigma=3.405,
21+
epsilon=0.01032,
22+
length_unit="Angstrom",
23+
energy_unit="eV",
24+
with_extension=False,
25+
)
26+
27+
28+
@pytest.fixture
29+
def model_in_kcal_per_mol():
30+
return metatomic_lj_test.lennard_jones_model(
31+
atomic_type=18,
32+
cutoff=7.0,
33+
sigma=3.405,
34+
epsilon=0.2380,
35+
length_unit="Angstrom",
36+
energy_unit="kcal/mol",
37+
with_extension=False,
38+
)
39+
40+
41+
@pytest.fixture
42+
def atoms(request):
43+
if hasattr(request, "param") and request.param == "atoms_triclinic":
44+
cell = np.array([[6.0, 3.0, 1.0], [2.0, 6.0, 0.0], [0.0, 0.0, 6.0]])
45+
positions = np.array([[0.0, 0.0, 0.0]])
46+
else:
47+
cell = np.array([[6.0, 0.0, 0.0], [0.0, 6.0, 0.0], [0.0, 0.0, 6.0]])
48+
positions = np.array([[3.0, 3.0, 3.0]])
49+
atoms = Atoms("Ar", scaled_positions=positions, cell=cell, pbc=True).repeat(
50+
(2, 2, 2)
51+
)
52+
MaxwellBoltzmannDistribution(
53+
atoms, temperature_K=300, rng=np.random.default_rng(42)
54+
)
55+
return atoms
56+
57+
58+
@pytest.mark.parametrize("use_script", [True, False])
59+
@pytest.mark.parametrize(
60+
"atoms, expected",
61+
[
62+
("atoms", [[8.8238e-05], [-2.5559e-04], [-2.0570e-04]]),
63+
],
64+
indirect=["atoms"],
65+
)
66+
def test_wrap(model, atoms, expected, use_script):
67+
wrapped_model = HeatFlux.wrap(model, scripting=use_script)
68+
calc = MetatomicCalculator(
69+
wrapped_model,
70+
device="cpu",
71+
additional_outputs={
72+
"heat_flux": ModelOutput(
73+
quantity="heat_flux",
74+
unit="eV*A/fs",
75+
explicit_gradients=[],
76+
per_atom=False,
77+
)
78+
},
79+
check_consistency=True,
80+
)
81+
atoms.calc = calc
82+
atoms.get_potential_energy()
83+
results = atoms.calc.additional_outputs["heat_flux"].block().values
84+
assert torch.allclose(
85+
results,
86+
torch.tensor(expected, dtype=results.dtype),
87+
)

0 commit comments

Comments
 (0)