Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion src/stratocaster/base/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ def resolve(self) -> dict[GufeKey, float | None]:
weights.update(modified_weights)
return weights

def __or__(self, other):
if self.weights.keys() & other.weights.keys():
raise ValueError(
"StrategyResults can only be combined when their transformation keys are mutually exclusive."
)
return StrategyResult(self.weights | other.weights)


class Strategy(GufeTokenizable):
"""An object that proposes the relative urgency of computing
Expand Down Expand Up @@ -125,4 +132,8 @@ def propose(
StrategyResult

"""
return self._propose(alchemical_network, protocol_results)
subgraphs = alchemical_network.connected_subgraphs()
acc = StrategyResult({})
for subgraph in subgraphs:
acc |= self._propose(subgraph, protocol_results)
return acc
3 changes: 2 additions & 1 deletion src/stratocaster/strategies/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from stratocaster.strategies.connectivity import ConnectivityStrategy
from stratocaster.strategies.radialgrowth import RadialGrowthStrategy

__all__ = ["ConnectivityStrategy"]
__all__ = ["ConnectivityStrategy", "RadialGrowthStrategy"]
200 changes: 200 additions & 0 deletions src/stratocaster/strategies/radialgrowth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
import networkx as nx

from gufe import AlchemicalNetwork, ProtocolResult
from gufe.tokenization import GufeKey

from pydantic import (
Field,
field_validator,
)

from stratocaster.base import Strategy, StrategyResult
from stratocaster.base.models import StrategySettings


class RadialGrowthStrategySettings(StrategySettings):
"""Settings required for the RadialGrowthStrategy."""

max_runs: int = Field(
default=3,
description="the upper limit of ProtocolDAG results needed before a Transformation is no longer weighed",
)

candidacy_max_distance: int = Field(
default=1,
description="the maximum distance a candidate ChemicalSystem can be from previously reached ChemicalSystems",
)

decay_repeat_rate: float = Field(
default=0.5,
description="decay rate of the exponential repeat decay penalty factor",
)

decay_distance_rate: float = Field(
default=0.5,
description="decay rate of the exponential distance decay penalty factor",
)

@field_validator("max_runs", mode="before")
def validate_max_runs(cls, value):
if not value >= 1:
raise ValueError("`max_runs` must be greater than or equal to 1")
return value

@field_validator("candidacy_max_distance", mode="before")
def validate_candidate_max_distance(cls, value):
if not value >= 1:
raise ValueError(
"`candidtate_max_distance` must be greater than or equal to 1"
)
return value

@field_validator("decay_repeat_rate", mode="before")
def validate_decay_repeat_rate(cls, value):
if not (0 < value < 1):
raise ValueError("`decay_repeat_rate` must be between 0 and 1")
return value

@field_validator("decay_distance_rate", mode="before")
def validate_decay_distance_rate(cls, value):
if not (0 < value < 1):
raise ValueError("`decay_distance_rate` must be between 0 and 1")
return value


class RadialGrowthStrategy(Strategy):
r"""A Strategy that favors Transformations close to the network
center.

The weight assigned to each Transformation depends on its highest
ChemicalSystem distance, as measured and labeled by the vertex
eccentricty, from the lowest completed tier of distances. In the
graph below, if at least one ``ProtocolDAGResult`` exists for both
edges in 3-2-3, the lowest completed distance is 3. If neither
edge or only one edge has a result, the lowest completed distance
is 2. In the former case, a transformation going from 3 to 4 has
an effective distance of 1, while in the latter case it has a
distance of 2.

.. code-block::

4 4
\ /
\ /
4--3-2-3--4
/ \
/ \
4 4

The strategy penalizes transformations that have high effective
distances by multiplying the weight with a penalty of :math:`r^d`, where r
is a user-specified decay rate and d is the effective
distance. The ``candidacy_max_distance`` setting limits how far
out transformations can be assigned non-zero weights.

"""

_settings_cls = RadialGrowthStrategySettings

@classmethod
def _default_settings(cls) -> StrategySettings:
return RadialGrowthStrategySettings(max_runs=3)

def _propose(
self,
alchemical_network: AlchemicalNetwork,
protocol_results: dict[GufeKey, ProtocolResult],
) -> StrategyResult:
"""Propose `Transformation` weight recommendations based on
`Transformation` distance from the graph center.

Parameters
----------
alchemical_network
protocol_results
A dictionary whose keys are the `GufeKey`s of `Transformation`s in the `AlchemicalNetwork`
and whose values are the `ProtocolResult`s for those `Transformation`s.

Returns
-------
StrategyResult
A `StrategyResult` containing the proposed `Transformation` weights.

"""

alchemical_network_mdg = alchemical_network.graph
weights: dict[GufeKey, float | None] = {}

# calculate all node eccentricies
e = nx.eccentricity(alchemical_network_mdg.to_undirected())

# start with the maximum value, this will be decremented as we
# see evidence the value should be lower
lowest_complete_eccentricity = max(e.values())
# hold on to the eccentricies of the transformations instead
# of the distances since we don't know the lowest complete
# eccentricity until we process the full graph, distances can
# be calculated after
transformation_eccentricity = {}

