diff --git a/src/stratocaster/base/strategy.py b/src/stratocaster/base/strategy.py index a7e3d3e..83cb30e 100644 --- a/src/stratocaster/base/strategy.py +++ b/src/stratocaster/base/strategy.py @@ -43,6 +43,13 @@ def resolve(self) -> dict[GufeKey, float | None]: weights.update(modified_weights) return weights + def __or__(self, other): + if self.weights.keys() & other.weights.keys(): + raise ValueError( + "StrategyResults can only be combined when their transformation keys are mutually exclusive." + ) + return StrategyResult(self.weights | other.weights) + class Strategy(GufeTokenizable): """An object that proposes the relative urgency of computing @@ -125,4 +132,8 @@ def propose( StrategyResult """ - return self._propose(alchemical_network, protocol_results) + subgraphs = alchemical_network.connected_subgraphs() + acc = StrategyResult({}) + for subgraph in subgraphs: + acc |= self._propose(subgraph, protocol_results) + return acc diff --git a/src/stratocaster/strategies/__init__.py b/src/stratocaster/strategies/__init__.py index 592b5e5..6a5d9fa 100644 --- a/src/stratocaster/strategies/__init__.py +++ b/src/stratocaster/strategies/__init__.py @@ -1,3 +1,4 @@ from stratocaster.strategies.connectivity import ConnectivityStrategy +from stratocaster.strategies.radialgrowth import RadialGrowthStrategy -__all__ = ["ConnectivityStrategy"] +__all__ = ["ConnectivityStrategy", "RadialGrowthStrategy"] diff --git a/src/stratocaster/strategies/radialgrowth.py b/src/stratocaster/strategies/radialgrowth.py new file mode 100644 index 0000000..27f0cc8 --- /dev/null +++ b/src/stratocaster/strategies/radialgrowth.py @@ -0,0 +1,200 @@ +import networkx as nx + +from gufe import AlchemicalNetwork, ProtocolResult +from gufe.tokenization import GufeKey + +from pydantic import ( + Field, + field_validator, +) + +from stratocaster.base import Strategy, StrategyResult +from stratocaster.base.models import StrategySettings + + +class RadialGrowthStrategySettings(StrategySettings): + 
"""Settings required for the RadialGrowthStrategy.""" + + max_runs: int = Field( + default=3, + description="the upper limit of ProtocolDAG results needed before a Transformation is no longer weighed", + ) + + candidacy_max_distance: int = Field( + default=1, + description="the maximum distance a candidate ChemicalSystem can be from previously reached ChemicalSystems", + ) + + decay_repeat_rate: float = Field( + default=0.5, + description="decay rate of the exponential repeat decay penalty factor", + ) + + decay_distance_rate: float = Field( + default=0.5, + description="decay rate of the exponential distance decay penalty factor", + ) + + @field_validator("max_runs", mode="before") + def validate_max_runs(cls, value): + if not value >= 1: + raise ValueError("`max_runs` must be greater than or equal to 1") + return value + + @field_validator("candidacy_max_distance", mode="before") + def validate_candidate_max_distance(cls, value): + if not value >= 1: + raise ValueError( + "`candidacy_max_distance` must be greater than or equal to 1" + ) + return value + + @field_validator("decay_repeat_rate", mode="before") + def validate_decay_repeat_rate(cls, value): + if not (0 < value < 1): + raise ValueError("`decay_repeat_rate` must be between 0 and 1") + return value + + @field_validator("decay_distance_rate", mode="before") + def validate_decay_distance_rate(cls, value): + if not (0 < value < 1): + raise ValueError("`decay_distance_rate` must be between 0 and 1") + return value + + +class RadialGrowthStrategy(Strategy): + r"""A Strategy that favors Transformations close to the network + center. + + The weight assigned to each Transformation depends on its highest + ChemicalSystem distance, as measured and labeled by the vertex + eccentricity, from the lowest completed tier of distances. In the + graph below, if at least one ``ProtocolDAGResult`` exists for both + edges in 3-2-3, the lowest completed distance is 3. 
If neither + edge or only one edge has a result, the lowest completed distance + is 2. In the former case, a transformation going from 3 to 4 has + an effective distance of 1, while in the latter case it has a + distance of 2. + + .. code-block:: + + 4 4 + \ / + \ / + 4--3-2-3--4 + / \ + / \ + 4 4 + + The strategy penalizes transformations that have high effective + distances by multiplying the weight with a penalty of :math:`r^d`, where r + is a user-specified decay rate and d is the effective + distance. The ``candidacy_max_distance`` setting limits how far + out transformations can be assigned non-zero weights. + + """ + + _settings_cls = RadialGrowthStrategySettings + + @classmethod + def _default_settings(cls) -> StrategySettings: + return RadialGrowthStrategySettings(max_runs=3) + + def _propose( + self, + alchemical_network: AlchemicalNetwork, + protocol_results: dict[GufeKey, ProtocolResult], + ) -> StrategyResult: + """Propose `Transformation` weight recommendations based on + `Transformation` distance from the graph center. + + Parameters + ---------- + alchemical_network + protocol_results + A dictionary whose keys are the `GufeKey`s of `Transformation`s in the `AlchemicalNetwork` + and whose values are the `ProtocolResult`s for those `Transformation`s. + + Returns + ------- + StrategyResult + A `StrategyResult` containing the proposed `Transformation` weights. 
 + + """ + + alchemical_network_mdg = alchemical_network.graph + weights: dict[GufeKey, float | None] = {} + + # calculate all node eccentricities + e = nx.eccentricity(alchemical_network_mdg.to_undirected()) + + # start with the maximum value, this will be decremented as we + # see evidence the value should be lower + lowest_complete_eccentricity = max(e.values()) + # hold on to the eccentricities of the transformations instead + # of the distances since we don't know the lowest complete + # eccentricity until we process the full graph, distances can + # be calculated after + transformation_eccentricity = {} + + for state_a, state_b in alchemical_network_mdg.edges(): + edge = e[state_a], e[state_b] + # find the range of eccentricities + lower, upper = min(edge), max(edge) + + transformation_key = alchemical_network_mdg.get_edge_data(state_a, state_b)[ + 0 + ]["object"].key + + factor_repeats = 1 + match (protocol_results.get(transformation_key)): + case None: + transformation_n_protcol_dag_results = 0 + # since we have no results for this + # transformation, we know the lowest complete + # eccentricity must be lower than the upper + # eccentricity of the transformation + if upper < lowest_complete_eccentricity: + lowest_complete_eccentricity = lower + case pr: + transformation_n_protcol_dag_results = pr.n_protocol_dag_results + # scale the repeat factor to discourage reruns as + # specified by the user's decay_repeat_rate + factor_repeats *= ( + self.settings.decay_repeat_rate + ** transformation_n_protcol_dag_results + ) + + # stop condition given max runs + if self.settings.max_runs <= transformation_n_protcol_dag_results: + weights[transformation_key] = None + continue + + # save the upper eccentricity for later when we know the + # lowest_completed. 
