From 1c7cd30a1c42c4cd1e141748eb90c330af0a110f Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Sat, 28 Feb 2026 12:23:23 -0500 Subject: [PATCH 01/15] Untested implementation of RadialGrowthStrategy New strategy that determines weights from eccentricity. It results in a radial growth with adjustable over-extension weight factors to reduce bottleneck behavior. Additionally, the strategy has an adjustable edge repeat weight factor. --- src/stratocaster/strategies/radialgrowth.py | 134 ++++++++++++++++++ src/stratocaster/tests/networks.py | 73 ++++++++++ .../tests/test_radialgrowth_strategy.py | 30 ++++ 3 files changed, 237 insertions(+) create mode 100644 src/stratocaster/strategies/radialgrowth.py create mode 100644 src/stratocaster/tests/test_radialgrowth_strategy.py diff --git a/src/stratocaster/strategies/radialgrowth.py b/src/stratocaster/strategies/radialgrowth.py new file mode 100644 index 0000000..b721ed0 --- /dev/null +++ b/src/stratocaster/strategies/radialgrowth.py @@ -0,0 +1,134 @@ +import networkx as nx + +from gufe import AlchemicalNetwork, ProtocolResult +from gufe.tokenization import GufeKey + +from pydantic import ( + Field, + model_validator, + field_validator, +) + +import pydantic + +from stratocaster.base import Strategy, StrategyResult +from stratocaster.base.models import StrategySettings + + +class RadialGrowthStrategySettings(StrategySettings): + + max_runs: int = Field( + default=3, + description="the upper limit of protocol DAG results needed before a transformation is no longer weighed", + ) + + candidacy_max_distance: int = Field( + default=1, + description="the maximum distance a candidate chemical can be from previously reached chemical systems", + ) + + decay_repeat_rate: float = Field( + default=0.5, + description="decay rate of the exponential repeat decay penalty factor", + ) + + decay_distance_rate: float = Field( + default=0.5, + description="decay rate of the exponential distance decay penalty factor", + ) + + @field_validator("max_runs", mode="before") + def validate_max_runs(cls, value): + if not value >= 1: + raise ValueError("`max_runs` must be greater than or equal to 1") + return value + + @field_validator("candidacy_max_distance", mode="before") + def validate_candidtate_max_distance(cls, value): + if not value >= 1: + raise ValueError( + "`candidtate_max_distance` must be greater than or equal to 1" + ) + return value + + @field_validator("decay_repeat_rate", mode="before") + def validate_decay_repeat_rate(cls, value): + if not (0 < value < 1): + raise ValueError("`decay_repeat_rate` must be between 0 and 1") + return value + + @field_validator("decay_distance_rate", mode="before") + def validate_decay_distance_rate(cls, value): + if not (0 < value < 1): + raise ValueError("`decay_distance_rate` must be between 0 and 1") + return value + + +class RadialGrowthStrategy(Strategy): + + _settings_cls = RadialGrowthStrategySettings + + @classmethod + def _default_settings(cls) -> StrategySettings: + return RadialGrowthStrategySettings(max_runs=3) + + def _propose( + self, + alchemical_network: AlchemicalNetwork, + protocol_results: dict[GufeKey, ProtocolResult], + ) -> StrategyResult: + + alchemical_network_mdg = alchemical_network.graph + weights: dict[GufeKey, float | None] = {} + + e = nx.eccentricity(alchemical_network_mdg.to_undirected()) + + lowest_complete_eccentricity = max(e.values()) + transformation_eccentricity = {} + + for state_a, state_b in alchemical_network_mdg.edges(): + edge = e[state_a], e[state_b] + lower, upper = min(edge), max(edge) + + transformation_key = alchemical_network_mdg.get_edge_data(state_a, state_b)[ + 0 + ]["object"].key + + factor_distance = 1 + factor_repeats = 1 + + match (protocol_results.get(transformation_key)): + case None: + transformation_n_protcol_dag_results = 0 + if upper < lowest_complete_eccentricity: + lowest_complete_eccentricity = lower + case pr: + assert isinstance(pr, ProtocolResult) + transformation_n_protcol_dag_results = pr.n_protocol_dag_results + factor_repeats *= ( + self.settings.decay_repeat_rate + ** transformation_n_protcol_dag_results + ) + + # stop condition given max runs + if self.settings.max_runs <= transformation_n_protcol_dag_results: + weights[transformation_key] = None + continue + + # save the upper eccentricity for later when we know the lowest_completed + transformation_eccentricity[transformation_key] = upper + + weights[transformation_key] = factor_repeats + + distance_factor = 1 + for transformation_key, e in transformation_eccentricity.items(): + distance = e - lowest_complete_eccentricity + + if distance <= 1: + distance_factor = 1 + if distance > self.settings.candidacy_max_distance: + distance_factor = 0 + + weights[transformation_key] *= distance_factor + + return StrategyResult(weights) diff --git a/src/stratocaster/tests/networks.py b/src/stratocaster/tests/networks.py index a6aae9e..d3bb08c 100644 --- a/src/stratocaster/tests/networks.py +++ b/src/stratocaster/tests/networks.py @@ -11,6 +11,7 @@ import gufe from gufe.tests.test_protocol import DummyProtocol +import networkx as nx from openff.units import unit from rdkit import Chem @@ -115,3 +116,75 @@ def benzene_variants_star_map(): return gufe.AlchemicalNetwork( solvated_ligand_transformations + solvated_complex_transformations ) + + +def digraph_to_alchemical_network(digraph): + """Convert a digraph to an AlchemicalNetwork.""" + node_to_chemical_system = lambda n: gufe.ChemicalSystem({}, name=n) + + transformations = [] + for a, b in digraph.edges(): + c_a = node_to_chemical_system(a) + c_b = node_to_chemical_system(b) + protocol = None + transformation = gufe.Transformation( + c_a, c_b, DummyProtocol(settings=DummyProtocol.default_settings()) + ) + transformations.append(transformation) + + return gufe.AlchemicalNetwork(transformations) + + +def int_sampler(sampler_start=0): + num = sampler_start + while True: + yield num + num += 1 + + +def fanning_network(branch=3, depth=3): + """Generate a network with a central node that recursively "fans" + out. Branch determines how many edges branch off of each node + while depth determines how many times the process is repeated. + """ + G = nx.DiGraph() + G.add_edges_from(fan(0, branch=branch, depth=depth)) + an = digraph_to_alchemical_network(G) + return an + + +def dual_center_fanning_network(branch=3, depth=3): + """Generate a network with two central nodes that recursively + "fan" out. Branch determines how many edges branch off of each + node while depth determines how many times the process is + repeated. + """ + + G = nx.DiGraph() + + edges = {(0, 1)} + + sampler = int_sampler(sampler_start=2) + edges |= fan(0, branch=branch, depth=depth, id_generator=sampler) + edges |= fan(1, branch=branch, depth=depth, id_generator=sampler) + + G.add_edges_from(edges) + + an = digraph_to_alchemical_network(G) + return an + + +def fan(node, branch=3, depth=3, id_generator=None): + """Recursive edge generator.""" + + id_generator = id_generator or int_sampler(sampler_start=node + 1) + + if depth == 0: + return set() + + edges = {(node, next(id_generator)) for _ in range(branch)} + + for _, next_node in edges.copy(): + edges |= fan(next_node, depth=depth - 1, id_generator=id_generator) + + return edges diff --git a/src/stratocaster/tests/test_radialgrowth_strategy.py b/src/stratocaster/tests/test_radialgrowth_strategy.py new file mode 100644 index 0000000..502ba07 --- /dev/null +++ b/src/stratocaster/tests/test_radialgrowth_strategy.py @@ -0,0 +1,30 @@ +import pytest + +from gufe import AlchemicalNetwork + +from stratocaster.strategies.radialgrowth import ( + RadialGrowthStrategy, + RadialGrowthStrategySettings, +) +from stratocaster.tests.networks import ( + benzene_variants_star_map as _benzene_variants_star_map, + fanning_network as _fanning_network, +) + + +@pytest.fixture(scope="module") +def fanning_network(): + return _fanning_network() + + +@pytest.fixture +def default_strategy(): + _settings = RadialGrowthStrategy._default_settings() + return RadialGrowthStrategy(_settings) + + +def test_propose_no_results( + default_strategy: RadialGrowthStrategy, fanning_network: AlchemicalNetwork +): + proposal: StrategyResult = default_strategy.propose(fanning_network, {}) + raise NotImplementedError From d0c7d1f3c9f5874ffa6f3769afd7537bccc19b54 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Wed, 4 Mar 2026 19:59:41 -0500 Subject: [PATCH 02/15] Allow StrategyResults merging through "|" --- src/stratocaster/base/strategy.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/stratocaster/base/strategy.py b/src/stratocaster/base/strategy.py index a7e3d3e..83cb30e 100644 --- a/src/stratocaster/base/strategy.py +++ b/src/stratocaster/base/strategy.py @@ -43,6 +43,13 @@ def resolve(self) -> dict[GufeKey, float | None]: weights.update(modified_weights) return weights + def __or__(self, other): + if self.weights.keys() & other.weights.keys(): + raise ValueError( + "StrategyResults can only be combined when their transformation keys are mutually exclusive." + ) + return StrategyResult(self.weights | other.weights) + class Strategy(GufeTokenizable): """An object that proposes the relative urgency of computing @@ -125,4 +132,8 @@ def propose( StrategyResult """ - return self._propose(alchemical_network, protocol_results) + subgraphs = alchemical_network.connected_subgraphs() + acc = StrategyResult({}) + for subgraph in subgraphs: + acc |= self._propose(subgraph, protocol_results) + return acc From 922a97ece39b286e5b5fc605dd8e028cf96d120b Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Thu, 5 Mar 2026 10:18:57 -0500 Subject: [PATCH 03/15] Move fixtures to conftest.py --- src/stratocaster/tests/conftest.py | 22 +++++++++++++++++++ src/stratocaster/tests/networks.py | 19 ++++++++++++++++ .../tests/test_connectivity_strategy.py | 9 -------- .../tests/test_radialgrowth_strategy.py | 6 ----- 4 files changed, 41 insertions(+), 15 deletions(-) create mode 100644 src/stratocaster/tests/conftest.py diff --git a/src/stratocaster/tests/conftest.py b/src/stratocaster/tests/conftest.py new file mode 100644 index 0000000..a0f9986 --- /dev/null +++ b/src/stratocaster/tests/conftest.py @@ -0,0 +1,22 @@ +import pytest + +from stratocaster.tests.networks import ( + benzene_variants_star_map as _benzene_variants_star_map, + fanning_network as _fanning_network, + disconnected_fanning_network as _disconnected_fanning_network, +) + + +@pytest.fixture(scope="module") +def benzene_variants_star_map(): + return _benzene_variants_star_map() + + +@pytest.fixture(scope="module") +def fanning_network(): + return _fanning_network() + + +@pytest.fixture(scope="module") +def disconnected_fanning_network(): + return _disconnected_fanning_network() diff --git a/src/stratocaster/tests/networks.py b/src/stratocaster/tests/networks.py index d3bb08c..eb0fdef 100644 --- a/src/stratocaster/tests/networks.py +++ b/src/stratocaster/tests/networks.py @@ -174,6 +174,25 @@ def dual_center_fanning_network(branch=3, depth=3): return an +def disconnected_fanning_network(branch=3, depth=3): + """Generate a network with disconnected fanning subgraphs.""" + + G = nx.DiGraph() + + edges = set() + + sampler = int_sampler(sampler_start=2) + edges |= fan(0, branch=branch, depth=depth, id_generator=sampler) + edges |= fan(1, branch=branch, depth=depth, id_generator=sampler) + + G.add_edges_from(edges) + + assert not nx.is_connected(G.to_undirected()) + + an = digraph_to_alchemical_network(G) + return an + + def fan(node, branch=3, depth=3, id_generator=None): """Recursive edge generator.""" diff --git a/src/stratocaster/tests/test_connectivity_strategy.py b/src/stratocaster/tests/test_connectivity_strategy.py index 236c511..60368a3 100644 --- a/src/stratocaster/tests/test_connectivity_strategy.py +++ b/src/stratocaster/tests/test_connectivity_strategy.py @@ -11,15 +11,6 @@ ConnectivityStrategy, ConnectivityStrategySettings, ) -from stratocaster.tests.networks import ( - benzene_variants_star_map as _benzene_variants_star_map, -) - - -@pytest.fixture(scope="module") -def benzene_variants_star_map(): - return _benzene_variants_star_map() - from gufe.tokenization import GufeKey diff --git a/src/stratocaster/tests/test_radialgrowth_strategy.py b/src/stratocaster/tests/test_radialgrowth_strategy.py index 502ba07..02ae0e3 100644 --- a/src/stratocaster/tests/test_radialgrowth_strategy.py +++ b/src/stratocaster/tests/test_radialgrowth_strategy.py @@ -8,15 +8,9 @@ ) from stratocaster.tests.networks import ( benzene_variants_star_map as _benzene_variants_star_map, - fanning_network as _fanning_network, ) -@pytest.fixture(scope="module") -def fanning_network(): - return _fanning_network() - - @pytest.fixture def default_strategy(): _settings = RadialGrowthStrategy._default_settings() From 1d2dc51c9d05bd24049b3aaaac79bc0fde053736 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Thu, 5 Mar 2026 10:20:01 -0500 Subject: [PATCH 04/15] Test empty results and disconnected graph --- src/stratocaster/tests/test_radialgrowth_strategy.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/stratocaster/tests/test_radialgrowth_strategy.py b/src/stratocaster/tests/test_radialgrowth_strategy.py index 02ae0e3..03de9c3 100644 --- a/src/stratocaster/tests/test_radialgrowth_strategy.py +++ b/src/stratocaster/tests/test_radialgrowth_strategy.py @@ -21,4 +21,14 @@ def test_propose_no_results( default_strategy: RadialGrowthStrategy, fanning_network: AlchemicalNetwork ): proposal: StrategyResult = default_strategy.propose(fanning_network, {}) - raise NotImplementedError + # check that there is at least 1 non-None weight + assert not all(proposal.weights.values()) + + +def test_disconnected( + default_strategy: RadialGrowthStrategy, + disconnected_fanning_network: AlchemicalNetwork, +): + proposal: StrategyResult = default_strategy.propose( + disconnected_fanning_network, {} + ) From 25c9a29d02c2db3430e4f2813161fe57c146a5c4 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Thu, 5 Mar 2026 10:38:18 -0500 Subject: [PATCH 05/15] Handle edge case for distance --- src/stratocaster/strategies/radialgrowth.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/stratocaster/strategies/radialgrowth.py b/src/stratocaster/strategies/radialgrowth.py index b721ed0..c472d72 100644 --- a/src/stratocaster/strategies/radialgrowth.py +++ b/src/stratocaster/strategies/radialgrowth.py @@ -117,16 +117,22 @@ def _propose( # save the upper eccentricity for later when we know the lowest_completed transformation_eccentricity[transformation_key] = upper - weights[transformation_key] = factor_repeats distance_factor = 1 for transformation_key, e in transformation_eccentricity.items(): distance = e - lowest_complete_eccentricity - if distance <= 1: - distance_factor = 1 - if distance > self.settings.candidacy_max_distance: + if distance <= self.settings.candidacy_max_distance: + # edge case where there are multiple vertices with + # eccentricity equal to graph radius + if distance == 0: + distance_factor = 1 + else: + distance_factor = self.settings.decay_distance_rate ** ( + distance - 1 + ) + else: distance_factor = 0 weights[transformation_key] *= distance_factor From 211fd9cc05f8d05c0eb89e7f61ad51f724d64dac Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Thu, 5 Mar 2026 10:44:34 -0500 Subject: [PATCH 06/15] Test graph connectivity in the correct way --- src/stratocaster/tests/networks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/stratocaster/tests/networks.py b/src/stratocaster/tests/networks.py index eb0fdef..b7456cc 100644 --- a/src/stratocaster/tests/networks.py +++ b/src/stratocaster/tests/networks.py @@ -187,7 +187,7 @@ def disconnected_fanning_network(branch=3, depth=3): G.add_edges_from(edges) - assert not nx.is_connected(G.to_undirected()) + assert not nx.is_weakly_connected(G) an = digraph_to_alchemical_network(G) return an From bbaea723b80477addd710bf721958aa980e916b3 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Fri, 6 Mar 2026 11:09:57 -0500 Subject: [PATCH 07/15] Create mixin test class --- .../tests/test_radialgrowth_strategy.py | 31 ++--------- src/stratocaster/tests/utils.py | 52 +++++++++++++++++++ 2 files changed, 55 insertions(+), 28 deletions(-) create mode 100644 src/stratocaster/tests/utils.py diff --git a/src/stratocaster/tests/test_radialgrowth_strategy.py b/src/stratocaster/tests/test_radialgrowth_strategy.py index 03de9c3..099230a 100644 --- a/src/stratocaster/tests/test_radialgrowth_strategy.py +++ b/src/stratocaster/tests/test_radialgrowth_strategy.py @@ -1,34 +1,9 @@ -import pytest - -from gufe import AlchemicalNetwork - from stratocaster.strategies.radialgrowth import ( RadialGrowthStrategy, - RadialGrowthStrategySettings, -) -from stratocaster.tests.networks import ( - benzene_variants_star_map as _benzene_variants_star_map, ) - -@pytest.fixture -def default_strategy(): - _settings = RadialGrowthStrategy._default_settings() - return RadialGrowthStrategy(_settings) - - -def test_propose_no_results( - default_strategy: RadialGrowthStrategy, fanning_network: AlchemicalNetwork -): - proposal: StrategyResult = default_strategy.propose(fanning_network, {}) - # check that there is at least 1 non-None weight - assert not all(proposal.weights.values()) +from stratocaster.tests.utils import StrategyTestMixin -def test_disconnected( - default_strategy: RadialGrowthStrategy, - disconnected_fanning_network: AlchemicalNetwork, -): - proposal: StrategyResult = default_strategy.propose( - disconnected_fanning_network, {} - ) +class TestRadialGrowth(StrategyTestMixin): + strategy_class = RadialGrowthStrategy diff --git a/src/stratocaster/tests/utils.py b/src/stratocaster/tests/utils.py new file mode 100644 index 0000000..8e1ce20 --- /dev/null +++ b/src/stratocaster/tests/utils.py @@ -0,0 +1,52 @@ +from random import randint +from gufe.tests.test_protocol import DummyProtocolResult + + +class StrategyTestMixin: + + _default_strategy = None + + @property + def default_strategy(self): + if not self._default_strategy: + _settings = self.strategy_class._default_settings() + self._default_strategy = self.strategy_class(_settings) + return self._default_strategy + + def test_deterministic(self, fanning_network): + settings = self.default_strategy.settings + + max_runs = settings.max_runs + assert isinstance(max_runs, int) + + def random_runs(): + """Generate random randomized inputs for propose.""" + return { + transformation.key: DummyProtocolResult( + n_protocol_dag_results=randint(0, max_runs), + info=f"key: {transformation.key}", + ) + for transformation in fanning_network.edges + } + + for _ in range(10): + random_protocol_results = random_runs() + proposal = self.default_strategy.propose( + fanning_network, protocol_results=random_protocol_results + ) + for _ in range(3): + _proposal = self.default_strategy.propose( + fanning_network, protocol_results=random_protocol_results + ) + assert _proposal == proposal + + def test_starts(self, fanning_network): + proposal: StrategyResult = self.default_strategy.propose(fanning_network, {}) + # check that there is at least 1 non-None weight + assert any(proposal.weights.values()) + + def test_disconnected( + self, + disconnected_fanning_network, + ): + self.default_strategy.propose(disconnected_fanning_network, {}) From 3af2b3c87c55eefa131177414d49ab4f973f6caf Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Fri, 6 Mar 2026 11:10:14 -0500 Subject: [PATCH 08/15] Use mixin test class for connectivity --- .../tests/test_connectivity_strategy.py | 303 ++++++++---------- 1 file changed, 137 insertions(+), 166 deletions(-) diff --git a/src/stratocaster/tests/test_connectivity_strategy.py b/src/stratocaster/tests/test_connectivity_strategy.py index 60368a3..0eff90c 100644 --- a/src/stratocaster/tests/test_connectivity_strategy.py +++ b/src/stratocaster/tests/test_connectivity_strategy.py @@ -12,205 +12,176 @@ ConnectivityStrategySettings, ) -from gufe.tokenization import GufeKey - -SETTINGS_VALID = [(0.5, 0.1, 10), (0.1, None, 10), (0.5, 0.1, None)] - - -@pytest.mark.parametrize( - ["decay_rate", "cutoff", "max_runs", "raises"], - [ - (0, None, None, ValueError), - (1, None, None, ValueError), - (0.5, 0, None, ValueError), - (0.5, None, 0, ValueError), - ] - + [(*vals, None) for vals in SETTINGS_VALID], # include all valid settings -) -def test_connectivity_strategy_settings(decay_rate, cutoff, max_runs, raises): - - def instantiate_settings(): - ConnectivityStrategySettings( - decay_rate=decay_rate, cutoff=cutoff, max_runs=max_runs - ) - - if raises: - with pytest.raises(raises): - instantiate_settings() - else: - instantiate_settings() +from stratocaster.tests.utils import StrategyTestMixin - -@pytest.fixture -def default_strategy(): - _settings = ConnectivityStrategy._default_settings() - return ConnectivityStrategy(_settings) +from gufe.tokenization import GufeKey -def test_propose_no_results( - default_strategy: ConnectivityStrategy, benzene_variants_star_map: AlchemicalNetwork -): - proposal: StrategyResult = default_strategy.propose(benzene_variants_star_map, {}) +class TestConnectivityStrategy(StrategyTestMixin): - assert all([weight == 3.5 for weight in proposal._weights.values()]) - assert 1 == sum( - weight for weight in proposal.resolve().values() if weight is not None - ) + strategy_class = ConnectivityStrategy + valid_settings = ((0.5, 0.1, 10), (0.1, None, 10), (0.5, 0.1, None)) + @pytest.mark.parametrize(["decay_rate", "cutoff", "max_runs"], valid_settings) + def test_simulated_termination( + self, benzene_variants_star_map, decay_rate, cutoff, max_runs + ): -def test_propose_previous_results( - default_strategy: ConnectivityStrategy, benzene_variants_star_map: AlchemicalNetwork -): - - result_data: dict[GufeKey, DummyProtocolResult] = {} - for transformation in benzene_variants_star_map.edges: - transformation_key = transformation.key - result = DummyProtocolResult( - n_protocol_dag_results=2, info=f"key: {transformation_key}" - ) - result_data[transformation_key] = result - - results = default_strategy.propose(benzene_variants_star_map, result_data) - results_no_data = default_strategy.propose(benzene_variants_star_map, {}) - - # the raw weights should no longer be the same - assert results.weights != results_no_data.weights - # since each transformation had the same number of previous results, resolve - # should give back the same normalized weights - assert results.resolve() == results_no_data.resolve() - - -def test_propose_max_runs_termination( - default_strategy: ConnectivityStrategy, benzene_variants_star_map: AlchemicalNetwork -): - assert isinstance(default_strategy.settings, ConnectivityStrategySettings) - max_runs = default_strategy.settings.max_runs - assert isinstance(max_runs, int) - - result_data: dict[GufeKey, DummyProtocolResult] = {} - for transformation in benzene_variants_star_map.edges: - transformation_key = transformation.key - result = DummyProtocolResult( - n_protocol_dag_results=max_runs, info=f"key: {transformation_key}" + settings = ConnectivityStrategySettings( + decay_rate=decay_rate, cutoff=cutoff, max_runs=max_runs ) - result_data[transformation_key] = result - - results = default_strategy.propose(benzene_variants_star_map, result_data) - - # since the default strategy has a max_runs of 3, we expect all Nones - assert not [weight for weight in results.resolve().values() if weight is not None] + default_strategy = self.strategy_class(settings) + + def counts_to_result_data(counts_dict): + result_data = {} + for transformation_key, count in counts_dict.items(): + result = DummyProtocolResult( + n_protocol_dag_results=count, info=f"key: {transformation_key}" + ) + result_data[transformation_key] = result + return result_data + + def shuffle_take_n(keys_list, n): + shuffle(keys_list) + return keys_list[:n] + + # initial transforms + transformation_counts = { + transformation.key: 0 for transformation in benzene_variants_star_map.edges + } + max_iterations = 100 + current_iteration = 0 + while current_iteration <= max_iterations: + + if current_iteration == max_iterations: + raise RuntimeError( + f"Strategy did not terminate in {max_iterations} iterations " + ) + + result_data = counts_to_result_data(transformation_counts) + proposal = default_strategy.propose(benzene_variants_star_map, result_data) + + # get random transformations from those with a non-None weight + resolved_keys = shuffle_take_n( + [ + key + for key, weight in proposal.resolve().items() + if weight is not None + ], + 5, + ) -def test_propose_cutoff_num_runs_predictioned_termination(benzene_variants_star_map): - """We can predict the number of runs needed to terminate with a given cutoff. + if resolved_keys: + # pretend we ran each of the randomly selected protocols + for key in resolved_keys: + transformation_counts[key] += 1 + # if we got an empty list back, there are not more protocols to run + else: + break + current_iteration += 1 - Each edge in benzene_variants_star_map has a base weight of 3.5. - """ + def test_propose_cutoff_num_runs_predictioned_termination( + self, benzene_variants_star_map + ): + """We can predict the number of runs needed to terminate with a given cutoff. - settings = ConnectivityStrategySettings(cutoff=2, decay_rate=0.5) - strategy = ConnectivityStrategy(settings) + Each edge in benzene_variants_star_map has a base weight of 3.5. + """ - assert isinstance(settings.cutoff, float) + settings = ConnectivityStrategySettings(cutoff=2, decay_rate=0.5) + strategy = ConnectivityStrategy(settings) - num_runs = math.floor( - math.log(settings.cutoff / 3.5) / math.log(settings.decay_rate) - ) + assert isinstance(settings.cutoff, float) - result_data: dict[GufeKey, DummyProtocolResult] = {} - for transformation in benzene_variants_star_map.edges: - transformation_key = transformation.key - result = DummyProtocolResult( - n_protocol_dag_results=num_runs + 1, info=f"key: {transformation_key}" + num_runs = math.floor( + math.log(settings.cutoff / 3.5) / math.log(settings.decay_rate) ) - result_data[transformation_key] = result - - results = strategy.propose(benzene_variants_star_map, result_data) - assert not [weight for weight in results.weights.values() if weight is not None] + result_data: dict[GufeKey, DummyProtocolResult] = {} + for transformation in benzene_variants_star_map.edges: + transformation_key = transformation.key + result = DummyProtocolResult( + n_protocol_dag_results=num_runs + 1, info=f"key: {transformation_key}" + ) + result_data[transformation_key] = result + results = strategy.propose(benzene_variants_star_map, result_data) -@pytest.mark.parametrize(["decay_rate", "cutoff", "max_runs"], SETTINGS_VALID) -def test_simulated_termination( - default_strategy, benzene_variants_star_map, decay_rate, cutoff, max_runs -): + assert not [weight for weight in results.weights.values() if weight is not None] - settings = ConnectivityStrategySettings( - decay_rate=decay_rate, cutoff=cutoff, max_runs=max_runs - ) - default_strategy = ConnectivityStrategy(settings) + def test_propose_max_runs_termination( + self, benzene_variants_star_map: AlchemicalNetwork + ): + assert isinstance(self.default_strategy.settings, ConnectivityStrategySettings) + max_runs = self.default_strategy.settings.max_runs + assert isinstance(max_runs, int) - def counts_to_result_data(counts_dict): - result_data = {} - for transformation_key, count in counts_dict.items(): + result_data: dict[GufeKey, DummyProtocolResult] = {} + for transformation in benzene_variants_star_map.edges: + transformation_key = transformation.key result = DummyProtocolResult( - n_protocol_dag_results=count, info=f"key: {transformation_key}" + n_protocol_dag_results=max_runs, info=f"key: {transformation_key}" ) result_data[transformation_key] = result - return result_data - def shuffle_take_n(keys_list, n): - shuffle(keys_list) - return keys_list[:n] + results = self.default_strategy.propose(benzene_variants_star_map, result_data) - # initial transforms - transformation_counts = { - transformation.key: 0 for transformation in benzene_variants_star_map.edges - } + # since the default strategy has a max_runs of 3, we expect all Nones + assert not [ + weight for weight in results.resolve().values() if weight is not None + ] - max_iterations = 100 - current_iteration = 0 - while current_iteration <= max_iterations: + def test_propose_previous_results( + self, benzene_variants_star_map: AlchemicalNetwork + ): - if current_iteration == max_iterations: - raise RuntimeError( - f"Strategy did not terminate in {max_iterations} iterations " + result_data: dict[GufeKey, DummyProtocolResult] = {} + for transformation in benzene_variants_star_map.edges: + transformation_key = transformation.key + result = DummyProtocolResult( + n_protocol_dag_results=2, info=f"key: {transformation_key}" ) + result_data[transformation_key] = result - result_data = counts_to_result_data(transformation_counts) - proposal = default_strategy.propose(benzene_variants_star_map, result_data) - - # get random transformations from those with a non-None weight - resolved_keys = shuffle_take_n( - [key for key, weight in proposal.resolve().items() if weight is not None], 5 - ) - - if resolved_keys: - # pretend we ran each of the randomly selected protocols - for key in resolved_keys: - transformation_counts[key] += 1 - # if we got an empty list back, there are not more protocols to run - else: - break - current_iteration += 1 + results = self.default_strategy.propose(benzene_variants_star_map, result_data) + results_no_data = self.default_strategy.propose(benzene_variants_star_map, {}) + # the raw weights should no longer be the same + assert results.weights != results_no_data.weights + # since each transformation had the same number of previous results, resolve + # should give back the same normalized weights + assert results.resolve() == results_no_data.resolve() -def test_deterministic( - default_strategy: ConnectivityStrategy, benzene_variants_star_map: AlchemicalNetwork -): + def test_propose_no_results(self, benzene_variants_star_map: AlchemicalNetwork): + proposal: StrategyResult = self.default_strategy.propose( + benzene_variants_star_map, {} + ) - settings = default_strategy.settings - assert isinstance(settings, ConnectivityStrategySettings) + assert all([weight == 3.5 for weight in proposal._weights.values()]) + assert 1 == sum( + weight for weight in proposal.resolve().values() if weight is not None + ) - max_runs = settings.max_runs - assert isinstance(max_runs, int) + @pytest.mark.parametrize( + ["decay_rate", "cutoff", "max_runs", "raises"], + [ + (0, None, None, ValueError), + (1, None, None, ValueError), + (0.5, 0, None, ValueError), + (0.5, None, 0, ValueError), + ] + + [(*vals, None) for vals in valid_settings], # include all valid settings + ) + def test_connectivity_strategy_settings(self, decay_rate, cutoff, max_runs, raises): - def random_runs(): - """Generate random randomized inputs for propose.""" - return { - transformation.key: DummyProtocolResult( - n_protocol_dag_results=randint(0, max_runs), - info=f"key: {transformation.key}", + def instantiate_settings(): + ConnectivityStrategySettings( + decay_rate=decay_rate, cutoff=cutoff, max_runs=max_runs ) - for transformation in benzene_variants_star_map.edges - } - for _ in range(10): - random_protocol_results = random_runs() - proposal = default_strategy.propose( - benzene_variants_star_map, protocol_results=random_protocol_results - ) - for _ in range(3): - _proposal = default_strategy.propose( - benzene_variants_star_map, protocol_results=random_protocol_results - ) - assert _proposal == proposal + if raises: + with pytest.raises(raises): + instantiate_settings() + else: + instantiate_settings() From d9b4d0b3f972cce62398e79a9f6673421cffab7f Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Fri, 6 Mar 2026 12:44:26 -0500 Subject: [PATCH 09/15] Make mixin test general --- .../tests/test_connectivity_strategy.py | 2 +- src/stratocaster/tests/utils.py | 19 ++++++++++++------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/stratocaster/tests/test_connectivity_strategy.py b/src/stratocaster/tests/test_connectivity_strategy.py index 0eff90c..59a1d2a 100644 --- a/src/stratocaster/tests/test_connectivity_strategy.py +++ b/src/stratocaster/tests/test_connectivity_strategy.py @@ -20,7 +20,7 @@ class TestConnectivityStrategy(StrategyTestMixin): strategy_class = ConnectivityStrategy - valid_settings = ((0.5, 0.1, 10), (0.1, None, 10), (0.5, 0.1, None)) + valid_settings = {(0.5, 0.1, 10), (0.1, None, 10), (0.5, 0.1, None)} @pytest.mark.parametrize(["decay_rate", "cutoff", "max_runs"], valid_settings) def test_simulated_termination( diff --git a/src/stratocaster/tests/utils.py b/src/stratocaster/tests/utils.py index 8e1ce20..113ce90 100644 --- a/src/stratocaster/tests/utils.py +++ b/src/stratocaster/tests/utils.py @@ -1,29 +1,34 @@ from random import randint + +import pytest + from gufe.tests.test_protocol import DummyProtocolResult class StrategyTestMixin: _default_strategy = None + _default_settings = None @property def default_strategy(self): if not self._default_strategy: - _settings = self.strategy_class._default_settings() - self._default_strategy = self.strategy_class(_settings) + self._default_strategy = self.strategy_class(self.default_settings) return self._default_strategy - def test_deterministic(self, fanning_network): - settings = self.default_strategy.settings + @property + def default_settings(self): + if not self._default_settings: + self._default_settings = self.strategy_class._default_settings() + return self._default_settings - max_runs = settings.max_runs - assert isinstance(max_runs, int) + def test_deterministic(self, fanning_network): def random_runs(): """Generate random randomized inputs for propose.""" return { transformation.key: DummyProtocolResult( - n_protocol_dag_results=randint(0, max_runs), + n_protocol_dag_results=randint(0, 3), info=f"key: {transformation.key}", ) for transformation in fanning_network.edges From 9f41a699c4e5ae9cf423ff2eb08da3c590dbce35 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Fri, 6 Mar 2026 14:50:40 -0500 Subject: [PATCH 10/15] Allow default tests to be overwritten with parametrize --- .../tests/test_connectivity_strategy.py | 60 ++------------- src/stratocaster/tests/utils.py | 73 +++++++++++++++++-- 2 files changed, 73 insertions(+), 60 deletions(-) diff --git a/src/stratocaster/tests/test_connectivity_strategy.py b/src/stratocaster/tests/test_connectivity_strategy.py index 59a1d2a..7e56c1e 100644 --- a/src/stratocaster/tests/test_connectivity_strategy.py +++ b/src/stratocaster/tests/test_connectivity_strategy.py @@ -23,63 +23,15 @@ class TestConnectivityStrategy(StrategyTestMixin): valid_settings = {(0.5, 0.1, 10), (0.1, None, 10), (0.5, 0.1, None)} @pytest.mark.parametrize(["decay_rate", "cutoff", "max_runs"], valid_settings) - def test_simulated_termination( - self, benzene_variants_star_map, decay_rate, cutoff, max_runs - ): + def test_simulated_termination(self, fanning_network, decay_rate, cutoff, max_runs): settings = ConnectivityStrategySettings( decay_rate=decay_rate, cutoff=cutoff, max_runs=max_runs ) - default_strategy = self.strategy_class(settings) - - def counts_to_result_data(counts_dict): - result_data = {} - for transformation_key, count in counts_dict.items(): - result = DummyProtocolResult( - n_protocol_dag_results=count, info=f"key: {transformation_key}" - ) - result_data[transformation_key] = result - return result_data - - def shuffle_take_n(keys_list, n): - shuffle(keys_list) - return keys_list[:n] - - # initial transforms - transformation_counts = { - transformation.key: 0 for transformation in benzene_variants_star_map.edges - } - - max_iterations = 100 - current_iteration = 0 - while current_iteration <= max_iterations: - - if current_iteration == max_iterations: - raise RuntimeError( - f"Strategy did not terminate in {max_iterations} iterations " - ) - - result_data = counts_to_result_data(transformation_counts) - proposal = default_strategy.propose(benzene_variants_star_map, result_data) - - # get random transformations from those with a non-None weight - resolved_keys = shuffle_take_n( - [ - key - for key, weight in proposal.resolve().items() - if weight is not None - ], - 5, - ) - if resolved_keys: - # pretend we ran each of the randomly selected protocols - for key in resolved_keys: - transformation_counts[key] += 1 - # if we got an empty list back, there are not more protocols to run - else: - break - current_iteration += 1 + StrategyTestMixin.test_simulated_termination( + self, fanning_network, settings=settings + ) def test_propose_cutoff_num_runs_predictioned_termination( self, benzene_variants_star_map @@ -90,9 +42,7 @@ def test_propose_cutoff_num_runs_predictioned_termination( """ settings = ConnectivityStrategySettings(cutoff=2, decay_rate=0.5) - strategy = ConnectivityStrategy(settings) - - assert isinstance(settings.cutoff, float) + strategy = self.strategy_or_default(settings) num_runs = math.floor( math.log(settings.cutoff / 3.5) / math.log(settings.decay_rate) diff --git a/src/stratocaster/tests/utils.py b/src/stratocaster/tests/utils.py index 113ce90..31a6ba8 100644 --- a/src/stratocaster/tests/utils.py +++ b/src/stratocaster/tests/utils.py @@ -1,4 +1,4 @@ -from random import randint +from random import randint, shuffle import pytest @@ -22,7 +22,12 @@ def default_settings(self): self._default_settings = self.strategy_class._default_settings() return self._default_settings - def test_deterministic(self, fanning_network): + def strategy_or_default(self, settings): + return self.strategy_class(settings) if settings else self.default_strategy + + def test_deterministic(self, fanning_network, settings=None): + + strategy = self.strategy_or_default(settings) def random_runs(): """Generate random randomized inputs for propose.""" @@ -45,13 +50,71 @@ def random_runs(): ) assert _proposal == proposal - def test_starts(self, fanning_network): - proposal: StrategyResult = self.default_strategy.propose(fanning_network, {}) + def test_starts(self, fanning_network, settings=None): + + strategy = self.strategy_or_default(settings) + + proposal: StrategyResult = strategy.propose(fanning_network, {}) # check that there is at least 1 non-None weight assert any(proposal.weights.values()) def test_disconnected( self, disconnected_fanning_network, + settings=None, ): - self.default_strategy.propose(disconnected_fanning_network, {}) + strategy = self.strategy_or_default(settings) + strategy.propose(disconnected_fanning_network, {}) + + def test_simulated_termination(self, fanning_network, settings=None): + + strategy = self.strategy_or_default(settings) + + def counts_to_result_data(counts_dict): + result_data = {} + for transformation_key, count in counts_dict.items(): + result = DummyProtocolResult( + n_protocol_dag_results=count, info=f"key: {transformation_key}" + ) + result_data[transformation_key] = result + return result_data + + def shuffle_take_n(keys_list, n): + shuffle(keys_list) + return keys_list[:n] + + # initial transforms + transformation_counts = { + transformation.key: 0 for transformation in fanning_network.edges + } + + max_iterations = 100 + current_iteration = 0 + while current_iteration <= max_iterations: + + if current_iteration == max_iterations: + raise RuntimeError( + f"Strategy did not terminate in {max_iterations} iterations " + ) + + result_data = counts_to_result_data(transformation_counts) + proposal = strategy.propose(fanning_network, result_data) + + # get random transformations from those with a non-None weight + resolved_keys = shuffle_take_n( + [ + key + for key, weight in proposal.resolve().items() + if weight is not None + ], + 5, + ) + + if resolved_keys: + # pretend we ran each of the randomly selected protocols + for key in resolved_keys: + transformation_counts[key] += 1 + # if we got an empty list back, there are not more protocols to run + else: + break + current_iteration += 1 From 7db5f88910fe44e6dc5fcbc210073f29fc1611b7 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Mon, 9 Mar 2026 15:03:20 -0400 Subject: [PATCH 11/15] Add docstrings --- src/stratocaster/strategies/radialgrowth.py | 67 ++++++++++++++++++++- 1 file changed, 64 insertions(+), 3 deletions(-) diff --git a/src/stratocaster/strategies/radialgrowth.py b/src/stratocaster/strategies/radialgrowth.py index c472d72..2c2cbf1 100644 --- a/src/stratocaster/strategies/radialgrowth.py +++ b/src/stratocaster/strategies/radialgrowth.py @@ -16,6 +16,7 @@ class RadialGrowthStrategySettings(StrategySettings): + """Settings required for the RadialGrowthStrategy.""" max_runs: int = Field( default=3, @@ -65,6 +66,32 @@ def validate_decay_distance_rate(cls, value): class RadialGrowthStrategy(Strategy): + """A Strategy that favors Transformations close to the network + center. + + The weight proposed for each Transformation depends on its highest + ChemicalSystem distance from the lowest completed tier of + distances. For the graph below, if at least one ProtocolDAGResult + exists for both edges in 3-2-3, then the lowest completed distance + is 3. If only one edge has a result or neither has a result, then + the lowest complete distance is 2. In the prior case, a + transformation going from 3 to 4 then has an effective distance of + 1, while the later case has a distance of 2. + + 4 4 + \ / + \ / + 4--3-2-3--4 + / \ + / \ + 4 4 + + The strategy will penialize transformations that have high + effective distances by multiplying the weight with a penalty of + r^d where r is a user specified decay rate and d is the effective + distance. The candidacy_max_distance setting limits how far out + transformations can be assigned non-zero weights. + """ _settings_cls = RadialGrowthStrategySettings @@ -77,34 +104,62 @@ def _propose( alchemical_network: AlchemicalNetwork, protocol_results: dict[GufeKey, ProtocolResult], ) -> StrategyResult: + """Propose `Transformation` weight recommendations based on + `Transformation` distance from the graph center. + + Parameters + ---------- + alchemical_network + protocol_results + A dictionary whose keys are the `GufeKey`s of `Transformation`s in the `AlchemicalNetwork` + and whose values are the `ProtocolResult`s for those `Transformation`s. + + Returns + ------- + StrategyResult + A `StrategyResult` containing the proposed `Transformation` weights. + + """ alchemical_network_mdg = alchemical_network.graph weights: dict[GufeKey, float | None] = {} + # calculate all node eccentricies e = nx.eccentricity(alchemical_network_mdg.to_undirected()) + # start with the maximum value, this will be decremented as we + # see evidence the value should be lower lowest_complete_eccentricity = max(e.values()) + # hold on to the eccentricies of the transformations instead + # of the distances since we don't know the lowest complete + # eccentricity until we process the full graph, distances can + # be calculated after transformation_eccentricity = {} for state_a, state_b in alchemical_network_mdg.edges(): edge = e[state_a], e[state_b] + # find the range of eccentricies lower, upper = min(edge), max(edge) transformation_key = alchemical_network_mdg.get_edge_data(state_a, state_b)[ 0 ]["object"].key - factor_distance = 1 factor_repeats = 1 - match (protocol_results.get(transformation_key)): case None: transformation_n_protcol_dag_results = 0 + # since we have no results for this + # transformation, we know the lowest complete + # eccentricity must be lower than the upper + # eccentricity of the transformation if upper < lowest_complete_eccentricity: lowest_complete_eccentricity = lower case pr: assert isinstance(pr, ProtocolResult) transformation_n_protcol_dag_results = pr.n_protocol_dag_results + # scale the repeat factor to discourage reruns as + # specified by the user's decay_repeat_rate factor_repeats *= ( self.settings.decay_repeat_rate ** transformation_n_protcol_dag_results @@ -115,10 +170,13 @@ def _propose( weights[transformation_key] = None continue - # save the upper eccentricity for later when we know the lowest_completed + # save the upper eccentricity for later when we know the + # lowest_completed. This is the transformation's effective + # distance from the center transformation_eccentricity[transformation_key] = upper weights[transformation_key] = factor_repeats + # start applying weights due to effective distance distance_factor = 1 for transformation_key, e in transformation_eccentricity.items(): distance = e - lowest_complete_eccentricity @@ -129,10 +187,13 @@ def _propose( if distance == 0: distance_factor = 1 else: + # scale the distance factor to limit the + # calculation of far-out transformations distance_factor = self.settings.decay_distance_rate ** ( distance - 1 ) else: + # set to zero, not None distance_factor = 0 weights[transformation_key] *= distance_factor From cd0a257a6a5bf56f540d12c64d5159a14de4a5a7 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Tue, 10 Mar 2026 14:24:35 -0400 Subject: [PATCH 12/15] Rework docstrings --- src/stratocaster/strategies/__init__.py | 3 +- src/stratocaster/strategies/radialgrowth.py | 51 +++++++++++---------- src/stratocaster/tests/utils.py | 15 +++++- 3 files changed, 43 insertions(+), 26 deletions(-) diff --git a/src/stratocaster/strategies/__init__.py b/src/stratocaster/strategies/__init__.py index 592b5e5..6a5d9fa 100644 --- a/src/stratocaster/strategies/__init__.py +++ b/src/stratocaster/strategies/__init__.py @@ -1,3 +1,4 @@ from stratocaster.strategies.connectivity import ConnectivityStrategy +from stratocaster.strategies.radialgrowth import RadialGrowthStrategy -__all__ = ["ConnectivityStrategy"] +__all__ = ["ConnectivityStrategy", "RadialGrowthStrategy"] diff --git a/src/stratocaster/strategies/radialgrowth.py b/src/stratocaster/strategies/radialgrowth.py index 2c2cbf1..08ae362 100644 --- a/src/stratocaster/strategies/radialgrowth.py +++ b/src/stratocaster/strategies/radialgrowth.py @@ -66,31 +66,35 @@ def validate_decay_distance_rate(cls, value): class RadialGrowthStrategy(Strategy): - """A Strategy that favors Transformations close to the network + r"""A Strategy that favors Transformations close to the network center. - The weight proposed for each Transformation depends on its highest - ChemicalSystem distance from the lowest completed tier of - distances. For the graph below, if at least one ProtocolDAGResult - exists for both edges in 3-2-3, then the lowest completed distance - is 3. If only one edge has a result or neither has a result, then - the lowest complete distance is 2. In the prior case, a - transformation going from 3 to 4 then has an effective distance of - 1, while the later case has a distance of 2. - - 4 4 - \ / - \ / - 4--3-2-3--4 - / \ - / \ - 4 4 - - The strategy will penialize transformations that have high - effective distances by multiplying the weight with a penalty of - r^d where r is a user specified decay rate and d is the effective - distance. The candidacy_max_distance setting limits how far out - transformations can be assigned non-zero weights. + The weight assigned to each Transformation depends on its highest + ChemicalSystem distance, as measured and labeled by the vertex + eccentricty, from the lowest completed tier of distances. In the + graph below, if at least one ``ProtocolDAGResult`` exists for both + edges in 3-2-3, the lowest completed distance is 3. If neither + edge or only one edge has a result, the lowest completed distance + is 2. In the former case, a transformation going from 3 to 4 has + an effective distance of 1, while in the latter case it has a + distance of 2. + + .. code-block:: + + 4 4 + \ / + \ / + 4--3-2-3--4 + / \ + / \ + 4 4 + + The strategy penalizes transformations that have high effective + distances by multiplying the weight with a penalty of :math:`r^d`, where r + is a user-specified decay rate and d is the effective + distance. The ``candidacy_max_distance`` setting limits how far + out transformations can be assigned non-zero weights. + """ _settings_cls = RadialGrowthStrategySettings @@ -177,7 +181,6 @@ def _propose( weights[transformation_key] = factor_repeats # start applying weights due to effective distance - distance_factor = 1 for transformation_key, e in transformation_eccentricity.items(): distance = e - lowest_complete_eccentricity diff --git a/src/stratocaster/tests/utils.py b/src/stratocaster/tests/utils.py index 31a6ba8..28564b8 100644 --- a/src/stratocaster/tests/utils.py +++ b/src/stratocaster/tests/utils.py @@ -4,8 +4,21 @@ from gufe.tests.test_protocol import DummyProtocolResult +from stratocaster.base.strategy import StrategyResult + class StrategyTestMixin: + r"""A mixin base class for testing strategies. + + All tests defined in this base class should include a ``settings`` + keyword argument, defaulting to ``None``. This requirement ensures + predictable pytest parametrization in derived strategy tests. The + ``StrategyTestMixin.strategy_or_default`` method, when called with + optional settings, will return either a strategy configured with + the provided settings or the default strategy established by the + strategy author. + + """ _default_strategy = None _default_settings = None @@ -41,7 +54,7 @@ def random_runs(): for _ in range(10): random_protocol_results = random_runs() - proposal = self.default_strategy.propose( + proposal = strategy.propose( fanning_network, protocol_results=random_protocol_results ) for _ in range(3): From d89a6c0d254c5b3e27681ec3ebe5120e99d1b94b Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Tue, 10 Mar 2026 14:56:24 -0400 Subject: [PATCH 13/15] Derived strategy tests should provide instances of the settings This effectively tests validation passing for free. --- .../tests/test_connectivity_strategy.py | 36 ++++++++----------- 1 file changed, 14 insertions(+), 22 deletions(-) diff --git a/src/stratocaster/tests/test_connectivity_strategy.py b/src/stratocaster/tests/test_connectivity_strategy.py index 7e56c1e..d3bc673 100644 --- a/src/stratocaster/tests/test_connectivity_strategy.py +++ b/src/stratocaster/tests/test_connectivity_strategy.py @@ -20,14 +20,13 @@ class TestConnectivityStrategy(StrategyTestMixin): strategy_class = ConnectivityStrategy - valid_settings = {(0.5, 0.1, 10), (0.1, None, 10), (0.5, 0.1, None)} + valid_settings = [ + ConnectivityStrategySettings(decay_rate=dr, cutoff=co, max_runs=mr) + for dr, co, mr in {(0.5, 0.1, 10), (0.1, None, 10), (0.5, 0.1, None)} + ] - @pytest.mark.parametrize(["decay_rate", "cutoff", "max_runs"], valid_settings) - def test_simulated_termination(self, fanning_network, decay_rate, cutoff, max_runs): - - settings = ConnectivityStrategySettings( - decay_rate=decay_rate, cutoff=cutoff, max_runs=max_runs - ) + @pytest.mark.parametrize("settings", valid_settings) + def test_simulated_termination(self, fanning_network, settings): StrategyTestMixin.test_simulated_termination( self, fanning_network, settings=settings @@ -114,24 +113,17 @@ def test_propose_no_results(self, benzene_variants_star_map: AlchemicalNetwork): ) @pytest.mark.parametrize( - ["decay_rate", "cutoff", "max_runs", "raises"], + ["decay_rate", "cutoff", "max_runs"], [ - (0, None, None, ValueError), - (1, None, None, ValueError), - (0.5, 0, None, ValueError), - (0.5, None, 0, ValueError), - ] - + [(*vals, None) for vals in valid_settings], # include all valid settings + (0, None, None), + (1, None, None), + (0.5, 0, None), + (0.5, None, 0), + ], ) - def test_connectivity_strategy_settings(self, decay_rate, cutoff, max_runs, raises): + def test_connectivity_strategy_invalid_settings(self, decay_rate, cutoff, max_runs): - def instantiate_settings(): + with pytest.raises(ValueError): ConnectivityStrategySettings( decay_rate=decay_rate, cutoff=cutoff, max_runs=max_runs ) - - if raises: - with pytest.raises(raises): - instantiate_settings() - else: - instantiate_settings() From 641e972514de521a37aafa7a44c36c227dca230a Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Wed, 8 Apr 2026 10:45:59 -0400 Subject: [PATCH 14/15] Update field descriptions and fix typos --- src/stratocaster/strategies/radialgrowth.py | 10 +++------- src/stratocaster/tests/utils.py | 2 +- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/stratocaster/strategies/radialgrowth.py b/src/stratocaster/strategies/radialgrowth.py index 08ae362..27f0cc8 100644 --- a/src/stratocaster/strategies/radialgrowth.py +++ b/src/stratocaster/strategies/radialgrowth.py @@ -5,12 +5,9 @@ from pydantic import ( Field, - model_validator, field_validator, ) -import pydantic - from stratocaster.base import Strategy, StrategyResult from stratocaster.base.models import StrategySettings @@ -20,12 +17,12 @@ class RadialGrowthStrategySettings(StrategySettings): max_runs: int = Field( default=3, - description="the upper limit of protocol DAG results needed before a transformation is no longer weighed", + description="the upper limit of ProtocolDAG results needed before a Transformation is no longer weighed", ) candidacy_max_distance: int = Field( default=1, - description="the maximum distance a candidate chemical can be from previously reached chemical systems", + description="the maximum distance a candidate ChemicalSystem can be from previously reached ChemicalSystems", ) decay_repeat_rate: float = Field( @@ -45,7 +42,7 @@ def validate_max_runs(cls, value): return value @field_validator("candidacy_max_distance", mode="before") - def validate_candidtate_max_distance(cls, value): + def validate_candidate_max_distance(cls, value): if not value >= 1: raise ValueError( "`candidtate_max_distance` must be greater than or equal to 1" @@ -160,7 +157,6 @@ def _propose( if upper < lowest_complete_eccentricity: lowest_complete_eccentricity = lower case pr: - assert isinstance(pr, ProtocolResult) transformation_n_protcol_dag_results = pr.n_protocol_dag_results # scale the repeat factor to discourage reruns as # specified by the user's decay_repeat_rate diff --git a/src/stratocaster/tests/utils.py b/src/stratocaster/tests/utils.py index 28564b8..d82b40a 100644 --- a/src/stratocaster/tests/utils.py +++ b/src/stratocaster/tests/utils.py @@ -43,7 +43,7 @@ def test_deterministic(self, fanning_network, settings=None): strategy = self.strategy_or_default(settings) def random_runs(): - """Generate random randomized inputs for propose.""" + """Generate random inputs for propose.""" return { transformation.key: DummyProtocolResult( n_protocol_dag_results=randint(0, 3), From e5da83e4c78ac390da46e3a7dd4af415197f68a7 Mon Sep 17 00:00:00 2001 From: Ian Kenney Date: Wed, 8 Apr 2026 10:46:38 -0400 Subject: [PATCH 15/15] Test against correct strategy in test_deterministic --- src/stratocaster/tests/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/stratocaster/tests/utils.py b/src/stratocaster/tests/utils.py index d82b40a..e5f6981 100644 --- a/src/stratocaster/tests/utils.py +++ b/src/stratocaster/tests/utils.py @@ -58,7 +58,7 @@ def random_runs(): fanning_network, protocol_results=random_protocol_results ) for _ in range(3): - _proposal = self.default_strategy.propose( + _proposal = strategy.propose( fanning_network, protocol_results=random_protocol_results ) assert _proposal == proposal