for state_a, state_b in alchemical_network_mdg.edges():
edge = e[state_a], e[state_b]
# find the range of eccentricies
lower, upper = min(edge), max(edge)

transformation_key = alchemical_network_mdg.get_edge_data(state_a, state_b)[
0
]["object"].key

factor_repeats = 1
match (protocol_results.get(transformation_key)):
case None:
transformation_n_protcol_dag_results = 0
# since we have no results for this
# transformation, we know the lowest complete
# eccentricity must be lower than the upper
# eccentricity of the transformation
if upper < lowest_complete_eccentricity:
lowest_complete_eccentricity = lower
case pr:
transformation_n_protcol_dag_results = pr.n_protocol_dag_results
# scale the repeat factor to discourage reruns as
# specified by the user's decay_repeat_rate
factor_repeats *= (
self.settings.decay_repeat_rate
** transformation_n_protcol_dag_results
)

# stop condition given max runs
if self.settings.max_runs <= transformation_n_protcol_dag_results:
weights[transformation_key] = None
continue

# save the upper eccentricity for later when we know the
# lowest_completed. This is the transformation's effective
# distance from the center
transformation_eccentricity[transformation_key] = upper
weights[transformation_key] = factor_repeats

# start applying weights due to effective distance
for transformation_key, e in transformation_eccentricity.items():
distance = e - lowest_complete_eccentricity

if distance <= self.settings.candidacy_max_distance:
# edge case where there are multiple vertices with
# eccentricity equal to graph radius
if distance == 0:
distance_factor = 1
else:
# scale the distance factor to limit the
# calculation of far-out transformations
distance_factor = self.settings.decay_distance_rate ** (
distance - 1
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain why we have a - 1 here?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it because we want distance == 1 cases to always get a distance_factor == 1?

)
else:
# set to zero, not None
distance_factor = 0

weights[transformation_key] *= distance_factor

return StrategyResult(weights)
22 changes: 22 additions & 0 deletions src/stratocaster/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import pytest

from stratocaster.tests.networks import (
benzene_variants_star_map as _benzene_variants_star_map,
fanning_network as _fanning_network,
disconnected_fanning_network as _disconnected_fanning_network,
)


@pytest.fixture(scope="module")
def benzene_variants_star_map():
return _benzene_variants_star_map()


@pytest.fixture(scope="module")
def fanning_network():
return _fanning_network()


@pytest.fixture(scope="module")
def disconnected_fanning_network():
return _disconnected_fanning_network()
92 changes: 92 additions & 0 deletions src/stratocaster/tests/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import gufe
from gufe.tests.test_protocol import DummyProtocol
import networkx as nx
from openff.units import unit
from rdkit import Chem

Expand Down Expand Up @@ -115,3 +116,94 @@ def benzene_variants_star_map():
return gufe.AlchemicalNetwork(
solvated_ligand_transformations + solvated_complex_transformations
)


def digraph_to_alchemical_network(digraph):
"""Convert a digraph to an AlchemicalNetwork."""
node_to_chemical_system = lambda n: gufe.ChemicalSystem({}, name=n)

transformations = []
for a, b in digraph.edges():
c_a = node_to_chemical_system(a)
c_b = node_to_chemical_system(b)
protocol = None
transformation = gufe.Transformation(
c_a, c_b, DummyProtocol(settings=DummyProtocol.default_settings())
)
transformations.append(transformation)

return gufe.AlchemicalNetwork(transformations)


def int_sampler(sampler_start=0):
num = sampler_start
while True:
yield num
num += 1


def fanning_network(branch=3, depth=3):
"""Generate a network with a central node that recursively "fans"
out. Branch determines how many edges branch off of each node
while depth determines how many times the process is repeated.
"""
G = nx.DiGraph()
G.add_edges_from(fan(0, branch=branch, depth=depth))
an = digraph_to_alchemical_network(G)
return an


def dual_center_fanning_network(branch=3, depth=3):
"""Generate a network with two central nodes that recursively
"fan" out. Branch determines how many edges branch off of each
node while depth determines how many times the process is
repeated.
"""

G = nx.DiGraph()

edges = {(0, 1)}

sampler = int_sampler(sampler_start=2)
edges |= fan(0, branch=branch, depth=depth, id_generator=sampler)
edges |= fan(1, branch=branch, depth=depth, id_generator=sampler)

G.add_edges_from(edges)

an = digraph_to_alchemical_network(G)
return an


def disconnected_fanning_network(branch=3, depth=3):
"""Generate a network with disconnected fanning subgraphs."""

G = nx.DiGraph()

edges = set()

sampler = int_sampler(sampler_start=2)
edges |= fan(0, branch=branch, depth=depth, id_generator=sampler)
edges |= fan(1, branch=branch, depth=depth, id_generator=sampler)

G.add_edges_from(edges)

assert not nx.is_weakly_connected(G)

an = digraph_to_alchemical_network(G)
return an


def fan(node, branch=3, depth=3, id_generator=None):
"""Recursive edge generator."""

id_generator = id_generator or int_sampler(sampler_start=node + 1)

if depth == 0:
return set()

edges = {(node, next(id_generator)) for _ in range(branch)}

for _, next_node in edges.copy():
edges |= fan(next_node, depth=depth - 1, id_generator=id_generator)

return edges
Loading
Loading