diff --git a/src/solver/impls/snes/snes.cxx b/src/solver/impls/snes/snes.cxx index 446903dd4f..2fc8b8ef5a 100644 --- a/src/solver/impls/snes/snes.cxx +++ b/src/solver/impls/snes/snes.cxx @@ -604,7 +604,9 @@ SNESSolver::SNESSolver(Options* opts) .withDefault(false)), asinh_vars((*options)["asinh_vars"] .doc("Apply asinh() to all variables?") - .withDefault(false)) {} + .withDefault(false)) { + has_constraints = true; ///< This solver can handle constraints +} int SNESSolver::init() { @@ -627,6 +629,29 @@ int SNESSolver::init() { output_info.write("\t3d fields = {:d}, 2d fields = {:d} neq={:d}, local_N={:d}\n", n3Dvars(), n2Dvars(), neq, nlocal); + // Check if there are any constraints + have_constraints = false; + + for (int i = 0; i < n2Dvars(); i++) { + if (f2d[i].constraint) { + have_constraints = true; + break; + } + } + for (int i = 0; i < n3Dvars(); i++) { + if (f3d[i].constraint) { + have_constraints = true; + break; + } + } + + if (have_constraints) { + is_dae.reallocate(nlocal); + // Call the Solver function, which sets the array + // to one when not a constraint, zero for constraint + set_id(std::begin(is_dae)); + } + // Initialise PETSc components // Vectors @@ -673,6 +698,34 @@ int SNESSolver::init() { PetscCall(initPseudoTimestepping()); } + if (have_constraints) { + // CreatePETSc-native index sets representing the two parts of your DAE. + PetscInt istart, iend; + PetscCall(VecGetOwnershipRange(snes_x, &istart, &iend)); + ASSERT2(iend - istart == nlocal); + + std::vector diff_idx; + std::vector alg_idx; + diff_idx.reserve(nlocal); + alg_idx.reserve(nlocal); + + for (PetscInt i = 0; i < nlocal; ++i) { + const PetscInt gi = istart + i; + if (is_dae[i] > 0.5) { // differential + diff_idx.push_back(gi); + } else { // algebraic constraint (i.e. phi) + alg_idx.push_back(gi); + } + } + + PetscCall(ISCreateGeneral(BoutComm::get(), diff_idx.size(), diff_idx.data(), + PETSC_COPY_VALUES, &is_diff)); + PetscCall(ISCreateGeneral(BoutComm::get(), alg_idx.size(), alg_idx.data(), + PETSC_COPY_VALUES, &is_alg)); + + have_is_maps = true; + } + // Nonlinear solver interface (SNES) output_info.write("Create SNES\n"); SNESCreate(BoutComm::get(), &snes); @@ -753,6 +806,9 @@ int SNESSolver::init() { SNESSetForceIteration(snes, PETSC_TRUE); #endif + // Enable checking for domain errors in Jacobian evaluation + SNESSetCheckJacobianDomainError(snes, PETSC_TRUE); + // Get KSP context from SNES KSP ksp; SNESGetKSP(snes, &ksp); @@ -805,6 +861,24 @@ int SNESSolver::init() { } } + if (have_constraints && have_is_maps && !matrix_free && pc_type == "fieldsplit") { + output_info.write("Using PCFieldSplit preconditioner for DAE system\n"); + + // Use PETSc fieldsplit + PetscCall(PCSetType(pc, PCFIELDSPLIT)); + + // Give PETSc the index sets + PetscCall(PCFieldSplitSetIS(pc, "diff", is_diff)); + PetscCall(PCFieldSplitSetIS(pc, "alg", is_alg)); + + // Let the user configure from options (recommended) + // Example options you can set in input file: + // -pc_fieldsplit_type additive + // -fieldsplit_alg_pc_type hypre -fieldsplit_alg_pc_hypre_type boomeramg + // -fieldsplit_diff_pc_type ilu + // + } + // Get runtime options lib.setOptionsFromInputFile(snes); @@ -1162,7 +1236,7 @@ int SNESSolver::run() { timestep = pid(timestep, nl_its, max_timestep); - // NOTE(malamast): Do we really need this? + // NOTE: Do we really need this? // Recompute Jacobian (for now) if (saved_jacobian_lag == 0) { SNESGetLagJacobian(snes, &saved_jacobian_lag); @@ -1262,6 +1336,8 @@ int SNESSolver::run() { run_rhs(target); // Run RHS to calculate auxilliary variables } catch (BoutException& e) { output_error.write("ERROR: BoutException thrown: {}\n", e.what()); + // NOTE: what happens if we hit the exception here? + // Should we add a relaxation step to update the state vector? } if (call_monitors(target, s, getNumberOutputSteps()) != 0) { @@ -1573,7 +1649,13 @@ PetscErrorCode SNESSolver::snes_function(Vec x, Vec f, bool linear) { // Call the RHS function if (rhs_function(x, f, linear) != PETSC_SUCCESS) { // Tell SNES that the input was out of domain - SNESSetFunctionDomainError(snes); + if (linear) { + // During Jacobian evaluation + SNESSetJacobianDomainError(snes); + } else { + // During function evaluation + SNESSetFunctionDomainError(snes); + } // Note: Returning non-zero error here leaves vectors in locked state return 0; } @@ -1597,10 +1679,33 @@ PetscErrorCode SNESSolver::snes_function(Vec x, Vec f, bool linear) { break; } case BoutSnesEquationForm::backward_euler: { - // Backward Euler - // Set f = x - x0 - Δt*f - VecAYPX(f, -dt, x); // f <- x - Δt*f - VecAXPY(f, -1.0, x0); // f <- f - x0 + // Backward Euler: + // Differential: F = x - x0 - dt*f + // Algebraic: F = G(x) (already stored in f by rhs_function) + + if (!have_constraints) { + + VecAYPX(f, -dt, x); // f <- x - Δt*f + VecAXPY(f, -1.0, x0); // f <- f - x0 + + } else { + + ASSERT2(have_is_maps); + // Some constraints + + Vec x_diff, x0_diff, f_diff; + + PetscCall(VecGetSubVector(x, is_diff, &x_diff)); + PetscCall(VecGetSubVector(x0, is_diff, &x0_diff)); + PetscCall(VecGetSubVector(f, is_diff, &f_diff)); + + PetscCall(VecAYPX(f_diff, -dt, x_diff)); // f_diff <- x_diff - dt*f_diff + PetscCall(VecAXPY(f_diff, -1.0, x0_diff)); // f_diff <- f_diff - x0_diff + + PetscCall(VecRestoreSubVector(x, is_diff, &x_diff)); + PetscCall(VecRestoreSubVector(x0, is_diff, &x0_diff)); + PetscCall(VecRestoreSubVector(f, is_diff, &f_diff)); + } break; } case BoutSnesEquationForm::direct_newton: { diff --git a/src/solver/impls/snes/snes.hxx b/src/solver/impls/snes/snes.hxx index 5d007f8ad7..1ea06f6481 100644 --- a/src/solver/impls/snes/snes.hxx +++ b/src/solver/impls/snes/snes.hxx @@ -182,6 +182,13 @@ private: int nlocal; ///< Number of variables on local processor int neq; ///< Number of variables in total + bool have_constraints; ///< Are there any constraint variables? + Array is_dae; ///< If using constraints, 1 -> DAE, 0 -> AE + + IS is_diff = nullptr; // is_dae == 1 + IS is_alg = nullptr; // is_dae == 0 (phi constraint and any other algebraics) + bool have_is_maps = false; + PetscLib lib; ///< Handles initialising, finalising PETSc Vec snes_f; ///< Used by SNES to store function Vec snes_x; ///< Result of SNES