diff --git a/heracles/__init__.py b/heracles/__init__.py index f1502ad..1514fff 100644 --- a/heracles/__init__.py +++ b/heracles/__init__.py @@ -76,6 +76,7 @@ "cl2corr", "corr2cl", # unmixing + "correct_correlation", "natural_unmixing", ] @@ -157,5 +158,6 @@ ) from .unmixing import ( + correct_correlation, natural_unmixing, ) diff --git a/heracles/dices/jackknife.py b/heracles/dices/jackknife.py index 1a65759..83407f3 100644 --- a/heracles/dices/jackknife.py +++ b/heracles/dices/jackknife.py @@ -25,7 +25,7 @@ from ..result import Result, get_result_array from ..mapping import transform from ..twopoint import angular_power_spectra -from ..unmixing import _natural_unmixing, logistic +from ..unmixing import _natural_unmixing, correct_correlation from ..transforms import cl2corr try: @@ -57,7 +57,7 @@ def jackknife_cls(data_maps, vis_maps, jk_maps, fields, nd=1): _cls = get_cls(data_maps, jk_maps, fields, *regions) _cls_mm = get_cls(vis_maps, jk_maps, fields, *regions) # Mask correction - alphas = mask_correction(_cls_mm, mls0) + alphas = mask_correction(_cls_mm, mls0, fields) _cls = _natural_unmixing(_cls, alphas) # Bias correction _cls = correct_bias(_cls, jk_maps, fields, *regions) @@ -205,17 +205,27 @@ def correct_bias(cls, jkmaps, fields, jk=0, jk2=0): return cls -def mask_correction(Mljk, Mls0): +def mask_correction(Mljk, Mls0, fields, options={}, rtol=0.2, smoothing=50): """ Internal method to compute the mask correction. input: - Mljk (np.array): mask of delete1 Cls - Mls0 (np.array): mask Cls + Mljk: mask of delete1 Cls + Mls0: mask Cls returns: - alpha (Float64): Mask correction factor + alpha: Mask correction factor """ + # inverse mapping of masks to fields + masks = {} + for key, field in fields.items(): + if field.mask is not None: + masks[field.mask] = key + alphas = {} for key in list(Mljk.keys()): + a, b, i, j = key + # Get corresponding mask keys + a = masks[a] + b = masks[b] mljk = Mljk[key] mls0 = Mls0[key] # Transform to real space @@ -225,9 +235,14 @@ def mask_correction(Mljk, Mls0): wmljk = wmljk.T[0] # Compute alpha alpha = wmljk / wmls0 - alpha *= logistic(np.log10(abs(wmljk))) - alphas[key] = alpha - return alphas + alphas[(a, b, i, j)] = alpha + corr_alphas = correct_correlation( + alphas, + options=options, + rtol=rtol, + smoothing=smoothing, + ) + return corr_alphas def jackknife_covariance(dict, nd=1): diff --git a/heracles/twopoint.py b/heracles/twopoint.py index b55aad5..40ca49b 100644 --- a/heracles/twopoint.py +++ b/heracles/twopoint.py @@ -398,6 +398,7 @@ def mixing_matrices( def invert_mixing_matrix( M, + options: dict = {}, rtol: float = 1e-5, progress: Progress | None = None, ): @@ -427,6 +428,11 @@ def invert_mixing_matrix( *_, _n, _m = _M.shape new_ell = np.arange(_m) + if key in list(options.keys()): + _rtol = options[key].get("rtol", rtol) + else: + _rtol = rtol + with progress.task(f"invert {key}"): if (s1 != 0) and (s2 != 0): _inv_m = np.linalg.pinv( @@ -435,10 +441,10 @@ def invert_mixing_matrix( ) _inv_M_EEEE = _inv_m[:_m, :_n] _inv_M_EEBB = _inv_m[_m:, :_n] - _inv_M_EBEB = np.linalg.pinv(_M[2], rcond=rtol) + _inv_M_EBEB = np.linalg.pinv(_M[2], rcond=_rtol) _inv_M = np.array([_inv_M_EEEE, _inv_M_EEBB, _inv_M_EBEB]) else: - _inv_M = np.linalg.pinv(_M, rcond=rtol) + _inv_M = np.linalg.pinv(_M, rcond=_rtol) inv_M[key] = Result(_inv_M, axis=value.axis, ell=new_ell) diff --git a/heracles/unmixing.py b/heracles/unmixing.py index fb87589..7476d18 100644 --- a/heracles/unmixing.py +++ b/heracles/unmixing.py @@ -27,45 +27,114 @@ from dataclasses import replace -def natural_unmixing(d, m, x0=-2, k=50, patch_hole=True, lmax=None): +def correct_correlation(wm, options={}, rtol=0.2, smoothing=50): + """ + Correct the correlation function of the mask to avoid + dividing by very small numbers during unmixing. + Args: + wm: mask correlation functions + options: dictionary of options for each mask + rtol: relative tolerance to apply + smoothing: smoothing parameter for the logistic function + Returns: + wm: corrected mask correlation functions + """ + wm_keys = list(wm.keys()) + corr_wm = {} + for wm_key in wm_keys: + if wm_key in list(options.keys()): + _rtol = options[wm_key].get("rtol", rtol) + _smoothing = options[wm_key].get("smoothing", smoothing) + else: + _rtol = rtol + _smoothing = smoothing + _wm = wm[wm_key] + _tol = _rtol * np.max(abs(_wm)) + _wm *= logistic( + np.log10(abs(_wm)), + tol=np.log10(_tol), + smoothing=_smoothing, + ) + corr_wm[wm_key] = _wm + return corr_wm + + +def natural_unmixing(d, m, fields, options={}, rtol=0.2, smoothing=50): + """ + Natural unmixing of the data Cl. + Args: + d: data cls + m: mask cls + fields: list of fields + patch_hole: If True, apply the patch hole correction + Returns: + corr_d: Corrected Cl + """ + # inverse mapping of masks to fields + masks = {} + for key, field in fields.items(): + if field.mask is not None: + masks[field.mask] = key + wm = {} m_keys = list(m.keys()) for m_key in m_keys: + a, b, i, j = m_key + # Get corresponding mask keys + a = masks[a] + b = masks[b] + # Transform to real space _m = m[m_key].array _wm = cl2corr(_m).T[0] - if patch_hole: - _wm *= logistic(np.log10(abs(_wm)), x0=x0, k=k) - wm[m_key] = _wm - return _natural_unmixing(d, wm, lmax=lmax) + wm[(a, b, i, j)] = _wm + + wm = correct_correlation( + wm, + options=options, + rtol=rtol, + smoothing=smoothing, + ) + return _natural_unmixing(d, wm) -def _natural_unmixing(d, wm, lmax=None): +def _natural_unmixing(d, wm): """ Natural unmixing of the data Cl. Args: - d: Data Cl - m: mask cls + d: data cls + m: mask correlation function patch_hole: If True, apply the patch hole correction Returns: corr_d: Corrected Cl """ corr_d = {} - d_keys = list(d.keys()) - wm_keys = list(wm.keys()) - for d_key, wm_key in zip(d_keys, wm_keys): - a, b, i, j = d_key - if lmax is None: - *_, lmax = d[d_key].shape - s1, s2 = d[d_key].spin - _d = np.atleast_2d(d[d_key]) - _wm = wm[wm_key] - lmax_mask = len(wm[wm_key]) + for key in list(d.keys()): + ell = d[key].ell + if ell is None: + *_, lmax = d[key].shape + else: + lmax = ell[-1] + 1 + s1, s2 = d[key].spin + _d = np.atleast_2d(d[key]) + # Get corresponding mask correlation function + if key in wm: + _wm = wm[key] + else: + a, b, i, j = key + if a == b and (a, b, j, i) in wm: + _wm = wm[(a, b, j, i)] + elif (b, a, j, i) in wm: + _wm = wm[(b, a, j, i)] + else: + raise KeyError(f"Key {key} not found in mask correlation functions.") + + lmax_mask = len(_wm) # pad cls pad_width = [(0, 0)] * _d.ndim # no padding for other dims pad_width[-1] = (0, lmax_mask - lmax) # pad only last dim _d = np.pad(_d, pad_width, mode="constant", constant_values=0) # Grab metadata - dtype = d[d_key].array.dtype + dtype = d[key].array.dtype if (s1 != 0) and (s2 != 0): __d = np.array( [ @@ -109,11 +178,11 @@ def _natural_unmixing(d, wm, lmax=None): _corr_d = np.squeeze(_corr_d) # Add metadata back _corr_d = np.array(list(_corr_d), dtype=dtype) - corr_d[d_key] = replace(d[d_key], array=_corr_d) + corr_d[key] = replace(d[key], array=_corr_d) # truncate to lmax corr_d = truncated(corr_d, lmax) return corr_d -def logistic(x, x0=-5, k=50): - return 1.0 + np.exp(-k * (x - x0)) +def logistic(x, tol=-5, smoothing=50): + return 1.0 + np.exp(-smoothing * (x - tol)) diff --git a/tests/test_dices.py b/tests/test_dices.py index b52d88d..8798933 100644 --- a/tests/test_dices.py +++ b/tests/test_dices.py @@ -91,29 +91,56 @@ def test_get_delete2_fsky(jk_maps, njk): assert alpha == pytest.approx(_alpha, rel=1e-1) -def test_mask_correction(cls0, mls0): - alphas = dices.mask_correction(mls0, mls0) - _cls = heracles.unmixing._natural_unmixing(cls0, alphas) +def test_mask_correction(cls0, mls0, fields): + # test natural umixing + masks = {} + for key, field in fields.items(): + if field.mask is not None: + masks[field.mask] = key + wm = {} + m_keys = list(mls0.keys()) + for m_key in m_keys: + a, b, i, j = m_key + # Get corresponding mask keys + a = masks[a] + b = masks[b] + # Transform to real space + _m = mls0[m_key].array + _wm = heracles.transforms.cl2corr(_m).T[0] + wm[(a, b, i, j)] = _wm + cls = heracles.unmixing.natural_unmixing(cls0, mls0, fields) + _cls = heracles.unmixing._natural_unmixing(cls0, wm) for key in list(cls0.keys()): - cl = cls0[key].array + cl = cls[key].array _cl = _cls[key].array assert np.isclose(cl[2:], _cl[2:]).all() - -def test_polspice(cls0): - from heracles.dices.utils import get_cl - - cls = np.array( - [ - get_cl(("POS", "POS", 1, 1), cls0), - get_cl(("SHE", "SHE", 1, 1), cls0)[0, 0], - get_cl(("SHE", "SHE", 1, 1), cls0)[1, 1], - get_cl(("POS", "SHE", 1, 1), cls0)[0], - ] - ).T - corrs = heracles.cl2corr(cls) - _cls = heracles.corr2cl(corrs) - for cl, _cl in zip(cls.T, _cls.T): + # test logistic correction + wm = {} + m_keys = list(mls0.keys()) + for m_key in m_keys: + a, b, i, j = m_key + # Get corresponding mask keys + a = masks[a] + b = masks[b] + # Transform to real space + _m = mls0[m_key].array + _wm = heracles.transforms.cl2corr(_m).T[0] + _wm = np.abs(_wm) + _wm /= np.max(_wm) + wm[(a, b, i, j)] = _wm + _wm = heracles.correct_correlation(wm, rtol=_wm) + for key in wm.keys(): + _w = wm[key] + __w = _wm[key] + assert np.isclose(__w, _w).all() + + # test dices mask correction + alphas = dices.mask_correction(mls0, mls0, fields) + _cls = heracles.unmixing._natural_unmixing(cls0, alphas) + for key in list(cls0.keys()): + cl = cls0[key].array + _cl = _cls[key].array assert np.isclose(cl[2:], _cl[2:]).all() diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 6513713..e74cf0f 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -2,9 +2,9 @@ import heracles -def test_cl2corr(): - # Is there something more clever we can do here? - # Like transforming the legendre nodes and return ones? +def test_cl_transform(cls0): + from heracles.dices.utils import get_cl + cl = np.array( [ np.ones(10), @@ -16,8 +16,6 @@ def test_cl2corr(): corr = heracles.cl2corr(cl.T).T assert corr.shape == cl.shape - -def test_corr2cl(): corr = np.array( [ np.ones(10), @@ -28,3 +26,16 @@ def test_corr2cl(): ) cl = heracles.corr2cl(corr.T).T assert corr.shape == cl.shape + + cls = np.array( + [ + get_cl(("POS", "POS", 1, 1), cls0), + get_cl(("SHE", "SHE", 1, 1), cls0)[0, 0], + get_cl(("SHE", "SHE", 1, 1), cls0)[1, 1], + get_cl(("POS", "SHE", 1, 1), cls0)[0], + ] + ).T + corrs = heracles.cl2corr(cls) + _cls = heracles.corr2cl(corrs) + for cl, _cl in zip(cls.T, _cls.T): + assert np.isclose(cl[2:], _cl[2:]).all() diff --git a/tests/test_twopoint.py b/tests/test_twopoint.py index bc6c337..20dbf9b 100644 --- a/tests/test_twopoint.py +++ b/tests/test_twopoint.py @@ -402,3 +402,15 @@ def test_inverting_mixing_matrices(): _mixed_cls = np.sum(mixed_cls[key].array) print(key, _mixed_cls) np.testing.assert_allclose(_mixed_cls, (n + 1) * 1.0) + + # test options + options = {("POS", "POS", 1, 1): {"rtol": 1}} + opt_inv_mms = invert_mixing_matrix(mms, options=options) + opt_mixed_cls = apply_mixing_matrix(cls, opt_inv_mms) + for key in opt_mixed_cls: + if key in list(options.keys()): + assert np.all( + opt_mixed_cls[key].array == np.zeros_like(mixed_cls[key].array) + ) + else: + assert np.all(opt_mixed_cls[key].array == mixed_cls[key].array)