Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 32 additions & 31 deletions bilby/core/prior/dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,11 +731,14 @@ def _resolve_conditions(self):

def _check_conditions_resolved(self, key, sampled_keys):
"""Checks if all required variables have already been sampled so we can sample this key"""
conditions_resolved = True
for k in self[key].required_variables:
if k not in sampled_keys:
conditions_resolved = False
return conditions_resolved
return False
elif isinstance(self[k], JointPrior):
for name in self[k].dist.names:
if name not in sampled_keys and name != key:
return False
return True

def sample_subset(self, keys=iter([]), size=None):
self.convert_floats_to_delta_functions()
Expand Down Expand Up @@ -874,36 +877,34 @@ def rescale(self, keys, theta):
result[key] = self[key].rescale(
theta[index], **self.get_required_variables(key)
)
self[key].least_recently_sampled = result[key]
if isinstance(self[key], JointPrior) and self[key].dist.distname not in joint:
joint[self[key].dist.distname] = [key]
elif isinstance(self[key], JointPrior):
joint[self[key].dist.distname].append(key)
for names in joint.values():
# this is needed to unpack how joint prior rescaling works
# as an example of a joint prior over {a, b, c, d} we might
# get the following based on the order within the joint prior
# {a: [], b: [], c: [1, 2, 3, 4], d: []}
# -> [1, 2, 3, 4]
# -> {a: 1, b: 2, c: 3, d: 4}
values = list()
for key in names:
values = np.concatenate([values, result[key]])
for key, value in zip(names, values):
result[key] = value

def safe_flatten(value):
"""
this is gross but can be removed whenever we switch to returning
arrays, flatten converts 0-d arrays to 1-d so has to be special
cased
"""
if isinstance(value, (float, int, np.int64)):
return value

# if any requested key depends on some joint prior `jp_key`
# self[jp_key].least_recently_sampled needs to be set before
# rescaling those requested keys.
# Thus we keep track of joint priors here
if isinstance(self[key], JointPrior):
# if joint prior, keep track if all names have been rescaled
distname = self[key].dist.distname
# maintain order of names as in the dist as this is the order
# in which they will be rescaled
names = self[key].dist.names
if distname not in joint:
joint[distname] = {key}
else:
joint[distname].add(key)
# only when all names have been rescaled, we can set the values
# we use sets because the order does not matter here
if set(names) == joint[distname]:
for name, value in zip(names, result[key]):
result[name] = value
self[name].least_recently_sampled = value
joint.pop(distname)
else:
return result[key].flatten()
# if not joint prior, set value immediately
self[key].least_recently_sampled = result[key]

return [safe_flatten(result[key]) for key in keys]
# finally return results in the order requested
return [result[key] for key in keys]

def _update_rescale_keys(self, keys):
if not keys == self._least_recently_rescaled_keys:
Expand Down
53 changes: 34 additions & 19 deletions test/core/prior/conditional_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,32 @@ def condition_func_3(reference_parameters, var_1, var_2):
).items():
self.conditional_priors_manually_set_items[key] = value

names = ["mvgvar_a", "mvgvar_b"]
mu = [[0.79, -0.83]]
cov = [[[0.03, 0.0], [0.0, 0.04]]]
mvg = bilby.core.prior.MultivariateGaussianDist(names, mus=mu, covs=cov)

def condition_func_4(reference_parameters, mvgvar_a):
return dict(minimum=reference_parameters["minimum"], maximum=mvgvar_a)

prior_4 = bilby.core.prior.ConditionalUniform(
condition_func=condition_func_4, minimum=self.minimum, maximum=self.maximum
)

self.conditional_priors_with_joint_prior = (
bilby.core.prior.ConditionalPriorDict(
dict(
var_4=prior_4,
var_3=self.prior_3,
var_2=self.prior_2,
var_0=self.prior_0,
var_1=self.prior_1,
mvgvar_a=bilby.core.prior.MultivariateGaussian(mvg, "mvgvar_a"),
mvgvar_b=bilby.core.prior.MultivariateGaussian(mvg, "mvgvar_b"),
)
)
)

def tearDown(self):
del self.minimum
del self.maximum
Expand All @@ -227,6 +253,7 @@ def tearDown(self):
del self.prior_3
del self.conditional_priors
del self.conditional_priors_manually_set_items
del self.conditional_priors_with_joint_prior
del self.test_sample

def test_conditions_resolved_upon_instantiation(self):
Expand Down Expand Up @@ -333,35 +360,23 @@ def test_rescale_with_joint_prior(self):
"""

# set multivariate Gaussian distribution
names = ["mvgvar_0", "mvgvar_1"]
mu = [[0.79, -0.83]]
cov = [[[0.03, 0.], [0., 0.04]]]
mvg = bilby.core.prior.MultivariateGaussianDist(names, mus=mu, covs=cov)

priordict = bilby.core.prior.ConditionalPriorDict(
dict(
var_3=self.prior_3,
var_2=self.prior_2,
var_0=self.prior_0,
var_1=self.prior_1,
mvgvar_0=bilby.core.prior.MultivariateGaussian(mvg, "mvgvar_0"),
mvgvar_1=bilby.core.prior.MultivariateGaussian(mvg, "mvgvar_1"),
)
)
priordict = self.conditional_priors_with_joint_prior
names = ["mvgvar_a", "mvgvar_b"]

ref_variables = list(self.test_sample.values()) + [0.4, 0.1]
keys = list(self.test_sample.keys()) + names
ref_variables = [0.1] + list(self.test_sample.values()) + [0.4, 0.1]
keys = ["var_4"] + list(self.test_sample.keys()) + names
res = priordict.rescale(keys=keys, theta=ref_variables)

self.assertIsInstance(res, list)
self.assertEqual(np.shape(res), (6,))
self.assertListEqual([isinstance(r, float) for r in res], 6 * [True])
self.assertEqual(np.shape(res), (7,))
self.assertListEqual([isinstance(r, float) for r in res], 7 * [True])

# check conditional values are still as expected
expected = [self.test_sample["var_0"]]
for ii in range(1, 4):
expected.append(expected[-1] * self.test_sample[f"var_{ii}"])
self.assertListEqual(expected, res[0:4])
self.assertListEqual(expected, res[1:5])

def test_cdf(self):
"""
Expand Down
Loading