Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 38 additions & 20 deletions src/scripts/intensity_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import jax
import jax.numpy as jnp
import jax.scipy.special as jss
import jax.scipy.stats as jsst
import jax.scipy.integrate as jsi
import numpy as np
import numpyro
import numpyro.distributions as dist
Expand Down Expand Up @@ -56,6 +58,11 @@ def log_dNdmCO(mco, a, b):
x = mco/mtr
return jnp.where(mco < mtr, -a*jnp.log(x), -b*jnp.log(x))

def smooth_log_dNdmCO(xx, a, b):
xtr = 20
delta = 0.05
return -a * jnp.log(xx / xtr) + delta * (a - b) * jnp.log(0.5 * (1 + (xx/xtr)**(1/delta)))

def log_smooth_turnon(m, mmin, width=0.05):
"""A function that smoothly transitions from 0 to 1.

Expand Down Expand Up @@ -159,10 +166,10 @@ class LogDNDMPISN_evolve(object):
mpisn: object
mbhmax: object
sigma: object
n_m: object = 800
n_m: object = 1024
mbh_grid: object = dataclasses.field(init=False)
log_dN_grid: object = dataclasses.field(init=False)

def __post_init__(self):
min_bh_mass = 3.0
min_co_mass = 1.0
Expand All @@ -184,17 +191,23 @@ def __post_init__(self):
mco = jnp.exp(log_mco)

mu = mean_mbh_from_mco(mco, mpisn, mbhmax)
log_mu = jnp.log(jnp.where(mu > 0, mu, 0))
log_p = -0.5*jnp.square((log_mbh - log_mu)/sigma) - 0.5*jnp.log(2*np.pi) - jnp.log(sigma) - log_mbh
mu_min = 0.1
mu = jnp.where(mu > 0, mu, mu_min)
log_mu = jnp.log(mu)

log_p = -0.5 * jnp.square((log_mbh - log_mu) / sigma) - 0.5*jnp.log(2*jnp.pi) - jnp.log(sigma) - log_mbh

#log_p = -0.5*jnp.square((log_mbh - log_mu)/sigma) - 0.5*jnp.log(2*np.pi) - jnp.log(sigma) - log_mbh

#log_p = jsst.norm.logpdf(x = log_mbh, loc = log_mu, scale = sigma) - log_mbh

log_wts = log_dNdmCO(mco, self.a, self.b) + log_p
#log_trapz = jnp.log(jsi.trapezoid(x = mco, y = jnp.exp(log_wts), axis=1))
log_trapz = np.log(0.5) + jnp.logaddexp(log_wts[:,:-1,:], log_wts[:,1:,:]) + jnp.log(jnp.diff(mco, axis=1))

self.log_dN_grid = jss.logsumexp(log_trapz, axis=1)
self.mbh_grid = mbh[0,0,:]

#def __call__(self, m):
#return jnp.interp(m, self.mbh_grid, self.log_dN_grid)


@dataclass
class LogDNDM_evolve(object):
Expand All @@ -215,7 +228,7 @@ class LogDNDM_evolve(object):
mbh_min: object = mbh_min
zmax: object = 20
mref: object = 30.0
zref: object = 0
zref: object = 0.001
log_dndm_pisn: object = dataclasses.field(init=False)

def __post_init__(self):
Expand All @@ -225,15 +238,15 @@ def __post_init__(self):
#mpisn_at_zref = self.mpisn + self.mpisndot * (1 - 1/(1 + self.zref))
#mbhmax_at_zref = mpisn_at_zref + self.dmbhmax

# self.log_pl_norm = jnp.log(self.fpl) + self.interp_2d_dndmpisn(mbhmax_at_zref, self.zref)
#self.log_pl_norm = jnp.log(self.fpl) + self.interp_2d_dndmpisn(mbhmax_at_zref, self.zref)

#self.log_norm = -(self(self.mref, self.zref) + jnp.log(self.mref)) # normalize so that m dNdm = 1 at mref

def setup_interp(self):
self.z_array = jnp.expm1(jnp.linspace(np.log(1), jnp.log(1+self.zmax), 50))
mpisns = self.mpisn + self.mpisndot * (1 - 1/(1+self.z_array))
mbhmaxs = mpisns + self.dmbhmax
self.log_dndm_pisn = LogDNDMPISN_evolve(self.a, self.b, jnp.array(mpisns), jnp.array(mbhmaxs), self.sigma)
self.log_dndm_pisn = LogDNDMPISN_evolve(self.a, self.b, mpisns, mbhmaxs, self.sigma)
self.mbh_grid = self.log_dndm_pisn.mbh_grid
self.log_dndm_pisn_grid = self.log_dndm_pisn.log_dN_grid.T
self.mbhmaxs = jnp.array(mbhmaxs)
Expand Down Expand Up @@ -261,8 +274,12 @@ def interp_2d_dndmpisn(self, m, z):

t = jnp.where(m_lower == m_upper, 0, (m - m_lower) / mdiffs)
u = jnp.where(z_lower == z_upper, 0, (z - z_lower) / zdiffs)

return (1 - t) * (1 - u) * f1 + t * (1 - u) * f2 + t * u *f3 + (1 - t) * u * f4

coefficients = jnp.array([(1 - t) * (1 - u), t * (1 - u), t * u, (1 - t) * u ])
fs = jnp.array([f1, f2, f3, f4])
coefficients = jnp.where(jnp.isinf(fs), 1e-6, coefficients)

return jnp.einsum('i...,i...',coefficients, fs)

def __call__(self, m, z):
m = jnp.array(m)
Expand All @@ -278,7 +295,7 @@ def __call__(self, m, z):
log_dNdm = jnp.logaddexp(log_dNdm, jnp.log(self.fpl) + log_dNdmbhmax_at_samples + -self.c*jnp.log(m/mbhmax_at_samples) + log_smooth_turnon(m, mbhmax_at_samples))
log_dNdm = jnp.where(m < self.mbh_min, np.NINF, log_dNdm)

return log_dNdm #+ self.log_norm
return log_dNdm

@dataclass
class LogDNDM(object):
Expand Down Expand Up @@ -331,7 +348,7 @@ class LogDNDV(object):
lam: object
kappa: object
zp: object
zref: object = 0.01
zref: object = 0.001
zmax: object = 20
log_norm: object = 0.0

Expand Down Expand Up @@ -361,7 +378,7 @@ class LogDNDMDQDV(object):
zp: object
mref: object = 30.0
qref: object = 1.0
zref: object = 0.01
zref: object = 0.001
log_dndm: object = dataclasses.field(init=False)
log_dndv: object = dataclasses.field(init=False)

Expand Down Expand Up @@ -399,7 +416,7 @@ class LogDNDMDQDV_evolve(object):
zp: object
mref: object = 30.0
qref: object = 1.0
zref: object = 0.0
zref: object = 0.001
zmax: object = 20
log_dndm: object = dataclasses.field(init=False)
log_dndv: object = dataclasses.field(init=False)
Expand Down Expand Up @@ -502,13 +519,12 @@ def mass_parameters():
c = numpyro.sample('c', dist.TruncatedNormal(4, 2, low=0, high=8))

mpisn = numpyro.sample('mpisn', dist.TruncatedNormal(35.0, 5.0, low=20.0, high=50.0))
dmbhmax = numpyro.sample('dmbhmax', dist.TruncatedNormal(5.0, 2.0, low=0.5, high=11.0))
dmbhmax = numpyro.sample('dmbhmax', dist.TruncatedNormal(3.0, 2.0, low=0.5, high=7.0))#used to be mean 5.0
mbhmax = numpyro.deterministic('mbhmax', mpisn + dmbhmax)
sigma = numpyro.sample('sigma', dist.TruncatedNormal(0.1, 0.1, low=0.05))

beta = numpyro.sample('beta', dist.Normal(0, 2))

log_fpl = numpyro.sample('log_fpl', dist.Uniform(np.log(1e-3), np.log(0.5)))
log_fpl = numpyro.sample('log_fpl', dist.Uniform(np.log(1e-2), np.log(0.5)))
fpl = numpyro.deterministic('fpl', jnp.exp(log_fpl))

return a,b,c,mpisn,mbhmax,sigma,beta,fpl
Expand All @@ -529,7 +545,9 @@ def cosmo_parameters():
return h,Om,w

def evolve_parameters():
numpyro.sample('mpisndot', dist.Uinform(low=-2, high=8))
mpisndot = numpyro.sample('mpisndot', dist.Uniform(low=-2, high=8))
#mpisndot = numpyro.sample('mpisndot', dist.Uniform(low=-50, high=60))
return mpisndot

def pop_cosmo_model(m1s_det, qs, dls, pdraw, m1s_det_sel, qs_sel, dls_sel, pdraw_sel, Ndraw, evolution = False, zmax=20, fixed_cosmo_params = None):
m1s_det, qs, dls, pdraw, m1s_det_sel, qs_sel, dls_sel, pdraw_sel = map(jnp.array, (m1s_det, qs, dls, pdraw, m1s_det_sel, qs_sel, dls_sel, pdraw_sel))
Expand Down
Loading