Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions src/stratocaster/base/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,19 +29,18 @@ def _from_dict(cls, dct: dict):

@property
def weights(self) -> dict[GufeKey, float | None]:
return self._weights
return self._weights.copy()

def resolve(self) -> dict[GufeKey, float | None]:
"""Normalize the proposal weights relative to all non-None Transformation weights."""
weights = self.weights
weight_sum = sum([weight for weight in weights.values() if weight is not None])
modified_weights = {
key: weight / weight_sum
for key, weight in weights.items()
if weight is not None
"""Get normalized proposal weights relative to all non-None Transformation weights."""
weight_sum = sum(
[weight for weight in self._weights.values() if weight is not None]
)
normalized_weights = {
key: weight / weight_sum if weight is not None else None
for key, weight in self._weights.items()
}
weights.update(modified_weights)
return weights
return normalized_weights

def __or__(self, other):
if self.weights.keys() & other.weights.keys():
Expand Down
11 changes: 11 additions & 0 deletions src/stratocaster/tests/test_strategy_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,17 @@ class TestStrategyResult:
def test_dict_roundtrip(self):
assert StrategyResult.from_dict(self.result.to_dict()) == self.result

def test_resolve_no_weight_side_effect(self):
"""Resolve returns a normalized copy of the result
weights and doesn't modify the original data."""
res = self.result.resolve()
assert res != self.result.weights

def test_resolve_normalization(self):
"""Resolve produces normalized weights for all non-None values."""
res = self.result.resolve()
assert 1 == sum([value for _, value in res.items() if value is not None])


class DummyStrategySettings(StrategySettings):
pass
Expand Down
Loading