
Fixing batched MD with JAX-MD #34

@bhcao


https://github.com/instadeepai/mlip/blob/main/src/mlip/simulation/jax_md/jax_md_simulation_engine.py#L284

```python
# Current code: index 1 removes the second system's size, not the dummy graph
sizes = np.delete(graph.n_node, 1)
# Correct: the dummy (padding) graph is the last one
sizes = np.delete(graph.n_node, -1)
```
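To see the effect of the index, here is a small NumPy sketch. It assumes, as stated above, that the dummy padding graph is appended last in `n_node`; the sizes and the `positions` array are hypothetical values chosen for illustration.

```python
import numpy as np

# Hypothetical batch: 3 real systems (3, 5 and 4 atoms) plus one dummy
# padding graph of size 1 appended at the end.
n_node = np.array([3, 5, 4, 1])

wrong = np.delete(n_node, 1)   # drops the 2nd real system, keeps the dummy
right = np.delete(n_node, -1)  # drops the trailing dummy graph

# Split a flat per-node array back into one chunk per real system,
# ignoring the trailing padding node(s).
positions = np.arange(n_node.sum())
per_system = np.split(positions[: right.sum()], np.cumsum(right)[:-1])
```

With index `1`, the per-system sizes would be `[3, 4, 1]`, so the split would silently mix atoms from different systems.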

https://github.com/instadeepai/mlip/blob/main/src/mlip/simulation/jax_md/jax_md_simulation_engine.py#L235

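```python
tree_map(
    lambda s, n: s.set(neighbors=n),
    self._internal_state.system_state,
    new_neighbors,
    # Add is_leaf so tree_map stops at SystemState / NeighborList objects;
    # the two trees have different leaf types and would not match otherwise.
    is_leaf=lambda x: is_system_state(x) or is_neighbor_list(x),
)
```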
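A minimal sketch of why `is_leaf` matters, using `jax.tree_util.tree_map` (the engine's own `tree_map` import may differ). `State` and `Nbrs` are hypothetical NamedTuple stand-ins for `SystemState` and `NeighborList`; because NamedTuples are pytrees, `tree_map` would otherwise descend into their fields and the two trees would not line up.

```python
from typing import NamedTuple

import jax


class State(NamedTuple):
    # Hypothetical stand-in for a per-system simulation state.
    positions: tuple
    neighbors: object


class Nbrs(NamedTuple):
    # Hypothetical stand-in for a neighbor list.
    idx: tuple


def is_state(x):
    return isinstance(x, State)


def is_nbrs(x):
    return isinstance(x, Nbrs)


# Batched simulation: one state and one fresh neighbor list per system.
states = [State((0.0, 1.0), None), State((2.0,), None)]
new_nbrs = [Nbrs((1, 0)), Nbrs((0,))]

# Treating State/Nbrs objects as leaves pairs each system with its own
# neighbor list, mirroring `s.set(neighbors=n)` in the engine.
updated = jax.tree_util.tree_map(
    lambda s, n: s._replace(neighbors=n),
    states,
    new_nbrs,
    is_leaf=lambda x: is_state(x) or is_nbrs(x),
)
```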

https://github.com/instadeepai/mlip/blob/main/src/mlip/simulation/jax_md/jax_md_simulation_engine.py#L237

`_pure_simulation_step_fun` must be rebuilt after `_reallocate_neighbors` is called: reallocation changes `n_edge`, while the model function used by the step function was built from the old graph.
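The following sketch illustrates the stale-function problem in isolation. `make_model_fun` is a hypothetical builder, not part of mlip; like a model built from a fixed graph, the function it returns keeps the edge arrays it was created with, so a neighbor reallocation that changes `n_edge` is only picked up by rebuilding it. Senders and receivers are taken from rows 1 and 0 of `idx`, as in the fix below.

```python
import numpy as np


def make_model_fun(senders, receivers):
    # Hypothetical model builder: the returned function closes over the
    # edge arrays it was built with.
    def pair_sum(positions):
        diff = positions[senders] - positions[receivers]
        return float(np.sum(diff ** 2))

    return pair_sum


positions = np.array([0.0, 1.0, 3.0])

# Neighbor indices before and after a hypothetical reallocation
# (n_edge grows from 2 to 4).
idx_old = np.array([[0, 1], [1, 2]])
idx_new = np.array([[0, 1, 0, 2], [1, 2, 2, 0]])

stale_fun = make_model_fun(idx_old[1], idx_old[0])  # built before reallocation
fresh_fun = make_model_fun(idx_new[1], idx_new[0])  # rebuilt afterwards

n_edge_old, n_edge_new = idx_old.shape[1], idx_new.shape[1]
stale = stale_fun(positions)  # still only sees the 2 old edges
fresh = fresh_fun(positions)  # sees all 4 new edges
```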

Add the following after L237

```python
# Rebuild the graph from the reallocated neighbor lists (n_edge has changed)
senders = tree_map(lambda n: n.idx[1, :], new_neighbors, is_leaf=is_neighbor_list)
receivers = tree_map(lambda n: n.idx[0, :], new_neighbors, is_leaf=is_neighbor_list)
graph = self._init_base_graph(
    self.atoms, senders, receivers, self.force_field.allowed_atomic_numbers
)
# Rebuild the model function and the simulation apply function for the new graph
model_calculate_fun = self._get_model_calculate_fun(
    graph, self.force_field, is_batched_sim=isinstance(self.atoms, list)
)
_, sim_apply_fun = init_simulation_algorithm(
    model_calculate_fun, self._shift_fun, self._config
)
# Swap the new apply function into the partially applied step function
self._pure_simulation_step_fun.keywords["apply_fun"] = sim_apply_fun
```

`self.atoms` and `self.force_field` are not currently stored on the engine, so the snippet above would fail. Store them in `_initialize` with `self.atoms = atoms` and `self.force_field = force_field`.
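The last line of the fix updates `apply_fun` through `.keywords`, which suggests `_pure_simulation_step_fun` is a `functools.partial`; that is an assumption here. The sketch below, with hypothetical `step`, `old_apply` and `new_apply` functions, shows that mutating a partial's `keywords` dict changes what later calls receive, and the alternative of binding a fresh partial instead.

```python
import functools


def step(state, apply_fun):
    # Hypothetical stand-in for a simulation step.
    return apply_fun(state)


def old_apply(x):
    return x + 1


def new_apply(x):
    return x + 100


step_fun = functools.partial(step, apply_fun=old_apply)
before = step_fun(0)  # uses old_apply

# In-place update, mirroring `.keywords["apply_fun"] = sim_apply_fun`.
step_fun.keywords["apply_fun"] = new_apply
after = step_fun(0)  # now uses new_apply

# Alternative: bind a fresh partial instead of mutating the existing object.
rebound = functools.partial(step_fun.func, apply_fun=old_apply)
```

Rebinding with a new `functools.partial` avoids surprises if some other component holds or has cached the original object.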
