diff --git a/src/scripts/intensity_models.py b/src/scripts/intensity_models.py index 1bda990..284abab 100644 --- a/src/scripts/intensity_models.py +++ b/src/scripts/intensity_models.py @@ -5,6 +5,8 @@ import jax import jax.numpy as jnp import jax.scipy.special as jss +import jax.scipy.stats as jsst +import jax.scipy.integrate as jsi import numpy as np import numpyro import numpyro.distributions as dist @@ -56,6 +58,11 @@ def log_dNdmCO(mco, a, b): x = mco/mtr return jnp.where(mco < mtr, -a*jnp.log(x), -b*jnp.log(x)) +def smooth_log_dNdmCO(xx, a, b): + xtr = 20 + delta = 0.05 + return -a * jnp.log(xx / xtr) + delta * (a - b) * jnp.log(0.5 * (1 + (xx/xtr)**(1/delta))) + def log_smooth_turnon(m, mmin, width=0.05): """A function that smoothly transitions from 0 to 1. @@ -159,10 +166,10 @@ class LogDNDMPISN_evolve(object): mpisn: object mbhmax: object sigma: object - n_m: object = 800 + n_m: object = 1024 mbh_grid: object = dataclasses.field(init=False) log_dN_grid: object = dataclasses.field(init=False) - + def __post_init__(self): min_bh_mass = 3.0 min_co_mass = 1.0 @@ -184,17 +191,23 @@ def __post_init__(self): mco = jnp.exp(log_mco) mu = mean_mbh_from_mco(mco, mpisn, mbhmax) - log_mu = jnp.log(jnp.where(mu > 0, mu, 0)) - log_p = -0.5*jnp.square((log_mbh - log_mu)/sigma) - 0.5*jnp.log(2*np.pi) - jnp.log(sigma) - log_mbh + mu_min = 0.1 + mu = jnp.where(mu > 0, mu, mu_min) + log_mu = jnp.log(mu) + + log_p = -0.5 * jnp.square((log_mbh - log_mu) / sigma) - 0.5*jnp.log(2*jnp.pi) - jnp.log(sigma) - log_mbh + + #log_p = -0.5*jnp.square((log_mbh - log_mu)/sigma) - 0.5*jnp.log(2*np.pi) - jnp.log(sigma) - log_mbh + + #log_p = jsst.norm.logpdf(x = log_mbh, loc = log_mu, scale = sigma) - log_mbh log_wts = log_dNdmCO(mco, self.a, self.b) + log_p + #log_trapz = jnp.log(jsi.trapezoid(x = mco, y = jnp.exp(log_wts), axis=1)) log_trapz = np.log(0.5) + jnp.logaddexp(log_wts[:,:-1,:], log_wts[:,1:,:]) + jnp.log(jnp.diff(mco, axis=1)) + self.log_dN_grid = jss.logsumexp(log_trapz, axis=1) self.mbh_grid = mbh[0,0,:] - #def __call__(self, m): - #return jnp.interp(m, self.mbh_grid, self.log_dN_grid) - @dataclass class LogDNDM_evolve(object): @@ -215,7 +228,7 @@ class LogDNDM_evolve(object): mbh_min: object = mbh_min zmax: object = 20 mref: object = 30.0 - zref: object = 0 + zref: object = 0.001 log_dndm_pisn: object = dataclasses.field(init=False) def __post_init__(self): @@ -225,7 +238,7 @@ def __post_init__(self): #mpisn_at_zref = self.mpisn + self.mpisndot * (1 - 1/(1 + self.zref)) #mbhmax_at_zref = mpisn_at_zref + self.dmbhmax - # self.log_pl_norm = jnp.log(self.fpl) + self.interp_2d_dndmpisn(mbhmax_at_zref, self.zref) + #self.log_pl_norm = jnp.log(self.fpl) + self.interp_2d_dndmpisn(mbhmax_at_zref, self.zref) #self.log_norm = -(self(self.mref, self.zref) + jnp.log(self.mref)) # normalize so that m dNdm = 1 at mref @@ -233,7 +246,7 @@ def setup_interp(self): self.z_array = jnp.expm1(jnp.linspace(np.log(1), jnp.log(1+self.zmax), 50)) mpisns = self.mpisn + self.mpisndot * (1 - 1/(1+self.z_array)) mbhmaxs = mpisns + self.dmbhmax - self.log_dndm_pisn = LogDNDMPISN_evolve(self.a, self.b, jnp.array(mpisns), jnp.array(mbhmaxs), self.sigma) + self.log_dndm_pisn = LogDNDMPISN_evolve(self.a, self.b, mpisns, mbhmaxs, self.sigma) self.mbh_grid = self.log_dndm_pisn.mbh_grid self.log_dndm_pisn_grid = self.log_dndm_pisn.log_dN_grid.T self.mbhmaxs = jnp.array(mbhmaxs) @@ -261,8 +274,12 @@ def interp_2d_dndmpisn(self, m, z): t = jnp.where(m_lower == m_upper, 0, (m - m_lower) / mdiffs) u = jnp.where(z_lower == z_upper, 0, (z - z_lower) / zdiffs) - - return (1 - t) * (1 - u) * f1 + t * (1 - u) * f2 + t * u *f3 + (1 - t) * u * f4 + + coefficients = jnp.array([(1 - t) * (1 - u), t * (1 - u), t * u, (1 - t) * u ]) + fs = jnp.array([f1, f2, f3, f4]) + coefficients = jnp.where(jnp.isinf(fs), 1e-6, coefficients) + + return jnp.einsum('i...,i...',coefficients, fs) def __call__(self, m, z): m = jnp.array(m) @@ -278,7 +295,7 @@ def __call__(self, m, z): log_dNdm = jnp.logaddexp(log_dNdm, jnp.log(self.fpl) + log_dNdmbhmax_at_samples + -self.c*jnp.log(m/mbhmax_at_samples) + log_smooth_turnon(m, mbhmax_at_samples)) log_dNdm = jnp.where(m < self.mbh_min, np.NINF, log_dNdm) - return log_dNdm #+ self.log_norm + return log_dNdm @dataclass class LogDNDM(object): @@ -331,7 +348,7 @@ class LogDNDV(object): lam: object kappa: object zp: object - zref: object = 0.01 + zref: object = 0.001 zmax: object = 20 log_norm: object = 0.0 @@ -361,7 +378,7 @@ class LogDNDMDQDV(object): zp: object mref: object = 30.0 qref: object = 1.0 - zref: object = 0.01 + zref: object = 0.001 log_dndm: object = dataclasses.field(init=False) log_dndv: object = dataclasses.field(init=False) @@ -399,7 +416,7 @@ class LogDNDMDQDV_evolve(object): zp: object mref: object = 30.0 qref: object = 1.0 - zref: object = 0.0 + zref: object = 0.001 zmax: object = 20 log_dndm: object = dataclasses.field(init=False) log_dndv: object = dataclasses.field(init=False) @@ -502,13 +519,12 @@ def mass_parameters(): c = numpyro.sample('c', dist.TruncatedNormal(4, 2, low=0, high=8)) mpisn = numpyro.sample('mpisn', dist.TruncatedNormal(35.0, 5.0, low=20.0, high=50.0)) - dmbhmax = numpyro.sample('dmbhmax', dist.TruncatedNormal(5.0, 2.0, low=0.5, high=11.0)) + dmbhmax = numpyro.sample('dmbhmax', dist.TruncatedNormal(3.0, 2.0, low=0.5, high=7.0))#used to be mean 5.0 mbhmax = numpyro.deterministic('mbhmax', mpisn + dmbhmax) sigma = numpyro.sample('sigma', dist.TruncatedNormal(0.1, 0.1, low=0.05)) beta = numpyro.sample('beta', dist.Normal(0, 2)) - - log_fpl = numpyro.sample('log_fpl', dist.Uniform(np.log(1e-3), np.log(0.5))) + log_fpl = numpyro.sample('log_fpl', dist.Uniform(np.log(1e-2), np.log(0.5))) fpl = numpyro.deterministic('fpl', jnp.exp(log_fpl)) return a,b,c,mpisn,mbhmax,sigma,beta,fpl @@ -529,7 +545,9 @@ def cosmo_parameters(): return h,Om,w def evolve_parameters(): - numpyro.sample('mpisndot', dist.Uinform(low=-2, high=8)) + mpisndot = numpyro.sample('mpisndot', dist.Uniform(low=-2, high=8)) + #mpisndot = numpyro.sample('mpisndot', dist.Uniform(low=-50, high=60)) + return mpisndot def pop_cosmo_model(m1s_det, qs, dls, pdraw, m1s_det_sel, qs_sel, dls_sel, pdraw_sel, Ndraw, evolution = False, zmax=20, fixed_cosmo_params = None): m1s_det, qs, dls, pdraw, m1s_det_sel, qs_sel, dls_sel, pdraw_sel = map(jnp.array, (m1s_det, qs, dls, pdraw, m1s_det_sel, qs_sel, dls_sel, pdraw_sel)) diff --git a/src/scripts/utils.py b/src/scripts/utils.py index 1a9f88a..d5f2d1e 100644 --- a/src/scripts/utils.py +++ b/src/scripts/utils.py @@ -5,4 +5,349 @@ def jnp_cumtrapz(ys, xs): xs = jnp.array(xs) ys = jnp.array(ys) - return jnp.concatenate((jnp.zeros(1), jnp.cumsum(0.5*jnp.diff(xs)*(ys[:-1] + ys[1:])))) \ No newline at end of file + return jnp.concatenate((jnp.zeros(1), jnp.cumsum(0.5*jnp.diff(xs)*(ys[:-1] + ys[1:])))) + +""" +Most of the stuff under this is cloned from Tom Callister's effective-spin-priors package: https://github.com/tcallister/effective-spin-priors +""" + +import numpy as np +from scipy.stats import gaussian_kde +from scipy.special import spence as PL + +def Di(z): + + """ + Wrapper for the scipy implmentation of Spence's function. + Note that we adhere to the Mathematica convention as detailed in: + https://reference.wolfram.com/language/ref/PolyLog.html + Inputs + z: A (possibly complex) scalar or array + Returns + Array equivalent to PolyLog[2,z], as defined by Mathematica + """ + + return PL(1.-z+0j) + +def chi_effective_prior_from_aligned_spins(xs, mass_ratio, a_max): + + """ + Function defining the conditional priors p(chi_eff|q) corresponding to + uniform, aligned component spin priors. + Inputs + q: Mass ratio value (according to the convention q<1) + aMax: Maximum allowed dimensionless component spin magnitude + xs: Chi_effective value or values at which we wish to compute prior + Returns: + Array of prior values + """ + + # Ensure that `xs` is an array and take absolute value + xs = np.reshape(xs,-1) + + # Set up various piecewise cases + pdfs = np.zeros(xs.size) + + if not isinstance(mass_ratio, np.ndarray): + mass_ratio *= np.ones(len(xs)) + if not isinstance(a_max, np.ndarray): + a_max *= np.ones(len(xs)) + + masks = dict() + masks["caseA"] = (xs>a_max*(1.-mass_ratio)/(1.+mass_ratio))*(xs<=a_max) + masks["caseB"] = (xs<-a_max*(1.-mass_ratio)/(1.+mass_ratio))*(xs>=-a_max) + masks["caseC"] = (xs>=-a_max*(1.-mass_ratio)/(1.+mass_ratio))*(xs<=a_max*(1.-mass_ratio)/(1.+mass_ratio)) + + """ + # Select relevant effective spins + x_A = xs[caseA] + x_B = xs[caseB] + x_C = xs[caseC] + """ + functions = dict() + functions["caseA"] = lambda X, q, aMax: (1.+q)**2.*(aMax-X)/(4.*q*aMax**2) + functions["caseB"] = lambda X, q, aMax: (1.+q)**2.*(aMax+X)/(4.*q*aMax**2) + functions["caseC"] = lambda X, q, aMax: (1.+q)/(2.*aMax) + for case in masks.keys(): + mask = masks[case] + Xs = xs[mask] + qs = mass_ratio[mask] + amaxes = a_max[mask] + pdfs[mask] = functions[case](X = Xs, q = qs, aMax = amaxes) + return pdfs + +def chi_effective_prior_from_isotropic_spins(xs, mass_ratio, a_max): + + """ + Function defining the conditional priors p(chi_eff|q) corresponding to + uniform, isotropic component spin priors. + Inputs + q: Mass ratio value (according to the convention q<1) + aMax: Maximum allowed dimensionless component spin magnitude + xs: Chi_effective value or values at which we wish to compute prior + Returns: + Array of prior values + """ + + # Ensure that `xs` is an array and take absolute value + xs = np.reshape(np.abs(xs),-1) + + # Set up various piecewise cases + pdfs = np.ones(xs.size,dtype=complex)*(-1.) + if not isinstance(mass_ratio, np.ndarray): + mass_ratio *= np.ones(len(xs)) + if not isinstance(a_max, np.ndarray): + a_max *= np.ones(len(xs)) + + """ + # Select relevant effective spins + x_A = xs[caseA] + x_B = xs[caseB] + x_C = xs[caseC] + x_D = xs[caseD] + x_E = xs[caseE] + + # Select relevant effective spins + q_A = q[caseA] + q_B = q[caseB] + q_C = q[caseC] + q_D = q[caseD] + q_E = q[caseE] + """ + functions = dict() + functions["caseZ"] = lambda X, q, aMax: (1.+q)/(2.*aMax)*(2.-np.log(q)) + + functions["caseA"] = lambda X, q, aMax: (1.+q)/(4.*q*aMax**2)*( + q*aMax*(4.+2.*np.log(aMax) - np.log(q**2*aMax**2 - (1.+q)**2*X**2)) + - 2.*(1.+q)*X*np.arctanh((1.+q)*X/(q*aMax)) + + (1.+q)*X*(Di(-q*aMax/((1.+q)*X)) - Di(q*aMax/((1.+q)*X))) + ) + + functions["caseB"] = lambda X, q, aMax: (1.+q)/(4.*q*aMax**2)*( + 4.*q*aMax + + 2.*q*aMax*np.log(aMax) + - 2.*(1.+q)*X*np.arctanh(q*aMax/((1.+q)*X)) + - q*aMax*np.log((1.+q)**2*X**2 - q**2*aMax**2) + + (1.+q)*X*(Di(-q*aMax/((1.+q)*X)) - Di(q*aMax/((1.+q)*X))) + ) + + functions["caseC"] = lambda X, q, aMax: (1.+q)/(4.*q*aMax**2)*( + 2.*(1.+q)*(aMax-X) + - (1.+q)*X*np.log(aMax)**2. + + (aMax + (1.+q)*X*np.log((1.+q)*X))*np.log(q*aMax/(aMax-(1.+q)*X)) + - (1.+q)*X*np.log(aMax)*(2. + np.log(q) - np.log(aMax-(1.+q)*X)) + + q*aMax*np.log(aMax/(q*aMax-(1.+q)*X)) + + (1.+q)*X*np.log((aMax-(1.+q)*X)*(q*aMax-(1.+q)*X)/q) + + (1.+q)*X*(Di(1.-aMax/((1.+q)*X)) - Di(q*aMax/((1.+q)*X))) + ) + + functions["caseD"] = lambda X, q, aMax: (1.+q)/(4.*q*aMax**2)*( + -X*np.log(aMax)**2 + + 2.*(1.+q)*(aMax-X) + + q*aMax*np.log(aMax/((1.+q)*X-q*aMax)) + + aMax*np.log(q*aMax/(aMax-(1.+q)*X)) + - X*np.log(aMax)*(2.*(1.+q) - np.log((1.+q)*X) - q*np.log((1.+q)*X/aMax)) + + (1.+q)*X*np.log((-q*aMax+(1.+q)*X)*(aMax-(1.+q)*X)/q) + + (1.+q)*X*np.log(aMax/((1.+q)*X))*np.log((aMax-(1.+q)*X)/q) + + (1.+q)*X*(Di(1.-aMax/((1.+q)*X)) - Di(q*aMax/((1.+q)*X))) + ) + + functions["caseE"] = lambda X, q, aMax: (1.+q)/(4.*q*aMax**2)*( + 2.*(1.+q)*(aMax-X) + - (1.+q)*X*np.log(aMax)**2 + + np.log(aMax)*( + aMax + -2.*(1.+q)*X + -(1.+q)*X*np.log(q/((1.+q)*X-aMax)) + ) + - aMax*np.log(((1.+q)*X-aMax)/q) + + (1.+q)*X*np.log(((1.+q)*X-aMax)*((1.+q)*X-q*aMax)/q) + + (1.+q)*X*np.log((1.+q)*X)*np.log(q*aMax/((1.+q)*X-aMax)) + - q*aMax*np.log(((1.+q)*X-q*aMax)/aMax) + + (1.+q)*X*(Di(1.-aMax/((1.+q)*X)) - Di(q*aMax/((1.+q)*X))) + ) + + functions["caseF"] = lambda X, q, aMax: 0. + + masks = dict() + masks["caseZ"] = (xs==0) + masks["caseA"] = (xs>0)*(xsmass_ratio*a_max/(1.+mass_ratio)) + masks["caseC"] = (xs>a_max*(1.-mass_ratio)/(1.+mass_ratio))*(xsa_max*(1.-mass_ratio)/(1.+mass_ratio))*(xs=mass_ratio*a_max/(1.+mass_ratio)) + masks["caseE"] = (xs>a_max*(1.-mass_ratio)/(1.+mass_ratio))*(xs>a_max/(1.+mass_ratio))*(xs=a_max) + + for case in masks.keys(): + mask = masks[case] + Xs = xs[mask] + qs = mass_ratio[mask] + amaxes = a_max[mask] + pdfs[mask] = functions[case](X = Xs, q = qs, aMax = amaxes) + + # Deal with spins on the boundary between cases + if np.any(pdfs==-1): + boundary = (pdfs==-1) + pdfs[boundary] = 0.5*(chi_effective_prior_from_isotropic_spins(mass_ratio = mass_ratio[boundary], a_max = a_max[boundary], xs = xs[boundary]+1e-6)\ + + chi_effective_prior_from_isotropic_spins(mass_ratio = mass_ratio[boundary], a_max = a_max[boundary], xs = xs[boundary]+1e-6)) + + return np.real(pdfs) + +def chi_p_prior_from_isotropic_spins(xs, mass_ratio, a_max): + + """ + Function defining the conditional priors p(chi_p|q) corresponding to + uniform, isotropic component spin priors. + Inputs + q: Mass ratio value (according to the convention q<1) + aMax: Maximum allowed dimensionless component spin magnitude + xs: Chi_p value or values at which we wish to compute prior + Returns: + Array of prior values + """ + + # Ensure that `xs` is an array and take absolute value + xs = np.reshape(xs,-1) + + # Set up various piecewise cases + pdfs = np.zeros(xs.size) + + masks = dict() + masks["caseA"] = xs=mass_ratio*a_max*(3.+4.*mass_ratio)/(4.+3.*mass_ratio))*(xs1): + to_replace = np.where((cost1<-1) | (cost1>1))[0] + a1[to_replace] = np.random.random(to_replace.size)*aMax + a2[to_replace] = np.random.random(to_replace.size)*aMax + cost2[to_replace] = 2.*np.random.random(to_replace.size)-1. + cost1 = (xeff*(1.+q) - q*a2*cost2)/a1 + + # Compute precessing spins and corresponding weights, build KDE + # See `Joint-ChiEff-ChiP-Prior.ipynb` for a discussion of these weights + Xp_draws = chi_p_from_components(a1,a2,cost1,cost2,q) + jacobian_weights = (1.+q)/a1 + prior_kde = gaussian_kde(Xp_draws,weights=jacobian_weights,bw_method=bw_method) + + # Compute maximum chi_p + if (1.+q)*np.abs(xeff)/q