diff --git a/src/stratocaster/base/strategy.py b/src/stratocaster/base/strategy.py index a7e3d3e..327e4ce 100644 --- a/src/stratocaster/base/strategy.py +++ b/src/stratocaster/base/strategy.py @@ -33,7 +33,7 @@ def weights(self) -> dict[GufeKey, float | None]: def resolve(self) -> dict[GufeKey, float | None]: """Normalize the proposal weights relative to all non-None Transformation weights.""" - weights = self.weights + weights = dict(self.weights) weight_sum = sum([weight for weight in weights.values() if weight is not None]) modified_weights = { key: weight / weight_sum diff --git a/src/stratocaster/tests/test_strategy_base.py b/src/stratocaster/tests/test_strategy_base.py index 6b20058..fc890b4 100644 --- a/src/stratocaster/tests/test_strategy_base.py +++ b/src/stratocaster/tests/test_strategy_base.py @@ -17,6 +17,16 @@ class TestStrategyResult: def test_dict_roundtrip(self): assert StrategyResult.from_dict(self.result.to_dict()) == self.result + def test_resolve_does_not_mutate_weights(self): + original_weights = dict(self.result.weights) + self.result.resolve() + assert self.result.weights == original_weights + + def test_resolve_is_idempotent(self): + first = self.result.resolve() + second = self.result.resolve() + assert first == second + class DummyStrategySettings(StrategySettings): pass