Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 2 additions & 76 deletions analyses/generate_molecules.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,12 @@ def append_predictions(
num_edges = padded_fragment.receivers.shape[0]
focus = pred.globals.focus_indices[0]
focus_position = positions[focus]
target_position = pred.globals.position_vectors[0] + focus_position
target_position = pred.globals.position_vectors[0][0] + focus_position
new_positions = positions.at[num_valid_nodes].set(target_position)

# Update the species of the first dummy node.
species = padded_fragment.nodes.species
target_species = pred.globals.target_species[0]
target_species = pred.globals.target_species[0][0]
new_species = species.at[num_valid_nodes].set(target_species)

# Compute the distance matrix to select the edges.
Expand Down Expand Up @@ -205,24 +205,6 @@ def generate_molecules(
padding_mode: str,
verbose: bool = False,
):
# def generate_molecules(
# apply_fn: Callable[[datatypes.Fragments, chex.PRNGKey], datatypes.Predictions],
# params: optax.Params,
# molecules_outputdir: str,
# radial_cutoff: float,
# focus_and_atom_type_inverse_temperature: float,
# position_inverse_temperature: float,
# start_seed: int,
# num_seeds: int,
# num_seeds_per_chunk: int,
# init_molecules: Sequence[Union[str, ase.Atoms]],
# max_num_atoms: int,
# avg_neighbors_per_atom: int,
# atomic_numbers: np.ndarray = np.arange(1, 81),
# visualize: bool = False,
# visualizations_dir: Optional[str] = None,
# verbose: bool = True,
# ):
"""Generates molecules from a model."""

if verbose:
Expand Down Expand Up @@ -367,64 +349,8 @@ def apply_on_chunk(

molecule_list = []
for i, seed in tqdm.tqdm(enumerate(seeds), desc="Visualizing molecules"):
init_fragment = jax.tree_util.tree_map(lambda x: x[i], init_fragments)
init_molecule_name = init_molecule_names[i]

# if visualize:
# # Get the padded fragment and predictions for this seed.
# padded_fragments_for_seed = jax.tree_util.tree_map(
# lambda x: x[i], padded_fragments
# )
# preds_for_seed = jax.tree_util.tree_map(lambda x: x[i], preds)

# figs = []
# for step in range(max_num_atoms):
# if step == 0:
# padded_fragment = init_fragment
# else:
# padded_fragment = jax.tree_util.tree_map(
# lambda x: x[step - 1], padded_fragments_for_seed
# )
# pred = jax.tree_util.tree_map(lambda x: x[step], preds_for_seed)

# # Save visualization of generation process.
# fragment = jraph.unpad_with_graphs(padded_fragment)
# pred = jraph.unpad_with_graphs(pred)
# fragment = fragment._replace(
# globals=jax.tree_util.tree_map(
# lambda x: np.squeeze(x, axis=0), fragment.globals
# )
# )
# pred = pred._replace(
# globals=jax.tree_util.tree_map(lambda x: np.squeeze(x, axis=0), pred.globals)
# )
# fig = analysis.visualize_predictions(pred, fragment)
# figs.append(fig)

# # This may be the final padded fragment.
# final_padded_fragment = padded_fragment

# # Check if we should stop.
# stop = pred.globals.stop
# if stop:
# break

# # Save the visualizations of the generation process.
# for index, fig in enumerate(figs):
# # Update the title.
# fig.update_layout(
# title=f"Predictions for Seed {seed}",
# title_x=0.5,
# )

# # Save to file.
# outputfile = os.path.join(
# visualizations_dir,
# f"seed_{seed}_fragments_{index}.html",
# )
# fig.write_html(outputfile, include_plotlyjs="cdn")

# else:
# We already have the final padded fragment.
final_padded_fragment = jax.tree_util.tree_map(
lambda x: x[i], final_padded_fragments
Expand Down
9 changes: 6 additions & 3 deletions sweep_scripts/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
#module load cuda/12.1.0-x86_64

mode=nn
max_targets_per_graph=1
cuda=0
max_targets_per_graph=4
cuda=3
dataset=qm9
embedder=nequip
# train=1000
workdir=/data/NFS/radish/songk/spherical-harmonic-net/workdirs/"$dataset"_dec31/e3schnet_and_"$embedder"/$mode/max_targets_$max_targets_per_graph
workdir=/data/NFS/radish/songk/spherical-harmonic-net/workdirs/"$dataset"_jan17_lr1e-3/e3schnet_and_"$embedder"/$mode/max_targets_$max_targets_per_graph
# workdir=/data/NFS/potato/songk/spherical-harmonic-net/workdirs/"$dataset"_nov18_"$train"/e3schnet_and_nequip/$mode/max_targets_$max_targets_per_graph

# CUDA_VISIBLE_DEVICES=$cuda python -m analyses.generate_molecules \
Expand All @@ -35,6 +35,9 @@ CUDA_VISIBLE_DEVICES=$cuda python -m symphony \
--config.num_train_steps=1000000 \
--config.position_noise_std=0.1 \
--config.target_distance_noise_std=0.1 \
--config.learning_rate=1e-3 \
--config.target_position_predictor.radial_predictor.num_bins=64 \
--config.target_position_predictor.radial_predictor_type="discretized" \
--config.max_targets_per_graph=$max_targets_per_graph


Expand Down
31 changes: 5 additions & 26 deletions symphony/data/fragments.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,25 +93,10 @@ def generate_fragments(


def pick_targets(
rng,
targets,
node_species,
target_species_probability_for_focus,
max_targets_per_graph,
):
# Pick a random target species.
rng, k = jax.random.split(rng)
target_species = jax.random.choice(
k,
len(target_species_probability_for_focus),
p=target_species_probability_for_focus,
)

# Pick up to max_targets_per_graph targets of the target species.
targets_of_this_species = targets[node_species[targets] == target_species]
targets_of_this_species = targets_of_this_species[:max_targets_per_graph]

return targets_of_this_species
return targets[:max_targets_per_graph]


def _make_first_fragment(
Expand Down Expand Up @@ -162,12 +147,8 @@ def _make_first_fragment(
graph.nodes.species[targets], num_species
)

rng, k = jax.random.split(rng)
target_nodes = pick_targets(
k,
targets,
graph.nodes.species,
target_species_probability[first_node],
max_targets_per_graph,
)

Expand Down Expand Up @@ -243,10 +224,7 @@ def _make_middle_fragment(
targets = receivers[(senders == focus_node) & mask]

target_nodes = pick_targets(
k,
targets,
graph.nodes.species,
target_species_probability[focus_node],
max_targets_per_graph,
)

Expand Down Expand Up @@ -291,9 +269,10 @@ def _into_fragment(
pos = graph.nodes.positions
species = graph.nodes.species

# Check that all target species are the same.
target_species = species[target_nodes]
assert np.all(target_species == target_species[0])
padded_target_species = np.pad(
target_species, (0, max_targets_per_graph - len(target_species))
)

padded_target_nodes = np.pad(
target_nodes, (0, max_targets_per_graph - len(target_nodes))
Expand All @@ -307,7 +286,7 @@ def _into_fragment(
)
globals = datatypes.FragmentsGlobals(
stop=np.asarray(stop, dtype=bool),
target_species=target_species[0],
target_species=padded_target_species,
target_positions=pos[padded_target_nodes] - pos[focus_node],
target_positions_mask=target_positions_mask,
)
Expand Down
52 changes: 36 additions & 16 deletions symphony/models/continuous_position_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@ def __init__(
radial_predictor_fn: Callable[[], RadiusPredictor],
angular_predictor_fn: Callable[[], AngularPredictor],
num_species: int,
num_targets: int,
name: Optional[str] = None,
):
super().__init__(name)
self.node_embedder = node_embedder_fn()
self.radial_predictor = radial_predictor_fn()
self.angular_predictor = angular_predictor_fn()
self.num_species = num_species
self.num_targets = num_targets

def compute_conditioning(
self,
Expand All @@ -46,14 +48,24 @@ def compute_conditioning(
num_graphs,
focus_node_embeddings.irreps.dim,
)
focus_node_embeddings = focus_node_embeddings.reshape(
(num_graphs, 1, focus_node_embeddings.irreps.dim)
)
focus_node_embeddings = focus_node_embeddings.broadcast_to(
(num_graphs, self.num_targets, focus_node_embeddings.irreps.dim)
)

# Embed the target species.
target_species_embeddings = hk.Embed(
self.num_species,
embed_dim=focus_node_embeddings.irreps.num_irreps,
target_species_embeddings = hk.vmap(
hk.Embed(
self.num_species,
embed_dim=focus_node_embeddings.irreps.num_irreps,
),
split_rng=False
)(target_species)
assert target_species_embeddings.shape == (
num_graphs,
self.num_targets,
focus_node_embeddings.irreps.num_irreps,
), (
target_species_embeddings.shape,
Expand All @@ -67,7 +79,7 @@ def compute_conditioning(
conditioning = e3nn.concatenate(
[focus_node_embeddings, target_species_embeddings], axis=-1
)
assert conditioning.shape == (num_graphs, conditioning.irreps.dim)
assert conditioning.shape == (num_graphs, self.num_targets, conditioning.irreps.dim)
return conditioning

def get_training_predictions(
Expand All @@ -83,7 +95,7 @@ def get_training_predictions(
target_species = graphs.globals.target_species
conditioning = self.compute_conditioning(
graphs, focus_node_indices, target_species
)
) # (num_graphs, num_targets, conditioning.irreps.dim)

target_positions = graphs.globals.target_positions
target_positions = e3nn.IrrepsArray("1o", target_positions)
Expand All @@ -98,16 +110,16 @@ def predict_logits_for_single_graph(
) -> Tuple[float, float]:
"""Predicts the logits for a single graph."""
assert target_positions.shape == (num_targets, 3)
assert conditioning.shape == (conditioning.irreps.dim,)
assert conditioning.shape == (num_targets, conditioning.irreps.dim,)

radial_logits = hk.vmap(
lambda pos: self.radial_predictor.log_prob(pos, conditioning),
self.radial_predictor.log_prob,
split_rng=False,
)(target_positions)
)(target_positions, conditioning)
angular_logits = hk.vmap(
lambda pos: self.angular_predictor.log_prob(pos, conditioning),
self.angular_predictor.log_prob,
split_rng=False,
)(target_positions)
)(target_positions, conditioning)
return radial_logits, angular_logits

radial_logits, angular_logits = hk.vmap(
Expand All @@ -128,19 +140,27 @@ def get_evaluation_predictions(
num_graphs = graphs.n_node.shape[0]

# Compute the conditioning based on the focus nodes and target species.
conditioning = self.compute_conditioning(graphs, focus_indices, target_species)
assert conditioning.shape == (num_graphs, conditioning.irreps.dim)
conditioning = self.compute_conditioning(
graphs,
focus_indices,
target_species,
)
assert conditioning.shape == (num_graphs, self.num_targets, conditioning.irreps.dim)

# Sample the radial component.
radii = hk.vmap(self.radial_predictor.sample, split_rng=True)(conditioning)
assert radii.shape == (num_graphs,), (radii.shape, num_graphs)
radii = hk.vmap(
hk.vmap(
self.radial_predictor.sample, split_rng=True
), split_rng=True
)(conditioning)
assert radii.shape == (num_graphs, self.num_targets), (radii.shape, num_graphs)

# Predict the target position vectors.
angular_sample_fn = lambda r, cond: self.angular_predictor.sample(
r, cond, inverse_temperature
)
position_vectors = hk.vmap(angular_sample_fn, split_rng=True)(
position_vectors = hk.vmap(hk.vmap(angular_sample_fn, split_rng=True), split_rng=True)(
radii, conditioning
)
assert position_vectors.shape == (num_graphs, 3)
assert position_vectors.shape == (num_graphs, self.num_targets, 3)
return None, None, position_vectors
12 changes: 10 additions & 2 deletions symphony/models/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,11 @@ def get_evaluation_predictions(
# Sample the focus node and target species.
rng, focus_rng = jax.random.split(rng)
focus_indices, target_species = utils.segment_sample_2D(
focus_and_target_species_probs, segment_ids, num_graphs, focus_rng
focus_and_target_species_probs,
segment_ids,
num_graphs,
self.target_position_predictor.num_targets,
focus_rng,
)

# Compute the position coefficients.
Expand All @@ -155,7 +159,11 @@ def get_evaluation_predictions(
num_nodes,
num_species,
)
assert position_vectors.shape == (num_graphs, 3)
assert position_vectors.shape == (
num_graphs,
self.target_position_predictor.num_targets,
3
)

return datatypes.Predictions(
nodes=datatypes.NodePredictions(
Expand Down
7 changes: 7 additions & 0 deletions symphony/models/radius_predictors/discretized_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,27 @@ def __init__(
range_max: float,
num_layers: int,
latent_size: int,
num_targets: int,
):
super().__init__()
self.num_bins = num_bins
self.range_min = range_min
self.range_max = range_max
self.num_layers = num_layers
self.latent_size = latent_size
self.num_targets = num_targets

def radii(self) -> jnp.ndarray:
return jnp.linspace(self.range_min, self.range_max, self.num_bins)

def predict_logits(self, conditioning: e3nn.IrrepsArray) -> distrax.Bijector:
"""Predicts the logits."""
conditioning = conditioning.filter("0e")
# conditioning = jnp.repeat(
# jnp.expand_dims(conditioning.array, axis=-1),
# self.num_targets,
# axis=-1
# )

logits = hk.nets.MLP(
[self.latent_size] * (self.num_layers - 1) + [self.num_bins],
Expand Down
2 changes: 2 additions & 0 deletions symphony/models/utils/create_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def create_predictor(config: ml_collections.ConfigDict) -> Predictor:
range_max=radial_predictor_config.max_radius,
num_layers=radial_predictor_config.num_layers,
latent_size=radial_predictor_config.latent_size,
num_targets=config.max_targets_per_graph,
)
else:
raise ValueError(
Expand All @@ -224,6 +225,7 @@ def create_predictor(config: ml_collections.ConfigDict) -> Predictor:
angular_predictor_fn=angular_predictor_fn,
radial_predictor_fn=radial_predictor_fn,
num_species=num_species,
num_targets=config.max_targets_per_graph,
)
predictor = Predictor(
focus_and_target_species_predictor=focus_and_target_species_predictor,
Expand Down
Loading