diff --git a/PyMPDATA/impl/formulae_antidiff.py b/PyMPDATA/impl/formulae_antidiff.py index 93cfb64a..1d0248b9 100644 --- a/PyMPDATA/impl/formulae_antidiff.py +++ b/PyMPDATA/impl/formulae_antidiff.py @@ -128,7 +128,7 @@ def antidiff_variants(psi, g_c, g_factor): result += tmp - if n_dims > 1: + if n_dims > 1 and not dimensionally_split: tmp = ( atv(*g_c, 1, 0.5) + atv(*g_c, 0, 0.5) diff --git a/PyMPDATA/impl/formulae_nonoscillatory.py b/PyMPDATA/impl/formulae_nonoscillatory.py index b99635f4..b8f6b199 100644 --- a/PyMPDATA/impl/formulae_nonoscillatory.py +++ b/PyMPDATA/impl/formulae_nonoscillatory.py @@ -27,7 +27,9 @@ def apply(_traversal_data, _psi_extrema, _psi): at_idx = INNER if traversals.n_dims == 1 else OUTER formulae = ( - __make_psi_extrema(options.jit_flags, traversals.n_dims, idx.ats[at_idx]), + __make_psi_extrema( + options.jit_flags, traversals.n_dims, idx.ats[at_idx], options + ), None, None, ) @@ -55,8 +57,8 @@ def apply(traversals_data, psi_extrema, psi): return apply -def __make_psi_extrema(jit_flags, n_dims, ats): - if n_dims == 1: +def __make_psi_extrema(jit_flags, n_dims, ats, options): + if n_dims == 1 or options.dimensionally_split: @numba.njit(**jit_flags) def _impl(psi, extremum): @@ -118,7 +120,7 @@ def apply(_traversal_data, _beta, _flux, _psi, _psi_extrema, _g_factor): ats=idx.ats[at_idx], atv=idx.atv[at_idx], non_unit_g_factor=non_unit_g_factor, - epsilon=options.epsilon, + options=options, ), None, None, @@ -147,8 +149,11 @@ def apply(traversals_data, beta, flux, psi, psi_extrema, g_factor): return apply -def __make_beta(*, jit_flags, n_dims, ats, atv, non_unit_g_factor, epsilon): - if n_dims == 1: +def __make_beta(*, jit_flags, n_dims, ats, atv, non_unit_g_factor, options): + epsilon = options.epsilon + dimensionally_split = options.dimensionally_split + + if n_dims == 1 or dimensionally_split: @numba.njit(**jit_flags) def denominator(flux, sign): @@ -199,7 +204,7 @@ def g_fun(arg): def g_fun(_): return 1 - if n_dims == 1: + if n_dims == 1 or dimensionally_split: @numba.njit(**jit_flags) # pylint: disable=too-many-arguments