Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions src/stratocaster/base/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,21 @@ def weights(self) -> dict[GufeKey, float | None]:

def resolve(self) -> dict[GufeKey, float | None]:
"""Get normalized proposal weights relative to all non-None Transformation weights."""
weight_sum = sum(
[weight for weight in self._weights.values() if weight is not None]
)
non_none_weights = [
weight for weight in self._weights.values() if weight is not None
]

if not non_none_weights:
return self.weights

weight_sum = sum(non_none_weights)
if weight_sum == 0:
raise ValueError(
"Cannot resolve weights: sum of non-None weights is zero. "
"This is likely a bug in the Strategy implementation. "
"Please raise an issue at https://github.com/OpenFreeEnergy/stratocaster/issues"
)

normalized_weights = {
key: weight / weight_sum if weight is not None else None
for key, weight in self._weights.items()
Expand Down
15 changes: 15 additions & 0 deletions src/stratocaster/tests/test_strategy_base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import pytest

from gufe import AlchemicalNetwork, ProtocolResult
from gufe.tokenization import GufeKey

Expand Down Expand Up @@ -28,6 +30,19 @@ def test_resolve_normalization(self):
res = self.result.resolve()
assert 1 == sum([value for _, value in res.items() if value is not None])

def test_resolve_no_zero_division(self):
all_zero_weights = StrategyResult(
{key: 0 for key, _ in self.result.weights.items()}
)
all_none_weights = StrategyResult(
{key: None for key, _ in self.result.weights.items()}
)

with pytest.raises(ValueError):
_ = all_zero_weights.resolve()

_ = all_none_weights.resolve()


class DummyStrategySettings(StrategySettings):
pass
Expand Down
Loading