From fa563e35b56a654e14f63a9dc45a1821e36b3cd7 Mon Sep 17 00:00:00 2001 From: Song Kim Date: Tue, 24 Dec 2024 20:10:52 -0500 Subject: [PATCH 1/8] some setup for multitarget prediction --- analyses/generate_molecules.py | 67 ++++++++++++++++++- .../models/continuous_position_predictor.py | 18 +++-- symphony/models/embedders/nequip.py | 16 +++-- symphony/models/predictor.py | 6 +- .../discretized_predictor.py | 7 ++ symphony/models/utils/create_model.py | 2 + 6 files changed, 102 insertions(+), 14 deletions(-) diff --git a/analyses/generate_molecules.py b/analyses/generate_molecules.py index ae98985f..f64f6ecf 100644 --- a/analyses/generate_molecules.py +++ b/analyses/generate_molecules.py @@ -97,10 +97,61 @@ def round_to_nearest_multiple_of_64(x): return padding_budget +def append_predictions_single( + target_position: jnp.ndarray, + target_species: int, + padded_fragment: datatypes.Fragments, + radial_cutoff: float, +) -> datatypes.Fragments: + """Appends the predictions to the padded fragment.""" + # Update the positions of the first dummy node. + positions = padded_fragment.nodes.positions + num_valid_nodes = padded_fragment.n_node[0] + num_nodes = padded_fragment.nodes.positions.shape[0] + num_edges = padded_fragment.receivers.shape[0] + new_positions = positions.at[num_valid_nodes].set(target_position) + + # Update the species of the first dummy node. + species = padded_fragment.nodes.species + new_species = species.at[num_valid_nodes].set(target_species) + + # Compute the distance matrix to select the edges. + distance_matrix = jnp.linalg.norm( + new_positions[None, :, :] - new_positions[:, None, :], axis=-1 + ) + node_indices = jnp.arange(num_nodes) + + # Avoid self-edges. + valid_edges = (distance_matrix > 0) & (distance_matrix < radial_cutoff) + valid_edges = ( + valid_edges + & (node_indices[None, :] <= num_valid_nodes) + & (node_indices[:, None] <= num_valid_nodes) + ) + senders, receivers = jnp.nonzero( + valid_edges, size=num_edges, fill_value=-1 + ) + num_valid_edges = jnp.sum(valid_edges) + num_valid_nodes += 1 + + return padded_fragment._replace( + nodes=padded_fragment.nodes._replace( + positions=new_positions, + species=new_species, + ), + n_node=jnp.asarray([num_valid_nodes, num_nodes - num_valid_nodes]), + n_edge=jnp.asarray([num_valid_edges, num_edges - num_valid_edges]), + senders=senders, + receivers=receivers, + ) + + def append_predictions( pred: datatypes.Predictions, padded_fragment: datatypes.Fragments, radial_cutoff: float, + eps: float, + max_num_atoms: int, ) -> datatypes.Fragments: """Appends the predictions to the padded fragment.""" # Update the positions of the first dummy node. @@ -155,12 +206,14 @@ def generate_one_step( rng: chex.PRNGKey, apply_fn: Callable[[datatypes.Fragments, chex.PRNGKey], datatypes.Predictions], radial_cutoff: float, + eps: float, + max_num_atoms: int, ) -> Tuple[ Tuple[datatypes.Fragments, bool], Tuple[datatypes.Fragments, datatypes.Predictions] ]: """Generates the next fragment for a given seed.""" pred = apply_fn(padded_fragment, rng) - next_padded_fragment = append_predictions(pred, padded_fragment, radial_cutoff) + next_padded_fragment = append_predictions(pred, padded_fragment, radial_cutoff, eps, max_num_atoms) stop = pred.globals.stop[0] | stop return jax.lax.cond( stop, @@ -172,6 +225,7 @@ def generate_one_step( def generate_for_one_seed( apply_fn: Callable[[datatypes.Fragments, chex.PRNGKey], datatypes.Predictions], init_fragment: datatypes.Fragments, + eps: float, max_num_atoms: int, cutoff: float, rng: chex.PRNGKey, @@ -180,7 +234,14 @@ def generate_for_one_seed( """Generates a single molecule for a given seed.""" step_rngs = jax.random.split(rng, num=max_num_atoms) (final_padded_fragment, stop), (padded_fragments, preds) = jax.lax.scan( - lambda args, rng: generate_one_step(*args, rng, apply_fn, cutoff), + lambda args, rng: generate_one_step( + *args, + rng, + apply_fn, + cutoff, + eps, + max_num_atoms, + ), (init_fragment, False), step_rngs, ) @@ -204,6 +265,7 @@ def generate_molecules( dataset: str, padding_mode: str, verbose: bool = False, + eps: float = 1e-5, ): # def generate_molecules( # apply_fn: Callable[[datatypes.Fragments, chex.PRNGKey], datatypes.Predictions], @@ -331,6 +393,7 @@ def apply_on_chunk( generate_for_one_seed_fn = lambda rng, init_fragment: generate_for_one_seed( apply_fn_wrapped, init_fragment, + eps, max_num_atoms, radial_cutoff, rng, diff --git a/symphony/models/continuous_position_predictor.py b/symphony/models/continuous_position_predictor.py index 12f8c958..6b830bd1 100644 --- a/symphony/models/continuous_position_predictor.py +++ b/symphony/models/continuous_position_predictor.py @@ -21,6 +21,7 @@ def __init__( radial_predictor_fn: Callable[[], RadiusPredictor], angular_predictor_fn: Callable[[], AngularPredictor], num_species: int, + num_targets: int, name: Optional[str] = None, ): super().__init__(name) @@ -28,6 +29,7 @@ def __init__( self.radial_predictor = radial_predictor_fn() self.angular_predictor = angular_predictor_fn() self.num_species = num_species + self.num_targets = num_targets def compute_conditioning( self, @@ -130,17 +132,25 @@ def get_evaluation_predictions( # Compute the conditioning based on the focus nodes and target species. conditioning = self.compute_conditioning(graphs, focus_indices, target_species) assert conditioning.shape == (num_graphs, conditioning.irreps.dim) + conditioning.array = jnp.repeat( + jnp.expand_dims(conditioning.array, axis=1), + self.num_targets, + ) # Sample the radial component. - radii = hk.vmap(self.radial_predictor.sample, split_rng=True)(conditioning) - assert radii.shape == (num_graphs,), (radii.shape, num_graphs) + radii = hk.vmap( + hk.vmap( + self.radial_predictor.sample, split_rng=True + ), split_rng=True + )(conditioning) + assert radii.shape == (num_graphs, self.num_targets), (radii.shape, num_graphs) # Predict the target position vectors. angular_sample_fn = lambda r, cond: self.angular_predictor.sample( r, cond, inverse_temperature ) - position_vectors = hk.vmap(angular_sample_fn, split_rng=True)( + position_vectors = hk.vmap(hk.vmap(angular_sample_fn, split_rng=True), split_rng=True)( radii, conditioning ) - assert position_vectors.shape == (num_graphs, 3) + assert position_vectors.shape == (num_graphs, self.num_targets, 3) return None, None, position_vectors diff --git a/symphony/models/embedders/nequip.py b/symphony/models/embedders/nequip.py index f0a0577d..305a92b5 100644 --- a/symphony/models/embedders/nequip.py +++ b/symphony/models/embedders/nequip.py @@ -35,7 +35,6 @@ def __init__( self.avg_num_neighbors = avg_num_neighbors self.max_ell = max_ell self.init_embedding_dims = init_embedding_dims - self.hidden_irreps = hidden_irreps self.output_irreps = output_irreps self.num_interactions = num_interactions self.even_activation = even_activation @@ -61,11 +60,11 @@ def __call__( node_feats = hk.Embed(self.num_species, self.init_embedding_dims)(species) node_feats = e3nn.IrrepsArray(f"{node_feats.shape[1]}x0e", node_feats) - for _ in range(self.num_interactions): - node_feats = nequip_jax.NEQUIPESCNLayerHaiku( + for interaction in range(self.num_interactions): + new_node_feats = nequip_jax.NEQUIPESCNLayerHaiku( avg_num_neighbors=self.avg_num_neighbors, num_species=self.num_species, - output_irreps=self.hidden_irreps, + output_irreps=self.output_irreps, even_activation=self.even_activation, odd_activation=self.odd_activation, mlp_activation=self.mlp_activation, @@ -73,10 +72,13 @@ def __call__( mlp_n_layers=self.mlp_n_layers, n_radial_basis=self.n_radial_basis, )(relative_positions, node_feats, species, graphs.senders, graphs.receivers) + new_node_feats = e3nn.haiku.Linear( + self.output_irreps, force_irreps_out=True + )(new_node_feats) - node_feats = e3nn.haiku.Linear( - self.output_irreps, force_irreps_out=True - )(node_feats) + if self.skip_connection and interaction > 0: + new_node_feats += node_feats + node_feats = new_node_feats alpha = 0.5 ** jnp.array(node_feats.irreps.ls) node_feats = node_feats * alpha diff --git a/symphony/models/predictor.py b/symphony/models/predictor.py index 475c5588..230dfbec 100644 --- a/symphony/models/predictor.py +++ b/symphony/models/predictor.py @@ -155,7 +155,11 @@ def get_evaluation_predictions( num_nodes, num_species, ) - assert position_vectors.shape == (num_graphs, 3) + assert position_vectors.shape == ( + num_graphs, + self.target_position_predictor.num_targets, + 3 + ) return datatypes.Predictions( nodes=datatypes.NodePredictions( diff --git a/symphony/models/radius_predictors/discretized_predictor.py b/symphony/models/radius_predictors/discretized_predictor.py index 6b8e9a51..5378fa7a 100644 --- a/symphony/models/radius_predictors/discretized_predictor.py +++ b/symphony/models/radius_predictors/discretized_predictor.py @@ -17,6 +17,7 @@ def __init__( range_max: float, num_layers: int, latent_size: int, + num_targets: int, ): super().__init__() self.num_bins = num_bins @@ -24,6 +25,7 @@ def __init__( self.range_max = range_max self.num_layers = num_layers self.latent_size = latent_size + self.num_targets = num_targets def radii(self) -> jnp.ndarray: return jnp.linspace(self.range_min, self.range_max, self.num_bins) @@ -31,6 +33,11 @@ def radii(self) -> jnp.ndarray: def predict_logits(self, conditioning: e3nn.IrrepsArray) -> distrax.Bijector: """Predicts the logits.""" conditioning = conditioning.filter("0e") + # conditioning = jnp.repeat( + # jnp.expand_dims(conditioning.array, axis=-1), + # self.num_targets, + # axis=-1 + # ) logits = hk.nets.MLP( [self.latent_size] * (self.num_layers - 1) + [self.num_bins], diff --git a/symphony/models/utils/create_model.py b/symphony/models/utils/create_model.py index 47dc9954..f1ed299e 100644 --- a/symphony/models/utils/create_model.py +++ b/symphony/models/utils/create_model.py @@ -207,6 +207,7 @@ def create_predictor(config: ml_collections.ConfigDict) -> Predictor: range_max=radial_predictor_config.max_radius, num_layers=radial_predictor_config.num_layers, latent_size=radial_predictor_config.latent_size, + num_targets=config.max_targets_per_graph, ) else: raise ValueError( @@ -220,6 +221,7 @@ def create_predictor(config: ml_collections.ConfigDict) -> Predictor: angular_predictor_fn=angular_predictor_fn, radial_predictor_fn=radial_predictor_fn, num_species=num_species, + num_targets=config.max_targets_per_graph, ) predictor = Predictor( focus_and_target_species_predictor=focus_and_target_species_predictor, From 66452eb6bc0ea434c212abf09bdd771dc98a8f57 Mon Sep 17 00:00:00 2001 From: Song Kim Date: Thu, 9 Jan 2025 20:01:07 -0500 Subject: [PATCH 2/8] multispecies multitarget --- analyses/generate_molecules.py | 151 ++++-------------- symphony/data/fragments.py | 28 +--- .../models/continuous_position_predictor.py | 37 +++-- symphony/models/predictor.py | 6 +- symphony/models/utils/utils.py | 9 +- 5 files changed, 70 insertions(+), 161 deletions(-) diff --git a/analyses/generate_molecules.py b/analyses/generate_molecules.py index f64f6ecf..fb5cf1ba 100644 --- a/analyses/generate_molecules.py +++ b/analyses/generate_molecules.py @@ -154,50 +154,39 @@ def append_predictions( max_num_atoms: int, ) -> datatypes.Fragments: """Appends the predictions to the padded fragment.""" - # Update the positions of the first dummy node. - positions = padded_fragment.nodes.positions - num_valid_nodes = padded_fragment.n_node[0] - num_nodes = padded_fragment.nodes.positions.shape[0] - num_edges = padded_fragment.receivers.shape[0] - focus = pred.globals.focus_indices[0] - focus_position = positions[focus] - target_position = pred.globals.position_vectors[0] + focus_position - new_positions = positions.at[num_valid_nodes].set(target_position) - - # Update the species of the first dummy node. - species = padded_fragment.nodes.species - target_species = pred.globals.target_species[0] - new_species = species.at[num_valid_nodes].set(target_species) - - # Compute the distance matrix to select the edges. - distance_matrix = jnp.linalg.norm( - new_positions[None, :, :] - new_positions[:, None, :], axis=-1 - ) - node_indices = jnp.arange(num_nodes) - - # Avoid self-edges. - valid_edges = (distance_matrix > 0) & (distance_matrix < radial_cutoff) - valid_edges = ( - valid_edges - & (node_indices[None, :] <= num_valid_nodes) - & (node_indices[:, None] <= num_valid_nodes) - ) - senders, receivers = jnp.nonzero( - valid_edges, size=num_edges, fill_value=-1 - ) - num_valid_edges = jnp.sum(valid_edges) - num_valid_nodes += 1 - - return padded_fragment._replace( - nodes=padded_fragment.nodes._replace( - positions=new_positions, - species=new_species, - ), - n_node=jnp.asarray([num_valid_nodes, num_nodes - num_valid_nodes]), - n_edge=jnp.asarray([num_valid_edges, num_edges - num_valid_edges]), - senders=senders, - receivers=receivers, - ) + n_nodes = padded_fragment.n_node[0] + target_relative_positions = pred.globals.position_vectors[0] # (num_targets, 3) + focus_indices = pred.globals.focus_indices[0] + focus_positions = padded_fragment.nodes.positions[focus_indices] + extra_positions = (target_relative_positions + focus_positions).reshape(-1, 3) + extra_species = (pred.globals.target_species[0]).reshape(-1,) + + + new_fragment = padded_fragment + extra_atoms = 0 + i = 0 + def f(fragment, extra_atoms): + return ( + append_predictions_single( + extra_positions[i], + extra_species[i], + fragment, + radial_cutoff + ), + extra_atoms + 1, + ) + for i in range(len(extra_positions)): + new_fragment, extra_atoms = jax.lax.cond( + jnp.logical_and( + jnp.linalg.norm(extra_positions[i]) > eps, + n_nodes + extra_atoms < max_num_atoms, + ), + f, + lambda x, y: (x, y), + new_fragment, + extra_atoms, + ) + return new_fragment def generate_one_step( @@ -267,24 +256,6 @@ def generate_molecules( verbose: bool = False, eps: float = 1e-5, ): -# def generate_molecules( -# apply_fn: Callable[[datatypes.Fragments, chex.PRNGKey], datatypes.Predictions], -# params: optax.Params, -# molecules_outputdir: str, -# radial_cutoff: float, -# focus_and_atom_type_inverse_temperature: float, -# position_inverse_temperature: float, -# start_seed: int, -# num_seeds: int, -# num_seeds_per_chunk: int, -# init_molecules: Sequence[Union[str, ase.Atoms]], -# max_num_atoms: int, -# avg_neighbors_per_atom: int, -# atomic_numbers: np.ndarray = np.arange(1, 81), -# visualize: bool = False, -# visualizations_dir: Optional[str] = None, -# verbose: bool = True, -# ): """Generates molecules from a model.""" if verbose: @@ -430,64 +401,8 @@ def apply_on_chunk( molecule_list = [] for i, seed in tqdm.tqdm(enumerate(seeds), desc="Visualizing molecules"): - init_fragment = jax.tree_util.tree_map(lambda x: x[i], init_fragments) init_molecule_name = init_molecule_names[i] - # if visualize: - # # Get the padded fragment and predictions for this seed. - # padded_fragments_for_seed = jax.tree_util.tree_map( - # lambda x: x[i], padded_fragments - # ) - # preds_for_seed = jax.tree_util.tree_map(lambda x: x[i], preds) - - # figs = [] - # for step in range(max_num_atoms): - # if step == 0: - # padded_fragment = init_fragment - # else: - # padded_fragment = jax.tree_util.tree_map( - # lambda x: x[step - 1], padded_fragments_for_seed - # ) - # pred = jax.tree_util.tree_map(lambda x: x[step], preds_for_seed) - - # # Save visualization of generation process. - # fragment = jraph.unpad_with_graphs(padded_fragment) - # pred = jraph.unpad_with_graphs(pred) - # fragment = fragment._replace( - # globals=jax.tree_util.tree_map( - # lambda x: np.squeeze(x, axis=0), fragment.globals - # ) - # ) - # pred = pred._replace( - # globals=jax.tree_util.tree_map(lambda x: np.squeeze(x, axis=0), pred.globals) - # ) - # fig = analysis.visualize_predictions(pred, fragment) - # figs.append(fig) - - # # This may be the final padded fragment. - # final_padded_fragment = padded_fragment - - # # Check if we should stop. - # stop = pred.globals.stop - # if stop: - # break - - # # Save the visualizations of the generation process. - # for index, fig in enumerate(figs): - # # Update the title. - # fig.update_layout( - # title=f"Predictions for Seed {seed}", - # title_x=0.5, - # ) - - # # Save to file. - # outputfile = os.path.join( - # visualizations_dir, - # f"seed_{seed}_fragments_{index}.html", - # ) - # fig.write_html(outputfile, include_plotlyjs="cdn") - - # else: # We already have the final padded fragment. final_padded_fragment = jax.tree_util.tree_map( lambda x: x[i], final_padded_fragments diff --git a/symphony/data/fragments.py b/symphony/data/fragments.py index 54fbecf2..8a8a3a21 100644 --- a/symphony/data/fragments.py +++ b/symphony/data/fragments.py @@ -93,25 +93,10 @@ def generate_fragments( def pick_targets( - rng, targets, - node_species, - target_species_probability_for_focus, max_targets_per_graph, ): - # Pick a random target species. - rng, k = jax.random.split(rng) - target_species = jax.random.choice( - k, - len(target_species_probability_for_focus), - p=target_species_probability_for_focus, - ) - - # Pick up to max_targets_per_graph targets of the target species. - targets_of_this_species = targets[node_species[targets] == target_species] - targets_of_this_species = targets_of_this_species[:max_targets_per_graph] - - return targets_of_this_species + return targets[:max_targets_per_graph] def _make_first_fragment( @@ -162,12 +147,8 @@ def _make_first_fragment( graph.nodes.species[targets], num_species ) - rng, k = jax.random.split(rng) target_nodes = pick_targets( - k, targets, - graph.nodes.species, - target_species_probability[first_node], max_targets_per_graph, ) @@ -243,10 +224,7 @@ def _make_middle_fragment( targets = receivers[(senders == focus_node) & mask] target_nodes = pick_targets( - k, targets, - graph.nodes.species, - target_species_probability[focus_node], max_targets_per_graph, ) @@ -293,7 +271,7 @@ def _into_fragment( # Check that all target species are the same. target_species = species[target_nodes] - assert np.all(target_species == target_species[0]) + # assert np.all(target_species == target_species[0]) padded_target_nodes = np.pad( target_nodes, (0, max_targets_per_graph - len(target_nodes)) @@ -307,7 +285,7 @@ def _into_fragment( ) globals = datatypes.FragmentsGlobals( stop=np.asarray(stop, dtype=bool), - target_species=target_species[0], + target_species=target_species, target_positions=pos[padded_target_nodes] - pos[focus_node], target_positions_mask=target_positions_mask, ) diff --git a/symphony/models/continuous_position_predictor.py b/symphony/models/continuous_position_predictor.py index 6b830bd1..92c6e6d4 100644 --- a/symphony/models/continuous_position_predictor.py +++ b/symphony/models/continuous_position_predictor.py @@ -48,14 +48,21 @@ def compute_conditioning( num_graphs, focus_node_embeddings.irreps.dim, ) + focus_node_embeddings = focus_node_embeddings.reshape( + (num_graphs, 1, focus_node_embeddings.irreps.dim) + ) # Embed the target species. - target_species_embeddings = hk.Embed( - self.num_species, - embed_dim=focus_node_embeddings.irreps.num_irreps, + target_species_embeddings = hk.vmap( + hk.Embed( + self.num_species, + embed_dim=focus_node_embeddings.irreps.num_irreps, + ), + split_rng=False )(target_species) assert target_species_embeddings.shape == ( num_graphs, + self.num_targets, focus_node_embeddings.irreps.num_irreps, ), ( target_species_embeddings.shape, @@ -69,7 +76,7 @@ def compute_conditioning( conditioning = e3nn.concatenate( [focus_node_embeddings, target_species_embeddings], axis=-1 ) - assert conditioning.shape == (num_graphs, conditioning.irreps.dim) + assert conditioning.shape == (num_graphs, self.num_targets, conditioning.irreps.dim) return conditioning def get_training_predictions( @@ -85,7 +92,7 @@ def get_training_predictions( target_species = graphs.globals.target_species conditioning = self.compute_conditioning( graphs, focus_node_indices, target_species - ) + ) # (num_graphs, num_targets, conditioning.irreps.dim) target_positions = graphs.globals.target_positions target_positions = e3nn.IrrepsArray("1o", target_positions) @@ -100,16 +107,16 @@ def predict_logits_for_single_graph( ) -> Tuple[float, float]: """Predicts the logits for a single graph.""" assert target_positions.shape == (num_targets, 3) - assert conditioning.shape == (conditioning.irreps.dim,) + assert conditioning.shape == (num_targets, conditioning.irreps.dim,) radial_logits = hk.vmap( - lambda pos: self.radial_predictor.log_prob(pos, conditioning), + self.radial_predictor.log_prob, split_rng=False, - )(target_positions) + )(target_positions, conditioning) angular_logits = hk.vmap( - lambda pos: self.angular_predictor.log_prob(pos, conditioning), + self.angular_predictor.log_prob, split_rng=False, - )(target_positions) + )(target_positions, conditioning) return radial_logits, angular_logits radial_logits, angular_logits = hk.vmap( @@ -130,12 +137,12 @@ def get_evaluation_predictions( num_graphs = graphs.n_node.shape[0] # Compute the conditioning based on the focus nodes and target species. - conditioning = self.compute_conditioning(graphs, focus_indices, target_species) - assert conditioning.shape == (num_graphs, conditioning.irreps.dim) - conditioning.array = jnp.repeat( - jnp.expand_dims(conditioning.array, axis=1), - self.num_targets, + conditioning = self.compute_conditioning( + graphs, + focus_indices, + target_species, ) + assert conditioning.shape == (num_graphs, self.num_targets, conditioning.irreps.dim) # Sample the radial component. radii = hk.vmap( diff --git a/symphony/models/predictor.py b/symphony/models/predictor.py index 230dfbec..152cf4d1 100644 --- a/symphony/models/predictor.py +++ b/symphony/models/predictor.py @@ -134,7 +134,11 @@ def get_evaluation_predictions( # Sample the focus node and target species. rng, focus_rng = jax.random.split(rng) focus_indices, target_species = utils.segment_sample_2D( - focus_and_target_species_probs, segment_ids, num_graphs, focus_rng + focus_and_target_species_probs, + segment_ids, + num_graphs, + self.target_position_predictor.num_targets, + focus_rng, ) # Compute the position coefficients. diff --git a/symphony/models/utils/utils.py b/symphony/models/utils/utils.py index 32e2b591..2da4f957 100755 --- a/symphony/models/utils/utils.py +++ b/symphony/models/utils/utils.py @@ -68,6 +68,7 @@ def segment_sample_2D( species_probabilities: jnp.ndarray, segment_ids: jnp.ndarray, num_segments: int, + num_targets: int, rng: chex.PRNGKey, ) -> Tuple[jnp.ndarray, jnp.ndarray]: """Sample indices from a categorical distribution across each segment. @@ -103,7 +104,11 @@ def sample_for_segment(rng: chex.PRNGKey, segment_id: int) -> Tuple[float, float species_probabilities[node_index] ) species_index = jax.random.choice( - logit_rng, jnp.arange(num_species), p=normalized_probs_for_index + logit_rng, + jnp.arange(num_species), + shape=(num_targets,), + replace=True, + p=normalized_probs_for_index ) return node_index, species_index @@ -112,7 +117,7 @@ def sample_for_segment(rng: chex.PRNGKey, segment_id: int) -> Tuple[float, float rngs, jnp.arange(num_segments) ) assert node_indices.shape == (num_segments,) - assert species_indices.shape == (num_segments,) + assert species_indices.shape == (num_segments, num_targets) return node_indices, species_indices From 9545acd865effe6c888718a09449d67592aa0199 Mon Sep 17 00:00:00 2001 From: Song Kim Date: Thu, 9 Jan 2025 22:23:57 -0500 Subject: [PATCH 3/8] fix generation atom filtering --- analyses/generate_molecules.py | 20 ++++++++++++++++---- symphony/data/input_pipeline.py | 8 ++++---- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/analyses/generate_molecules.py b/analyses/generate_molecules.py index fb5cf1ba..5653519c 100644 --- a/analyses/generate_molecules.py +++ b/analyses/generate_molecules.py @@ -156,11 +156,11 @@ def append_predictions( """Appends the predictions to the padded fragment.""" n_nodes = padded_fragment.n_node[0] target_relative_positions = pred.globals.position_vectors[0] # (num_targets, 3) - focus_indices = pred.globals.focus_indices[0] - focus_positions = padded_fragment.nodes.positions[focus_indices] + num_targets = target_relative_positions.shape[0] + focus = pred.globals.focus_indices[0] + focus_positions = padded_fragment.nodes.positions[focus] extra_positions = (target_relative_positions + focus_positions).reshape(-1, 3) extra_species = (pred.globals.target_species[0]).reshape(-1,) - new_fragment = padded_fragment extra_atoms = 0 @@ -175,10 +175,22 @@ def f(fragment, extra_atoms): ), extra_atoms + 1, ) + all_positions = jnp.concatenate([extra_positions, padded_fragment.nodes.positions], axis=0) for i in range(len(extra_positions)): + collision_dists = jnp.linalg.norm( + all_positions - extra_positions[i], axis=-1 + ) + # filter out nodes that aren't part of the mol + collision_dists = jnp.where( + jnp.arange(padded_fragment.nodes.positions.shape[0] + num_targets) < n_nodes + num_targets, + collision_dists, + jnp.inf, + ) + # filter out the node itself + following targets + collision_dists = collision_dists.at[i:num_targets].set(jnp.inf) new_fragment, extra_atoms = jax.lax.cond( jnp.logical_and( - jnp.linalg.norm(extra_positions[i]) > eps, + jnp.min(collision_dists) > eps, n_nodes + extra_atoms < max_num_atoms, ), f, diff --git a/symphony/data/input_pipeline.py b/symphony/data/input_pipeline.py index acdf9d3d..b81f20fb 100755 --- a/symphony/data/input_pipeline.py +++ b/symphony/data/input_pipeline.py @@ -48,9 +48,9 @@ def create_fragments_dataset( fragment_logic: str, heavy_first: bool, max_targets_per_graph: int, + num_seeds: int, max_radius: Optional[float] = None, nn_tolerance: Optional[float] = None, - num_seeds: int = 1, transition_first: Optional[bool] = False, fragment_number: Optional[int] = -1, ) -> Iterator[datatypes.Fragments]: @@ -60,12 +60,12 @@ def create_fragments_dataset( def fragment_generator(rng: chex.PRNGKey): """Generates fragments for a split.""" - original_rng = rng + rngs = jax.random.split(rng, num_seeds) # Loop indefinitely. while True: - for _ in range(num_seeds): - _, rng = jax.random.split(original_rng) + for rng_ndx in range(num_seeds): + rng = rngs[rng_ndx] for index in keep_indices: structure = structures[index] if use_same_rng_across_structures: From fcceff33c32ec5b99f8baccabcfb245037b779e6 Mon Sep 17 00:00:00 2001 From: Song Kim Date: Fri, 10 Jan 2025 19:16:15 -0500 Subject: [PATCH 4/8] fix target species shape --- symphony/data/fragments.py | 7 ++++--- symphony/models/continuous_position_predictor.py | 3 +++ 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/symphony/data/fragments.py b/symphony/data/fragments.py index 8a8a3a21..14584803 100644 --- a/symphony/data/fragments.py +++ b/symphony/data/fragments.py @@ -269,9 +269,10 @@ def _into_fragment( pos = graph.nodes.positions species = graph.nodes.species - # Check that all target species are the same. target_species = species[target_nodes] - # assert np.all(target_species == target_species[0]) + padded_target_species = np.pad( + target_species, (0, max_targets_per_graph - len(target_species)) + ) padded_target_nodes = np.pad( target_nodes, (0, max_targets_per_graph - len(target_nodes)) @@ -285,7 +286,7 @@ def _into_fragment( ) globals = datatypes.FragmentsGlobals( stop=np.asarray(stop, dtype=bool), - target_species=target_species, + target_species=padded_target_species, target_positions=pos[padded_target_nodes] - pos[focus_node], target_positions_mask=target_positions_mask, ) diff --git a/symphony/models/continuous_position_predictor.py b/symphony/models/continuous_position_predictor.py index 92c6e6d4..5403ca75 100644 --- a/symphony/models/continuous_position_predictor.py +++ b/symphony/models/continuous_position_predictor.py @@ -51,6 +51,9 @@ def compute_conditioning( focus_node_embeddings = focus_node_embeddings.reshape( (num_graphs, 1, focus_node_embeddings.irreps.dim) ) + focus_node_embeddings = focus_node_embeddings.broadcast_to( + (num_graphs, self.num_targets, focus_node_embeddings.irreps.dim) + ) # Embed the target species. target_species_embeddings = hk.vmap( From 91f57051f651231d7762b6ce75531e36abe88ed1 Mon Sep 17 00:00:00 2001 From: Song Kim Date: Sat, 11 Jan 2025 17:21:08 -0500 Subject: [PATCH 5/8] i don't actually want to change nequip ._. --- symphony/models/embedders/nequip.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/symphony/models/embedders/nequip.py b/symphony/models/embedders/nequip.py index 305a92b5..f0a0577d 100644 --- a/symphony/models/embedders/nequip.py +++ b/symphony/models/embedders/nequip.py @@ -35,6 +35,7 @@ def __init__( self.avg_num_neighbors = avg_num_neighbors self.max_ell = max_ell self.init_embedding_dims = init_embedding_dims + self.hidden_irreps = hidden_irreps self.output_irreps = output_irreps self.num_interactions = num_interactions self.even_activation = even_activation @@ -60,11 +61,11 @@ def __call__( node_feats = hk.Embed(self.num_species, self.init_embedding_dims)(species) node_feats = e3nn.IrrepsArray(f"{node_feats.shape[1]}x0e", node_feats) - for interaction in range(self.num_interactions): - new_node_feats = nequip_jax.NEQUIPESCNLayerHaiku( + for _ in range(self.num_interactions): + node_feats = nequip_jax.NEQUIPESCNLayerHaiku( avg_num_neighbors=self.avg_num_neighbors, num_species=self.num_species, - output_irreps=self.output_irreps, + output_irreps=self.hidden_irreps, even_activation=self.even_activation, odd_activation=self.odd_activation, mlp_activation=self.mlp_activation, @@ -72,13 +73,10 @@ def __call__( mlp_n_layers=self.mlp_n_layers, n_radial_basis=self.n_radial_basis, )(relative_positions, node_feats, species, graphs.senders, graphs.receivers) - new_node_feats = e3nn.haiku.Linear( - self.output_irreps, force_irreps_out=True - )(new_node_feats) - if self.skip_connection and interaction > 0: - new_node_feats += node_feats - node_feats = new_node_feats + node_feats = e3nn.haiku.Linear( + self.output_irreps, force_irreps_out=True + )(node_feats) alpha = 0.5 ** jnp.array(node_feats.irreps.ls) node_feats = node_feats * alpha From 60ef49115f3578cdd538867e4ef3a2235dec4733 Mon Sep 17 00:00:00 2001 From: Song Kim Date: Sun, 12 Jan 2025 11:44:34 -0500 Subject: [PATCH 6/8] fix generation indexing --- analyses/generate_molecules.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/analyses/generate_molecules.py b/analyses/generate_molecules.py index 449150f1..3823ba5e 100644 --- a/analyses/generate_molecules.py +++ b/analyses/generate_molecules.py @@ -164,12 +164,11 @@ def append_predictions( new_fragment = padded_fragment extra_atoms = 0 - i = 0 - def f(fragment, extra_atoms): + def f(fragment, extra_atoms, ndx): return ( append_predictions_single( - extra_positions[i], - extra_species[i], + extra_positions[ndx], + extra_species[ndx], fragment, radial_cutoff ), @@ -188,12 +187,20 @@ def f(fragment, extra_atoms): ) # filter out the node itself + following targets collision_dists = collision_dists.at[i:num_targets].set(jnp.inf) + # jax.debug.print("index {i}:\nall positions: {all_positions}\nmin collision dist: {dist}, {b}", + # i=i, + # all_positions=all_positions, + # dist=jnp.min(collision_dists), + # b=jnp.logical_and( + # jnp.min(collision_dists) > eps, + # n_nodes + extra_atoms < max_num_atoms, + # )) new_fragment, extra_atoms = jax.lax.cond( jnp.logical_and( jnp.min(collision_dists) > eps, n_nodes + extra_atoms < max_num_atoms, ), - f, + lambda x, y: f(x, y, i), lambda x, y: (x, y), new_fragment, extra_atoms, @@ -266,7 +273,7 @@ def generate_molecules( dataset: str, padding_mode: str, verbose: bool = False, - eps: float = 1e-5, + eps: float = 5e-1, # ~bohr radius ): """Generates molecules from a model.""" From 7a53acdb240a56b475cb8fb3d59ba77b945bc1da Mon Sep 17 00:00:00 2001 From: Song Kim Date: Sun, 12 Jan 2025 12:22:07 -0500 Subject: [PATCH 7/8] reset generate_molecules to put down 1 atom at a time --- analyses/generate_molecules.py | 87 ++++------------------------------ 1 file changed, 8 insertions(+), 79 deletions(-) diff --git a/analyses/generate_molecules.py b/analyses/generate_molecules.py index 3823ba5e..c9956d96 100644 --- a/analyses/generate_molecules.py +++ b/analyses/generate_molecules.py @@ -97,9 +97,8 @@ def round_to_nearest_multiple_of_64(x): return padding_budget -def append_predictions_single( - target_position: jnp.ndarray, - target_species: int, +def append_predictions( + pred: datatypes.Predictions, padded_fragment: datatypes.Fragments, radial_cutoff: float, ) -> datatypes.Fragments: @@ -109,10 +108,14 @@ def append_predictions_single( num_valid_nodes = padded_fragment.n_node[0] num_nodes = padded_fragment.nodes.positions.shape[0] num_edges = padded_fragment.receivers.shape[0] + focus = pred.globals.focus_indices[0] + focus_position = positions[focus] + target_position = pred.globals.position_vectors[0][0] + focus_position new_positions = positions.at[num_valid_nodes].set(target_position) # Update the species of the first dummy node. species = padded_fragment.nodes.species + target_species = pred.globals.target_species[0][0] new_species = species.at[num_valid_nodes].set(target_species) # Compute the distance matrix to select the edges. @@ -146,82 +149,18 @@ def append_predictions_single( ) -def append_predictions( - pred: datatypes.Predictions, - padded_fragment: datatypes.Fragments, - radial_cutoff: float, - eps: float, - max_num_atoms: int, -) -> datatypes.Fragments: - """Appends the predictions to the padded fragment.""" - n_nodes = padded_fragment.n_node[0] - target_relative_positions = pred.globals.position_vectors[0] # (num_targets, 3) - num_targets = target_relative_positions.shape[0] - focus = pred.globals.focus_indices[0] - focus_positions = padded_fragment.nodes.positions[focus] - extra_positions = (target_relative_positions + focus_positions).reshape(-1, 3) - extra_species = (pred.globals.target_species[0]).reshape(-1,) - - new_fragment = padded_fragment - extra_atoms = 0 - def f(fragment, extra_atoms, ndx): - return ( - append_predictions_single( - extra_positions[ndx], - extra_species[ndx], - fragment, - radial_cutoff - ), - extra_atoms + 1, - ) - all_positions = jnp.concatenate([extra_positions, padded_fragment.nodes.positions], axis=0) - for i in range(len(extra_positions)): - collision_dists = jnp.linalg.norm( - all_positions - extra_positions[i], axis=-1 - ) - # filter out nodes that aren't part of the mol - collision_dists = jnp.where( - jnp.arange(padded_fragment.nodes.positions.shape[0] + num_targets) < n_nodes + num_targets, - collision_dists, - jnp.inf, - ) - # filter out the node itself + following targets - collision_dists = collision_dists.at[i:num_targets].set(jnp.inf) - # jax.debug.print("index {i}:\nall positions: {all_positions}\nmin collision dist: {dist}, {b}", - # i=i, - # all_positions=all_positions, - # dist=jnp.min(collision_dists), - # b=jnp.logical_and( - # jnp.min(collision_dists) > eps, - # n_nodes + extra_atoms < max_num_atoms, - # )) - new_fragment, extra_atoms = jax.lax.cond( - jnp.logical_and( - jnp.min(collision_dists) > eps, - n_nodes + extra_atoms < max_num_atoms, - ), - lambda x, y: f(x, y, i), - lambda x, y: (x, y), - new_fragment, - extra_atoms, - ) - return new_fragment - - def generate_one_step( padded_fragment: datatypes.Fragments, stop: bool, rng: chex.PRNGKey, apply_fn: Callable[[datatypes.Fragments, chex.PRNGKey], datatypes.Predictions], radial_cutoff: float, - eps: float, - max_num_atoms: int, ) -> Tuple[ Tuple[datatypes.Fragments, bool], Tuple[datatypes.Fragments, datatypes.Predictions] ]: """Generates the next fragment for a given seed.""" pred = apply_fn(padded_fragment, rng) - next_padded_fragment = append_predictions(pred, padded_fragment, radial_cutoff, eps, max_num_atoms) + next_padded_fragment = append_predictions(pred, padded_fragment, radial_cutoff) stop = pred.globals.stop[0] | stop return jax.lax.cond( stop, @@ -233,7 +172,6 @@ def generate_one_step( def generate_for_one_seed( apply_fn: Callable[[datatypes.Fragments, chex.PRNGKey], datatypes.Predictions], init_fragment: datatypes.Fragments, - eps: float, max_num_atoms: int, cutoff: float, rng: chex.PRNGKey, @@ -242,14 +180,7 @@ def generate_for_one_seed( """Generates a single molecule for a given seed.""" step_rngs = jax.random.split(rng, num=max_num_atoms) (final_padded_fragment, stop), (padded_fragments, preds) = jax.lax.scan( - lambda args, rng: generate_one_step( - *args, - rng, - apply_fn, - cutoff, - eps, - max_num_atoms, - ), + lambda args, rng: generate_one_step(*args, rng, apply_fn, cutoff), (init_fragment, False), step_rngs, ) @@ -273,7 +204,6 @@ def generate_molecules( dataset: str, padding_mode: str, verbose: bool = False, - eps: float = 5e-1, # ~bohr radius ): """Generates molecules from a model.""" @@ -383,7 +313,6 @@ def apply_on_chunk( generate_for_one_seed_fn = lambda rng, init_fragment: generate_for_one_seed( apply_fn_wrapped, init_fragment, - eps, max_num_atoms, radial_cutoff, rng, From 81508378a255c0ef07f146ea3550801555aa75d0 Mon Sep 17 00:00:00 2001 From: Song Kim Date: Fri, 17 Jan 2025 23:57:10 -0500 Subject: [PATCH 8/8] run --- sweep_scripts/run.sh | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/sweep_scripts/run.sh b/sweep_scripts/run.sh index 60f8e141..b84ba076 100755 --- a/sweep_scripts/run.sh +++ b/sweep_scripts/run.sh @@ -12,12 +12,12 @@ #module load cuda/12.1.0-x86_64 mode=nn -max_targets_per_graph=1 -cuda=0 +max_targets_per_graph=4 +cuda=3 dataset=qm9 embedder=nequip # train=1000 -workdir=/data/NFS/radish/songk/spherical-harmonic-net/workdirs/"$dataset"_dec31/e3schnet_and_"$embedder"/$mode/max_targets_$max_targets_per_graph +workdir=/data/NFS/radish/songk/spherical-harmonic-net/workdirs/"$dataset"_jan17_lr1e-3/e3schnet_and_"$embedder"/$mode/max_targets_$max_targets_per_graph # workdir=/data/NFS/potato/songk/spherical-harmonic-net/workdirs/"$dataset"_nov18_"$train"/e3schnet_and_nequip/$mode/max_targets_$max_targets_per_graph # CUDA_VISIBLE_DEVICES=$cuda python -m analyses.generate_molecules \ @@ -35,6 +35,9 @@ CUDA_VISIBLE_DEVICES=$cuda python -m symphony \ --config.num_train_steps=1000000 \ --config.position_noise_std=0.1 \ --config.target_distance_noise_std=0.1 \ + --config.learning_rate=1e-3 \ + --config.target_position_predictor.radial_predictor.num_bins=64 \ + --config.target_position_predictor.radial_predictor_type="discretized" \ --config.max_targets_per_graph=$max_targets_per_graph