This is the transformation's effective + # distance from the center + transformation_eccentricity[transformation_key] = upper + weights[transformation_key] = factor_repeats + + # start applying weights due to effective distance + for transformation_key, e in transformation_eccentricity.items(): + distance = e - lowest_complete_eccentricity + + if distance <= self.settings.candidacy_max_distance: + # edge case where there are multiple vertices with + # eccentricity equal to graph radius + if distance == 0: + distance_factor = 1 + else: + # scale the distance factor to limit the + # calculation of far-out transformations + distance_factor = self.settings.decay_distance_rate ** ( + distance - 1 + ) + else: + # set to zero, not None + distance_factor = 0 + + weights[transformation_key] *= distance_factor + + return StrategyResult(weights) diff --git a/src/stratocaster/tests/conftest.py b/src/stratocaster/tests/conftest.py new file mode 100644 index 0000000..a0f9986 --- /dev/null +++ b/src/stratocaster/tests/conftest.py @@ -0,0 +1,22 @@ +import pytest + +from stratocaster.tests.networks import ( + benzene_variants_star_map as _benzene_variants_star_map, + fanning_network as _fanning_network, + disconnected_fanning_network as _disconnected_fanning_network, +) + + +@pytest.fixture(scope="module") +def benzene_variants_star_map(): + return _benzene_variants_star_map() + + +@pytest.fixture(scope="module") +def fanning_network(): + return _fanning_network() + + +@pytest.fixture(scope="module") +def disconnected_fanning_network(): + return _disconnected_fanning_network() diff --git a/src/stratocaster/tests/networks.py b/src/stratocaster/tests/networks.py index a6aae9e..b7456cc 100644 --- a/src/stratocaster/tests/networks.py +++ b/src/stratocaster/tests/networks.py @@ -11,6 +11,7 @@ import gufe from gufe.tests.test_protocol import DummyProtocol +import networkx as nx from openff.units import unit from rdkit import Chem @@ -115,3 +116,94 @@ def benzene_variants_star_map(): 
return gufe.AlchemicalNetwork( solvated_ligand_transformations + solvated_complex_transformations ) + + +def digraph_to_alchemical_network(digraph): + """Convert a digraph to an AlchemicalNetwork.""" + node_to_chemical_system = lambda n: gufe.ChemicalSystem({}, name=n) + + transformations = [] + for a, b in digraph.edges(): + c_a = node_to_chemical_system(a) + c_b = node_to_chemical_system(b) + protocol = None + transformation = gufe.Transformation( + c_a, c_b, DummyProtocol(settings=DummyProtocol.default_settings()) + ) + transformations.append(transformation) + + return gufe.AlchemicalNetwork(transformations) + + +def int_sampler(sampler_start=0): + num = sampler_start + while True: + yield num + num += 1 + + +def fanning_network(branch=3, depth=3): + """Generate a network with a central node that recursively "fans" + out. Branch determines how many edges branch off of each node + while depth determines how many times the process is repeated. + """ + G = nx.DiGraph() + G.add_edges_from(fan(0, branch=branch, depth=depth)) + an = digraph_to_alchemical_network(G) + return an + + +def dual_center_fanning_network(branch=3, depth=3): + """Generate a network with two central nodes that recursively + "fan" out. Branch determines how many edges branch off of each + node while depth determines how many times the process is + repeated. 
+ """ + + G = nx.DiGraph() + + edges = {(0, 1)} + + sampler = int_sampler(sampler_start=2) + edges |= fan(0, branch=branch, depth=depth, id_generator=sampler) + edges |= fan(1, branch=branch, depth=depth, id_generator=sampler) + + G.add_edges_from(edges) + + an = digraph_to_alchemical_network(G) + return an + + +def disconnected_fanning_network(branch=3, depth=3): + """Generate a network with disconnected fanning subgraphs.""" + + G = nx.DiGraph() + + edges = set() + + sampler = int_sampler(sampler_start=2) + edges |= fan(0, branch=branch, depth=depth, id_generator=sampler) + edges |= fan(1, branch=branch, depth=depth, id_generator=sampler) + + G.add_edges_from(edges) + + assert not nx.is_weakly_connected(G) + + an = digraph_to_alchemical_network(G) + return an + + +def fan(node, branch=3, depth=3, id_generator=None): + """Recursive edge generator.""" + + id_generator = id_generator or int_sampler(sampler_start=node + 1) + + if depth == 0: + return set() + + edges = {(node, next(id_generator)) for _ in range(branch)} + + for _, next_node in edges.copy(): + edges |= fan(next_node, depth=depth - 1, id_generator=id_generator) + + return edges diff --git a/src/stratocaster/tests/test_connectivity_strategy.py b/src/stratocaster/tests/test_connectivity_strategy.py index 236c511..d3bc673 100644 --- a/src/stratocaster/tests/test_connectivity_strategy.py +++ b/src/stratocaster/tests/test_connectivity_strategy.py @@ -11,215 +11,119 @@ ConnectivityStrategy, ConnectivityStrategySettings, ) -from stratocaster.tests.networks import ( - benzene_variants_star_map as _benzene_variants_star_map, -) - - -@pytest.fixture(scope="module") -def benzene_variants_star_map(): - return _benzene_variants_star_map() +from stratocaster.tests.utils import StrategyTestMixin from gufe.tokenization import GufeKey -SETTINGS_VALID = [(0.5, 0.1, 10), (0.1, None, 10), (0.5, 0.1, None)] +class TestConnectivityStrategy(StrategyTestMixin): -@pytest.mark.parametrize( - ["decay_rate", "cutoff", "max_runs", 
"raises"], - [ - (0, None, None, ValueError), - (1, None, None, ValueError), - (0.5, 0, None, ValueError), - (0.5, None, 0, ValueError), + strategy_class = ConnectivityStrategy + valid_settings = [ + ConnectivityStrategySettings(decay_rate=dr, cutoff=co, max_runs=mr) + for dr, co, mr in {(0.5, 0.1, 10), (0.1, None, 10), (0.5, 0.1, None)} ] - + [(*vals, None) for vals in SETTINGS_VALID], # include all valid settings -) -def test_connectivity_strategy_settings(decay_rate, cutoff, max_runs, raises): - - def instantiate_settings(): - ConnectivityStrategySettings( - decay_rate=decay_rate, cutoff=cutoff, max_runs=max_runs - ) - - if raises: - with pytest.raises(raises): - instantiate_settings() - else: - instantiate_settings() - - -@pytest.fixture -def default_strategy(): - _settings = ConnectivityStrategy._default_settings() - return ConnectivityStrategy(_settings) - -def test_propose_no_results( - default_strategy: ConnectivityStrategy, benzene_variants_star_map: AlchemicalNetwork -): - proposal: StrategyResult = default_strategy.propose(benzene_variants_star_map, {}) + @pytest.mark.parametrize("settings", valid_settings) + def test_simulated_termination(self, fanning_network, settings): - assert all([weight == 3.5 for weight in proposal._weights.values()]) - assert 1 == sum( - weight for weight in proposal.resolve().values() if weight is not None - ) - - -def test_propose_previous_results( - default_strategy: ConnectivityStrategy, benzene_variants_star_map: AlchemicalNetwork -): - - result_data: dict[GufeKey, DummyProtocolResult] = {} - for transformation in benzene_variants_star_map.edges: - transformation_key = transformation.key - result = DummyProtocolResult( - n_protocol_dag_results=2, info=f"key: {transformation_key}" - ) - result_data[transformation_key] = result - - results = default_strategy.propose(benzene_variants_star_map, result_data) - results_no_data = default_strategy.propose(benzene_variants_star_map, {}) - - # the raw weights should no longer be the 
same - assert results.weights != results_no_data.weights - # since each transformation had the same number of previous results, resolve - # should give back the same normalized weights - assert results.resolve() == results_no_data.resolve() - - -def test_propose_max_runs_termination( - default_strategy: ConnectivityStrategy, benzene_variants_star_map: AlchemicalNetwork -): - assert isinstance(default_strategy.settings, ConnectivityStrategySettings) - max_runs = default_strategy.settings.max_runs - assert isinstance(max_runs, int) - - result_data: dict[GufeKey, DummyProtocolResult] = {} - for transformation in benzene_variants_star_map.edges: - transformation_key = transformation.key - result = DummyProtocolResult( - n_protocol_dag_results=max_runs, info=f"key: {transformation_key}" + StrategyTestMixin.test_simulated_termination( + self, fanning_network, settings=settings ) - result_data[transformation_key] = result - - results = default_strategy.propose(benzene_variants_star_map, result_data) - - # since the default strategy has a max_runs of 3, we expect all Nones - assert not [weight for weight in results.resolve().values() if weight is not None] - - -def test_propose_cutoff_num_runs_predictioned_termination(benzene_variants_star_map): - """We can predict the number of runs needed to terminate with a given cutoff. - Each edge in benzene_variants_star_map has a base weight of 3.5. - """ + def test_propose_cutoff_num_runs_predictioned_termination( + self, benzene_variants_star_map + ): + """We can predict the number of runs needed to terminate with a given cutoff. - settings = ConnectivityStrategySettings(cutoff=2, decay_rate=0.5) - strategy = ConnectivityStrategy(settings) + Each edge in benzene_variants_star_map has a base weight of 3.5. 
+ """ - assert isinstance(settings.cutoff, float) + settings = ConnectivityStrategySettings(cutoff=2, decay_rate=0.5) + strategy = self.strategy_or_default(settings) - num_runs = math.floor( - math.log(settings.cutoff / 3.5) / math.log(settings.decay_rate) - ) - - result_data: dict[GufeKey, DummyProtocolResult] = {} - for transformation in benzene_variants_star_map.edges: - transformation_key = transformation.key - result = DummyProtocolResult( - n_protocol_dag_results=num_runs + 1, info=f"key: {transformation_key}" + num_runs = math.floor( + math.log(settings.cutoff / 3.5) / math.log(settings.decay_rate) ) - result_data[transformation_key] = result - - results = strategy.propose(benzene_variants_star_map, result_data) - assert not [weight for weight in results.weights.values() if weight is not None] + result_data: dict[GufeKey, DummyProtocolResult] = {} + for transformation in benzene_variants_star_map.edges: + transformation_key = transformation.key + result = DummyProtocolResult( + n_protocol_dag_results=num_runs + 1, info=f"key: {transformation_key}" + ) + result_data[transformation_key] = result + results = strategy.propose(benzene_variants_star_map, result_data) -@pytest.mark.parametrize(["decay_rate", "cutoff", "max_runs"], SETTINGS_VALID) -def test_simulated_termination( - default_strategy, benzene_variants_star_map, decay_rate, cutoff, max_runs -): + assert not [weight for weight in results.weights.values() if weight is not None] - settings = ConnectivityStrategySettings( - decay_rate=decay_rate, cutoff=cutoff, max_runs=max_runs - ) - default_strategy = ConnectivityStrategy(settings) + def test_propose_max_runs_termination( + self, benzene_variants_star_map: AlchemicalNetwork + ): + assert isinstance(self.default_strategy.settings, ConnectivityStrategySettings) + max_runs = self.default_strategy.settings.max_runs + assert isinstance(max_runs, int) - def counts_to_result_data(counts_dict): - result_data = {} - for transformation_key, count in 
counts_dict.items(): + result_data: dict[GufeKey, DummyProtocolResult] = {} + for transformation in benzene_variants_star_map.edges: + transformation_key = transformation.key result = DummyProtocolResult( - n_protocol_dag_results=count, info=f"key: {transformation_key}" + n_protocol_dag_results=max_runs, info=f"key: {transformation_key}" ) result_data[transformation_key] = result - return result_data - def shuffle_take_n(keys_list, n): - shuffle(keys_list) - return keys_list[:n] + results = self.default_strategy.propose(benzene_variants_star_map, result_data) - # initial transforms - transformation_counts = { - transformation.key: 0 for transformation in benzene_variants_star_map.edges - } + # since the default strategy has a max_runs of 3, we expect all Nones + assert not [ + weight for weight in results.resolve().values() if weight is not None + ] - max_iterations = 100 - current_iteration = 0 - while current_iteration <= max_iterations: + def test_propose_previous_results( + self, benzene_variants_star_map: AlchemicalNetwork + ): - if current_iteration == max_iterations: - raise RuntimeError( - f"Strategy did not terminate in {max_iterations} iterations " + result_data: dict[GufeKey, DummyProtocolResult] = {} + for transformation in benzene_variants_star_map.edges: + transformation_key = transformation.key + result = DummyProtocolResult( + n_protocol_dag_results=2, info=f"key: {transformation_key}" ) + result_data[transformation_key] = result - result_data = counts_to_result_data(transformation_counts) - proposal = default_strategy.propose(benzene_variants_star_map, result_data) - - # get random transformations from those with a non-None weight - resolved_keys = shuffle_take_n( - [key for key, weight in proposal.resolve().items() if weight is not None], 5 - ) - - if resolved_keys: - # pretend we ran each of the randomly selected protocols - for key in resolved_keys: - transformation_counts[key] += 1 - # if we got an empty list back, there are not more protocols 
to run - else: - break - current_iteration += 1 - + results = self.default_strategy.propose(benzene_variants_star_map, result_data) + results_no_data = self.default_strategy.propose(benzene_variants_star_map, {}) -def test_deterministic( - default_strategy: ConnectivityStrategy, benzene_variants_star_map: AlchemicalNetwork -): + # the raw weights should no longer be the same + assert results.weights != results_no_data.weights + # since each transformation had the same number of previous results, resolve + # should give back the same normalized weights + assert results.resolve() == results_no_data.resolve() - settings = default_strategy.settings - assert isinstance(settings, ConnectivityStrategySettings) + def test_propose_no_results(self, benzene_variants_star_map: AlchemicalNetwork): + proposal: StrategyResult = self.default_strategy.propose( + benzene_variants_star_map, {} + ) - max_runs = settings.max_runs - assert isinstance(max_runs, int) + assert all([weight == 3.5 for weight in proposal._weights.values()]) + assert 1 == sum( + weight for weight in proposal.resolve().values() if weight is not None + ) - def random_runs(): - """Generate random randomized inputs for propose.""" - return { - transformation.key: DummyProtocolResult( - n_protocol_dag_results=randint(0, max_runs), - info=f"key: {transformation.key}", - ) - for transformation in benzene_variants_star_map.edges - } + @pytest.mark.parametrize( + ["decay_rate", "cutoff", "max_runs"], + [ + (0, None, None), + (1, None, None), + (0.5, 0, None), + (0.5, None, 0), + ], + ) + def test_connectivity_strategy_invalid_settings(self, decay_rate, cutoff, max_runs): - for _ in range(10): - random_protocol_results = random_runs() - proposal = default_strategy.propose( - benzene_variants_star_map, protocol_results=random_protocol_results - ) - for _ in range(3): - _proposal = default_strategy.propose( - benzene_variants_star_map, protocol_results=random_protocol_results + with pytest.raises(ValueError): + 
ConnectivityStrategySettings( + decay_rate=decay_rate, cutoff=cutoff, max_runs=max_runs ) - assert _proposal == proposal diff --git a/src/stratocaster/tests/test_radialgrowth_strategy.py b/src/stratocaster/tests/test_radialgrowth_strategy.py new file mode 100644 index 0000000..099230a --- /dev/null +++ b/src/stratocaster/tests/test_radialgrowth_strategy.py @@ -0,0 +1,9 @@ +from stratocaster.strategies.radialgrowth import ( + RadialGrowthStrategy, +) + +from stratocaster.tests.utils import StrategyTestMixin + + +class TestRadialGrowth(StrategyTestMixin): + strategy_class = RadialGrowthStrategy diff --git a/src/stratocaster/tests/utils.py b/src/stratocaster/tests/utils.py new file mode 100644 index 0000000..e5f6981 --- /dev/null +++ b/src/stratocaster/tests/utils.py @@ -0,0 +1,133 @@ +from random import randint, shuffle + +import pytest + +from gufe.tests.test_protocol import DummyProtocolResult + +from stratocaster.base.strategy import StrategyResult + + +class StrategyTestMixin: + r"""A mixin base class for testing strategies. + + All tests defined in this base class should include a ``settings`` + keyword argument, defaulting to ``None``. This requirement ensures + predictable pytest parametrization in derived strategy tests. The + ``StrategyTestMixin.strategy_or_default`` method, when called with + optional settings, will return either a strategy configured with + the provided settings or the default strategy established by the + strategy author. 
+ + """ + + _default_strategy = None + _default_settings = None + + @property + def default_strategy(self): + if not self._default_strategy: + self._default_strategy = self.strategy_class(self.default_settings) + return self._default_strategy + + @property + def default_settings(self): + if not self._default_settings: + self._default_settings = self.strategy_class._default_settings() + return self._default_settings + + def strategy_or_default(self, settings): + return self.strategy_class(settings) if settings else self.default_strategy + + def test_deterministic(self, fanning_network, settings=None): + + strategy = self.strategy_or_default(settings) + + def random_runs(): + """Generate random inputs for propose.""" + return { + transformation.key: DummyProtocolResult( + n_protocol_dag_results=randint(0, 3), + info=f"key: {transformation.key}", + ) + for transformation in fanning_network.edges + } + + for _ in range(10): + random_protocol_results = random_runs() + proposal = strategy.propose( + fanning_network, protocol_results=random_protocol_results + ) + for _ in range(3): + _proposal = strategy.propose( + fanning_network, protocol_results=random_protocol_results + ) + assert _proposal == proposal + + def test_starts(self, fanning_network, settings=None): + + strategy = self.strategy_or_default(settings) + + proposal: StrategyResult = strategy.propose(fanning_network, {}) + # check that there is at least 1 non-None weight + assert any(proposal.weights.values()) + + def test_disconnected( + self, + disconnected_fanning_network, + settings=None, + ): + strategy = self.strategy_or_default(settings) + strategy.propose(disconnected_fanning_network, {}) + + def test_simulated_termination(self, fanning_network, settings=None): + + strategy = self.strategy_or_default(settings) + + def counts_to_result_data(counts_dict): + result_data = {} + for transformation_key, count in counts_dict.items(): + result = DummyProtocolResult( + n_protocol_dag_results=count, info=f"key: 
{transformation_key}" + ) + result_data[transformation_key] = result + return result_data + + def shuffle_take_n(keys_list, n): + shuffle(keys_list) + return keys_list[:n] + + # initial transforms + transformation_counts = { + transformation.key: 0 for transformation in fanning_network.edges + } + + max_iterations = 100 + current_iteration = 0 + while current_iteration <= max_iterations: + + if current_iteration == max_iterations: + raise RuntimeError( + f"Strategy did not terminate in {max_iterations} iterations " + ) + + result_data = counts_to_result_data(transformation_counts) + proposal = strategy.propose(fanning_network, result_data) + + # get random transformations from those with a non-None weight + resolved_keys = shuffle_take_n( + [ + key + for key, weight in proposal.resolve().items() + if weight is not None + ], + 5, + ) + + if resolved_keys: + # pretend we ran each of the randomly selected protocols + for key in resolved_keys: + transformation_counts[key] += 1 + # if we got an empty list back, there are not more protocols to run + else: + break + current_iteration += 1