https://github.com/instadeepai/mlip/blob/main/src/mlip/simulation/jax_md/jax_md_simulation_engine.py#L284
# Yours: removes the second graph (index 1)
sizes = np.delete(graph.n_node, 1)
# Correct: the dummy (padding) graph is the last one
sizes = np.delete(graph.n_node, -1)
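A small NumPy illustration of the index mix-up, assuming a padded batch whose dummy graph is appended after the real systems (as stated above, and as jraph's padding does). The `n_node` values are made up for the example.

```python
import numpy as np

# Hypothetical padded batch: three real systems with 4, 6 and 5 atoms,
# followed by one dummy (padding) graph of 3 nodes at the end.
n_node = np.array([4, 6, 5, 3])

wrong = np.delete(n_node, 1)   # drops the second *real* system, keeps the dummy
right = np.delete(n_node, -1)  # drops the trailing dummy graph

print(wrong)  # [4 5 3]
print(right)  # [4 6 5]
```

With index `1`, the per-system sizes are wrong for every system after the first, and the dummy graph's node count is treated as a real system.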
https://github.com/instadeepai/mlip/blob/main/src/mlip/simulation/jax_md/jax_md_simulation_engine.py#L235
tree_map(
    lambda s, n: s.set(neighbors=n),
    self._internal_state.system_state,
    new_neighbors,
    # Add this: treat SystemState / NeighborList as leaves so the two trees' structures match
    is_leaf=lambda x: is_system_state(x) or is_neighbor_list(x),
)
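A minimal sketch of why `is_leaf` is needed in the batched case. `SystemState`, `NeighborList`, `is_system_state` and `is_neighbor_list` below are hypothetical NamedTuple stand-ins (so `_replace` plays the role of `.set`), not the mlip or jax_md classes. `jax.tree_util.tree_map` applies `is_leaf` while flattening the first tree; the remaining trees only need that tree as a prefix, so each `NeighborList` is handed to the lambda whole.

```python
from typing import NamedTuple

import jax.numpy as jnp
from jax.tree_util import tree_map


class NeighborList(NamedTuple):
    """Hypothetical stand-in for a jax_md neighbor list (only idx kept)."""
    idx: jnp.ndarray


class SystemState(NamedTuple):
    """Hypothetical stand-in for the engine's per-system state."""
    positions: jnp.ndarray
    neighbors: NeighborList


def is_system_state(x):
    return isinstance(x, SystemState)


def is_neighbor_list(x):
    return isinstance(x, NeighborList)


# Batched simulation: one state and one reallocated neighbor list per system.
states = [
    SystemState(jnp.zeros((2, 3)), NeighborList(jnp.zeros((2, 4), dtype=jnp.int32))),
    SystemState(jnp.zeros((3, 3)), NeighborList(jnp.zeros((2, 6), dtype=jnp.int32))),
]
new_neighbors = [
    NeighborList(jnp.ones((2, 8), dtype=jnp.int32)),
    NeighborList(jnp.ones((2, 12), dtype=jnp.int32)),
]

# Without is_leaf, tree_map recurses into each SystemState and fails to match
# its leaves against new_neighbors. With is_leaf, each SystemState is one leaf.
updated = tree_map(
    lambda s, n: s._replace(neighbors=n),
    states,
    new_neighbors,
    is_leaf=lambda x: is_system_state(x) or is_neighbor_list(x),
)

print([u.neighbors.idx.shape for u in updated])  # [(2, 8), (2, 12)]
```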
https://github.com/instadeepai/mlip/blob/main/src/mlip/simulation/jax_md/jax_md_simulation_engine.py#L237
_pure_simulation_step_fun needs to be updated after calling _reallocate_neighbors, because reallocation changes n_edge and the step function still holds an apply_fun built from the old graph.
Add the following after L237:
# Rebuild the base graph from the reallocated neighbor list (new n_edge)
senders = tree_map(lambda n: n.idx[1, :], new_neighbors, is_leaf=is_neighbor_list)
receivers = tree_map(lambda n: n.idx[0, :], new_neighbors, is_leaf=is_neighbor_list)
graph = self._init_base_graph(
    self.atoms, senders, receivers, self.force_field.allowed_atomic_numbers
)
# Rebuild the model and simulation apply functions for the new graph
model_calculate_fun = self._get_model_calculate_fun(
    graph, self.force_field, is_batched_sim=isinstance(self.atoms, list)
)
_, sim_apply_fun = init_simulation_algorithm(
    model_calculate_fun, self._shift_fun, self._config
)
# Swap the new apply function into the partially applied step function
self._pure_simulation_step_fun.keywords["apply_fun"] = sim_apply_fun
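The last line assumes `_pure_simulation_step_fun` is a `functools.partial` (its `.keywords` attribute suggests so). A minimal sketch of why the swap is needed and why editing `keywords` in place takes effect on later calls; `make_apply_fun` and `step` are hypothetical stand-ins, not mlip functions.

```python
import functools

import numpy as np


def make_apply_fun(senders, receivers):
    """Hypothetical stand-in for init_simulation_algorithm(...)[1]: the
    returned function closes over edge arrays of a fixed size (n_edge)."""
    n_edge = senders.shape[0]

    def apply_fun(state):
        # A real apply_fun would run the model on the captured graph;
        # here it only reports the edge capacity it was built for.
        return state, n_edge

    return apply_fun


def step(state, *, apply_fun):
    """Hypothetical stand-in for the pure simulation step."""
    return apply_fun(state)


old_idx = np.zeros((2, 8), dtype=np.int32)
step_fun = functools.partial(step, apply_fun=make_apply_fun(old_idx[1], old_idx[0]))
print(step_fun(None)[1])  # 8

# After reallocation the neighbor list has more slots (n_edge changed) ...
new_idx = np.zeros((2, 12), dtype=np.int32)
print(step_fun(None)[1])  # still 8: the old apply_fun is still bound

# ... so rebuild apply_fun and swap it into the partial's keywords dict.
step_fun.keywords["apply_fun"] = make_apply_fun(new_idx[1], new_idx[0])
print(step_fun(None)[1])  # 12
```

One thing worth checking: if this partial is itself wrapped in `jax.jit` somewhere, an in-place edit of `keywords` would likely not invalidate jit's cache, so the old trace could keep being used.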
Because self.atoms is not currently available on the engine, store self.atoms = atoms and self.force_field = force_field in _initialize.
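A heavily reduced, hypothetical sketch of where the two attributes would be set and how the rebuild code reads them. `SimulationEngineSketch` and `_rebuild_inputs` are illustrative names, and the real `_initialize` takes other arguments as well.

```python
class SimulationEngineSketch:
    """Hypothetical, reduced engine showing only the stored attributes."""

    def _initialize(self, atoms, force_field):
        # Keep references so the code run after neighbor reallocation
        # can rebuild the base graph and the model function.
        self.atoms = atoms
        self.force_field = force_field

    def _rebuild_inputs(self):
        # Same batching check as in the snippet above.
        is_batched_sim = isinstance(self.atoms, list)
        return self.atoms, self.force_field, is_batched_sim


engine = SimulationEngineSketch()
engine._initialize(atoms=["system_a", "system_b"], force_field="ff")
print(engine._rebuild_inputs()[2])  # True
```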