Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions heracles/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
"cl2corr",
"corr2cl",
# unmixing
"correct_correlation",
"natural_unmixing",
]

Expand Down Expand Up @@ -157,5 +158,6 @@
)

from .unmixing import (
correct_correlation,
natural_unmixing,
)
33 changes: 24 additions & 9 deletions heracles/dices/jackknife.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from ..result import Result, get_result_array
from ..mapping import transform
from ..twopoint import angular_power_spectra
from ..unmixing import _natural_unmixing, logistic
from ..unmixing import _natural_unmixing, correct_correlation
from ..transforms import cl2corr

try:
Expand Down Expand Up @@ -57,7 +57,7 @@ def jackknife_cls(data_maps, vis_maps, jk_maps, fields, nd=1):
_cls = get_cls(data_maps, jk_maps, fields, *regions)
_cls_mm = get_cls(vis_maps, jk_maps, fields, *regions)
# Mask correction
alphas = mask_correction(_cls_mm, mls0)
alphas = mask_correction(_cls_mm, mls0, fields)
_cls = _natural_unmixing(_cls, alphas)
# Bias correction
_cls = correct_bias(_cls, jk_maps, fields, *regions)
Expand Down Expand Up @@ -205,17 +205,27 @@ def correct_bias(cls, jkmaps, fields, jk=0, jk2=0):
return cls


def mask_correction(Mljk, Mls0):
def mask_correction(Mljk, Mls0, fields, options={}, rtol=0.2, smoothing=50):
"""
Internal method to compute the mask correction.
input:
Mljk (np.array): mask of delete1 Cls
Mls0 (np.array): mask Cls
Mljk: mask of delete1 Cls
Mls0: mask Cls
returns:
alpha (Float64): Mask correction factor
alpha: Mask correction factor
"""
# inverse mapping of masks to fields
masks = {}
for key, field in fields.items():
if field.mask is not None:
masks[field.mask] = key

alphas = {}
for key in list(Mljk.keys()):
a, b, i, j = key
# Get corresponding mask keys
a = masks[a]
b = masks[b]
mljk = Mljk[key]
mls0 = Mls0[key]
# Transform to real space
Expand All @@ -225,9 +235,14 @@ def mask_correction(Mljk, Mls0):
wmljk = wmljk.T[0]
# Compute alpha
alpha = wmljk / wmls0
alpha *= logistic(np.log10(abs(wmljk)))
alphas[key] = alpha
return alphas
alphas[(a, b, i, j)] = alpha
corr_alphas = correct_correlation(
alphas,
options=options,
rtol=rtol,
smoothing=smoothing,
)
return corr_alphas


