diff --git a/src/stratocaster/base/strategy.py b/src/stratocaster/base/strategy.py index 5817fc3..cd36d5a 100644 --- a/src/stratocaster/base/strategy.py +++ b/src/stratocaster/base/strategy.py @@ -33,9 +33,21 @@ def weights(self) -> dict[GufeKey, float | None]: def resolve(self) -> dict[GufeKey, float | None]: """Get normalized proposal weights relative to all non-None Transformation weights.""" - weight_sum = sum( - [weight for weight in self._weights.values() if weight is not None] - ) + non_none_weights = [ + weight for weight in self._weights.values() if weight is not None + ] + + if not non_none_weights: + return self.weights + + weight_sum = sum(non_none_weights) + if weight_sum == 0: + raise ValueError( + "Cannot resolve weights: sum of non-None weights is zero. " + "This is likely a bug in the Strategy implementation. " + "Please raise an issue at https://github.com/OpenFreeEnergy/stratocaster/issues" + ) + normalized_weights = { key: weight / weight_sum if weight is not None else None for key, weight in self._weights.items() diff --git a/src/stratocaster/tests/test_strategy_base.py b/src/stratocaster/tests/test_strategy_base.py index aaabf58..f40d675 100644 --- a/src/stratocaster/tests/test_strategy_base.py +++ b/src/stratocaster/tests/test_strategy_base.py @@ -1,3 +1,5 @@ +import pytest + from gufe import AlchemicalNetwork, ProtocolResult from gufe.tokenization import GufeKey @@ -28,6 +30,19 @@ def test_resolve_normalization(self): res = self.result.resolve() assert 1 == sum([value for _, value in res.items() if value is not None]) + def test_resolve_no_zero_division(self): + all_zero_weights = StrategyResult( + {key: 0 for key, _ in self.result.weights.items()} + ) + all_none_weights = StrategyResult( + {key: None for key, _ in self.result.weights.items()} + ) + + with pytest.raises(ValueError): + _ = all_zero_weights.resolve() + + _ = all_none_weights.resolve() + class DummyStrategySettings(StrategySettings): pass