Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 69 additions & 41 deletions openfe/protocols/openmm_rfe/_rfe_utils/multistate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import openmmtools.states as states
from openmm import unit
from openmmtools import cache
from openmmtools.integrators import FIREMinimizationIntegrator
from openmmtools.multistate import multistatesampler, replicaexchange, sams
from openmmtools.states import CompoundThermodynamicState, SamplerState, ThermodynamicState

Expand All @@ -32,14 +31,21 @@ class HybridCompatibilityMixin(object):
unsampled endpoints have a different number of degrees of freedom.
"""

def __init__(self, *args, hybrid_factory=None, **kwargs):
self._hybrid_factory = hybrid_factory
def __init__(self, *args, hybrid_system, hybrid_positions, **kwargs):
self._hybrid_system = hybrid_system
self._hybrid_positions = hybrid_positions
super(HybridCompatibilityMixin, self).__init__(*args, **kwargs)

def setup(self, reporter, lambda_protocol,
temperature=298.15 * unit.kelvin, n_replicas=None,
endstates=True, minimization_steps=100,
minimization_platform="CPU"):
def setup(
self,
reporter,
lambda_protocol,
temperature=298.15 * unit.kelvin,
n_replicas=None,
endstates=True,
minimization_steps=100,
minimization_platform="CPU"
):
"""
Setup MultistateSampler based on the input lambda protocol and number
of replicas.
Expand Down Expand Up @@ -73,15 +79,17 @@ class creation of LambdaProtocol.
"""
n_states = len(lambda_protocol.lambda_schedule)

hybrid_system = self._factory.hybrid_system
lambda_zero_state = RelativeAlchemicalState.from_system(self._hybrid_system)

lambda_zero_state = RelativeAlchemicalState.from_system(hybrid_system)
thermostate = ThermodynamicState(
self._hybrid_system,
temperature=temperature
)

thermostate = ThermodynamicState(hybrid_system,
temperature=temperature)
compound_thermostate = CompoundThermodynamicState(
thermostate,
composable_states=[lambda_zero_state])
thermostate,
composable_states=[lambda_zero_state]
)

# create lists for storing thermostates and sampler states
thermodynamic_state_list = []
Expand All @@ -105,24 +113,30 @@ class creation of LambdaProtocol.
raise ValueError(errmsg)

# starting with the hybrid factory positions
box = hybrid_system.getDefaultPeriodicBoxVectors()
sampler_state = SamplerState(self._factory.hybrid_positions,
box_vectors=box)
box = self._hybrid_system.getDefaultPeriodicBoxVectors()
sampler_state = SamplerState(
self._hybrid_positions,
box_vectors=box
)

# Loop over the lambdas and create & store a compound thermostate at
# that lambda value
for lambda_val in lambda_schedule:
compound_thermostate_copy = copy.deepcopy(compound_thermostate)
compound_thermostate_copy.set_alchemical_parameters(
lambda_val, lambda_protocol)
lambda_val, lambda_protocol
)
thermodynamic_state_list.append(compound_thermostate_copy)

# now generating a sampler_state for each thermodyanmic state,
# with relaxed positions
# Note: remove once choderalab/openmmtools#672 is completed
minimize(compound_thermostate_copy, sampler_state,
max_iterations=minimization_steps,
platform_name=minimization_platform)
minimize(
compound_thermostate_copy,
sampler_state,
max_iterations=minimization_steps,
platform_name=minimization_platform
)
sampler_state_list.append(copy.deepcopy(sampler_state))

del compound_thermostate, sampler_state
Expand All @@ -131,25 +145,34 @@ class creation of LambdaProtocol.
if len(sampler_state_list) != n_replicas:
# picking roughly evenly spaced sampler states
# if n_replicas == 1, then it will pick the first in the list
samples = np.linspace(0, len(sampler_state_list) - 1,
n_replicas)
samples = np.linspace(0, len(sampler_state_list) - 1, n_replicas)
idx = np.round(samples).astype(int)
sampler_state_list = [state for i, state in
enumerate(sampler_state_list) if i in idx]
sampler_state_list = [
state
for i, state in enumerate(sampler_state_list)
if i in idx
]

assert len(sampler_state_list) == n_replicas

if endstates:
# generating unsampled endstates
unsampled_dispersion_endstates = create_endstates(
copy.deepcopy(thermodynamic_state_list[0]),
copy.deepcopy(thermodynamic_state_list[-1]))
self.create(thermodynamic_states=thermodynamic_state_list,
sampler_states=sampler_state_list, storage=reporter,
unsampled_thermodynamic_states=unsampled_dispersion_endstates)
copy.deepcopy(thermodynamic_state_list[-1])
)
self.create(
thermodynamic_states=thermodynamic_state_list,
sampler_states=sampler_state_list,
storage=reporter,
unsampled_thermodynamic_states=unsampled_dispersion_endstates
)
else:
self.create(thermodynamic_states=thermodynamic_state_list,
sampler_states=sampler_state_list, storage=reporter)
self.create(
thermodynamic_states=thermodynamic_state_list,
sampler_states=sampler_state_list,
storage=reporter
)


class HybridRepexSampler(HybridCompatibilityMixin,
Expand All @@ -158,24 +181,27 @@ class HybridRepexSampler(HybridCompatibilityMixin,
ReplicaExchangeSampler that supports unsampled end states with a different
number of positions
"""

def __init__(self, *args, hybrid_factory=None, **kwargs):
def __init__(self, *args, hybrid_system, hybrid_positions, **kwargs):
super(HybridRepexSampler, self).__init__(
*args, hybrid_factory=hybrid_factory, **kwargs)
self._factory = hybrid_factory
*args,
hybrid_system=hybrid_system,
hybrid_positions=hybrid_positions,
**kwargs
)


class HybridSAMSSampler(HybridCompatibilityMixin, sams.SAMSSampler):
"""
SAMSSampler that supports unsampled end states with a different number
of positions
"""

def __init__(self, *args, hybrid_factory=None, **kwargs):
def __init__(self, *args, hybrid_system, hybrid_positions, **kwargs):
super(HybridSAMSSampler, self).__init__(
*args, hybrid_factory=hybrid_factory, **kwargs
*args,
hybrid_system=hybrid_system,
hybrid_positions=hybrid_positions,
**kwargs
)
self._factory = hybrid_factory


class HybridMultiStateSampler(HybridCompatibilityMixin,
Expand All @@ -184,11 +210,13 @@ class HybridMultiStateSampler(HybridCompatibilityMixin,
MultiStateSampler that supports unsample end states with a different
number of positions
"""
def __init__(self, *args, hybrid_factory=None, **kwargs):
def __init__(self, *args, hybrid_system, hybrid_positions, **kwargs):
super(HybridMultiStateSampler, self).__init__(
*args, hybrid_factory=hybrid_factory, **kwargs
*args,
hybrid_system=hybrid_system,
hybrid_positions=hybrid_positions,
**kwargs
)
self._factory = hybrid_factory


def create_endstates(first_thermostate, last_thermostate):
Expand Down
Loading
Loading