Skip to content
32 changes: 18 additions & 14 deletions sunbird/cosmology/growth_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
class Growth:
def __init__(
self,
theta_star: float = AbacusSummit(0).theta_star,
theta_MC_100: float = AbacusSummit(0)['theta_MC_100'],
emulate=False,
emulator_data_dir=DEFAULT_PATH / "data/hemu/",
):
self.theta_star = theta_star
self.theta_MC_100 = theta_MC_100
self.emulate = emulate
self.emulator_data_dir = emulator_data_dir
if self.emulate:
Expand Down Expand Up @@ -59,10 +59,13 @@ def generate_emulator_training_data(
h_values, sample_parameters = [], []
for i, sample in enumerate(samples_matrix):
try:
cosmology = self.get_cosmology_fixed_theta_star(
# print every 1000th sample
if i % 1000 == 0:
print(i)
cosmology = self.get_cosmology_fixed_theta_MC_100(
DESI(engine="class"),
dict(
theta_star=self.theta_star,
theta_MC_100=self.theta_MC_100,
omega_b=sample[0],
omega_cdm=sample[1],
sigma8=sample[2],
Expand Down Expand Up @@ -274,29 +277,30 @@ def get_emulated_h(self, omega_b, omega_cdm, sigma8, N_ur, n_s, w0_fld, wa_fld):
x = jnp.vstack([omega_b, omega_cdm, sigma8, N_ur, n_s, w0_fld, wa_fld]).T
return self.model.apply(self.params, x)

def get_cosmology_fixed_theta_star(
def get_cosmology_fixed_theta_MC_100(
self,
fiducial,
params,
h_limits=[0.4, 1.0],
xtol=1.0e-6,
):
theta = params.pop("theta_star", None)
theta = params.pop("theta_MC_100", None)
fiducial = fiducial.clone(base="input", **params)
if theta is not None:
if "h" in params:
raise ValueError("Cannot provide both theta_star and h")
raise ValueError("Cannot provide both theta_MC_100 and h")

def f(h):
cosmo = fiducial.clone(base="input", h=h)
return 100.0 * (theta - cosmo.get_thermodynamics().theta_star)
# return 100.0 * (theta - cosmo.get_thermodynamics().theta_MC_100)
return 100.0 * (theta - cosmo['theta_MC_100'])

rtol = xtol
try:
h = optimize.bisect(f, *h_limits, xtol=xtol, rtol=rtol, disp=True)
except ValueError as exc:
raise ValueError(
"Could not find proper h value in the interval that matches theta_star = {:.4f} with [f({:.3f}), f({:.3f})] = [{:.4f}, {:.4f}]".format(
"Could not find proper h value in the interval that matches theta_MC_100 = {:.4f} with [f({:.3f}), f({:.3f})] = [{:.4f}, {:.4f}]".format(
theta, *h_limits, *list(map(f, h_limits))
)
) from exc
Expand Down Expand Up @@ -335,10 +339,10 @@ def get_growth(
z=z,
)
else:
cosmology = self.get_cosmology_fixed_theta_star(
cosmology = self.get_cosmology_fixed_theta_MC_100(
DESI(engine="class"),
dict(
theta_star=self.theta_star,
theta_MC_100=self.theta_MC_100,
omega_b=omega_b,
omega_cdm=omega_cdm,
sigma8=sigma8,
Expand Down Expand Up @@ -393,10 +397,10 @@ def get_fsigma8(
)
return growth_rate * sigma8_z
else:
cosmology = self.get_cosmology_fixed_theta_star(
cosmology = self.get_cosmology_fixed_theta_MC_100(
DESI(engine="class"),
dict(
theta_star=self.theta_star,
theta_MC_100=self.theta_MC_100,
omega_b=omega_b,
omega_cdm=omega_cdm,
sigma8=sigma8,
Expand All @@ -414,6 +418,6 @@ def get_fsigma8(

# Script entry: retrain the h-emulator from scratch and report wall-clock time.
# NOTE(review): runs at import time (no `if __name__ == "__main__"` guard) — confirm intended.
t0 = time.time()
growth = Growth()
# growth.generate_emulator_training_data()
# Generate 100k Latin-hypercube-style training samples, then fit the emulator.
growth.generate_emulator_training_data(n_samples=100_000)
growth.train_emulator()
print(f"It took {time.time() - t0} seconds")
105 changes: 105 additions & 0 deletions sunbird/data/transforms_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,27 @@ def transform(self, x):
@abstractmethod
def inverse_transform(self, x):
    """Map ``x`` from the transformed space back to the original space.

    Subclasses must implement the exact inverse of :meth:`transform`.
    """
    pass

@abstractmethod
def get_jacobian_diagonal(self, y):
    """Return the diagonal of the Jacobian df/dy of this element-wise transform.

    Because the transform acts element-wise, its Jacobian matrix is diagonal,
    and a covariance matrix propagates through the transform as::

        Cov_transformed = diag(J) @ Cov @ diag(J)

    where ``J = df/dy`` is the Jacobian diagonal returned here.

    Parameters
    ----------
    y : array_like
        Data vector in the original (untransformed) space.

    Returns
    -------
    array_like
        Diagonal of the Jacobian matrix, same shape as ``y``.
    """
    pass


class LogTransform(BaseTransform):
Expand All @@ -22,6 +43,27 @@ def transform(self, x):

def inverse_transform(self, x):
return 10**x

def get_jacobian_diagonal(self, y):
    """Jacobian diagonal of the log10 transform: d(log10(y))/dy = 1/(y * ln(10)).

    Parameters
    ----------
    y : array_like
        Data vector in the original (untransformed) space. Must be nonzero
        (positive for a real-valued log10).

    Returns
    -------
    array_like
        Jacobian diagonal 1/(y * ln(10)), same shape and array type as ``y``.
    """
    # isinstance (not `type(y) ==`) so subclasses of Tensor/ndarray dispatch correctly.
    if isinstance(y, torch.Tensor):
        return 1.0 / (y * torch.log(torch.tensor(10.0)))
    elif isinstance(y, np.ndarray):
        return 1.0 / (y * np.log(10.0))
    # Fall back to jax arrays, mirroring the other transforms in this module.
    return 1.0 / (y * jnp.log(10.0))


class ArcsinhTransform(BaseTransform):
Expand All @@ -40,6 +82,27 @@ def inverse_transform(self, x):
return np.sinh(x)
else:
return jnp.sinh(x)

def get_jacobian_diagonal(self, y):
    """Jacobian diagonal of the arcsinh transform: d(arcsinh(y))/dy = 1/sqrt(1 + y^2).

    Parameters
    ----------
    y : array_like
        Data vector in the original (untransformed) space.

    Returns
    -------
    array_like
        Jacobian diagonal 1/sqrt(1 + y^2), same shape and array type as ``y``.
    """
    # isinstance (not `type(y) ==`) so subclasses of Tensor/ndarray dispatch correctly.
    if isinstance(y, torch.Tensor):
        return 1.0 / torch.sqrt(1.0 + y**2)
    elif isinstance(y, np.ndarray):
        return 1.0 / np.sqrt(1.0 + y**2)
    # Fall back to jax arrays, mirroring inverse_transform in this class.
    return 1.0 / jnp.sqrt(1.0 + y**2)

class WeiLiuOutputTransForm(BaseTransform):
"""Class to reconcile output the Minkowski functionals model
Expand All @@ -56,6 +119,27 @@ def transform(self, x):

def inverse_transform(self, x):
return x * self.std + self.mean

def get_jacobian_diagonal(self, y):
    """Jacobian diagonal of the affine inverse map y * std + mean: d/dy = std.

    Parameters
    ----------
    y : array_like
        Data vector in the original (untransformed) space.

    Returns
    -------
    array_like
        The constant ``std``, broadcast to the shape of ``y``.
    """
    # NOTE(review): self.std appears to be a torch.Tensor (`.numpy()` below) — confirm.
    # isinstance (not `type(y) ==`) so subclasses of Tensor/ndarray dispatch correctly.
    if isinstance(y, torch.Tensor):
        return torch.ones_like(y) * self.std
    elif isinstance(y, np.ndarray):
        return np.ones_like(y) * self.std.numpy()
    return jnp.ones_like(y) * self.std.numpy()

class WeiLiuInputTransform(BaseTransform):
"""Class to reconcile input of the Minkowski functionals model
Expand All @@ -72,4 +156,25 @@ def transform(self, x):

def inverse_transform(self, x):
return x

def get_jacobian_diagonal(self, y):
    """Jacobian diagonal of the standardization (y - mean) / std: d/dy = 1/std.

    Parameters
    ----------
    y : array_like
        Data vector in the original (untransformed) space.

    Returns
    -------
    array_like
        The constant ``1/std``, broadcast to the shape of ``y``.
    """
    # NOTE(review): self.std appears to be a torch.Tensor (`.numpy()` below) — confirm.
    # isinstance (not `type(y) ==`) so subclasses of Tensor/ndarray dispatch correctly.
    if isinstance(y, torch.Tensor):
        return torch.ones_like(y) / self.std
    elif isinstance(y, np.ndarray):
        return np.ones_like(y) / self.std.numpy()
    return jnp.ones_like(y) / self.std.numpy()

20 changes: 16 additions & 4 deletions sunbird/emulators/models/fcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(
],
)
self.compression_matrix = compression_matrix

@staticmethod
def add_model_specific_args(parent_parser):
"""Model arguments that could vary
Expand Down Expand Up @@ -275,16 +275,28 @@ def forward(self, x: Tensor) -> Tensor:
y_var = torch.zeros_like(y_pred)
return y_pred, y_var

def get_prediction(self, x: Tensor, filters: Optional[dict] = None) -> Tensor:
def get_prediction(self, x: Tensor, filters: Optional[dict] = None, skip_output_inverse_transform: bool = False) -> Tensor:
"""Get prediction from the model.

Args:
x (Tensor): Input tensor
filters (dict, optional): Filters to apply. Defaults to None.
skip_output_inverse_transform (bool, optional): If True, skip the output inverse transformation,
keeping predictions in the transformed space. Useful when performing inference in transformed
space (requires transforming observations and covariance to match). Defaults to False.

Returns:
Tensor: Model prediction
"""
x = torch.Tensor(x)
if self.transform_input:
if self.transform_input is not None:
x = self.transform_input.transform(x)
y, _ = self.forward(x)
if self.standarize_output:
std_output = self.std_output.to(x.device)
mean_output = self.mean_output.to(x.device)
y = y * std_output + mean_output
if self.transform_output:
if self.transform_output is not None and not skip_output_inverse_transform:
y = self.transform_output.inverse_transform(y)
if self.compression_matrix is not None:
y = y @ self.compression_matrix
Expand Down
71 changes: 64 additions & 7 deletions sunbird/inference/base.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,97 @@
"""Base classes and utilities for inference samplers."""

import logging
import numpy as np
from tabulate import tabulate
from typing import Dict, Optional
from sunbird.inference.priors import AbacusSummitEllipsoid

class BaseSampler:
def __init__(self,
"""Base class for inference samplers.

Handles parameter bookkeeping, optional transformed-space sampling, and
convenience utilities for saving chains and summary tables.
"""

def __init__(
    self,
    observation,
    precision_matrix,
    theory_model,
    priors,
    ranges: Optional[Dict[str, tuple]] = None,
    labels: Optional[Dict[str, str]] = None,
    fixed_parameters: Optional[Dict[str, float]] = None,
    slice_filters: Optional[Dict] = None,
    select_filters: Optional[Dict] = None,
    coordinates: Optional[list] = None,
    ellipsoid: bool = False,
    markers: Optional[dict] = None,
    sample_in_transformed_space: bool = False,
    **kwargs,
):
    """Initialize the sampler base.

    Args:
        observation: Observed data vector.
        precision_matrix: Inverse covariance matrix.
        theory_model: Callable model that maps parameters to predictions.
            When ``sample_in_transformed_space`` is True it must be a bound
            method of an observable exposing ``.model.transform_output``.
        priors: Mapping of parameter names to prior objects.
        ranges: Optional plotting or reporting ranges by parameter.
        labels: Optional labels by parameter.
        fixed_parameters: Mapping of parameter names to fixed values.
        slice_filters: Accepted for API compatibility; not used here — TODO confirm consumers.
        select_filters: Accepted for API compatibility; not used here — TODO confirm consumers.
        coordinates: Accepted for API compatibility; not used here — TODO confirm consumers.
        ellipsoid: Whether to include the AbacusSummit ellipsoid prior.
        markers: Optional marker styling for plots.
        sample_in_transformed_space: If True, use transformed outputs.
        **kwargs: Extra arguments for subclasses.

    Raises:
        ValueError: If ``sample_in_transformed_space`` is True but the theory
            model has no usable output transform.
    """
    self.logger = logging.getLogger(self.__class__.__name__)
    self.theory_model = theory_model
    # Normalize mutable defaults: `= {}` in a signature is shared across calls.
    self.fixed_parameters = {} if fixed_parameters is None else fixed_parameters
    self.observation = observation
    self.priors = priors
    self.ranges = {} if ranges is None else ranges
    self.labels = {} if labels is None else labels
    self.precision_matrix = precision_matrix
    self.ellipsoid = ellipsoid
    self.markers = {} if markers is None else markers
    self.sample_in_transformed_space = sample_in_transformed_space

    # Handle transformation of observations and covariance
    if sample_in_transformed_space:
        # Validate that the observable has an output transform
        if not hasattr(theory_model.__self__.model, 'transform_output'):
            raise ValueError('Cannot sample in transformed space: observable does not have a transform_output. '
                             'Either set sample_in_transformed_space=False or use an observable with transform_output.')

        # Check if transform_output is valid (not None or empty list)
        transform = theory_model.__self__.model.transform_output
        if transform is None:
            raise ValueError('Cannot sample in transformed space: transform_output is None. '
                             'Either set sample_in_transformed_space=False or use an observable with transform_output.')

        # For combined observables, transform_output is a list
        if isinstance(transform, list):
            if all(t is None for t in transform):
                raise ValueError('Cannot sample in transformed space: all transforms in combined observable are None. '
                                 'Either set sample_in_transformed_space=False or use observables with transform_output.')

        self.logger.warning('Sampling in transformed space (skip_output_inverse_transform=True). '
                            'Ensure observations and covariance matrix are also transformed to match!')

    if self.ellipsoid:
        self.abacus_ellipsoid = AbacusSummitEllipsoid()

    # Dimensionality of the sampled space: priors minus fixed parameters.
    self.ndim = len(self.priors.keys()) - len(self.fixed_parameters.keys())
    self.logger.info(f'Free parameters: {[key for key in priors.keys() if key not in self.fixed_parameters.keys()]}')
    self.logger.info(f'Fixed parameters: {[key for key in priors.keys() if key in self.fixed_parameters.keys()]}')
def save_chain(self, save_fn, metadata=None):
"""Save the chain to a file
"""Save a chain dictionary to a NumPy file.

Args:
save_fn: Output filename for the NumPy archive.
metadata: Optional extra metadata to include.
"""
data = self.get_chain(flat=True)
names = [param for param in self.priors.keys() if param not in self.fixed_parameters]
Expand All @@ -66,6 +118,11 @@ def save_chain(self, save_fn, metadata=None):
np.save(save_fn, cout)

def save_table(self, save_fn):
"""Write a summary table with MAP/mean/std values.

Args:
save_fn: Output filename for the text table.
"""
chain = self.get_chain(flat=True)
maxp = chain['samples'][chain['log_posterior'].argmax()]
mean = chain['samples'].mean(axis=0)
Expand Down
Loading