Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/code/newstrat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from gufe import AlchemicalNetwork, ProtocolResult
from gufe.tokenization import GufeKey
from gufe.transformations import Transformation, NonTransformation

# if including validators with settings, recommended
from pydantic import Field, field_validator
Expand Down Expand Up @@ -38,6 +38,6 @@ def _default_settings(cls) -> StrategySettings:
def _propose(
self,
alchem_network: AlchemicalNetwork,
protocol_results: dict[GufeKey, ProtocolResult]
protocol_results: dict[Transformation | NonTransformation, ProtocolResult]
) -> StrategyResult:
...
2 changes: 1 addition & 1 deletion docs/getting_started.rst
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ You can calculate transformation weights for an :external+gufe:py:class:`~gufe.n
settings = ConnectivityStrategy.default_settings()
strategy = ConnectivityStrategy(settings)

previous_results: dict[GufeKey, ProtocolResult] = {}
previous_results: dict[Transformation | NonTransformation, ProtocolResult] = {}

strategy_result: StrategyResult = strategy.propose(alchem_network, previous_results)

Expand Down
1 change: 1 addition & 0 deletions src/stratocaster/base/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .exceptions import StrategyResultValidationError
from .models import StrategySettings
from .strategy import Strategy, StrategyResult
2 changes: 2 additions & 0 deletions src/stratocaster/base/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
class StrategyResultValidationError(Exception):
pass
63 changes: 46 additions & 17 deletions src/stratocaster/base/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,24 @@
from typing import TypeVar

from gufe import AlchemicalNetwork, ProtocolResult
from gufe.tokenization import GufeKey, GufeTokenizable
from gufe.tokenization import GufeTokenizable
from gufe.transformations import Transformation, NonTransformation

from .models import StrategySettings
from .exceptions import StrategyResultValidationError

TProtocolResult = TypeVar("TProtocolResult", bound=ProtocolResult)


class StrategyResult(GufeTokenizable):
"""Results produced by a Strategy."""

def __init__(self, weights: dict[GufeKey, float | None]):
self._weights = weights
def __init__(self, weights: dict[Transformation | NonTransformation, float | None]):
self._weights = [
[transformation, weight] for transformation, weight in weights.items()
]
if not self.validate():
raise StrategyResultValidationError

@classmethod
def _defaults(cls):
Expand All @@ -25,20 +31,21 @@ def _to_dict(self) -> dict:
# TODO: Return type from typing.Self when Python 3.10 is no longer supported
@classmethod
def _from_dict(cls, dct: dict):
return cls(**dct)
weights = dct["weights"]
return cls(
weights={transformation: weight for transformation, weight in weights}
)

@property
def weights(self) -> dict[GufeKey, float | None]:
return self._weights.copy()
def weights(self) -> dict[Transformation | NonTransformation, float | None]:
return {transformation: weight for transformation, weight in self._weights}

def resolve(self) -> dict[GufeKey, float | None]:
def resolve(self) -> dict[Transformation | NonTransformation, float | None]:
"""Get normalized proposal weights relative to all non-None Transformation weights."""
weight_sum = sum(
[weight for weight in self._weights.values() if weight is not None]
)
weight_sum = sum([weight for _, weight in self._weights if weight is not None])
normalized_weights = {
key: weight / weight_sum if weight is not None else None
for key, weight in self._weights.items()
transformation: weight / weight_sum if weight is not None else None
for transformation, weight in self._weights
}
return normalized_weights

Expand All @@ -49,6 +56,28 @@ def __or__(self, other):
)
return StrategyResult(self.weights | other.weights)

def validate(self) -> bool:
for transformation, weight in self._weights:
if not isinstance(transformation, (Transformation, NonTransformation)):
return False

if weight is None:
continue

match weight:
case None:
continue
case int():
weight = float(weight)
case float():
pass
case _:
return False

if weight < 0:
return False
return True


class Strategy(GufeTokenizable):
"""An object that proposes the relative urgency of computing
Expand Down Expand Up @@ -106,14 +135,14 @@ def default_settings(cls) -> StrategySettings:
def _propose(
self,
alchemical_network: AlchemicalNetwork,
protocol_results: dict[GufeKey, TProtocolResult],
protocol_results: dict[Transformation | NonTransformation, TProtocolResult],
) -> StrategyResult:
raise NotImplementedError

def propose(
self,
alchemical_network: AlchemicalNetwork,
protocol_results: dict[GufeKey, TProtocolResult],
protocol_results: dict[Transformation | NonTransformation, TProtocolResult],
) -> StrategyResult:
"""Compute Transformation weights from the ProtocolResults of
the Transformations.
Expand All @@ -122,9 +151,9 @@ def propose(
----------
alchemical_network: AlchemicalNetwork
The AlchemicalNetwork containing the Transformations.
protocol_results: dict[GufeKey, ProtocolResult]
A dictionary of Transformation GufeKeys paired with the
Transformation's ProtocolResults.
protocol_results: dict[Transformation | NonTransformation, ProtocolResult]
A dictionary of Transformations paired with the their
ProtocolResults.

Returns
-------
Expand Down
20 changes: 10 additions & 10 deletions src/stratocaster/strategies/connectivity.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from gufe import AlchemicalNetwork, ProtocolResult
from gufe.tokenization import GufeKey
from gufe.transformations import Transformation, NonTransformation

from stratocaster.base import Strategy, StrategyResult
from stratocaster.base.models import StrategySettings
Expand Down Expand Up @@ -90,15 +90,15 @@ def _exponential_decay_scaling(
def _propose(
self,
alchemical_network: AlchemicalNetwork,
protocol_results: dict[GufeKey, ProtocolResult],
protocol_results: dict[Transformation | NonTransformation, ProtocolResult],
) -> StrategyResult:
"""Propose `Transformation` weight recommendations based on high connectivity nodes.

Parameters
----------
alchemical_network: AlchemicalNetwork
protocol_results: dict[GufeKey, ProtocolResult]
A dictionary whose keys are the `GufeKey`s of `Transformation`s in the `AlchemicalNetwork`
protocol_results: dict[Transformation | NonTransformation, ProtocolResult]
A dictionary whose keys are the `Transformation`s of an `AlchemicalNetwork`
and whose values are the `ProtocolResult`s for those `Transformation`s.

Returns
Expand All @@ -113,7 +113,7 @@ def _propose(
assert isinstance(settings, ConnectivityStrategySettings)

alchemical_network_mdg = alchemical_network.graph
weights: dict[GufeKey, float | None] = {}
weights: dict[Transformation | NonTransformation, float | None] = {}

for state_a, state_b in alchemical_network_mdg.edges():
num_neighbors_a = alchemical_network_mdg.degree(state_a)
Expand All @@ -122,11 +122,11 @@ def _propose(
# linter-satisfying assertion
assert isinstance(num_neighbors_a, int) and isinstance(num_neighbors_b, int)

transformation_key = alchemical_network_mdg.get_edge_data(state_a, state_b)[
0
]["object"].key
transformation = alchemical_network_mdg.get_edge_data(state_a, state_b)[0][
"object"
]

match (protocol_results.get(transformation_key)):
match (protocol_results.get(transformation)):
case None:
transformation_n_protcol_dag_results = 0
case pr:
Expand All @@ -152,7 +152,7 @@ def _propose(
):
weight = None

weights[transformation_key] = weight
weights[transformation] = weight

results = StrategyResult(weights=weights)
return results
Expand Down
29 changes: 15 additions & 14 deletions src/stratocaster/strategies/radialgrowth.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import networkx as nx

from gufe import AlchemicalNetwork, ProtocolResult
from gufe.tokenization import GufeKey
from gufe.transformations import Transformation, NonTransformation

from pydantic import (
Field,
Expand Down Expand Up @@ -103,7 +103,7 @@ def _default_settings(cls) -> StrategySettings:
def _propose(
self,
alchemical_network: AlchemicalNetwork,
protocol_results: dict[GufeKey, ProtocolResult],
protocol_results: dict[Transformation | NonTransformation, ProtocolResult],
) -> StrategyResult:
"""Propose `Transformation` weight recommendations based on
`Transformation` distance from the graph center.
Expand All @@ -112,8 +112,9 @@ def _propose(
----------
alchemical_network
protocol_results
A dictionary whose keys are the `GufeKey`s of `Transformation`s in the `AlchemicalNetwork`
and whose values are the `ProtocolResult`s for those `Transformation`s.
A dictionary whose keys are the `Transformation`s (or `NonTransformation`s)
in the `AlchemicalNetwork` and whose values are the `ProtocolResult`s for
those `Transformation`s.

Returns
-------
Expand All @@ -123,7 +124,7 @@ def _propose(
"""

alchemical_network_mdg = alchemical_network.graph
weights: dict[GufeKey, float | None] = {}
weights: dict[Transformation | NonTransformation, float | None] = {}

# calculate all node eccentricies
e = nx.eccentricity(alchemical_network_mdg.to_undirected())
Expand All @@ -142,12 +143,12 @@ def _propose(
# find the range of eccentricies
lower, upper = min(edge), max(edge)

transformation_key = alchemical_network_mdg.get_edge_data(state_a, state_b)[
0
]["object"].key
transformation = alchemical_network_mdg.get_edge_data(state_a, state_b)[0][
"object"
]

factor_repeats = 1
match (protocol_results.get(transformation_key)):
match (protocol_results.get(transformation)):
case None:
transformation_n_protcol_dag_results = 0
# since we have no results for this
Expand All @@ -167,17 +168,17 @@ def _propose(

# stop condition given max runs
if self.settings.max_runs <= transformation_n_protcol_dag_results:
weights[transformation_key] = None
weights[transformation] = None
continue

# save the upper eccentricity for later when we know the
# lowest_completed. This is the transformation's effective
# distance from the center
transformation_eccentricity[transformation_key] = upper
weights[transformation_key] = factor_repeats
transformation_eccentricity[transformation] = upper
weights[transformation] = factor_repeats

# start applying weights due to effective distance
for transformation_key, e in transformation_eccentricity.items():
for transformation, e in transformation_eccentricity.items():
distance = e - lowest_complete_eccentricity

if distance <= self.settings.candidacy_max_distance:
Expand All @@ -195,6 +196,6 @@ def _propose(
# set to zero, not None
distance_factor = 0

weights[transformation_key] *= distance_factor
weights[transformation] *= distance_factor

return StrategyResult(weights)
26 changes: 11 additions & 15 deletions src/stratocaster/tests/test_connectivity_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
from gufe import AlchemicalNetwork
from gufe.tests.test_protocol import DummyProtocol, DummyProtocolResult
from gufe.transformations import Transformation, NonTransformation

from stratocaster.base.models import StrategySettings
from stratocaster.base.strategy import StrategyResult
Expand All @@ -14,8 +15,6 @@

from stratocaster.tests.utils import StrategyTestMixin

from gufe.tokenization import GufeKey


class TestConnectivityStrategy(StrategyTestMixin):

Expand Down Expand Up @@ -47,13 +46,12 @@ def test_propose_cutoff_num_runs_predictioned_termination(
math.log(settings.cutoff / 3.5) / math.log(settings.decay_rate)
)

result_data: dict[GufeKey, DummyProtocolResult] = {}
result_data: dict[Transformation | NonTransformation, DummyProtocolResult] = {}
for transformation in benzene_variants_star_map.edges:
transformation_key = transformation.key
result = DummyProtocolResult(
n_protocol_dag_results=num_runs + 1, info=f"key: {transformation_key}"
n_protocol_dag_results=num_runs + 1, info=f"key: {transformation.key}"
)
result_data[transformation_key] = result
result_data[transformation] = result

results = strategy.propose(benzene_variants_star_map, result_data)

Expand All @@ -66,13 +64,12 @@ def test_propose_max_runs_termination(
max_runs = self.default_strategy.settings.max_runs
assert isinstance(max_runs, int)

result_data: dict[GufeKey, DummyProtocolResult] = {}
result_data: dict[Transformation | NonTransformation, DummyProtocolResult] = {}
for transformation in benzene_variants_star_map.edges:
transformation_key = transformation.key
result = DummyProtocolResult(
n_protocol_dag_results=max_runs, info=f"key: {transformation_key}"
n_protocol_dag_results=max_runs, info=f"key: {transformation.key}"
)
result_data[transformation_key] = result
result_data[transformation] = result

results = self.default_strategy.propose(benzene_variants_star_map, result_data)

Expand All @@ -85,13 +82,12 @@ def test_propose_previous_results(
self, benzene_variants_star_map: AlchemicalNetwork
):

result_data: dict[GufeKey, DummyProtocolResult] = {}
result_data: dict[Transformation | NonTransformation, DummyProtocolResult] = {}
for transformation in benzene_variants_star_map.edges:
transformation_key = transformation.key
result = DummyProtocolResult(
n_protocol_dag_results=2, info=f"key: {transformation_key}"
n_protocol_dag_results=2, info=f"key: {transformation.key}"
)
result_data[transformation_key] = result
result_data[transformation] = result

results = self.default_strategy.propose(benzene_variants_star_map, result_data)
results_no_data = self.default_strategy.propose(benzene_variants_star_map, {})
Expand All @@ -107,7 +103,7 @@ def test_propose_no_results(self, benzene_variants_star_map: AlchemicalNetwork):
benzene_variants_star_map, {}
)

assert all([weight == 3.5 for weight in proposal._weights.values()])
assert all([weight == 3.5 for _, weight in proposal._weights])
assert 1 == sum(
weight for weight in proposal.resolve().values() if weight is not None
)
Expand Down
4 changes: 2 additions & 2 deletions src/stratocaster/tests/test_strategy_abstraction.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
from gufe import AlchemicalNetwork, ProtocolResult
from gufe.tokenization import GufeKey
from gufe.transformations import Transformation, NonTransformation

from stratocaster.base import Strategy, StrategySettings
from stratocaster.base.strategy import StrategyResult
Expand All @@ -23,7 +23,7 @@ def _default_settings(cls) -> StrategySettings:
def _propose(
self,
alchemical_network: AlchemicalNetwork,
protocol_results: dict[GufeKey, ProtocolResult],
protocol_results: dict[Transformation | NonTransformation, ProtocolResult],
) -> StrategyResult:
return StrategyResult({})

Expand Down
Loading
Loading