diff --git a/gwinferno/distributions.py b/gwinferno/distributions.py index 40db691..b701005 100644 --- a/gwinferno/distributions.py +++ b/gwinferno/distributions.py @@ -18,7 +18,8 @@ def smooth(dx, x, xmin): func = jnp.exp(dx / (x - xmin) + dx / (x - xmin - dx)) s1 = jnp.where(jnp.less(x, xmin), 0, 1) s2 = jnp.where(jnp.less(x, xmin + dx) | jnp.greater_equal(x, xmin), (func + 1) ** (-1), s1) - return s2 + s3 = jnp.where(jnp.greater_equal(x, xmin + dx), 1, s2) + return s3 def logistic_function(x, L, k, x0): diff --git a/gwinferno/models/parametric/parametric.py b/gwinferno/models/parametric/parametric.py index 1bc312e..07c8bcb 100644 --- a/gwinferno/models/parametric/parametric.py +++ b/gwinferno/models/parametric/parametric.py @@ -50,7 +50,7 @@ def plpeak_primary_pdf(m1, alpha, mmin, mmax, mpp, sigpp, lam, delta=None): if delta is None: return (1 - lam) * powerlaw_pdf(m1, alpha, mmin, mmax) + lam * truncnorm_pdf(m1, mpp, sigpp, mmin, mmax) else: - return (1 - lam) * powerlaw_pdf(m1, alpha, mmin, mmax) * smooth(delta, m1, mmin) + lam * truncnorm_pdf(m1, mpp, sigpp, mmin, mmax) + return ( (1 - lam) * powerlaw_pdf(m1, alpha, mmin, mmax) + lam * truncnorm_pdf(m1, mpp, sigpp, mmin, mmax) ) * smooth(delta, m1, mmin) """