diff --git a/docs/code/newstrat.py b/docs/code/newstrat.py index cce3784..16dc642 100644 --- a/docs/code/newstrat.py +++ b/docs/code/newstrat.py @@ -1,5 +1,5 @@ from gufe import AlchemicalNetwork, ProtocolResult -from gufe.tokenization import GufeKey +from gufe.transformations import Transformation, NonTransformation # if including validators with settings, recommended from pydantic import Field, field_validator @@ -38,6 +38,6 @@ def _default_settings(cls) -> StrategySettings: def _propose( self, alchem_network: AlchemicalNetwork, - protocol_results: dict[GufeKey, ProtocolResult] + protocol_results: dict[Transformation | NonTransformation, ProtocolResult] ) -> StrategyResult: ... diff --git a/docs/getting_started.rst b/docs/getting_started.rst index 212d0f1..b12e795 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -31,7 +31,7 @@ You can calculate transformation weights for an :external+gufe:py:class:`~gufe.n settings = ConnectivityStrategy.default_settings() strategy = ConnectivityStrategy(settings) - previous_results: dict[GufeKey, ProtocolResult] = {} + previous_results: dict[Transformation | NonTransformation, ProtocolResult] = {} strategy_result: StrategyResult = strategy.propose(alchem_network, previous_results) diff --git a/src/stratocaster/base/__init__.py b/src/stratocaster/base/__init__.py index dbf17b6..963397a 100644 --- a/src/stratocaster/base/__init__.py +++ b/src/stratocaster/base/__init__.py @@ -1,2 +1,3 @@ +from .exceptions import StrategyResultValidationError from .models import StrategySettings from .strategy import Strategy, StrategyResult diff --git a/src/stratocaster/base/exceptions.py b/src/stratocaster/base/exceptions.py new file mode 100644 index 0000000..245c1ba --- /dev/null +++ b/src/stratocaster/base/exceptions.py @@ -0,0 +1,2 @@ +class StrategyResultValidationError(Exception): + pass diff --git a/src/stratocaster/base/strategy.py b/src/stratocaster/base/strategy.py index 5817fc3..2457d26 100644 --- a/src/stratocaster/base/strategy.py +++ b/src/stratocaster/base/strategy.py @@ -2,9 +2,11 @@ from typing import TypeVar from gufe import AlchemicalNetwork, ProtocolResult -from gufe.tokenization import GufeKey, GufeTokenizable +from gufe.tokenization import GufeTokenizable +from gufe.transformations import Transformation, NonTransformation from .models import StrategySettings +from .exceptions import StrategyResultValidationError TProtocolResult = TypeVar("TProtocolResult", bound=ProtocolResult) @@ -12,8 +14,12 @@ class StrategyResult(GufeTokenizable): """Results produced by a Strategy.""" - def __init__(self, weights: dict[GufeKey, float | None]): - self._weights = weights + def __init__(self, weights: dict[Transformation | NonTransformation, float | None]): + self._weights = [ + [transformation, weight] for transformation, weight in weights.items() + ] + if not self.validate(): + raise StrategyResultValidationError @classmethod def _defaults(cls): @@ -25,20 +31,21 @@ def _to_dict(self) -> dict: # TODO: Return type from typing.Self when Python 3.10 is no longer supported @classmethod def _from_dict(cls, dct: dict): - return cls(**dct) + weights = dct["weights"] + return cls( + weights={transformation: weight for transformation, weight in weights} + ) @property - def weights(self) -> dict[GufeKey, float | None]: - return self._weights.copy() + def weights(self) -> dict[Transformation | NonTransformation, float | None]: + return {transformation: weight for transformation, weight in self._weights} - def resolve(self) -> dict[GufeKey, float | None]: + def resolve(self) -> dict[Transformation | NonTransformation, float | None]: """Get normalized proposal weights relative to all non-None Transformation weights.""" - weight_sum = sum( - [weight for weight in self._weights.values() if weight is not None] - ) + weight_sum = sum([weight for _, weight in self._weights if weight is not None]) normalized_weights = { - key: weight / weight_sum if weight is not None else None - for key, weight in self._weights.items() + transformation: weight / weight_sum if weight is not None else None + for transformation, weight in self._weights } return normalized_weights @@ -49,6 +56,28 @@ def __or__(self, other): ) return StrategyResult(self.weights | other.weights) + def validate(self) -> bool: + for transformation, weight in self._weights: + if not isinstance(transformation, (Transformation, NonTransformation)): + return False + + if weight is None: + continue + + match weight: + case None: + continue + case int(): + weight = float(weight) + case float(): + pass + case _: + return False + + if weight < 0: + return False + return True + class Strategy(GufeTokenizable): """An object that proposes the relative urgency of computing @@ -106,14 +135,14 @@ def default_settings(cls) -> StrategySettings: def _propose( self, alchemical_network: AlchemicalNetwork, - protocol_results: dict[GufeKey, TProtocolResult], + protocol_results: dict[Transformation | NonTransformation, TProtocolResult], ) -> StrategyResult: raise NotImplementedError def propose( self, alchemical_network: AlchemicalNetwork, - protocol_results: dict[GufeKey, TProtocolResult], + protocol_results: dict[Transformation | NonTransformation, TProtocolResult], ) -> StrategyResult: """Compute Transformation weights from the ProtocolResults of the Transformations. @@ -122,9 +151,9 @@ def propose( ---------- alchemical_network: AlchemicalNetwork The AlchemicalNetwork containing the Transformations. - protocol_results: dict[GufeKey, ProtocolResult] - A dictionary of Transformation GufeKeys paired with the - Transformation's ProtocolResults. + protocol_results: dict[Transformation | NonTransformation, ProtocolResult] + A dictionary of Transformations paired with the their + ProtocolResults. Returns ------- diff --git a/src/stratocaster/strategies/connectivity.py b/src/stratocaster/strategies/connectivity.py index d6a310b..593dfad 100644 --- a/src/stratocaster/strategies/connectivity.py +++ b/src/stratocaster/strategies/connectivity.py @@ -1,5 +1,5 @@ from gufe import AlchemicalNetwork, ProtocolResult -from gufe.tokenization import GufeKey +from gufe.transformations import Transformation, NonTransformation from stratocaster.base import Strategy, StrategyResult from stratocaster.base.models import StrategySettings @@ -90,15 +90,15 @@ def _exponential_decay_scaling( def _propose( self, alchemical_network: AlchemicalNetwork, - protocol_results: dict[GufeKey, ProtocolResult], + protocol_results: dict[Transformation | NonTransformation, ProtocolResult], ) -> StrategyResult: """Propose `Transformation` weight recommendations based on high connectivity nodes. Parameters ---------- alchemical_network: AlchemicalNetwork - protocol_results: dict[GufeKey, ProtocolResult] - A dictionary whose keys are the `GufeKey`s of `Transformation`s in the `AlchemicalNetwork` + protocol_results: dict[Transformation | NonTransformation, ProtocolResult] + A dictionary whose keys are the `Transformation`s of an `AlchemicalNetwork` and whose values are the `ProtocolResult`s for those `Transformation`s. Returns @@ -113,7 +113,7 @@ def _propose( assert isinstance(settings, ConnectivityStrategySettings) alchemical_network_mdg = alchemical_network.graph - weights: dict[GufeKey, float | None] = {} + weights: dict[Transformation | NonTransformation, float | None] = {} for state_a, state_b in alchemical_network_mdg.edges(): num_neighbors_a = alchemical_network_mdg.degree(state_a) @@ -122,11 +122,11 @@ def _propose( # linter-satisfying assertion assert isinstance(num_neighbors_a, int) and isinstance(num_neighbors_b, int) - transformation_key = alchemical_network_mdg.get_edge_data(state_a, state_b)[ - 0 - ]["object"].key + transformation = alchemical_network_mdg.get_edge_data(state_a, state_b)[0][ + "object" + ] - match (protocol_results.get(transformation_key)): + match (protocol_results.get(transformation)): case None: transformation_n_protcol_dag_results = 0 case pr: @@ -152,7 +152,7 @@ def _propose( ): weight = None - weights[transformation_key] = weight + weights[transformation] = weight results = StrategyResult(weights=weights) return results diff --git a/src/stratocaster/strategies/radialgrowth.py b/src/stratocaster/strategies/radialgrowth.py index 27f0cc8..3e60398 100644 --- a/src/stratocaster/strategies/radialgrowth.py +++ b/src/stratocaster/strategies/radialgrowth.py @@ -1,7 +1,7 @@ import networkx as nx from gufe import AlchemicalNetwork, ProtocolResult -from gufe.tokenization import GufeKey +from gufe.transformations import Transformation, NonTransformation from pydantic import ( Field, @@ -103,7 +103,7 @@ def _default_settings(cls) -> StrategySettings: def _propose( self, alchemical_network: AlchemicalNetwork, - protocol_results: dict[GufeKey, ProtocolResult], + protocol_results: dict[Transformation | NonTransformation, ProtocolResult], ) -> StrategyResult: """Propose `Transformation` weight recommendations based on `Transformation` distance from the graph center. @@ -112,8 +112,9 @@ def _propose( ---------- alchemical_network protocol_results - A dictionary whose keys are the `GufeKey`s of `Transformation`s in the `AlchemicalNetwork` - and whose values are the `ProtocolResult`s for those `Transformation`s. + A dictionary whose keys are the `Transformation`s (or `NonTransformation`s) + in the `AlchemicalNetwork` and whose values are the `ProtocolResult`s for + those `Transformation`s. Returns ------- @@ -123,7 +124,7 @@ def _propose( """ alchemical_network_mdg = alchemical_network.graph - weights: dict[GufeKey, float | None] = {} + weights: dict[Transformation | NonTransformation, float | None] = {} # calculate all node eccentricies e = nx.eccentricity(alchemical_network_mdg.to_undirected()) @@ -142,12 +143,12 @@ def _propose( # find the range of eccentricies lower, upper = min(edge), max(edge) - transformation_key = alchemical_network_mdg.get_edge_data(state_a, state_b)[ - 0 - ]["object"].key + transformation = alchemical_network_mdg.get_edge_data(state_a, state_b)[0][ + "object" + ] factor_repeats = 1 - match (protocol_results.get(transformation_key)): + match (protocol_results.get(transformation)): case None: transformation_n_protcol_dag_results = 0 # since we have no results for this @@ -167,17 +168,17 @@ def _propose( # stop condition given max runs if self.settings.max_runs <= transformation_n_protcol_dag_results: - weights[transformation_key] = None + weights[transformation] = None continue # save the upper eccentricity for later when we know the # lowest_completed. This is the transformation's effective # distance from the center - transformation_eccentricity[transformation_key] = upper - weights[transformation_key] = factor_repeats + transformation_eccentricity[transformation] = upper + weights[transformation] = factor_repeats # start applying weights due to effective distance - for transformation_key, e in transformation_eccentricity.items(): + for transformation, e in transformation_eccentricity.items(): distance = e - lowest_complete_eccentricity if distance <= self.settings.candidacy_max_distance: @@ -195,6 +196,6 @@ def _propose( # set to zero, not None distance_factor = 0 - weights[transformation_key] *= distance_factor + weights[transformation] *= distance_factor return StrategyResult(weights) diff --git a/src/stratocaster/tests/test_connectivity_strategy.py b/src/stratocaster/tests/test_connectivity_strategy.py index d3bc673..81d2977 100644 --- a/src/stratocaster/tests/test_connectivity_strategy.py +++ b/src/stratocaster/tests/test_connectivity_strategy.py @@ -4,6 +4,7 @@ import pytest from gufe import AlchemicalNetwork from gufe.tests.test_protocol import DummyProtocol, DummyProtocolResult +from gufe.transformations import Transformation, NonTransformation from stratocaster.base.models import StrategySettings from stratocaster.base.strategy import StrategyResult @@ -14,8 +15,6 @@ from stratocaster.tests.utils import StrategyTestMixin -from gufe.tokenization import GufeKey - class TestConnectivityStrategy(StrategyTestMixin): @@ -47,13 +46,12 @@ def test_propose_cutoff_num_runs_predictioned_termination( math.log(settings.cutoff / 3.5) / math.log(settings.decay_rate) ) - result_data: dict[GufeKey, DummyProtocolResult] = {} + result_data: dict[Transformation | NonTransformation, DummyProtocolResult] = {} for transformation in benzene_variants_star_map.edges: - transformation_key = transformation.key result = DummyProtocolResult( - n_protocol_dag_results=num_runs + 1, info=f"key: {transformation_key}" + n_protocol_dag_results=num_runs + 1, info=f"key: {transformation.key}" ) - result_data[transformation_key] = result + result_data[transformation] = result results = strategy.propose(benzene_variants_star_map, result_data) @@ -66,13 +64,12 @@ def test_propose_max_runs_termination( max_runs = self.default_strategy.settings.max_runs assert isinstance(max_runs, int) - result_data: dict[GufeKey, DummyProtocolResult] = {} + result_data: dict[Transformation | NonTransformation, DummyProtocolResult] = {} for transformation in benzene_variants_star_map.edges: - transformation_key = transformation.key result = DummyProtocolResult( - n_protocol_dag_results=max_runs, info=f"key: {transformation_key}" + n_protocol_dag_results=max_runs, info=f"key: {transformation.key}" ) - result_data[transformation_key] = result + result_data[transformation] = result results = self.default_strategy.propose(benzene_variants_star_map, result_data) @@ -85,13 +82,12 @@ def test_propose_previous_results( self, benzene_variants_star_map: AlchemicalNetwork ): - result_data: dict[GufeKey, DummyProtocolResult] = {} + result_data: dict[Transformation | NonTransformation, DummyProtocolResult] = {} for transformation in benzene_variants_star_map.edges: - transformation_key = transformation.key result = DummyProtocolResult( - n_protocol_dag_results=2, info=f"key: {transformation_key}" + n_protocol_dag_results=2, info=f"key: {transformation.key}" ) - result_data[transformation_key] = result + result_data[transformation] = result results = self.default_strategy.propose(benzene_variants_star_map, result_data) results_no_data = self.default_strategy.propose(benzene_variants_star_map, {}) @@ -107,7 +103,7 @@ def test_propose_no_results(self, benzene_variants_star_map: AlchemicalNetwork): benzene_variants_star_map, {} ) - assert all([weight == 3.5 for weight in proposal._weights.values()]) + assert all([weight == 3.5 for _, weight in proposal._weights]) assert 1 == sum( weight for weight in proposal.resolve().values() if weight is not None ) diff --git a/src/stratocaster/tests/test_strategy_abstraction.py b/src/stratocaster/tests/test_strategy_abstraction.py index 2d50150..f347ca6 100644 --- a/src/stratocaster/tests/test_strategy_abstraction.py +++ b/src/stratocaster/tests/test_strategy_abstraction.py @@ -1,6 +1,6 @@ import pytest from gufe import AlchemicalNetwork, ProtocolResult -from gufe.tokenization import GufeKey +from gufe.transformations import Transformation, NonTransformation from stratocaster.base import Strategy, StrategySettings from stratocaster.base.strategy import StrategyResult @@ -23,7 +23,7 @@ def _default_settings(cls) -> StrategySettings: def _propose( self, alchemical_network: AlchemicalNetwork, - protocol_results: dict[GufeKey, ProtocolResult], + protocol_results: dict[Transformation | NonTransformation, ProtocolResult], ) -> StrategyResult: return StrategyResult({}) diff --git a/src/stratocaster/tests/test_strategy_base.py b/src/stratocaster/tests/test_strategy_base.py index aaabf58..e26a90e 100644 --- a/src/stratocaster/tests/test_strategy_base.py +++ b/src/stratocaster/tests/test_strategy_base.py @@ -1,18 +1,34 @@ +import pytest + from gufe import AlchemicalNetwork, ProtocolResult -from gufe.tokenization import GufeKey +from gufe.transformations import Transformation, NonTransformation + +from stratocaster.base import ( + Strategy, + StrategyResult, + StrategyResultValidationError, + StrategySettings, +) + + +class DummyTransformation(Transformation): + pass + -from stratocaster.base import Strategy, StrategyResult, StrategySettings +class DummyNonTransformation(NonTransformation): + pass class TestStrategyResult: result = StrategyResult( { - GufeKey("MyTransformation-ABC123"): 1, - GufeKey("MyTransformation-321CBA"): None, - GufeKey("MyOtherTransformation-789xyz"): 10, + DummyTransformation(stateA=0, stateB=1, protocol=None): 1, + DummyTransformation(stateA=1, stateB=2, protocol=None): None, + DummyNonTransformation(system=2, protocol=None): 10, } ) + assert 3 == len(result.weights) def test_dict_roundtrip(self): assert StrategyResult.from_dict(self.result.to_dict()) == self.result @@ -28,6 +44,29 @@ def test_resolve_normalization(self): res = self.result.resolve() assert 1 == sum([value for _, value in res.items() if value is not None]) + def test_validation(self): + + # negative weight should not pass validation + with pytest.raises(StrategyResultValidationError): + result = StrategyResult( + { + DummyTransformation(stateA=0, stateB=1, protocol=None): -1, + DummyTransformation(stateA=1, stateB=2, protocol=None): None, + DummyNonTransformation(system=2, protocol=None): 10, + } + ) + + # Keys must be either Transformation or NonTransformation, + # provide old style keys for test + with pytest.raises(StrategyResultValidationError): + result = StrategyResult( + { + DummyTransformation(stateA=0, stateB=1, protocol=None).key: 1, + DummyTransformation(stateA=1, stateB=2, protocol=None).key: None, + DummyNonTransformation(system=2, protocol=None).key: 10, + } + ) + class DummyStrategySettings(StrategySettings): pass @@ -44,7 +83,7 @@ def _default_settings(cls) -> DummyStrategySettings: def _propose( self, alchemical_network: AlchemicalNetwork, - protocol_results: dict[GufeKey, ProtocolResult], + protocol_results: dict[Transformation | NonTransformation, ProtocolResult], ): assert alchemical_network, protocol_results return StrategyResult({}) diff --git a/src/stratocaster/tests/utils.py b/src/stratocaster/tests/utils.py index e5f6981..f573582 100644 --- a/src/stratocaster/tests/utils.py +++ b/src/stratocaster/tests/utils.py @@ -45,7 +45,7 @@ def test_deterministic(self, fanning_network, settings=None): def random_runs(): """Generate random inputs for propose.""" return { - transformation.key: DummyProtocolResult( + transformation: DummyProtocolResult( n_protocol_dag_results=randint(0, 3), info=f"key: {transformation.key}", ) @@ -85,11 +85,11 @@ def test_simulated_termination(self, fanning_network, settings=None): def counts_to_result_data(counts_dict): result_data = {} - for transformation_key, count in counts_dict.items(): + for transformation, count in counts_dict.items(): result = DummyProtocolResult( - n_protocol_dag_results=count, info=f"key: {transformation_key}" + n_protocol_dag_results=count, info=f"key: {transformation.key}" ) - result_data[transformation_key] = result + result_data[transformation] = result return result_data def shuffle_take_n(keys_list, n): @@ -98,7 +98,7 @@ def shuffle_take_n(keys_list, n): # initial transforms transformation_counts = { - transformation.key: 0 for transformation in fanning_network.edges + transformation: 0 for transformation in fanning_network.edges } max_iterations = 100 @@ -114,19 +114,19 @@ def shuffle_take_n(keys_list, n): proposal = strategy.propose(fanning_network, result_data) # get random transformations from those with a non-None weight - resolved_keys = shuffle_take_n( + resolved_transformations = shuffle_take_n( [ - key - for key, weight in proposal.resolve().items() + transformation + for transformation, weight in proposal.resolve().items() if weight is not None ], 5, ) - if resolved_keys: + if resolved_transformations: # pretend we ran each of the randomly selected protocols - for key in resolved_keys: - transformation_counts[key] += 1 + for transformation in resolved_transformations: + transformation_counts[transformation] += 1 # if we got an empty list back, there are not more protocols to run else: break