From a43dfc73254cad5bb7ae295243bdd1855673a1bf Mon Sep 17 00:00:00 2001 From: Sylwester Arabas Date: Wed, 28 May 2025 00:42:07 +0200 Subject: [PATCH 1/2] dimensionally_split mode fixes --- PyMPDATA/impl/formulae_antidiff.py | 2 +- PyMPDATA/impl/formulae_nonoscillatory.py | 17 ++++++++++------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/PyMPDATA/impl/formulae_antidiff.py b/PyMPDATA/impl/formulae_antidiff.py index 93cfb64a..1d0248b9 100644 --- a/PyMPDATA/impl/formulae_antidiff.py +++ b/PyMPDATA/impl/formulae_antidiff.py @@ -128,7 +128,7 @@ def antidiff_variants(psi, g_c, g_factor): result += tmp - if n_dims > 1: + if n_dims > 1 and not dimensionally_split: tmp = ( atv(*g_c, 1, 0.5) + atv(*g_c, 0, 0.5) diff --git a/PyMPDATA/impl/formulae_nonoscillatory.py b/PyMPDATA/impl/formulae_nonoscillatory.py index b99635f4..1bd072a9 100644 --- a/PyMPDATA/impl/formulae_nonoscillatory.py +++ b/PyMPDATA/impl/formulae_nonoscillatory.py @@ -27,7 +27,7 @@ def apply(_traversal_data, _psi_extrema, _psi): at_idx = INNER if traversals.n_dims == 1 else OUTER formulae = ( - __make_psi_extrema(options.jit_flags, traversals.n_dims, idx.ats[at_idx]), + __make_psi_extrema(options.jit_flags, traversals.n_dims, idx.ats[at_idx], options), None, None, ) @@ -55,8 +55,8 @@ def apply(traversals_data, psi_extrema, psi): return apply -def __make_psi_extrema(jit_flags, n_dims, ats): - if n_dims == 1: +def __make_psi_extrema(jit_flags, n_dims, ats, options): + if n_dims == 1 or options.dimensionally_split: @numba.njit(**jit_flags) def _impl(psi, extremum): @@ -118,7 +118,7 @@ def apply(_traversal_data, _beta, _flux, _psi, _psi_extrema, _g_factor): ats=idx.ats[at_idx], atv=idx.atv[at_idx], non_unit_g_factor=non_unit_g_factor, - epsilon=options.epsilon, + options=options, ), None, None, @@ -147,8 +147,11 @@ def apply(traversals_data, beta, flux, psi, psi_extrema, g_factor): return apply -def __make_beta(*, jit_flags, n_dims, ats, atv, non_unit_g_factor, epsilon): - if n_dims == 1: +def __make_beta(*, jit_flags, n_dims, ats, atv, non_unit_g_factor, options): + epsilon = options.epsilon + dimensionally_split = options.dimensionally_split + + if n_dims == 1 or dimensionally_split: @numba.njit(**jit_flags) def denominator(flux, sign): @@ -199,7 +202,7 @@ def g_fun(arg): def g_fun(_): return 1 - if n_dims == 1: + if n_dims == 1 or dimensionally_split: @numba.njit(**jit_flags) # pylint: disable=too-many-arguments From df612385fc2a8bab4835973206ba51d6b8af3b86 Mon Sep 17 00:00:00 2001 From: Sylwester Arabas Date: Wed, 28 May 2025 00:46:30 +0200 Subject: [PATCH 2/2] pre-commit --- PyMPDATA/impl/formulae_nonoscillatory.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/PyMPDATA/impl/formulae_nonoscillatory.py b/PyMPDATA/impl/formulae_nonoscillatory.py index 1bd072a9..b8f6b199 100644 --- a/PyMPDATA/impl/formulae_nonoscillatory.py +++ b/PyMPDATA/impl/formulae_nonoscillatory.py @@ -27,7 +27,9 @@ def apply(_traversal_data, _psi_extrema, _psi): at_idx = INNER if traversals.n_dims == 1 else OUTER formulae = ( - __make_psi_extrema(options.jit_flags, traversals.n_dims, idx.ats[at_idx], options), + __make_psi_extrema( + options.jit_flags, traversals.n_dims, idx.ats[at_idx], options + ), None, None, )