-
Notifications
You must be signed in to change notification settings - Fork 1
Add RadialGrowthStrategy #16
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
1c7cd30
Untested implementation of RadialGrowthStrategy
ianmkenney d0c7d1f
Allow StrategyResults merging through "|"
ianmkenney 922a97e
Move fixtures to conftest.py
ianmkenney 1d2dc51
Test empty results and disconnected graph
ianmkenney 25c9a29
Handle edge case for distance
ianmkenney 211fd9c
Test graph connectivity in the correct way
ianmkenney bbaea72
Create mixin test class
ianmkenney 3af2b3c
Use mixin test class for connectivity
ianmkenney d9b4d0b
Make mixin test general
ianmkenney 9f41a69
Allow default tests to be overwritten with parametrize
ianmkenney 7db5f88
Add docstrings
ianmkenney cd0a257
Rework docstrings
ianmkenney 7024b15
Merge branch 'main' into feat/strategy/radialgrowth
ianmkenney d89a6c0
Derived strategy tests should provide instances of the settings
ianmkenney 641e972
Update field descriptions and fix typos
ianmkenney e5da83e
Test against correct strategy in test_deterministic
ianmkenney File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,4 @@ | ||
| from stratocaster.strategies.connectivity import ConnectivityStrategy | ||
| from stratocaster.strategies.radialgrowth import RadialGrowthStrategy | ||
|
|
||
| __all__ = ["ConnectivityStrategy"] | ||
| __all__ = ["ConnectivityStrategy", "RadialGrowthStrategy"] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,200 @@ | ||
| import networkx as nx | ||
|
|
||
| from gufe import AlchemicalNetwork, ProtocolResult | ||
| from gufe.tokenization import GufeKey | ||
|
|
||
| from pydantic import ( | ||
| Field, | ||
| field_validator, | ||
| ) | ||
|
|
||
| from stratocaster.base import Strategy, StrategyResult | ||
| from stratocaster.base.models import StrategySettings | ||
|
|
||
|
|
||
| class RadialGrowthStrategySettings(StrategySettings): | ||
| """Settings required for the RadialGrowthStrategy.""" | ||
|
|
||
| max_runs: int = Field( | ||
| default=3, | ||
| description="the upper limit of ProtocolDAG results needed before a Transformation is no longer weighed", | ||
| ) | ||
|
|
||
| candidacy_max_distance: int = Field( | ||
| default=1, | ||
| description="the maximum distance a candidate ChemicalSystem can be from previously reached ChemicalSystems", | ||
| ) | ||
|
|
||
| decay_repeat_rate: float = Field( | ||
| default=0.5, | ||
| description="decay rate of the exponential repeat decay penalty factor", | ||
| ) | ||
|
|
||
| decay_distance_rate: float = Field( | ||
| default=0.5, | ||
| description="decay rate of the exponential distance decay penalty factor", | ||
| ) | ||
|
|
||
| @field_validator("max_runs", mode="before") | ||
| def validate_max_runs(cls, value): | ||
| if not value >= 1: | ||
| raise ValueError("`max_runs` must be greater than or equal to 1") | ||
| return value | ||
|
|
||
| @field_validator("candidacy_max_distance", mode="before") | ||
| def validate_candidate_max_distance(cls, value): | ||
| if not value >= 1: | ||
| raise ValueError( | ||
| "`candidtate_max_distance` must be greater than or equal to 1" | ||
| ) | ||
| return value | ||
|
|
||
| @field_validator("decay_repeat_rate", mode="before") | ||
| def validate_decay_repeat_rate(cls, value): | ||
| if not (0 < value < 1): | ||
| raise ValueError("`decay_repeat_rate` must be between 0 and 1") | ||
| return value | ||
|
|
||
| @field_validator("decay_distance_rate", mode="before") | ||
| def validate_decay_distance_rate(cls, value): | ||
| if not (0 < value < 1): | ||
| raise ValueError("`decay_distance_rate` must be between 0 and 1") | ||
| return value | ||
|
|
||
|
|
||
| class RadialGrowthStrategy(Strategy): | ||
| r"""A Strategy that favors Transformations close to the network | ||
| center. | ||
|
|
||
| The weight assigned to each Transformation depends on its highest | ||
| ChemicalSystem distance, as measured and labeled by the vertex | ||
| eccentricty, from the lowest completed tier of distances. In the | ||
| graph below, if at least one ``ProtocolDAGResult`` exists for both | ||
| edges in 3-2-3, the lowest completed distance is 3. If neither | ||
| edge or only one edge has a result, the lowest completed distance | ||
| is 2. In the former case, a transformation going from 3 to 4 has | ||
| an effective distance of 1, while in the latter case it has a | ||
| distance of 2. | ||
|
|
||
| .. code-block:: | ||
|
|
||
| 4 4 | ||
| \ / | ||
| \ / | ||
| 4--3-2-3--4 | ||
| / \ | ||
| / \ | ||
| 4 4 | ||
|
|
||
| The strategy penalizes transformations that have high effective | ||
| distances by multiplying the weight with a penalty of :math:`r^d`, where r | ||
| is a user-specified decay rate and d is the effective | ||
| distance. The ``candidacy_max_distance`` setting limits how far | ||
| out transformations can be assigned non-zero weights. | ||
|
|
||
| """ | ||
|
|
||
| _settings_cls = RadialGrowthStrategySettings | ||
|
|
||
| @classmethod | ||
| def _default_settings(cls) -> StrategySettings: | ||
| return RadialGrowthStrategySettings(max_runs=3) | ||
|
|
||
| def _propose( | ||
| self, | ||
| alchemical_network: AlchemicalNetwork, | ||
| protocol_results: dict[GufeKey, ProtocolResult], | ||
| ) -> StrategyResult: | ||
| """Propose `Transformation` weight recommendations based on | ||
| `Transformation` distance from the graph center. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| alchemical_network | ||
| protocol_results | ||
| A dictionary whose keys are the `GufeKey`s of `Transformation`s in the `AlchemicalNetwork` | ||
| and whose values are the `ProtocolResult`s for those `Transformation`s. | ||
|
|
||
| Returns | ||
| ------- | ||
| StrategyResult | ||
| A `StrategyResult` containing the proposed `Transformation` weights. | ||
|
|
||
| """ | ||
|
|
||
| alchemical_network_mdg = alchemical_network.graph | ||
| weights: dict[GufeKey, float | None] = {} | ||
|
|
||
| # calculate all node eccentricies | ||
| e = nx.eccentricity(alchemical_network_mdg.to_undirected()) | ||
|
|
||
| # start with the maximum value, this will be decremented as we | ||
| # see evidence the value should be lower | ||
| lowest_complete_eccentricity = max(e.values()) | ||
| # hold on to the eccentricies of the transformations instead | ||
| # of the distances since we don't know the lowest complete | ||
| # eccentricity until we process the full graph, distances can | ||
| # be calculated after | ||
| transformation_eccentricity = {} | ||
|
|
||
| for state_a, state_b in alchemical_network_mdg.edges(): | ||
| edge = e[state_a], e[state_b] | ||
| # find the range of eccentricies | ||
| lower, upper = min(edge), max(edge) | ||
|
|
||
| transformation_key = alchemical_network_mdg.get_edge_data(state_a, state_b)[ | ||
| 0 | ||
| ]["object"].key | ||
|
|
||
| factor_repeats = 1 | ||
| match (protocol_results.get(transformation_key)): | ||
| case None: | ||
| transformation_n_protcol_dag_results = 0 | ||
| # since we have no results for this | ||
| # transformation, we know the lowest complete | ||
| # eccentricity must be lower than the upper | ||
| # eccentricity of the transformation | ||
| if upper < lowest_complete_eccentricity: | ||
| lowest_complete_eccentricity = lower | ||
| case pr: | ||
| transformation_n_protcol_dag_results = pr.n_protocol_dag_results | ||
| # scale the repeat factor to discourage reruns as | ||
| # specified by the user's decay_repeat_rate | ||
| factor_repeats *= ( | ||
| self.settings.decay_repeat_rate | ||
| ** transformation_n_protcol_dag_results | ||
| ) | ||
|
|
||
| # stop condition given max runs | ||
| if self.settings.max_runs <= transformation_n_protcol_dag_results: | ||
| weights[transformation_key] = None | ||
| continue | ||
|
|
||
| # save the upper eccentricity for later when we know the | ||
| # lowest_completed. This is the transformation's effective | ||
| # distance from the center | ||
| transformation_eccentricity[transformation_key] = upper | ||
| weights[transformation_key] = factor_repeats | ||
|
|
||
| # start applying weights due to effective distance | ||
| for transformation_key, e in transformation_eccentricity.items(): | ||
| distance = e - lowest_complete_eccentricity | ||
|
|
||
| if distance <= self.settings.candidacy_max_distance: | ||
| # edge case where there are multiple vertices with | ||
| # eccentricity equal to graph radius | ||
| if distance == 0: | ||
| distance_factor = 1 | ||
| else: | ||
| # scale the distance factor to limit the | ||
| # calculation of far-out transformations | ||
| distance_factor = self.settings.decay_distance_rate ** ( | ||
| distance - 1 | ||
| ) | ||
| else: | ||
| # set to zero, not None | ||
| distance_factor = 0 | ||
|
|
||
| weights[transformation_key] *= distance_factor | ||
|
|
||
| return StrategyResult(weights) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,22 @@ | ||
| import pytest | ||
|
|
||
| from stratocaster.tests.networks import ( | ||
| benzene_variants_star_map as _benzene_variants_star_map, | ||
| fanning_network as _fanning_network, | ||
| disconnected_fanning_network as _disconnected_fanning_network, | ||
| ) | ||
|
|
||
|
|
||
| @pytest.fixture(scope="module") | ||
| def benzene_variants_star_map(): | ||
| return _benzene_variants_star_map() | ||
|
|
||
|
|
||
| @pytest.fixture(scope="module") | ||
| def fanning_network(): | ||
| return _fanning_network() | ||
|
|
||
|
|
||
| @pytest.fixture(scope="module") | ||
| def disconnected_fanning_network(): | ||
| return _disconnected_fanning_network() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you explain why we have a
- 1here?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it because we want
distance == 1cases to always get adistance_factor == 1?