From c9955c379f99951f19a5d9d1edb859b69d848a1d Mon Sep 17 00:00:00 2001 From: Ben Dudson Date: Sun, 21 Jun 2026 13:55:08 -0700 Subject: [PATCH 1/7] Stencil expressions: DDZ_C2 and DDZ_C4 operators Return BinaryExpr types, perform derivatives in Z using CoordinateAccessor to get dz. Appear to inline in Hasegawa-Wakatani example. --- examples/hasegawa-wakatani/hw.cxx | 3 +- include/bout/stencil_expr.hxx | 80 +++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 1 deletion(-) create mode 100644 include/bout/stencil_expr.hxx diff --git a/examples/hasegawa-wakatani/hw.cxx b/examples/hasegawa-wakatani/hw.cxx index c3f8717597..7bea5938d8 100644 --- a/examples/hasegawa-wakatani/hw.cxx +++ b/examples/hasegawa-wakatani/hw.cxx @@ -3,6 +3,7 @@ #include #include #include +#include class HW : public PhysicsModel { private: @@ -110,7 +111,7 @@ class HW : public PhysicsModel { } ddt(n) = - -bracket(phi, n, bm) + alpha * (nonzonal_phi - nonzonal_n) - kappa * DDZ(phi); + -bracket(phi, n, bm) + alpha * (nonzonal_phi - nonzonal_n) - kappa * DDZ_C2(phi); ddt(vort) = -bracket(phi, vort, bm) + alpha * (nonzonal_phi - nonzonal_n); diff --git a/include/bout/stencil_expr.hxx b/include/bout/stencil_expr.hxx new file mode 100644 index 0000000000..1e67c4779f --- /dev/null +++ b/include/bout/stencil_expr.hxx @@ -0,0 +1,80 @@ +#pragma once +#ifndef BOUT_STENCIL_EXPR_HXX +#define BOUT_STENCIL_EXPR_HXX + +#include "bout/coordinates_accessor.hxx" +#include "bout/field3d.hxx" +#include "bout/fieldops.hxx" +#include "bout/single_index_ops.hxx" + +namespace bout::stencil { + +struct DDZ_C2_Op { + CoordinatesAccessor coords; + int nz{0}; + + template + BOUT_HOST_DEVICE BOUT_FORCEINLINE BoutReal operator()(int idx, const LView& lhs, + const RView&) const { + const int izp = i_zp(idx, nz); + const int izm = i_zm(idx, nz); + + return 0.5 * (lhs(izp) - lhs(izm)) / coords.dz(idx); + } +}; + +struct DDZ_C4_Op { + CoordinatesAccessor coords; + int nz{0}; + + template + BOUT_HOST_DEVICE 
BOUT_FORCEINLINE BoutReal operator()(int idx, const LView& lhs, + const RView&) const { + const int izp = i_zp(idx, nz); + const int izm = i_zm(idx, nz); + const int izp2 = i_zp(izp, nz); + const int izm2 = i_zm(izm, nz); + + return (-lhs(izp2) + 8.0 * lhs(izp) - 8.0 * lhs(izm) + lhs(izm2)) + / (12.0 * coords.dz(idx)); + } +}; + +using DDZExprC2 = BinaryExpr; +using DDZExprC4 = BinaryExpr; + +} // namespace bout::stencil + +inline bout::stencil::DDZExprC2 DDZ_C2(const Field3D& f) { + checkData(f); + + const auto region_id = f.getMesh()->getRegionID("RGN_NOBNDRY"); + + return bout::stencil::DDZExprC2{ + static_cast(f), + static_cast(f), + bout::stencil::DDZ_C2_Op{CoordinatesAccessor{f.getCoordinates()}, f.getNz()}, + f.getMesh(), + f.getLocation(), + f.getDirections(), + region_id, + f.getMesh()->getRegion("RGN_NOBNDRY")}; +} + +inline bout::stencil::DDZExprC4 DDZ_C4(const Field3D& f) { + checkData(f); + + const auto region_id = f.getMesh()->getRegionID("RGN_NOBNDRY"); + + return bout::stencil::DDZExprC4{ + static_cast(f), + static_cast(f), + bout::stencil::DDZ_C4_Op{CoordinatesAccessor{f.getCoordinates()}, f.getNz()}, + f.getMesh(), + f.getLocation(), + f.getDirections(), + region_id, + f.getMesh()->getRegion("RGN_NOBNDRY")}; +} + +#endif From acb30aba4e76e239c54ddbf9d5c73b882a76cbae Mon Sep 17 00:00:00 2001 From: Ben Dudson Date: Mon, 22 Jun 2026 15:02:03 -0700 Subject: [PATCH 2/7] Stencil expressions: bracket_arakawa Implements Arakawa bracket in X-Z as a BinaryExpr. 
--- examples/hasegawa-wakatani/hw.cxx | 6 ++-- include/bout/stencil_expr.hxx | 57 +++++++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 3 deletions(-) diff --git a/examples/hasegawa-wakatani/hw.cxx b/examples/hasegawa-wakatani/hw.cxx index 7bea5938d8..e72d844a00 100644 --- a/examples/hasegawa-wakatani/hw.cxx +++ b/examples/hasegawa-wakatani/hw.cxx @@ -110,10 +110,10 @@ class HW : public PhysicsModel { nonzonal_phi -= averageY(DC(phi)); } - ddt(n) = - -bracket(phi, n, bm) + alpha * (nonzonal_phi - nonzonal_n) - kappa * DDZ_C2(phi); + ddt(n) = -bracket_arakawa(phi, n) + alpha * (nonzonal_phi - nonzonal_n) + - kappa * DDZ_C2(phi); - ddt(vort) = -bracket(phi, vort, bm) + alpha * (nonzonal_phi - nonzonal_n); + ddt(vort) = -bracket_arakawa(phi, vort) + alpha * (nonzonal_phi - nonzonal_n); return 0; } diff --git a/include/bout/stencil_expr.hxx b/include/bout/stencil_expr.hxx index 1e67c4779f..807a1c3852 100644 --- a/include/bout/stencil_expr.hxx +++ b/include/bout/stencil_expr.hxx @@ -5,6 +5,7 @@ #include "bout/coordinates_accessor.hxx" #include "bout/field3d.hxx" #include "bout/fieldops.hxx" +#include "bout/mesh.hxx" #include "bout/single_index_ops.hxx" namespace bout::stencil { @@ -40,8 +41,44 @@ struct DDZ_C4_Op { } }; +struct BracketArakawaOp { + CoordinatesAccessor coords; + int ny{0}; + int nz{0}; + + template + BOUT_HOST_DEVICE BOUT_FORCEINLINE BoutReal operator()(int idx, const LView& f, + const RView& g) const { + const int ixp = i_xp(idx, ny, nz); + const int ixm = i_xm(idx, ny, nz); + const int izp = i_zp(idx, nz); + const int izm = i_zm(idx, nz); + + const int izpxp = i_xp(izp, ny, nz); + const int izpxm = i_xm(izp, ny, nz); + const int izmxp = i_xp(izm, ny, nz); + const int izmxm = i_xm(izm, ny, nz); + + // J++ = DDZ(f)*DDX(g) - DDX(f)*DDZ(g) + const BoutReal Jpp = + ((f(izp) - f(izm)) * (g(ixp) - g(ixm)) - (f(ixp) - f(ixm)) * (g(izp) - g(izm))); + + // J+x + const BoutReal Jpx = + (g(ixp) * (f(izpxp) - f(izmxp)) - g(ixm) * (f(izpxm) - f(izmxm)) 
+ - g(izp) * (f(izpxp) - f(izpxm)) + g(izm) * (f(izmxp) - f(izmxm))); + + // Jx+ + const BoutReal Jxp = (g(izpxp) * (f(izp) - f(ixp)) - g(izmxm) * (f(ixm) - f(izm)) + - g(izpxm) * (f(izp) - f(ixm)) + g(izmxp) * (f(ixp) - f(izm))); + + return (Jpp + Jpx + Jxp) / (12.0 * coords.dx(idx) * coords.dz(idx)); + } +}; + using DDZExprC2 = BinaryExpr; using DDZExprC4 = BinaryExpr; +using BracketArakawaExpr = BinaryExpr; } // namespace bout::stencil @@ -77,4 +114,24 @@ inline bout::stencil::DDZExprC4 DDZ_C4(const Field3D& f) { f.getMesh()->getRegion("RGN_NOBNDRY")}; } +inline bout::stencil::BracketArakawaExpr bracket_arakawa(const Field3D& f, + const Field3D& g) { + checkData(f); + checkData(g); + ASSERT1_FIELDS_COMPATIBLE(f, g); + + const auto region_id = f.getMesh()->getRegionID("RGN_NOBNDRY"); + + return bout::stencil::BracketArakawaExpr{ + static_cast(f), + static_cast(g), + bout::stencil::BracketArakawaOp{CoordinatesAccessor{f.getCoordinates()}, f.getNy(), + f.getNz()}, + f.getMesh(), + f.getLocation(), + f.getDirections(), + region_id, + f.getMesh()->getRegion("RGN_NOBNDRY")}; +} + #endif From 4fd34afd448135c0af48761face2e6cda99fb9bf Mon Sep 17 00:00:00 2001 From: Ben Dudson Date: Mon, 22 Jun 2026 22:54:25 -0700 Subject: [PATCH 3/7] stencil expr: DDZ_Dispatch Testing the performance impact of runtime dispatch. The hope is to preserve the ability to switch method at runtime. The conditional that selects the method is the same for all iterations, so hopefully good branch prediction and no warp divergence. DDZ_Dispatch(f, method) selects between DDZ_C2 and DDZ_C4 at runtime. 
--- examples/hasegawa-wakatani/hw.cxx | 2 +- include/bout/stencil_expr.hxx | 61 +++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 1 deletion(-) diff --git a/examples/hasegawa-wakatani/hw.cxx b/examples/hasegawa-wakatani/hw.cxx index e72d844a00..39ffdd709c 100644 --- a/examples/hasegawa-wakatani/hw.cxx +++ b/examples/hasegawa-wakatani/hw.cxx @@ -111,7 +111,7 @@ class HW : public PhysicsModel { } ddt(n) = -bracket_arakawa(phi, n) + alpha * (nonzonal_phi - nonzonal_n) - - kappa * DDZ_C2(phi); + - kappa * DDZ_Dispatch(phi, DIFF_C2); ddt(vort) = -bracket_arakawa(phi, vort) + alpha * (nonzonal_phi - nonzonal_n); diff --git a/include/bout/stencil_expr.hxx b/include/bout/stencil_expr.hxx index 807a1c3852..9820012fa2 100644 --- a/include/bout/stencil_expr.hxx +++ b/include/bout/stencil_expr.hxx @@ -41,6 +41,44 @@ struct DDZ_C4_Op { } }; +struct DDZ_Dispatch_Op { + CoordinatesAccessor coords; + int nz{0}; + DIFF_METHOD method{DIFF_DEFAULT}; + + template + BOUT_HOST_DEVICE BOUT_FORCEINLINE BoutReal apply_c2(int idx, const LView& lhs) const { + const int izp = i_zp(idx, nz); + const int izm = i_zm(idx, nz); + + return 0.5 * (lhs(izp) - lhs(izm)) / coords.dz(idx); + } + + template + BOUT_HOST_DEVICE BOUT_FORCEINLINE BoutReal apply_c4(int idx, const LView& lhs) const { + const int izp = i_zp(idx, nz); + const int izm = i_zm(idx, nz); + const int izp2 = i_zp(izp, nz); + const int izm2 = i_zm(izm, nz); + + return (-lhs(izp2) + 8.0 * lhs(izp) - 8.0 * lhs(izm) + lhs(izm2)) + / (12.0 * coords.dz(idx)); + } + + template + BOUT_HOST_DEVICE BOUT_FORCEINLINE BoutReal operator()(int idx, const LView& lhs, + const RView&) const { + switch (method) { + case DIFF_C2: + return apply_c2(idx, lhs); + case DIFF_C4: + return apply_c4(idx, lhs); + default: + return 0.0; + } + } +}; + struct BracketArakawaOp { CoordinatesAccessor coords; int ny{0}; @@ -78,6 +116,7 @@ struct BracketArakawaOp { using DDZExprC2 = BinaryExpr; using DDZExprC4 = BinaryExpr; +using DDZDispatchExpr = 
BinaryExpr; using BracketArakawaExpr = BinaryExpr; } // namespace bout::stencil @@ -114,6 +153,28 @@ inline bout::stencil::DDZExprC4 DDZ_C4(const Field3D& f) { f.getMesh()->getRegion("RGN_NOBNDRY")}; } +inline bout::stencil::DDZDispatchExpr DDZ_Dispatch(const Field3D& f, DIFF_METHOD method) { + checkData(f); + + if ((method != DIFF_C2) && (method != DIFF_C4)) { + throw BoutException("DDZ_Dispatch only supports DIFF_C2 and DIFF_C4, got {:s}", + toString(method)); + } + + const auto region_id = f.getMesh()->getRegionID("RGN_NOBNDRY"); + + return bout::stencil::DDZDispatchExpr{ + static_cast(f), + static_cast(f), + bout::stencil::DDZ_Dispatch_Op{CoordinatesAccessor{f.getCoordinates()}, f.getNz(), + method}, + f.getMesh(), + f.getLocation(), + f.getDirections(), + region_id, + f.getMesh()->getRegion("RGN_NOBNDRY")}; +} + inline bout::stencil::BracketArakawaExpr bracket_arakawa(const Field3D& f, const Field3D& g) { checkData(f); From 1238f4b908a2f20449ae83b2cea5ee0bf6854e03 Mon Sep 17 00:00:00 2001 From: Ben Dudson Date: Tue, 23 Jun 2026 12:36:42 -0700 Subject: [PATCH 4/7] pow: Convert to BinaryExpr lazy function --- include/bout/field.hxx | 163 ++++++++++++++++++---------- include/bout/field3d.hxx | 19 +++- include/bout/fieldops.hxx | 10 ++ src/field/field3d.cxx | 16 --- tests/unit/field/test_field2d.cxx | 25 +++++ tests/unit/field/test_field3d.cxx | 26 +++++ tests/unit/field/test_fieldperp.cxx | 27 +++++ 7 files changed, 213 insertions(+), 73 deletions(-) diff --git a/include/bout/field.hxx b/include/bout/field.hxx index a8e1952546..51281d08de 100644 --- a/include/bout/field.hxx +++ b/include/bout/field.hxx @@ -539,53 +539,127 @@ inline BoutReal mean(const BinaryExpr& f, bool allpe = false, return bout::reduce::Mean::finalize(state); } -/// Exponent: pow(lhs, lhs) is \p lhs raised to the power of \p rhs -/// -/// This loops over the entire domain, including guard/boundary cells by -/// default (can be changed using the \p rgn argument) -/// If CHECK >= 3 then the 
result will be checked for non-finite numbers -template > -T pow(const T& lhs, const T& rhs, const std::string& rgn = "RGN_ALL") { - - ASSERT1(areFieldsCompatible(lhs, rhs)); +class Field3DParallel; +class FieldPerp; - T result{emptyFrom(lhs)}; +namespace bout::detail { +template +struct expression_result { + using type = std::decay_t; +}; - BOUT_FOR(i, result.getRegion(rgn)) { result[i] = ::pow(lhs[i], rhs[i]); } +template +struct expression_result> { + using type = ResT; +}; - checkData(result); - return result; -} +template +using expression_result_t = typename expression_result>::type; -template > -T pow(const T& lhs, BoutReal rhs, const std::string& rgn = "RGN_ALL") { +template +inline constexpr bool is_expression_field_v = + is_expr_field2d_v || is_expr_field3d_v || is_expr_fieldperp_v; - // Check if the inputs are allocated - checkData(lhs); - checkData(rhs); +template +inline constexpr bool is_same_expression_rank_v = + (is_expr_field2d_v && is_expr_field2d_v) + || (is_expr_field3d_v && is_expr_field3d_v) + || (is_expr_fieldperp_v && is_expr_fieldperp_v); - T result{emptyFrom(lhs)}; +template +std::optional getPerpYIndex(const T& value) { + if constexpr (std::is_same_v, ::FieldPerp>) { + return value.getIndex(); + } else { + return std::nullopt; + } +} - BOUT_FOR(i, result.getRegion(rgn)) { result[i] = ::pow(lhs[i], rhs); } +template +std::optional getPerpYIndex(const BinaryExpr& expr) { + if constexpr (std::is_same_v) { + return expr.getIndex(); + } else { + return std::nullopt; + } +} - checkData(result); - return result; +template +std::optional getExpressionRegionID(const Expr& expr, + const std::string& region_name) { + std::optional region_id{}; + if (expr.getMesh()->hasRegion3D(region_name)) { + region_id = expr.getMesh()->getRegionID(region_name); + } + return expr.getMesh()->getCommonRegion(region_id, expr.getRegionID()); } -template > -T pow(BoutReal lhs, const T& rhs, const std::string& rgn = "RGN_ALL") { +template +std::optional 
getExpressionRegionID(const L& lhs, const R& rhs, + const std::string& region_name) { + return lhs.getMesh()->getCommonRegion(getExpressionRegionID(lhs, region_name), + rhs.getRegionID()); +} +} // namespace bout::detail - // Check if the inputs are allocated - checkData(lhs); - checkData(rhs); +/// Exponent: pow(lhs, rhs) is \p lhs raised to the power of \p rhs +/// +/// This loops over the entire domain, including guard/boundary cells by +/// default (can be changed using the \p rgn argument) +/// If CHECK >= 3 then the result will be checked for non-finite numbers +template , + typename = std::enable_if_t>> +auto pow(const L& lhs, const R& rhs, const std::string& rgn = "RGN_ALL") { - // Define and allocate the output result - T result{emptyFrom(rhs)}; + if constexpr (bout::utils::is_Field_v> + && bout::utils::is_Field_v>) { + ASSERT1(areFieldsCompatible(lhs, rhs)); + } else { + ASSERT1_EXPR_COMPATIBLE(lhs, rhs); + } + + return BinaryExpr{ + static_cast(lhs), + static_cast(rhs), + bout::op::Pow{}, + lhs.getMesh(), + lhs.getLocation(), + lhs.getDirections(), + bout::detail::getExpressionRegionID(lhs, rhs, rgn), + lhs.getMesh()->template getRegion(rgn), + bout::detail::getPerpYIndex(lhs)}; +} - BOUT_FOR(i, result.getRegion(rgn)) { result[i] = ::pow(lhs, rhs[i]); } +template , + typename = std::enable_if_t>> +auto pow(const T& lhs, BoutReal rhs, const std::string& rgn = "RGN_ALL") { + + return BinaryExpr, bout::op::Pow>{ + static_cast(lhs), + static_cast::View>(rhs), + bout::op::Pow{}, + lhs.getMesh(), + lhs.getLocation(), + lhs.getDirections(), + bout::detail::getExpressionRegionID(lhs, rgn), + lhs.getMesh()->template getRegion(rgn), + bout::detail::getPerpYIndex(lhs)}; +} - checkData(result); - return result; +template , + typename = std::enable_if_t>> +auto pow(BoutReal lhs, const T& rhs, const std::string& rgn = "RGN_ALL") { + + return BinaryExpr, T, bout::op::Pow>{ + static_cast::View>(lhs), + static_cast(rhs), + bout::op::Pow{}, + rhs.getMesh(), + 
rhs.getLocation(), + rhs.getDirections(), + bout::detail::getExpressionRegionID(rhs, rgn), + rhs.getMesh()->template getRegion(rgn), + bout::detail::getPerpYIndex(rhs)}; } /*! @@ -604,29 +678,6 @@ T pow(BoutReal lhs, const T& rhs, const std::string& rgn = "RGN_ALL") { * result for non-finite numbers * */ -class Field3DParallel; -class FieldPerp; - -namespace bout::detail { -template -std::optional getPerpYIndex(const T& value) { - if constexpr (std::is_same_v, ::FieldPerp>) { - return value.getIndex(); - } else { - return std::nullopt; - } -} - -template -std::optional getPerpYIndex(const BinaryExpr& expr) { - if constexpr (std::is_same_v) { - return expr.getIndex(); - } else { - return std::nullopt; - } -} -} // namespace bout::detail - #ifdef FIELD_FUNC #error This macro has already been defined #else diff --git a/include/bout/field3d.hxx b/include/bout/field3d.hxx index 09c7f1ce55..abcb52f892 100644 --- a/include/bout/field3d.hxx +++ b/include/bout/field3d.hxx @@ -900,7 +900,24 @@ inline auto operator-(const Field3D& f) { /// This loops over the entire domain, including guard/boundary cells by /// default (can be changed using the \p rgn argument). /// If CHECK >= 3 then the result will be checked for non-finite numbers -Field3D pow(const Field3D& lhs, const Field2D& rhs, const std::string& rgn = "RGN_ALL"); +template +std::enable_if_t && is_expr_field2d_v, + BinaryExpr> +pow(const L& lhs, const R& rhs, const std::string& rgn = "RGN_ALL") { + ASSERT1_EXPR_COMPATIBLE(lhs, rhs); + auto regionID = bout::detail::getExpressionRegionID(lhs, rhs, rgn); + int mesh_nz = lhs.getMesh()->LocalNz; + return BinaryExpr{ + static_cast(lhs), + static_cast(rhs).setScale(1, mesh_nz), + bout::op::Pow{}, + lhs.getMesh(), + lhs.getLocation(), + lhs.getDirections(), + regionID, + (regionID.has_value() ? 
lhs.getMesh()->getRegion(regionID.value()) + : lhs.getMesh()->getRegion(rgn))}; +} FieldPerp pow(const Field3D& lhs, const FieldPerp& rhs, const std::string& rgn = "RGN_ALL"); diff --git a/include/bout/fieldops.hxx b/include/bout/fieldops.hxx index 53e28042c9..1d3193f638 100644 --- a/include/bout/fieldops.hxx +++ b/include/bout/fieldops.hxx @@ -118,6 +118,16 @@ struct Div { return a / b; } }; +struct Pow { + template + BOUT_HOST_DEVICE BOUT_FORCEINLINE BoutReal operator()(int idx, const LView& L, + const RView& R) const { + return ::pow(L(idx), R(idx)); + } + BOUT_HOST_DEVICE BOUT_FORCEINLINE BoutReal operator()(BoutReal a, BoutReal b) const { + return ::pow(a, b); + } +}; struct IfElse { bool condition; diff --git a/src/field/field3d.cxx b/src/field/field3d.cxx index 7915579440..0d7687d0ca 100644 --- a/src/field/field3d.cxx +++ b/src/field/field3d.cxx @@ -688,22 +688,6 @@ void Field3D::swapData(Field3D& other) { std::swap(data, other.data); } //////////////// NON-MEMBER FUNCTIONS ////////////////// -Field3D pow(const Field3D& lhs, const Field2D& rhs, const std::string& rgn) { - - // Check if the inputs are allocated - checkData(lhs); - checkData(rhs); - ASSERT1_FIELDS_COMPATIBLE(lhs, rhs); - - // Define and allocate the output result - Field3D result{emptyFrom(lhs)}; - - BOUT_FOR(i, result.getRegion(rgn)) { result[i] = ::pow(lhs[i], rhs[i]); } - - checkData(result); - return result; -} - FieldPerp pow(const Field3D& lhs, const FieldPerp& rhs, const std::string& rgn) { checkData(lhs); diff --git a/tests/unit/field/test_field2d.cxx b/tests/unit/field/test_field2d.cxx index a91acd4a40..a5920d5d19 100644 --- a/tests/unit/field/test_field2d.cxx +++ b/tests/unit/field/test_field2d.cxx @@ -1168,6 +1168,31 @@ TEST_F(Field2DTest, PowField2DField2D) { EXPECT_TRUE(IsFieldEqual(c, 64.0)); } +TEST_F(Field2DTest, PowExpressionUsesBinaryExpr) { + Field2D field; + + field = 2.0; + const auto expr = field + 1.0; + + EXPECT_TRUE( + (std::is_same_v, + BinaryExpr, bout::op::Pow>>)); 
+ EXPECT_TRUE((std::is_same_v, + BinaryExpr, + Constant, bout::op::Pow>>)); + EXPECT_TRUE(IsFieldEqual(pow(expr, 2.0), 9.0)); +} + +TEST_F(Field2DTest, PowRegionLimitedExpressionConstructsField2D) { + Field2D field; + + field = 2.0; + + Field2D result = pow(field, 2.0, "RGN_NOBNDRY"); + + EXPECT_TRUE(IsFieldEqual(result, 4.0, "RGN_NOBNDRY")); +} + TEST_F(Field2DTest, Sqrt) { Field2D field; diff --git a/tests/unit/field/test_field3d.cxx b/tests/unit/field/test_field3d.cxx index 905b182018..8502837642 100644 --- a/tests/unit/field/test_field3d.cxx +++ b/tests/unit/field/test_field3d.cxx @@ -1942,6 +1942,32 @@ TEST_F(Field3DTest, PowField3DField3D) { EXPECT_TRUE(IsFieldEqual(c, 64.0)); } +TEST_F(Field3DTest, PowExpressionUsesBinaryExpr) { + Field3D field; + + field = 2.0; + const auto expr = field + 1.0; + + EXPECT_TRUE( + (std::is_same_v, + BinaryExpr, bout::op::Pow>>)); + EXPECT_TRUE((std::is_same_v, + BinaryExpr, + Constant, bout::op::Pow>>)); + EXPECT_TRUE(IsFieldEqual(pow(expr, 2.0), 9.0)); +} + +TEST_F(Field3DTest, PowRegionArgumentSetsRegionID) { + Field3D field; + + field = 2.0; + const auto expr = pow(field, 2.0, "RGN_NOBNDRY"); + + ASSERT_TRUE(expr.getRegionID().has_value()); + EXPECT_EQ(expr.getRegionID().value(), field.getMesh()->getRegionID("RGN_NOBNDRY")); + EXPECT_TRUE(IsFieldEqual(expr, 4.0, "RGN_NOBNDRY")); +} + TEST_F(Field3DTest, Sqrt) { Field3D field; diff --git a/tests/unit/field/test_fieldperp.cxx b/tests/unit/field/test_fieldperp.cxx index 46f07d589f..60a15ea693 100644 --- a/tests/unit/field/test_fieldperp.cxx +++ b/tests/unit/field/test_fieldperp.cxx @@ -1569,6 +1569,33 @@ TEST_F(FieldPerpTest, PowFieldPerpFieldPerp) { EXPECT_TRUE(IsFieldEqual(c, 64.0)); } +TEST_F(FieldPerpTest, PowExpressionUsesBinaryExpr) { + FieldPerp field; + field.setIndex(0); + + field = 2.0; + const auto expr = field + 1.0; + + EXPECT_TRUE((std::is_same_v< + std::decay_t, + BinaryExpr, bout::op::Pow>>)); + EXPECT_TRUE((std::is_same_v, + BinaryExpr, + Constant, 
bout::op::Pow>>)); + EXPECT_TRUE(IsFieldEqual(pow(expr, 2.0), 9.0)); +} + +TEST_F(FieldPerpTest, PowRegionLimitedExpressionConstructsFieldPerp) { + FieldPerp field; + field.setIndex(0); + + field = 2.0; + + FieldPerp result = pow(field, 2.0, "RGN_NOBNDRY"); + + EXPECT_TRUE(IsFieldEqual(result, 4.0, "RGN_NOBNDRY")); +} + TEST_F(FieldPerpTest, Sqrt) { FieldPerp field; field.setIndex(0); From 19502b404423558f876b8a48e3141191f879daf0 Mon Sep 17 00:00:00 2001 From: Ben Dudson Date: Tue, 23 Jun 2026 13:15:32 -0700 Subject: [PATCH 5/7] DDZ_Dispatch: Extend to handle staggered locations Retains the runtime dispatch of DDZ, while building a BinaryExpr lazy expression. --- examples/hasegawa-wakatani/hw.cxx | 2 +- include/bout/stencil_expr.hxx | 96 +++++++++++++++++-- tests/unit/CMakeLists.txt | 1 + tests/unit/include/bout/test_stencil_expr.cxx | 78 +++++++++++++++ 4 files changed, 168 insertions(+), 9 deletions(-) create mode 100644 tests/unit/include/bout/test_stencil_expr.cxx diff --git a/examples/hasegawa-wakatani/hw.cxx b/examples/hasegawa-wakatani/hw.cxx index 39ffdd709c..42dd923c59 100644 --- a/examples/hasegawa-wakatani/hw.cxx +++ b/examples/hasegawa-wakatani/hw.cxx @@ -111,7 +111,7 @@ class HW : public PhysicsModel { } ddt(n) = -bracket_arakawa(phi, n) + alpha * (nonzonal_phi - nonzonal_n) - - kappa * DDZ_Dispatch(phi, DIFF_C2); + - kappa * DDZ_Dispatch(phi, CELL_DEFAULT, DIFF_C2); ddt(vort) = -bracket_arakawa(phi, vort) + alpha * (nonzonal_phi - nonzonal_n); diff --git a/include/bout/stencil_expr.hxx b/include/bout/stencil_expr.hxx index 9820012fa2..7e031305f4 100644 --- a/include/bout/stencil_expr.hxx +++ b/include/bout/stencil_expr.hxx @@ -2,12 +2,18 @@ #ifndef BOUT_STENCIL_EXPR_HXX #define BOUT_STENCIL_EXPR_HXX +#include "bout/bout_types.hxx" +#include "bout/boutexception.hxx" +#include "bout/build_config.hxx" #include "bout/coordinates_accessor.hxx" +#include "bout/field.hxx" #include "bout/field3d.hxx" #include "bout/fieldops.hxx" #include "bout/mesh.hxx" 
#include "bout/single_index_ops.hxx" +#include + namespace bout::stencil { struct DDZ_C2_Op { @@ -44,10 +50,12 @@ struct DDZ_C4_Op { struct DDZ_Dispatch_Op { CoordinatesAccessor coords; int nz{0}; + STAGGER stagger{STAGGER::None}; DIFF_METHOD method{DIFF_DEFAULT}; template - BOUT_HOST_DEVICE BOUT_FORCEINLINE BoutReal apply_c2(int idx, const LView& lhs) const { + BOUT_HOST_DEVICE BOUT_FORCEINLINE BoutReal apply_c2_none(int idx, + const LView& lhs) const { const int izp = i_zp(idx, nz); const int izm = i_zm(idx, nz); @@ -55,7 +63,8 @@ struct DDZ_Dispatch_Op { } template - BOUT_HOST_DEVICE BOUT_FORCEINLINE BoutReal apply_c4(int idx, const LView& lhs) const { + BOUT_HOST_DEVICE BOUT_FORCEINLINE BoutReal apply_c4_none(int idx, + const LView& lhs) const { const int izp = i_zp(idx, nz); const int izm = i_zm(idx, nz); const int izp2 = i_zp(izp, nz); @@ -65,6 +74,70 @@ struct DDZ_Dispatch_Op { / (12.0 * coords.dz(idx)); } + template + BOUT_HOST_DEVICE BOUT_FORCEINLINE BoutReal apply_c2_c2l(int idx, + const LView& lhs) const { + const int izm = i_zm(idx, nz); + + return (lhs(idx) - lhs(izm)) / coords.dz(idx); + } + + template + BOUT_HOST_DEVICE BOUT_FORCEINLINE BoutReal apply_c4_c2l(int idx, + const LView& lhs) const { + const int izm = i_zm(idx, nz); + const int izp = i_zp(idx, nz); + const int izm2 = i_zm(izm, nz); + + return (27.0 * (lhs(idx) - lhs(izm)) - (lhs(izp) - lhs(izm2))) + / (24.0 * coords.dz(idx)); + } + + template + BOUT_HOST_DEVICE BOUT_FORCEINLINE BoutReal apply_c2_l2c(int idx, + const LView& lhs) const { + const int izp = i_zp(idx, nz); + + return (lhs(izp) - lhs(idx)) / coords.dz(idx); + } + + template + BOUT_HOST_DEVICE BOUT_FORCEINLINE BoutReal apply_c4_l2c(int idx, + const LView& lhs) const { + const int izp = i_zp(idx, nz); + const int izm = i_zm(idx, nz); + const int izp2 = i_zp(izp, nz); + + return (27.0 * (lhs(izp) - lhs(idx)) - (lhs(izp2) - lhs(izm))) + / (24.0 * coords.dz(idx)); + } + + template + BOUT_HOST_DEVICE BOUT_FORCEINLINE BoutReal 
apply_c2(int idx, const LView& lhs) const { + switch (stagger) { + case STAGGER::None: + return apply_c2_none(idx, lhs); + case STAGGER::C2L: + return apply_c2_c2l(idx, lhs); + case STAGGER::L2C: + return apply_c2_l2c(idx, lhs); + } + return 0.0; + } + + template + BOUT_HOST_DEVICE BOUT_FORCEINLINE BoutReal apply_c4(int idx, const LView& lhs) const { + switch (stagger) { + case STAGGER::None: + return apply_c4_none(idx, lhs); + case STAGGER::C2L: + return apply_c4_c2l(idx, lhs); + case STAGGER::L2C: + return apply_c4_l2c(idx, lhs); + } + return 0.0; + } + template BOUT_HOST_DEVICE BOUT_FORCEINLINE BoutReal operator()(int idx, const LView& lhs, const RView&) const { @@ -153,7 +226,10 @@ inline bout::stencil::DDZExprC4 DDZ_C4(const Field3D& f) { f.getMesh()->getRegion("RGN_NOBNDRY")}; } -inline bout::stencil::DDZDispatchExpr DDZ_Dispatch(const Field3D& f, DIFF_METHOD method) { +inline bout::stencil::DDZDispatchExpr +DDZ_Dispatch(const Field3D& f, CELL_LOC outloc = CELL_DEFAULT, + DIFF_METHOD method = DIFF_DEFAULT, + const std::string& region = "RGN_NOBNDRY") { checkData(f); if ((method != DIFF_C2) && (method != DIFF_C4)) { @@ -161,18 +237,22 @@ inline bout::stencil::DDZDispatchExpr DDZ_Dispatch(const Field3D& f, DIFF_METHOD toString(method)); } - const auto region_id = f.getMesh()->getRegionID("RGN_NOBNDRY"); + const auto resolved_outloc = (outloc == CELL_DEFAULT) ? 
f.getLocation() : outloc; + const auto stagger = + f.getMesh()->getStagger(f.getLocation(), resolved_outloc, CELL_ZLOW); + const auto region_id = f.getMesh()->getRegionID(region); return bout::stencil::DDZDispatchExpr{ static_cast(f), static_cast(f), - bout::stencil::DDZ_Dispatch_Op{CoordinatesAccessor{f.getCoordinates()}, f.getNz(), - method}, + bout::stencil::DDZ_Dispatch_Op{ + CoordinatesAccessor{f.getCoordinates(resolved_outloc)}, f.getNz(), stagger, + method}, f.getMesh(), - f.getLocation(), + resolved_outloc, f.getDirections(), region_id, - f.getMesh()->getRegion("RGN_NOBNDRY")}; + f.getMesh()->getRegion(region)}; } inline bout::stencil::BracketArakawaExpr bracket_arakawa(const Field3D& f, diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index 4963aaf1f0..0e87a4d406 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -80,6 +80,7 @@ set(serial_tests_source ./include/bout/test_petsc_vector.cxx ./include/bout/test_region.cxx ./include/bout/test_single_index_ops.cxx + ./include/bout/test_stencil_expr.cxx ./include/bout/test_stencil.cxx ./include/bout/test_template_combinations.cxx ./include/bout/test_traits.cxx diff --git a/tests/unit/include/bout/test_stencil_expr.cxx b/tests/unit/include/bout/test_stencil_expr.cxx new file mode 100644 index 0000000000..898d9d5694 --- /dev/null +++ b/tests/unit/include/bout/test_stencil_expr.cxx @@ -0,0 +1,78 @@ +#include "gtest/gtest.h" + +#include "test_extras.hxx" + +#include "bout/derivs.hxx" +#include "bout/stencil_expr.hxx" + +#include + +#include "fake_mesh_fixture.hxx" + +namespace { + +Field3D makeTestField(Mesh* mesh, CELL_LOC location) { + Field3D result(mesh, location); + result.allocate(); + + for (auto i : result.getRegion("RGN_ALL")) { + result[i] = 0.1 * i.x() + 0.2 * i.y() + 0.3 * i.z() + 0.05 * i.x() * i.z(); + } + + return result; +} + +class DDZDispatchExprTest : public FakeMeshFixture {}; + +class DDZDispatchExprParamTest + : public DDZDispatchExprTest, + public 
::testing::WithParamInterface> { +}; + +std::string paramToString( + const ::testing::TestParamInfo>& param) { + const auto [inloc, outloc, method] = param.param; + return toString(inloc) + "_to_" + toString(outloc) + "_" + toString(method); +} + +} // namespace + +INSTANTIATE_TEST_SUITE_P( + SupportedLocations, DDZDispatchExprParamTest, + ::testing::Values(std::make_tuple(CELL_CENTRE, CELL_CENTRE, DIFF_C2), + std::make_tuple(CELL_CENTRE, CELL_CENTRE, DIFF_C4), + std::make_tuple(CELL_CENTRE, CELL_ZLOW, DIFF_C2), + std::make_tuple(CELL_CENTRE, CELL_ZLOW, DIFF_C4), + std::make_tuple(CELL_ZLOW, CELL_CENTRE, DIFF_C2), + std::make_tuple(CELL_ZLOW, CELL_CENTRE, DIFF_C4), + std::make_tuple(CELL_ZLOW, CELL_ZLOW, DIFF_C2), + std::make_tuple(CELL_ZLOW, CELL_ZLOW, DIFF_C4)), + paramToString); + +TEST_P(DDZDispatchExprParamTest, MatchesDDZ) { + const auto [inloc, outloc, method] = GetParam(); + auto input = makeTestField(mesh_staggered, inloc); + + const auto actual = Field3D{DDZ_Dispatch(input, outloc, method)}; + const auto expected = DDZ(input, outloc, toString(method)); + + EXPECT_EQ(actual.getLocation(), outloc); + EXPECT_TRUE(IsFieldEqual(actual, expected, "RGN_NOBNDRY")); +} + +TEST_F(DDZDispatchExprTest, UsesRequestedRegion) { + auto input = makeTestField(mesh_staggered, CELL_CENTRE); + + const auto actual = Field3D{DDZ_Dispatch(input, CELL_ZLOW, DIFF_C2, "RGN_ALL")}; + const auto expected = DDZ(input, CELL_ZLOW, toString(DIFF_C2), "RGN_ALL"); + + EXPECT_EQ(actual.getLocation(), CELL_ZLOW); + EXPECT_TRUE(IsFieldEqual(actual, expected, "RGN_ALL")); +} + +TEST_F(DDZDispatchExprTest, RejectsUnsupportedMethods) { + auto input = makeTestField(mesh_staggered, CELL_CENTRE); + + EXPECT_THROW((void)DDZ_Dispatch(input, CELL_DEFAULT, DIFF_DEFAULT), BoutException); + EXPECT_THROW((void)DDZ_Dispatch(input, CELL_DEFAULT, DIFF_FFT), BoutException); +} From 6c15db49d40b4661eb62012ae17925df31454ab0 Mon Sep 17 00:00:00 2001 From: Ben Dudson Date: Tue, 23 Jun 2026 14:24:08 -0700 Subject: 
[PATCH 6/7] Mesh::DerivativeDefaults default method Stores the default numerical methods as DIFF_METHOD enums. This enables runtime dispatch with minimal overhead in expressions. --- examples/hasegawa-wakatani/hw.cxx | 2 +- include/bout/mesh.hxx | 97 +++++++++++++++++++ include/bout/stencil_expr.hxx | 11 ++- src/mesh/index_derivs.cxx | 94 ++++++++++++++++++ tests/unit/include/bout/test_stencil_expr.cxx | 17 +++- tests/unit/mesh/test_mesh.cxx | 38 ++++++++ 6 files changed, 254 insertions(+), 5 deletions(-) diff --git a/examples/hasegawa-wakatani/hw.cxx b/examples/hasegawa-wakatani/hw.cxx index 42dd923c59..f4e546908f 100644 --- a/examples/hasegawa-wakatani/hw.cxx +++ b/examples/hasegawa-wakatani/hw.cxx @@ -111,7 +111,7 @@ class HW : public PhysicsModel { } ddt(n) = -bracket_arakawa(phi, n) + alpha * (nonzonal_phi - nonzonal_n) - - kappa * DDZ_Dispatch(phi, CELL_DEFAULT, DIFF_C2); + - kappa * DDZ_Dispatch(phi); ddt(vort) = -bracket_arakawa(phi, vort) + alpha * (nonzonal_phi - nonzonal_n); diff --git a/include/bout/mesh.hxx b/include/bout/mesh.hxx index 929e9d5aa7..2cef181157 100644 --- a/include/bout/mesh.hxx +++ b/include/bout/mesh.hxx @@ -59,6 +59,7 @@ class Mesh; #include "bout/sys/range.hxx" // RangeIterator #include "bout/unused.hxx" +#include #include #include #include @@ -95,6 +96,94 @@ using comm_handle = void*; class Mesh { public: + struct DerivativeDefaults { + static constexpr int num_directions = 3; + static constexpr int num_staggers = 3; + static constexpr int num_deriv_kinds = 5; + + std::array values{}; + + DerivativeDefaults() { + for (auto direction : {DIRECTION::X, DIRECTION::Y, DIRECTION::Z}) { + for (auto stagger : {STAGGER::None, STAGGER::C2L, STAGGER::L2C}) { + for (auto deriv : {DERIV::Standard, DERIV::StandardSecond, + DERIV::StandardFourth, DERIV::Upwind, DERIV::Flux}) { + set(direction, deriv, builtinDefaultMethod(deriv), stagger); + } + } + } + } + + static DIFF_METHOD builtinDefaultMethod(DERIV deriv) { + switch (deriv) { + case 
DERIV::Standard: + case DERIV::StandardSecond: + case DERIV::StandardFourth: + return DIFF_C2; + case DERIV::Upwind: + case DERIV::Flux: + return DIFF_U1; + } + throw BoutException("Unhandled derivative kind in builtinDefaultMethod"); + } + + static int directionIndex(DIRECTION direction) { + switch (direction) { + case DIRECTION::X: + return 0; + case DIRECTION::Y: + case DIRECTION::YOrthogonal: + case DIRECTION::YAligned: + return 1; + case DIRECTION::Z: + return 2; + } + throw BoutException("Unhandled direction in DerivativeDefaults"); + } + + static int staggerIndex(STAGGER stagger) { + switch (stagger) { + case STAGGER::None: + return 0; + case STAGGER::C2L: + return 1; + case STAGGER::L2C: + return 2; + } + throw BoutException("Unhandled stagger in DerivativeDefaults"); + } + + static int derivIndex(DERIV deriv) { + switch (deriv) { + case DERIV::Standard: + return 0; + case DERIV::StandardSecond: + return 1; + case DERIV::StandardFourth: + return 2; + case DERIV::Upwind: + return 3; + case DERIV::Flux: + return 4; + } + throw BoutException("Unhandled derivative kind in DerivativeDefaults"); + } + + DIFF_METHOD get(DIRECTION direction, DERIV deriv, + STAGGER stagger = STAGGER::None) const { + return values[(directionIndex(direction) * num_staggers + staggerIndex(stagger)) + * num_deriv_kinds + + derivIndex(deriv)]; + } + + void set(DIRECTION direction, DERIV deriv, DIFF_METHOD method, + STAGGER stagger = STAGGER::None) { + values[(directionIndex(direction) * num_staggers + staggerIndex(stagger)) + * num_deriv_kinds + + derivIndex(deriv)] = method; + } + }; + /// Constructor for a "bare", uninitialised Mesh /// Only useful for testing Mesh() : source(nullptr), options(nullptr), include_corner_cells(true) {} @@ -719,6 +808,14 @@ public: /// Fraction of modes to filter. 
This is set in derivs_init from option "ddz:fft_filter" BoutReal fft_derivs_filter{0.0}; + /// Concrete default methods initialised from options in derivs_init + DerivativeDefaults derivative_defaults{}; + + DIFF_METHOD getDefaultMethod(DIRECTION direction, DERIV deriv, + STAGGER stagger = STAGGER::None) const { + return derivative_defaults.get(direction, deriv, stagger); + } + /// Determines the resultant output stagger location in derivatives /// given the input and output location. Also checks that the /// combination of locations is allowed diff --git a/include/bout/stencil_expr.hxx b/include/bout/stencil_expr.hxx index 7e031305f4..bbccc267bf 100644 --- a/include/bout/stencil_expr.hxx +++ b/include/bout/stencil_expr.hxx @@ -232,14 +232,19 @@ DDZ_Dispatch(const Field3D& f, CELL_LOC outloc = CELL_DEFAULT, const std::string& region = "RGN_NOBNDRY") { checkData(f); + const auto resolved_outloc = (outloc == CELL_DEFAULT) ? f.getLocation() : outloc; + const auto stagger = + f.getMesh()->getStagger(f.getLocation(), resolved_outloc, CELL_ZLOW); + + if (method == DIFF_DEFAULT) { + method = f.getMesh()->getDefaultMethod(DIRECTION::Z, DERIV::Standard, stagger); + } + if ((method != DIFF_C2) && (method != DIFF_C4)) { throw BoutException("DDZ_Dispatch only supports DIFF_C2 and DIFF_C4, got {:s}", toString(method)); } - const auto resolved_outloc = (outloc == CELL_DEFAULT) ? 
f.getLocation() : outloc; - const auto stagger = - f.getMesh()->getStagger(f.getLocation(), resolved_outloc, CELL_ZLOW); const auto region_id = f.getMesh()->getRegionID(region); return bout::stencil::DDZDispatchExpr{ diff --git a/src/mesh/index_derivs.cxx b/src/mesh/index_derivs.cxx index 70fc47b538..f01df89b68 100644 --- a/src/mesh/index_derivs.cxx +++ b/src/mesh/index_derivs.cxx @@ -23,10 +23,58 @@ #include "bout/build_defines.hxx" #include "bout/traits.hxx" +#include "bout/utils.hxx" #include #include #include +#include +#include + +namespace { + +DIFF_METHOD parseConcreteDiffMethod(const std::string& method_name) { + const auto method = uppercase(method_name); + + if (method == "U1") { + return DIFF_U1; + } + if (method == "U2") { + return DIFF_U2; + } + if (method == "C2") { + return DIFF_C2; + } + if (method == "W2") { + return DIFF_W2; + } + if (method == "W3") { + return DIFF_W3; + } + if (method == "C4") { + return DIFF_C4; + } + if (method == "U3") { + return DIFF_U3; + } + if (method == "FFT") { + return DIFF_FFT; + } + if (method == "SPLIT") { + return DIFF_SPLIT; + } + if (method == "S2") { + return DIFF_S2; + } + if (method == "DEFAULT") { + throw BoutException("Default derivative options must resolve to a concrete method"); + } + + throw BoutException("Unknown differential method '{:s}'", method_name); +} + +} // namespace + /******************************************************************************* * Helper routines *******************************************************************************/ @@ -37,6 +85,52 @@ void Mesh::derivs_init(Options* options) { // of derivative. 
DerivativeStore::getInstance().initialise(options); DerivativeStore::getInstance().initialise(options); + + auto backup_section = options->getSection("diff"); + const std::array, 3> directions{{ + {DIRECTION::X, "ddx"}, + {DIRECTION::Y, "ddy"}, + {DIRECTION::Z, "ddz"}, + }}; + const std::array, 5> deriv_types{{ + {DERIV::Standard, "first"}, + {DERIV::StandardSecond, "second"}, + {DERIV::StandardFourth, "fourth"}, + {DERIV::Upwind, "upwind"}, + {DERIV::Flux, "flux"}, + }}; + + derivative_defaults = DerivativeDefaults{}; + + for (const auto& [direction, section_name] : directions) { + auto specific_section = options->getSection(section_name); + auto staggered_section = options->getSection(section_name + "stag"); + + for (const auto& [deriv, option_name] : deriv_types) { + auto default_method = DerivativeDefaults::builtinDefaultMethod(deriv); + auto default_name = toString(default_method); + + if (specific_section->isSet(option_name)) { + specific_section->get(option_name, default_name, default_name); + } else if (backup_section->isSet(option_name)) { + backup_section->get(option_name, default_name, default_name); + } + + default_method = parseConcreteDiffMethod(default_name); + derivative_defaults.set(direction, deriv, default_method, STAGGER::None); + + auto staggered_name = toString(default_method); + auto staggered_method = default_method; + if (staggered_section->isSet(option_name)) { + staggered_section->get(option_name, staggered_name, staggered_name); + staggered_method = parseConcreteDiffMethod(staggered_name); + } + + derivative_defaults.set(direction, deriv, staggered_method, STAGGER::C2L); + derivative_defaults.set(direction, deriv, staggered_method, STAGGER::L2C); + } + } + // Get the fraction of modes filtered out in FFT derivatives options->getSection("ddz")->get("fft_filter", fft_derivs_filter, 0.0); } diff --git a/tests/unit/include/bout/test_stencil_expr.cxx b/tests/unit/include/bout/test_stencil_expr.cxx index 898d9d5694..9d3154c29a 100644 --- 
a/tests/unit/include/bout/test_stencil_expr.cxx +++ b/tests/unit/include/bout/test_stencil_expr.cxx @@ -7,6 +7,7 @@ #include +#include "fake_mesh.hxx" #include "fake_mesh_fixture.hxx" namespace { @@ -73,6 +74,20 @@ TEST_F(DDZDispatchExprTest, UsesRequestedRegion) { TEST_F(DDZDispatchExprTest, RejectsUnsupportedMethods) { auto input = makeTestField(mesh_staggered, CELL_CENTRE); - EXPECT_THROW((void)DDZ_Dispatch(input, CELL_DEFAULT, DIFF_DEFAULT), BoutException); EXPECT_THROW((void)DDZ_Dispatch(input, CELL_DEFAULT, DIFF_FFT), BoutException); } + +TEST_F(DDZDispatchExprTest, ResolvesDefaultMethodFromMesh) { + auto input = makeTestField(mesh_staggered, CELL_CENTRE); + + Options diff_options{{"ddz", {{"first", "C4"}}}, {"ddzstag", {{"first", "C2"}}}}; + static_cast(mesh_staggered)->initDerivs(&diff_options); + + const auto centre_actual = Field3D{DDZ_Dispatch(input, CELL_CENTRE, DIFF_DEFAULT)}; + const auto centre_expected = DDZ(input, CELL_CENTRE, toString(DIFF_C4)); + EXPECT_TRUE(IsFieldEqual(centre_actual, centre_expected, "RGN_NOBNDRY")); + + const auto staggered_actual = Field3D{DDZ_Dispatch(input, CELL_ZLOW, DIFF_DEFAULT)}; + const auto staggered_expected = DDZ(input, CELL_ZLOW, toString(DIFF_C2)); + EXPECT_TRUE(IsFieldEqual(staggered_actual, staggered_expected, "RGN_NOBNDRY")); +} diff --git a/tests/unit/mesh/test_mesh.cxx b/tests/unit/mesh/test_mesh.cxx index 8d64f1d0a3..0b30176704 100644 --- a/tests/unit/mesh/test_mesh.cxx +++ b/tests/unit/mesh/test_mesh.cxx @@ -201,6 +201,44 @@ TEST_F(MeshTest, GetStringNoSource) { EXPECT_EQ(string_value, ""); } +TEST_F(MeshTest, GetDefaultMethodUsesBuiltinDefaults) { + Options options; + + localmesh.initDerivs(&options); + + EXPECT_EQ(localmesh.getDefaultMethod(DIRECTION::X, DERIV::Standard), DIFF_C2); + EXPECT_EQ(localmesh.getDefaultMethod(DIRECTION::Y, DERIV::StandardSecond), DIFF_C2); + EXPECT_EQ(localmesh.getDefaultMethod(DIRECTION::Z, DERIV::StandardFourth), DIFF_C2); + EXPECT_EQ(localmesh.getDefaultMethod(DIRECTION::Z, 
DERIV::Upwind), DIFF_U1); + EXPECT_EQ(localmesh.getDefaultMethod(DIRECTION::Z, DERIV::Flux), DIFF_U1); + EXPECT_EQ(localmesh.getDefaultMethod(DIRECTION::Z, DERIV::Standard, STAGGER::C2L), + DIFF_C2); + EXPECT_EQ(localmesh.getDefaultMethod(DIRECTION::Z, DERIV::Standard, STAGGER::L2C), + DIFF_C2); +} + +TEST_F(MeshTest, GetDefaultMethodReadsDirectionAndStaggeredOptions) { + Options options{{"diff", {{"first", "W2"}, {"upwind", "U2"}}}, + {"ddz", {{"first", "C4"}}}, + {"ddzstag", {{"first", "S2"}}}}; + + localmesh.initDerivs(&options); + + EXPECT_EQ(localmesh.getDefaultMethod(DIRECTION::X, DERIV::Standard), DIFF_W2); + EXPECT_EQ(localmesh.getDefaultMethod(DIRECTION::Y, DERIV::Upwind), DIFF_U2); + EXPECT_EQ(localmesh.getDefaultMethod(DIRECTION::Z, DERIV::Standard), DIFF_C4); + EXPECT_EQ(localmesh.getDefaultMethod(DIRECTION::Z, DERIV::Standard, STAGGER::C2L), + DIFF_S2); + EXPECT_EQ(localmesh.getDefaultMethod(DIRECTION::Z, DERIV::Standard, STAGGER::L2C), + DIFF_S2); +} + +TEST_F(MeshTest, GetDefaultMethodRejectsDefaultOptionValue) { + Options options{{"ddz", {{"first", "DEFAULT"}}}}; + + EXPECT_THROW(localmesh.initDerivs(&options), BoutException); +} + TEST_F(MeshTest, GetStringNoSourceWithDefault) { std::string string_value; const std::string default_value = "some default"; From 38fb3c29b5fad7faa0e169ef932d95bf4b1a3871 Mon Sep 17 00:00:00 2001 From: Ben Dudson Date: Tue, 23 Jun 2026 15:27:49 -0700 Subject: [PATCH 7/7] Replace DDZ(Field3D) with BinaryExpr implementation Retains runtime choice of method, and handling of staggered inputs or outputs. Only C2 and C4 methods are supported. 
--- examples/hasegawa-wakatani/hw.cxx | 4 +- examples/performance/ddz/ddz.cxx | 12 +--- include/bout/derivs.hxx | 29 ++++----- include/bout/stencil_expr.hxx | 62 +++++++++++++++++-- .../laplace/impls/naulin/naulin_laplace.cxx | 4 +- src/sys/derivs.cxx | 46 +++++++++++--- tests/unit/include/bout/test_stencil_expr.cxx | 16 ++--- 7 files changed, 120 insertions(+), 53 deletions(-) diff --git a/examples/hasegawa-wakatani/hw.cxx b/examples/hasegawa-wakatani/hw.cxx index f4e546908f..45864fa7eb 100644 --- a/examples/hasegawa-wakatani/hw.cxx +++ b/examples/hasegawa-wakatani/hw.cxx @@ -110,8 +110,8 @@ class HW : public PhysicsModel { nonzonal_phi -= averageY(DC(phi)); } - ddt(n) = -bracket_arakawa(phi, n) + alpha * (nonzonal_phi - nonzonal_n) - - kappa * DDZ_Dispatch(phi); + ddt(n) = + -bracket_arakawa(phi, n) + alpha * (nonzonal_phi - nonzonal_n) - kappa * DDZ(phi); ddt(vort) = -bracket_arakawa(phi, vort) + alpha * (nonzonal_phi - nonzonal_n); diff --git a/examples/performance/ddz/ddz.cxx b/examples/performance/ddz/ddz.cxx index 5d3bf2c961..0b1be902f2 100644 --- a/examples/performance/ddz/ddz.cxx +++ b/examples/performance/ddz/ddz.cxx @@ -65,17 +65,9 @@ int main(int argc, char** argv) { // Nested loops over block data ITERATOR_TEST_BLOCK("DDZ Default", result = DDZ(a);); - ITERATOR_TEST_BLOCK("DDZ C2", result = DDZ(a, CELL_DEFAULT, "DIFF_C2");); + ITERATOR_TEST_BLOCK("DDZ C2", result = DDZ(a, CELL_DEFAULT, DIFF_C2);); - ITERATOR_TEST_BLOCK("DDZ C4", result = DDZ(a, CELL_DEFAULT, "DIFF_C4");); - - ITERATOR_TEST_BLOCK("DDZ S2", result = DDZ(a, CELL_DEFAULT, "DIFF_S2");); - - ITERATOR_TEST_BLOCK("DDZ W2", result = DDZ(a, CELL_DEFAULT, "DIFF_W2");); - - ITERATOR_TEST_BLOCK("DDZ W3", result = DDZ(a, CELL_DEFAULT, "DIFF_W3");); - - ITERATOR_TEST_BLOCK("DDZ FFT", result = DDZ(a, CELL_DEFAULT, "DIFF_FFT");); + ITERATOR_TEST_BLOCK("DDZ C4", result = DDZ(a, CELL_DEFAULT, DIFF_C4);); if (profileMode) { int nthreads = 0; diff --git a/include/bout/derivs.hxx b/include/bout/derivs.hxx 
index a8d9279378..5adb58b544 100644 --- a/include/bout/derivs.hxx +++ b/include/bout/derivs.hxx @@ -31,6 +31,7 @@ #include "bout/field2d.hxx" #include "bout/field3d.hxx" +#include "bout/stencil_expr.hxx" #include "bout/vector2d.hxx" #include "bout/vector3d.hxx" @@ -102,22 +103,6 @@ Coordinates::FieldMetric DDY(const Field2D& f, CELL_LOC outloc = CELL_DEFAULT, const std::string& method = "DEFAULT", const std::string& region = "RGN_NOBNDRY"); -/// Calculate first partial derivative in Z -/// -/// \f$\partial / \partial z\f$ -/// -/// @param[in] f The field to be differentiated -/// @param[in] outloc The cell location where the result is desired. If -/// staggered grids is not enabled then this has no effect -/// If not given, defaults to CELL_DEFAULT -/// @param[in] method Differencing method to use. This overrides the default -/// If not given, defaults to DIFF_DEFAULT -/// @param[in] region What region is expected to be calculated -/// If not given, defaults to RGN_NOBNDRY -Field3D DDZ(const Field3D& f, CELL_LOC outloc = CELL_DEFAULT, - const std::string& method = "DEFAULT", - const std::string& region = "RGN_NOBNDRY"); - /// Calculate first partial derivative in Z /// /// \f$\partial / \partial z\f$ @@ -147,7 +132,11 @@ Coordinates::FieldMetric DDZ(const Field2D& f, CELL_LOC outloc = CELL_DEFAULT, /// @param[in] region What region is expected to be calculated /// If not given, defaults to RGN_NOBNDRY Vector3D DDZ(const Vector3D& f, CELL_LOC outloc = CELL_DEFAULT, - const std::string& method = "DEFAULT", + DIFF_METHOD method = DIFF_DEFAULT, + const std::string& region = "RGN_NOBNDRY"); + +/// Compatibility overload for string-based callers. 
+Vector3D DDZ(const Vector3D& f, CELL_LOC outloc, const std::string& method, const std::string& region = "RGN_NOBNDRY"); /// Calculate first partial derivative in Z @@ -163,7 +152,11 @@ Vector3D DDZ(const Vector3D& f, CELL_LOC outloc = CELL_DEFAULT, /// @param[in] region What region is expected to be calculated /// If not given, defaults to RGN_NOBNDRY Vector2D DDZ(const Vector2D& f, CELL_LOC outloc = CELL_DEFAULT, - const std::string& method = "DEFAULT", + DIFF_METHOD method = DIFF_DEFAULT, + const std::string& region = "RGN_NOBNDRY"); + +/// Compatibility overload for string-based callers. +Vector2D DDZ(const Vector2D& f, CELL_LOC outloc, const std::string& method, const std::string& region = "RGN_NOBNDRY"); ////////// SECOND DERIVATIVES ////////// diff --git a/include/bout/stencil_expr.hxx b/include/bout/stencil_expr.hxx index bbccc267bf..8542fc06dc 100644 --- a/include/bout/stencil_expr.hxx +++ b/include/bout/stencil_expr.hxx @@ -11,6 +11,7 @@ #include "bout/fieldops.hxx" #include "bout/mesh.hxx" #include "bout/single_index_ops.hxx" +#include "bout/utils.hxx" #include @@ -194,6 +195,53 @@ using BracketArakawaExpr = BinaryExprgetRegion("RGN_NOBNDRY")}; } -inline bout::stencil::DDZDispatchExpr -DDZ_Dispatch(const Field3D& f, CELL_LOC outloc = CELL_DEFAULT, - DIFF_METHOD method = DIFF_DEFAULT, - const std::string& region = "RGN_NOBNDRY") { +inline bout::stencil::DDZDispatchExpr DDZ(const Field3D& f, + CELL_LOC outloc = CELL_DEFAULT, + DIFF_METHOD method = DIFF_DEFAULT, + const std::string& region = "RGN_NOBNDRY") { checkData(f); const auto resolved_outloc = (outloc == CELL_DEFAULT) ? 
f.getLocation() : outloc; @@ -260,6 +308,12 @@ DDZ_Dispatch(const Field3D& f, CELL_LOC outloc = CELL_DEFAULT, f.getMesh()->getRegion(region)}; } +inline bout::stencil::DDZDispatchExpr DDZ(const Field3D& f, CELL_LOC outloc, + const std::string& method, + const std::string& region = "RGN_NOBNDRY") { + return DDZ(f, outloc, parseDDZMethodString(method), region); +} + inline bout::stencil::BracketArakawaExpr bracket_arakawa(const Field3D& f, const Field3D& g) { checkData(f); diff --git a/src/invert/laplace/impls/naulin/naulin_laplace.cxx b/src/invert/laplace/impls/naulin/naulin_laplace.cxx index ae8e78d1ff..a323bc6c08 100644 --- a/src/invert/laplace/impls/naulin/naulin_laplace.cxx +++ b/src/invert/laplace/impls/naulin/naulin_laplace.cxx @@ -247,7 +247,7 @@ Field3D LaplaceNaulin::solve(const Field3D& rhs, const Field3D& x0) { Field3D coef_y = DDY(C2coef, location, "C2") / C1TimesD; // z-component of 1./(C1*D) * Grad_perp(C2) - Field3D coef_z = DDZ(C2coef, location, "FFT") / C1TimesD; + Field3D coef_z = DDZ(C2coef, location, DIFF_C4) / C1TimesD; Field3D AOverD = Acoef / Dcoef; @@ -286,7 +286,7 @@ Field3D LaplaceNaulin::solve(const Field3D& rhs, const Field3D& x0) { auto calc_b_guess = [&](const Field3D& x_in) { // Derivatives of x Field3D ddx_x = DDX(x_in, location, "C2"); - Field3D ddz_x = DDZ(x_in, location, "FFT"); + Field3D ddz_x = DDZ(x_in, location, DIFF_C4); return rhsOverD - (coords->g11 * coef_x_AC * ddx_x + coords->g33 * coef_z * ddz_x + coords->g13 * (coef_x_AC * ddz_x + coef_z * ddx_x)) diff --git a/src/sys/derivs.cxx b/src/sys/derivs.cxx index e449dbcd30..b0a5b14388 100644 --- a/src/sys/derivs.cxx +++ b/src/sys/derivs.cxx @@ -86,12 +86,6 @@ Coordinates::FieldMetric DDY(const Field2D& f, CELL_LOC outloc, const std::strin ////////////// Z DERIVATIVE ///////////////// -Field3D DDZ(const Field3D& f, CELL_LOC outloc, const std::string& method, - const std::string& region) { - return bout::derivatives::index::DDZ(f, outloc, method, region) - / 
f.getCoordinates(outloc)->dz; -} - Coordinates::FieldMetric DDZ(const Field2D& f, CELL_LOC UNUSED(outloc), const std::string& UNUSED(method), const std::string& UNUSED(region)) { @@ -100,7 +94,7 @@ Coordinates::FieldMetric DDZ(const Field2D& f, CELL_LOC UNUSED(outloc), return tmp; } -Vector3D DDZ(const Vector3D& v, CELL_LOC outloc, const std::string& method, +Vector3D DDZ(const Vector3D& v, CELL_LOC outloc, DIFF_METHOD method, const std::string& region) { Vector3D result(v.getMesh()); const Coordinates* metric = v.x.getCoordinates(outloc); @@ -131,8 +125,37 @@ Vector3D DDZ(const Vector3D& v, CELL_LOC outloc, const std::string& method, return result; } -Vector2D DDZ(const Vector2D& v, CELL_LOC UNUSED(outloc), - const std::string& UNUSED(method), const std::string& UNUSED(region)) { +Vector3D DDZ(const Vector3D& v, CELL_LOC outloc, const std::string& method, + const std::string& region) { + Vector3D result(v.getMesh()); + const Coordinates* metric = v.x.getCoordinates(outloc); + + if (v.covariant) { + result.x = DDZ(v.x, outloc, method, region) - v.x * metric->G1_13 + - v.y * metric->G2_13 - v.z * metric->G3_13; + result.y = DDZ(v.y, outloc, method, region) - v.x * metric->G1_23 + - v.y * metric->G2_23 - v.z * metric->G3_23; + result.z = DDZ(v.z, outloc, method, region) - v.x * metric->G1_33 + - v.y * metric->G2_33 - v.z * metric->G3_33; + result.covariant = true; + } else { + result.x = DDZ(v.x, outloc, method, region) + v.x * metric->G1_13 + + v.y * metric->G1_23 + v.z * metric->G1_33; + result.y = DDZ(v.y, outloc, method, region) + v.x * metric->G2_13 + + v.y * metric->G2_23 + v.z * metric->G2_33; + result.z = DDZ(v.z, outloc, method, region) + v.x * metric->G3_13 + + v.y * metric->G3_23 + v.z * metric->G3_33; + result.covariant = false; + } + + ASSERT2(((outloc == CELL_DEFAULT) && (result.getLocation() == v.getLocation())) + || (result.getLocation() == outloc)); + + return result; +} + +Vector2D DDZ(const Vector2D& v, CELL_LOC UNUSED(outloc), DIFF_METHOD 
UNUSED(method), + const std::string& UNUSED(region)) { Vector2D result(v.getMesh()); result.covariant = v.covariant; @@ -147,6 +170,11 @@ Vector2D DDZ(const Vector2D& v, CELL_LOC UNUSED(outloc), return result; } +Vector2D DDZ(const Vector2D& v, CELL_LOC outloc, const std::string& method, + const std::string& region) { + return DDZ(v, outloc, parseDDZMethodString(method), region); +} + /******************************************************************************* * 2nd derivative *******************************************************************************/ diff --git a/tests/unit/include/bout/test_stencil_expr.cxx b/tests/unit/include/bout/test_stencil_expr.cxx index 9d3154c29a..e2e36e3b00 100644 --- a/tests/unit/include/bout/test_stencil_expr.cxx +++ b/tests/unit/include/bout/test_stencil_expr.cxx @@ -54,7 +54,7 @@ TEST_P(DDZDispatchExprParamTest, MatchesDDZ) { const auto [inloc, outloc, method] = GetParam(); auto input = makeTestField(mesh_staggered, inloc); - const auto actual = Field3D{DDZ_Dispatch(input, outloc, method)}; + const auto actual = Field3D{DDZ(input, outloc, method)}; const auto expected = DDZ(input, outloc, toString(method)); EXPECT_EQ(actual.getLocation(), outloc); @@ -64,8 +64,8 @@ TEST_P(DDZDispatchExprParamTest, MatchesDDZ) { TEST_F(DDZDispatchExprTest, UsesRequestedRegion) { auto input = makeTestField(mesh_staggered, CELL_CENTRE); - const auto actual = Field3D{DDZ_Dispatch(input, CELL_ZLOW, DIFF_C2, "RGN_ALL")}; - const auto expected = DDZ(input, CELL_ZLOW, toString(DIFF_C2), "RGN_ALL"); + const auto actual = Field3D{DDZ(input, CELL_ZLOW, DIFF_C2, "RGN_ALL")}; + const auto expected = DDZ(input, CELL_ZLOW, "C2", "RGN_ALL"); EXPECT_EQ(actual.getLocation(), CELL_ZLOW); EXPECT_TRUE(IsFieldEqual(actual, expected, "RGN_ALL")); @@ -74,7 +74,7 @@ TEST_F(DDZDispatchExprTest, UsesRequestedRegion) { TEST_F(DDZDispatchExprTest, RejectsUnsupportedMethods) { auto input = makeTestField(mesh_staggered, CELL_CENTRE); - 
EXPECT_THROW((void)DDZ_Dispatch(input, CELL_DEFAULT, DIFF_FFT), BoutException); + EXPECT_THROW((void)DDZ(input, CELL_DEFAULT, DIFF_FFT), BoutException); } TEST_F(DDZDispatchExprTest, ResolvesDefaultMethodFromMesh) { @@ -83,11 +83,11 @@ TEST_F(DDZDispatchExprTest, ResolvesDefaultMethodFromMesh) { Options diff_options{{"ddz", {{"first", "C4"}}}, {"ddzstag", {{"first", "C2"}}}}; static_cast(mesh_staggered)->initDerivs(&diff_options); - const auto centre_actual = Field3D{DDZ_Dispatch(input, CELL_CENTRE, DIFF_DEFAULT)}; - const auto centre_expected = DDZ(input, CELL_CENTRE, toString(DIFF_C4)); + const auto centre_actual = Field3D{DDZ(input, CELL_CENTRE, DIFF_DEFAULT)}; + const auto centre_expected = DDZ(input, CELL_CENTRE, "C4"); EXPECT_TRUE(IsFieldEqual(centre_actual, centre_expected, "RGN_NOBNDRY")); - const auto staggered_actual = Field3D{DDZ_Dispatch(input, CELL_ZLOW, DIFF_DEFAULT)}; - const auto staggered_expected = DDZ(input, CELL_ZLOW, toString(DIFF_C2)); + const auto staggered_actual = Field3D{DDZ(input, CELL_ZLOW, DIFF_DEFAULT)}; + const auto staggered_expected = DDZ(input, CELL_ZLOW, "C2"); EXPECT_TRUE(IsFieldEqual(staggered_actual, staggered_expected, "RGN_NOBNDRY")); }