def jackknife_covariance(dict, nd=1):
Expand Down
10 changes: 8 additions & 2 deletions heracles/twopoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,7 @@ def mixing_matrices(

def invert_mixing_matrix(
M,
options: dict = {},
rtol: float = 1e-5,
progress: Progress | None = None,
):
Expand Down Expand Up @@ -427,6 +428,11 @@ def invert_mixing_matrix(
*_, _n, _m = _M.shape
new_ell = np.arange(_m)

if key in list(options.keys()):
_rtol = options[key].get("rtol", rtol)
else:
_rtol = rtol

with progress.task(f"invert {key}"):
if (s1 != 0) and (s2 != 0):
_inv_m = np.linalg.pinv(
Expand All @@ -435,10 +441,10 @@ def invert_mixing_matrix(
)
_inv_M_EEEE = _inv_m[:_m, :_n]
_inv_M_EEBB = _inv_m[_m:, :_n]
_inv_M_EBEB = np.linalg.pinv(_M[2], rcond=rtol)
_inv_M_EBEB = np.linalg.pinv(_M[2], rcond=_rtol)
_inv_M = np.array([_inv_M_EEEE, _inv_M_EEBB, _inv_M_EBEB])
else:
_inv_M = np.linalg.pinv(_M, rcond=rtol)
_inv_M = np.linalg.pinv(_M, rcond=_rtol)

inv_M[key] = Result(_inv_M, axis=value.axis, ell=new_ell)

Expand Down
113 changes: 91 additions & 22 deletions heracles/unmixing.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,45 +27,114 @@
from dataclasses import replace


def natural_unmixing(d, m, x0=-2, k=50, patch_hole=True, lmax=None):
def correct_correlation(wm, options=None, rtol=0.2, smoothing=50):
    """
    Correct the correlation function of the mask to avoid
    dividing by very small numbers during unmixing.

    Each correlation function is multiplied by a smooth amplification
    factor (``logistic``): values well above the tolerance are left
    essentially unchanged, while near-zero values are inflated, so that
    later divisions by the correlation function stay numerically stable.

    Args:
        wm: mapping of keys to mask correlation functions (arrays)
        options: optional dict of per-key overrides; each entry may
            provide its own ``"rtol"`` and/or ``"smoothing"``
        rtol: relative tolerance, as a fraction of each correlation
            function's maximum absolute value
        smoothing: smoothing parameter for the logistic function
    Returns:
        dict of corrected mask correlation functions
    """
    # Avoid the shared-mutable-default pitfall; None means "no overrides".
    if options is None:
        options = {}
    corr_wm = {}
    for key, values in wm.items():
        # Per-key overrides take precedence over the global defaults.
        opts = options.get(key, {})
        _rtol = opts.get("rtol", rtol)
        _smoothing = opts.get("smoothing", smoothing)
        # Absolute tolerance relative to the peak of this correlation.
        tol = _rtol * np.max(np.abs(values))
        # Multiply out-of-place so the caller's input arrays are not
        # silently mutated (the original in-place `*=` aliased wm[key]).
        corr_wm[key] = values * logistic(
            np.log10(np.abs(values)),
            tol=np.log10(tol),
            smoothing=_smoothing,
        )
    return corr_wm


def natural_unmixing(d, m, fields, options=None, rtol=0.2, smoothing=50):
    """
    Natural unmixing of the data Cls.

    Args:
        d: data cls
        m: mask cls, keyed by mask names
        fields: mapping of field keys to fields; used to translate the
            mask keys in ``m`` into the corresponding field keys
        options: optional per-key overrides passed to
            ``correct_correlation`` (``"rtol"``/``"smoothing"``)
        rtol: relative tolerance passed to ``correct_correlation``
        smoothing: smoothing parameter passed to ``correct_correlation``
    Returns:
        corr_d: Corrected Cl
    """
    # Avoid the shared-mutable-default pitfall; None means "no overrides".
    if options is None:
        options = {}

    # inverse mapping of masks to fields
    masks = {}
    for key, field in fields.items():
        if field.mask is not None:
            masks[field.mask] = key

    # Mask correlation functions, re-keyed by field (not mask) names.
    wm = {}
    for m_key, m_val in m.items():
        a, b, i, j = m_key
        # Get corresponding field keys for the two masks.
        fa = masks[a]
        fb = masks[b]
        # Transform the mask Cl to real space.
        wm[(fa, fb, i, j)] = cl2corr(m_val.array).T[0]

    # Guard against divisions by near-zero correlation values.
    wm = correct_correlation(
        wm,
        options=options,
        rtol=rtol,
        smoothing=smoothing,
    )
    return _natural_unmixing(d, wm)


def _natural_unmixing(d, wm, lmax=None):
def _natural_unmixing(d, wm):
"""
Natural unmixing of the data Cl.
Args:
d: Data Cl
m: mask cls
d: data cls
m: mask correlation function
patch_hole: If True, apply the patch hole correction
Returns:
corr_d: Corrected Cl
"""
corr_d = {}
d_keys = list(d.keys())
wm_keys = list(wm.keys())
for d_key, wm_key in zip(d_keys, wm_keys):
a, b, i, j = d_key
if lmax is None:
*_, lmax = d[d_key].shape
s1, s2 = d[d_key].spin
_d = np.atleast_2d(d[d_key])
_wm = wm[wm_key]
lmax_mask = len(wm[wm_key])
for key in list(d.keys()):
ell = d[key].ell
if ell is None:
*_, lmax = d[key].shape
else:
lmax = ell[-1] + 1
s1, s2 = d[key].spin
_d = np.atleast_2d(d[key])
# Get corresponding mask correlation function
if key in wm:
_wm = wm[key]
else:
a, b, i, j = key
if a == b and (a, b, j, i) in wm:
_wm = wm[(a, b, j, i)]
elif (b, a, j, i) in wm:
_wm = wm[(b, a, j, i)]
else:
raise KeyError(f"Key {key} not found in mask correlation functions.")

lmax_mask = len(_wm)
# pad cls
pad_width = [(0, 0)] * _d.ndim # no padding for other dims
pad_width[-1] = (0, lmax_mask - lmax) # pad only last dim
_d = np.pad(_d, pad_width, mode="constant", constant_values=0)
# Grab metadata
dtype = d[d_key].array.dtype
dtype = d[key].array.dtype
if (s1 != 0) and (s2 != 0):
__d = np.array(
[
Expand Down Expand Up @@ -109,11 +178,11 @@ def _natural_unmixing(d, wm, lmax=None):
_corr_d = np.squeeze(_corr_d)
# Add metadata back
_corr_d = np.array(list(_corr_d), dtype=dtype)
corr_d[d_key] = replace(d[d_key], array=_corr_d)
corr_d[key] = replace(d[key], array=_corr_d)
# truncate to lmax
corr_d = truncated(corr_d, lmax)
return corr_d


def logistic(x, tol=-5, smoothing=50):
    """
    Smooth amplification factor used to regularise small values.

    Returns ``1 + exp(-smoothing * (x - tol))``: approximately 1 for
    ``x`` well above ``tol`` and exponentially large below it, so that
    multiplying a quantity by this factor inflates its near-zero values
    before they are used as divisors.

    Args:
        x: input value(s), typically ``log10(|w|)`` of a correlation
        tol: threshold below which the amplification kicks in
        smoothing: steepness of the transition around ``tol``
    Returns:
        amplification factor with the same shape as ``x``
    """
    return 1.0 + np.exp(-smoothing * (x - tol))
65 changes: 46 additions & 19 deletions tests/test_dices.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,29 +91,56 @@ def test_get_delete2_fsky(jk_maps, njk):
assert alpha == pytest.approx(_alpha, rel=1e-1)


def test_mask_correction(cls0, mls0):
alphas = dices.mask_correction(mls0, mls0)
_cls = heracles.unmixing._natural_unmixing(cls0, alphas)
def test_mask_correction(cls0, mls0, fields):
# test natural umixing
masks = {}
for key, field in fields.items():
if field.mask is not None:
masks[field.mask] = key
wm = {}
m_keys = list(mls0.keys())
for m_key in m_keys:
a, b, i, j = m_key
# Get corresponding mask keys
a = masks[a]
b = masks[b]
# Transform to real space
_m = mls0[m_key].array
_wm = heracles.transforms.cl2corr(_m).T[0]
wm[(a, b, i, j)] = _wm
cls = heracles.unmixing.natural_unmixing(cls0, mls0, fields)
_cls = heracles.unmixing._natural_unmixing(cls0, wm)
for key in list(cls0.keys()):
cl = cls0[key].array
cl = cls[key].array
_cl = _cls[key].array
assert np.isclose(cl[2:], _cl[2:]).all()


def test_polspice(cls0):
from heracles.dices.utils import get_cl

cls = np.array(
[
get_cl(("POS", "POS", 1, 1), cls0),
get_cl(("SHE", "SHE", 1, 1), cls0)[0, 0],
get_cl(("SHE", "SHE", 1, 1), cls0)[1, 1],
get_cl(("POS", "SHE", 1, 1), cls0)[0],
]
).T
corrs = heracles.cl2corr(cls)
_cls = heracles.corr2cl(corrs)
for cl, _cl in zip(cls.T, _cls.T):
# test logistic correction
wm = {}
m_keys = list(mls0.keys())
for m_key in m_keys:
a, b, i, j = m_key
# Get corresponding mask keys
a = masks[a]
b = masks[b]
# Transform to real space
_m = mls0[m_key].array
_wm = heracles.transforms.cl2corr(_m).T[0]
_wm = np.abs(_wm)
_wm /= np.max(_wm)
wm[(a, b, i, j)] = _wm
_wm = heracles.correct_correlation(wm, rtol=_wm)
for key in wm.keys():
_w = wm[key]
__w = _wm[key]
assert np.isclose(__w, _w).all()

# test dices mask correction
alphas = dices.mask_correction(mls0, mls0, fields)
_cls = heracles.unmixing._natural_unmixing(cls0, alphas)
for key in list(cls0.keys()):
cl = cls0[key].array
_cl = _cls[key].array
assert np.isclose(cl[2:], _cl[2:]).all()


Expand Down
21 changes: 16 additions & 5 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import heracles


def test_cl2corr():
# Is there something more clever we can do here?
# Like transforming the legendre nodes and return ones?
def test_cl_transform(cls0):
from heracles.dices.utils import get_cl

cl = np.array(
[
np.ones(10),
Expand All @@ -16,8 +16,6 @@ def test_cl2corr():
corr = heracles.cl2corr(cl.T).T
assert corr.shape == cl.shape


def test_corr2cl():
corr = np.array(
[
np.ones(10),
Expand All @@ -28,3 +26,16 @@ def test_corr2cl():
)
cl = heracles.corr2cl(corr.T).T
assert corr.shape == cl.shape

cls = np.array(
[
get_cl(("POS", "POS", 1, 1), cls0),
get_cl(("SHE", "SHE", 1, 1), cls0)[0, 0],
get_cl(("SHE", "SHE", 1, 1), cls0)[1, 1],
get_cl(("POS", "SHE", 1, 1), cls0)[0],
]
).T
corrs = heracles.cl2corr(cls)
_cls = heracles.corr2cl(corrs)
for cl, _cl in zip(cls.T, _cls.T):
assert np.isclose(cl[2:], _cl[2:]).all()
12 changes: 12 additions & 0 deletions tests/test_twopoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,3 +402,15 @@ def test_inverting_mixing_matrices():
_mixed_cls = np.sum(mixed_cls[key].array)
print(key, _mixed_cls)
np.testing.assert_allclose(_mixed_cls, (n + 1) * 1.0)

# test options
options = {("POS", "POS", 1, 1): {"rtol": 1}}
opt_inv_mms = invert_mixing_matrix(mms, options=options)
opt_mixed_cls = apply_mixing_matrix(cls, opt_inv_mms)
for key in opt_mixed_cls:
if key in list(options.keys()):
assert np.all(
opt_mixed_cls[key].array == np.zeros_like(mixed_cls[key].array)
)
else:
assert np.all(opt_mixed_cls[key].array == mixed_cls[key].array)