diff --git a/analyses/generate_molecules.py b/analyses/generate_molecules.py index 01faae82..c9956d96 100644 --- a/analyses/generate_molecules.py +++ b/analyses/generate_molecules.py @@ -110,12 +110,12 @@ def append_predictions( num_edges = padded_fragment.receivers.shape[0] focus = pred.globals.focus_indices[0] focus_position = positions[focus] - target_position = pred.globals.position_vectors[0] + focus_position + target_position = pred.globals.position_vectors[0][0] + focus_position new_positions = positions.at[num_valid_nodes].set(target_position) # Update the species of the first dummy node. species = padded_fragment.nodes.species - target_species = pred.globals.target_species[0] + target_species = pred.globals.target_species[0][0] new_species = species.at[num_valid_nodes].set(target_species) # Compute the distance matrix to select the edges. @@ -205,24 +205,6 @@ def generate_molecules( padding_mode: str, verbose: bool = False, ): -# def generate_molecules( -# apply_fn: Callable[[datatypes.Fragments, chex.PRNGKey], datatypes.Predictions], -# params: optax.Params, -# molecules_outputdir: str, -# radial_cutoff: float, -# focus_and_atom_type_inverse_temperature: float, -# position_inverse_temperature: float, -# start_seed: int, -# num_seeds: int, -# num_seeds_per_chunk: int, -# init_molecules: Sequence[Union[str, ase.Atoms]], -# max_num_atoms: int, -# avg_neighbors_per_atom: int, -# atomic_numbers: np.ndarray = np.arange(1, 81), -# visualize: bool = False, -# visualizations_dir: Optional[str] = None, -# verbose: bool = True, -# ): """Generates molecules from a model.""" if verbose: @@ -367,64 +349,8 @@ def apply_on_chunk( molecule_list = [] for i, seed in tqdm.tqdm(enumerate(seeds), desc="Visualizing molecules"): - init_fragment = jax.tree_util.tree_map(lambda x: x[i], init_fragments) init_molecule_name = init_molecule_names[i] - # if visualize: - # # Get the padded fragment and predictions for this seed. - # padded_fragments_for_seed = jax.tree_util.tree_map( - # lambda x: x[i], padded_fragments - # ) - # preds_for_seed = jax.tree_util.tree_map(lambda x: x[i], preds) - - # figs = [] - # for step in range(max_num_atoms): - # if step == 0: - # padded_fragment = init_fragment - # else: - # padded_fragment = jax.tree_util.tree_map( - # lambda x: x[step - 1], padded_fragments_for_seed - # ) - # pred = jax.tree_util.tree_map(lambda x: x[step], preds_for_seed) - - # # Save visualization of generation process. - # fragment = jraph.unpad_with_graphs(padded_fragment) - # pred = jraph.unpad_with_graphs(pred) - # fragment = fragment._replace( - # globals=jax.tree_util.tree_map( - # lambda x: np.squeeze(x, axis=0), fragment.globals - # ) - # ) - # pred = pred._replace( - # globals=jax.tree_util.tree_map(lambda x: np.squeeze(x, axis=0), pred.globals) - # ) - # fig = analysis.visualize_predictions(pred, fragment) - # figs.append(fig) - - # # This may be the final padded fragment. - # final_padded_fragment = padded_fragment - - # # Check if we should stop. - # stop = pred.globals.stop - # if stop: - # break - - # # Save the visualizations of the generation process. - # for index, fig in enumerate(figs): - # # Update the title. - # fig.update_layout( - # title=f"Predictions for Seed {seed}", - # title_x=0.5, - # ) - - # # Save to file. - # outputfile = os.path.join( - # visualizations_dir, - # f"seed_{seed}_fragments_{index}.html", - # ) - # fig.write_html(outputfile, include_plotlyjs="cdn") - - # else: # We already have the final padded fragment. final_padded_fragment = jax.tree_util.tree_map( lambda x: x[i], final_padded_fragments diff --git a/sweep_scripts/run.sh b/sweep_scripts/run.sh index 60f8e141..b84ba076 100755 --- a/sweep_scripts/run.sh +++ b/sweep_scripts/run.sh @@ -12,12 +12,12 @@ #module load cuda/12.1.0-x86_64 mode=nn -max_targets_per_graph=1 -cuda=0 +max_targets_per_graph=4 +cuda=3 dataset=qm9 embedder=nequip # train=1000 -workdir=/data/NFS/radish/songk/spherical-harmonic-net/workdirs/"$dataset"_dec31/e3schnet_and_"$embedder"/$mode/max_targets_$max_targets_per_graph +workdir=/data/NFS/radish/songk/spherical-harmonic-net/workdirs/"$dataset"_jan17_lr1e-3/e3schnet_and_"$embedder"/$mode/max_targets_$max_targets_per_graph # workdir=/data/NFS/potato/songk/spherical-harmonic-net/workdirs/"$dataset"_nov18_"$train"/e3schnet_and_nequip/$mode/max_targets_$max_targets_per_graph # CUDA_VISIBLE_DEVICES=$cuda python -m analyses.generate_molecules \ @@ -35,6 +35,9 @@ CUDA_VISIBLE_DEVICES=$cuda python -m symphony \ --config.num_train_steps=1000000 \ --config.position_noise_std=0.1 \ --config.target_distance_noise_std=0.1 \ + --config.learning_rate=1e-3 \ + --config.target_position_predictor.radial_predictor.num_bins=64 \ + --config.target_position_predictor.radial_predictor_type="discretized" \ --config.max_targets_per_graph=$max_targets_per_graph diff --git a/symphony/data/fragments.py b/symphony/data/fragments.py index 54fbecf2..14584803 100644 --- a/symphony/data/fragments.py +++ b/symphony/data/fragments.py @@ -93,25 +93,10 @@ def generate_fragments( def pick_targets( - rng, targets, - node_species, - target_species_probability_for_focus, max_targets_per_graph, ): - # Pick a random target species. - rng, k = jax.random.split(rng) - target_species = jax.random.choice( - k, - len(target_species_probability_for_focus), - p=target_species_probability_for_focus, - ) - - # Pick up to max_targets_per_graph targets of the target species. - targets_of_this_species = targets[node_species[targets] == target_species] - targets_of_this_species = targets_of_this_species[:max_targets_per_graph] - - return targets_of_this_species + return targets[:max_targets_per_graph] def _make_first_fragment( @@ -162,12 +147,8 @@ def _make_first_fragment( graph.nodes.species[targets], num_species ) - rng, k = jax.random.split(rng) target_nodes = pick_targets( - k, targets, - graph.nodes.species, - target_species_probability[first_node], max_targets_per_graph, ) @@ -243,10 +224,7 @@ def _make_middle_fragment( targets = receivers[(senders == focus_node) & mask] target_nodes = pick_targets( - k, targets, - graph.nodes.species, - target_species_probability[focus_node], max_targets_per_graph, ) @@ -291,9 +269,10 @@ def _into_fragment( pos = graph.nodes.positions species = graph.nodes.species - # Check that all target species are the same. target_species = species[target_nodes] - assert np.all(target_species == target_species[0]) + padded_target_species = np.pad( + target_species, (0, max_targets_per_graph - len(target_species)) + ) padded_target_nodes = np.pad( target_nodes, (0, max_targets_per_graph - len(target_nodes)) @@ -307,7 +286,7 @@ def _into_fragment( ) globals = datatypes.FragmentsGlobals( stop=np.asarray(stop, dtype=bool), - target_species=target_species[0], + target_species=padded_target_species, target_positions=pos[padded_target_nodes] - pos[focus_node], target_positions_mask=target_positions_mask, ) diff --git a/symphony/models/continuous_position_predictor.py b/symphony/models/continuous_position_predictor.py index 12f8c958..5403ca75 100644 --- a/symphony/models/continuous_position_predictor.py +++ b/symphony/models/continuous_position_predictor.py @@ -21,6 +21,7 @@ def __init__( radial_predictor_fn: Callable[[], RadiusPredictor], angular_predictor_fn: Callable[[], AngularPredictor], num_species: int, + num_targets: int, name: Optional[str] = None, ): super().__init__(name) @@ -28,6 +29,7 @@ def __init__( self.radial_predictor = radial_predictor_fn() self.angular_predictor = angular_predictor_fn() self.num_species = num_species + self.num_targets = num_targets def compute_conditioning( self, @@ -46,14 +48,24 @@ def compute_conditioning( num_graphs, focus_node_embeddings.irreps.dim, ) + focus_node_embeddings = focus_node_embeddings.reshape( + (num_graphs, 1, focus_node_embeddings.irreps.dim) + ) + focus_node_embeddings = focus_node_embeddings.broadcast_to( + (num_graphs, self.num_targets, focus_node_embeddings.irreps.dim) + ) # Embed the target species. - target_species_embeddings = hk.Embed( - self.num_species, - embed_dim=focus_node_embeddings.irreps.num_irreps, + target_species_embeddings = hk.vmap( + hk.Embed( + self.num_species, + embed_dim=focus_node_embeddings.irreps.num_irreps, + ), + split_rng=False )(target_species) assert target_species_embeddings.shape == ( num_graphs, + self.num_targets, focus_node_embeddings.irreps.num_irreps, ), ( target_species_embeddings.shape, @@ -67,7 +79,7 @@ def compute_conditioning( conditioning = e3nn.concatenate( [focus_node_embeddings, target_species_embeddings], axis=-1 ) - assert conditioning.shape == (num_graphs, conditioning.irreps.dim) + assert conditioning.shape == (num_graphs, self.num_targets, conditioning.irreps.dim) return conditioning def get_training_predictions( @@ -83,7 +95,7 @@ def get_training_predictions( target_species = graphs.globals.target_species conditioning = self.compute_conditioning( graphs, focus_node_indices, target_species - ) + ) # (num_graphs, num_targets, conditioning.irreps.dim) target_positions = graphs.globals.target_positions target_positions = e3nn.IrrepsArray("1o", target_positions) @@ -98,16 +110,16 @@ def predict_logits_for_single_graph( ) -> Tuple[float, float]: """Predicts the logits for a single graph.""" assert target_positions.shape == (num_targets, 3) - assert conditioning.shape == (conditioning.irreps.dim,) + assert conditioning.shape == (num_targets, conditioning.irreps.dim,) radial_logits = hk.vmap( - lambda pos: self.radial_predictor.log_prob(pos, conditioning), + self.radial_predictor.log_prob, split_rng=False, - )(target_positions) + )(target_positions, conditioning) angular_logits = hk.vmap( - lambda pos: self.angular_predictor.log_prob(pos, conditioning), + self.angular_predictor.log_prob, split_rng=False, - )(target_positions) + )(target_positions, conditioning) return radial_logits, angular_logits radial_logits, angular_logits = hk.vmap( @@ -128,19 +140,27 @@ def get_evaluation_predictions( num_graphs = graphs.n_node.shape[0] # Compute the conditioning based on the focus nodes and target species. - conditioning = self.compute_conditioning(graphs, focus_indices, target_species) - assert conditioning.shape == (num_graphs, conditioning.irreps.dim) + conditioning = self.compute_conditioning( + graphs, + focus_indices, + target_species, + ) + assert conditioning.shape == (num_graphs, self.num_targets, conditioning.irreps.dim) # Sample the radial component. - radii = hk.vmap(self.radial_predictor.sample, split_rng=True)(conditioning) - assert radii.shape == (num_graphs,), (radii.shape, num_graphs) + radii = hk.vmap( + hk.vmap( + self.radial_predictor.sample, split_rng=True + ), split_rng=True + )(conditioning) + assert radii.shape == (num_graphs, self.num_targets), (radii.shape, num_graphs) # Predict the target position vectors. angular_sample_fn = lambda r, cond: self.angular_predictor.sample( r, cond, inverse_temperature ) - position_vectors = hk.vmap(angular_sample_fn, split_rng=True)( + position_vectors = hk.vmap(hk.vmap(angular_sample_fn, split_rng=True), split_rng=True)( radii, conditioning ) - assert position_vectors.shape == (num_graphs, 3) + assert position_vectors.shape == (num_graphs, self.num_targets, 3) return None, None, position_vectors diff --git a/symphony/models/predictor.py b/symphony/models/predictor.py index 475c5588..152cf4d1 100644 --- a/symphony/models/predictor.py +++ b/symphony/models/predictor.py @@ -134,7 +134,11 @@ def get_evaluation_predictions( # Sample the focus node and target species. rng, focus_rng = jax.random.split(rng) focus_indices, target_species = utils.segment_sample_2D( - focus_and_target_species_probs, segment_ids, num_graphs, focus_rng + focus_and_target_species_probs, + segment_ids, + num_graphs, + self.target_position_predictor.num_targets, + focus_rng, ) # Compute the position coefficients. @@ -155,7 +159,11 @@ def get_evaluation_predictions( num_nodes, num_species, ) - assert position_vectors.shape == (num_graphs, 3) + assert position_vectors.shape == ( + num_graphs, + self.target_position_predictor.num_targets, + 3 + ) return datatypes.Predictions( nodes=datatypes.NodePredictions( diff --git a/symphony/models/radius_predictors/discretized_predictor.py b/symphony/models/radius_predictors/discretized_predictor.py index 6b8e9a51..5378fa7a 100644 --- a/symphony/models/radius_predictors/discretized_predictor.py +++ b/symphony/models/radius_predictors/discretized_predictor.py @@ -17,6 +17,7 @@ def __init__( range_max: float, num_layers: int, latent_size: int, + num_targets: int, ): super().__init__() self.num_bins = num_bins @@ -24,6 +25,7 @@ def __init__( self.range_max = range_max self.num_layers = num_layers self.latent_size = latent_size + self.num_targets = num_targets def radii(self) -> jnp.ndarray: return jnp.linspace(self.range_min, self.range_max, self.num_bins) @@ -31,6 +33,11 @@ def radii(self) -> jnp.ndarray: def predict_logits(self, conditioning: e3nn.IrrepsArray) -> distrax.Bijector: """Predicts the logits.""" conditioning = conditioning.filter("0e") + # conditioning = jnp.repeat( + # jnp.expand_dims(conditioning.array, axis=-1), + # self.num_targets, + # axis=-1 + # ) logits = hk.nets.MLP( [self.latent_size] * (self.num_layers - 1) + [self.num_bins], diff --git a/symphony/models/utils/create_model.py b/symphony/models/utils/create_model.py index ba8d82eb..51cf56a8 100644 --- a/symphony/models/utils/create_model.py +++ b/symphony/models/utils/create_model.py @@ -211,6 +211,7 @@ def create_predictor(config: ml_collections.ConfigDict) -> Predictor: range_max=radial_predictor_config.max_radius, num_layers=radial_predictor_config.num_layers, latent_size=radial_predictor_config.latent_size, + num_targets=config.max_targets_per_graph, ) else: raise ValueError( @@ -224,6 +225,7 @@ def create_predictor(config: ml_collections.ConfigDict) -> Predictor: angular_predictor_fn=angular_predictor_fn, radial_predictor_fn=radial_predictor_fn, num_species=num_species, + num_targets=config.max_targets_per_graph, ) predictor = Predictor( focus_and_target_species_predictor=focus_and_target_species_predictor, diff --git a/symphony/models/utils/utils.py b/symphony/models/utils/utils.py index 32e2b591..2da4f957 100755 --- a/symphony/models/utils/utils.py +++ b/symphony/models/utils/utils.py @@ -68,6 +68,7 @@ def segment_sample_2D( species_probabilities: jnp.ndarray, segment_ids: jnp.ndarray, num_segments: int, + num_targets: int, rng: chex.PRNGKey, ) -> Tuple[jnp.ndarray, jnp.ndarray]: """Sample indices from a categorical distribution across each segment. @@ -103,7 +104,11 @@ def sample_for_segment(rng: chex.PRNGKey, segment_id: int) -> Tuple[float, float species_probabilities[node_index] ) species_index = jax.random.choice( - logit_rng, jnp.arange(num_species), p=normalized_probs_for_index + logit_rng, + jnp.arange(num_species), + shape=(num_targets,), + replace=True, + p=normalized_probs_for_index ) return node_index, species_index @@ -112,7 +117,7 @@ def sample_for_segment(rng: chex.PRNGKey, segment_id: int) -> Tuple[float, float rngs, jnp.arange(num_segments) ) assert node_indices.shape == (num_segments,) - assert species_indices.shape == (num_segments,) + assert species_indices.shape == (num_segments, num_targets) return node_indices, species_indices