Skip to content

Commit 3edd349

Browse files
committed
two new callbacks and minor cleanup
1 parent f3673a7 commit 3edd349

7 files changed

Lines changed: 374 additions & 267 deletions

File tree

docs/source/newtons_method.md

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,8 @@ The following are the callback containers used:
219219
class SolverCallbacks
220220
{
221221
std::vector<std::function<void()>> before_energy_evaluation;
222+
std::vector<std::function<void()>> before_step;
223+
std::vector<std::function<void()>> after_step;
222224
std::vector<std::function<bool()>> is_initial_state_valid;
223225
std::vector<std::function<bool()>> is_intermediate_state_valid;
224226
std::vector<std::function<void()>> on_intermediate_state_invalid;
@@ -230,17 +232,19 @@ class SolverCallbacks
230232
```
231233
232234
Brief descriptions:
233-
- `before_energy_evaluation`: Will be executed before evaluating any value or derivative. Example: it can be used to update contact pairs.
234-
- `is_initial_state_valid`: Exits the Newton process before it starts. Example: Element inversions in the rest pose.
235+
- `before_step`: Executed at the very start of each Newton iteration. Example: logging, resetting per-iteration state.
236+
- `before_energy_evaluation`: Executed before evaluating any value or derivative each iteration. Example: update contact pairs.
237+
- `after_step`: Executed at the end of each Newton iteration, including failed/early-exit iterations. Guaranteed to be called once for every `before_step` invocation. Example: per-iteration diagnostics, adaptive parameter updates.
238+
- `is_initial_state_valid`: Exits the Newton process before it starts. Example: element inversions in the rest pose.
235239
- `is_intermediate_state_valid`: Used to backtrack invalid states in the line search. Example: element inversions as a consequence of a step.
236-
- `on_intermediate_state_invalid`: Used if the invalid state could not be backtracked after the afforded number of iterations. Example: increase penalty stiffness.
237-
- `on_armijo_fail`: Analogous if Armijo's backtracking fails. Example: restart the solve with a shorter time step size.
240+
- `on_intermediate_state_invalid`: Called when the invalid state could not be backtracked. Example: increase penalty stiffness.
241+
- `on_armijo_fail`: Called when Armijo backtracking fails. Example: restart the solve with a shorter time step size.
238242
- `is_converged`: User's custom convergence criteria.
239-
- `is_converged_state_valid`: Let Newton's Method return non-success right before return. Example: penalties are too soft, tighten them before requesting repeating the solve.
240-
- `max_allowed_step`: Allows the user to specify a max step for the current iteration. Example: CCD.
243+
- `is_converged_state_valid`: Checked right before returning success — lets the solver return non-success if the converged state is unacceptable. Example: tighten constraints before requesting a repeat solve.
244+
- `max_allowed_step`: Returns the maximum step length for the current iteration. Example: CCD.
241245
242-
Callbacks can be appended to those lists by using the corresponding `.add_<callback_name>(T f)` method, e.g. `.add_before_energy_evaluation(std::function<void()> f)`.
243-
`SolverCallbacks` can be accessed in `NewtonsMethod.callbacks`.
246+
Callbacks can be appended via the corresponding `.add_<callback_name>(T f)` method, e.g. `.add_before_energy_evaluation(std::function<void()> f)`.
247+
`SolverCallbacks` can be accessed via `NewtonsMethod.callbacks`.
244248
245249
Further, a custom residual can be specified by `.set_residual(std::function<double(Eigen::VectorXd&)> f)`.
246250

symx/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ set(SOURCE_FILES
8888
src/solver/NewtonsMethod.cpp
8989
src/solver/gnuplot.h
9090
src/solver/gnuplot.cpp
91+
src/solver/solver_types.h
92+
src/solver/solver_settings.h
93+
src/solver/SolverCallbacks.h
9194
src/solver/solver_utils.h
9295
src/solver/fem_integrators.h
9396
src/solver/fem_integrators.cpp

symx/src/solver/NewtonsMethod.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,15 @@ SolverReturn NewtonsMethod::solve()
8383
while (result == SolverReturn::Running) {
8484
newton_iteration++;
8585

86+
// Before-step callback (called for every iteration, including failed ones)
87+
this->callbacks->run_before_step();
88+
89+
// After-step callback (RAII)
90+
struct AfterNewtonStepGuard {
91+
spSolverCallbacks& callbacks;
92+
~AfterNewtonStepGuard() { callbacks->run_after_step(); }
93+
} _after_step_guard{this->callbacks};
94+
8695
// Check maximum iterations
8796
if (newton_iteration == this->settings.max_iterations) {
8897
this->output->print_with_new_line("Newton failure: Too many iterations.", Verbosity::Medium);

symx/src/solver/SolverCallbacks.h

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
#pragma once
2+
#include <functional>
3+
#include <algorithm>
4+
#include <vector>
5+
6+
#include <Eigen/Dense>
7+
8+
#include "solver_types.h"
9+
#include "Context.h"
10+
11+
namespace symx
12+
{
13+
static auto default_residual = [](const Eigen::VectorXd& r) { return r.cwiseAbs().maxCoeff(); };
14+
15+
/*
16+
Holds user-registered callbacks that the Newton solver invokes
17+
at well-defined points during the solve. Callbacks are registered
18+
via add_*() and run via run_*() (called internally by the solver).
19+
*/
20+
class SolverCallbacks
21+
{
22+
private:
23+
/* Fields */
24+
std::vector<std::function<void()>> before_energy_evaluation;
25+
std::vector<std::function<void()>> before_step;
26+
std::vector<std::function<void()>> after_step;
27+
std::vector<std::function<bool()>> is_initial_state_valid;
28+
std::vector<std::function<bool()>> is_intermediate_state_valid;
29+
std::vector<std::function<void()>> on_intermediate_state_invalid;
30+
std::vector<std::function<void()>> on_armijo_fail;
31+
std::vector<std::function<bool()>> is_converged;
32+
std::vector<std::function<bool()>> is_converged_state_valid;
33+
std::vector<std::function<double()>> max_allowed_step;
34+
std::function<double(const Eigen::VectorXd&)> residual = default_residual;
35+
spContext context = nullptr;
36+
37+
/* Internal helpers */
38+
void _run(const std::vector<std::function<void()>>& fs)
39+
{
40+
for (auto& f : fs) {
41+
f();
42+
}
43+
}
44+
bool _run_bool(bool default_bool, const std::vector<std::function<bool()>>& fs)
45+
{
46+
bool valid = default_bool;
47+
for (auto& f : fs) {
48+
valid = valid && f();
49+
}
50+
return valid;
51+
}
52+
53+
public:
54+
/* Construction */
55+
SolverCallbacks(spContext context) : context(context) {}
56+
static std::shared_ptr<SolverCallbacks> create(spContext context) { return std::make_shared<SolverCallbacks>(context); }
57+
58+
// ---- Register callbacks ----
59+
60+
void set_residual(std::function<double(const Eigen::VectorXd&)> f) { this->residual = f; }
61+
62+
/// Called once per Newton iteration, before energy/gradient/Hessian evaluation.
63+
void add_before_energy_evaluation(std::function<void()> f) { this->before_energy_evaluation.push_back(f); }
64+
65+
/// Called at the very start of each Newton iteration (iteration index is 0-based).
66+
void add_before_step(std::function<void()> f) { this->before_step.push_back(f); }
67+
68+
/// Called at the end of each Newton iteration (including failed/early-exit iterations).
69+
void add_after_step(std::function<void()> f) { this->after_step.push_back(f); }
70+
71+
void add_is_initial_state_valid(std::function<bool()> f) { this->is_initial_state_valid.push_back(f); }
72+
void add_is_intermediate_state_valid(std::function<bool()> f) { this->is_intermediate_state_valid.push_back(f); }
73+
void add_on_intermediate_state_invalid(std::function<void()> f) { this->on_intermediate_state_invalid.push_back(f); }
74+
void add_on_armijo_fail(std::function<void()> f) { this->on_armijo_fail.push_back(f); }
75+
void add_is_converged(std::function<bool()> f) { this->is_converged.push_back(f); }
76+
void add_is_converged_state_valid(std::function<bool()> f) { this->is_converged_state_valid.push_back(f); }
77+
void add_max_allowed_step(std::function<double()> f) { this->max_allowed_step.push_back(f); }
78+
79+
// ---- Invoke callbacks (called by the solver) ----
80+
81+
void run_before_energy_evaluation()
82+
{
83+
auto _t = this->context->logger->time("before_energy_evaluation");
84+
this->_run(this->before_energy_evaluation);
85+
}
86+
void run_before_step()
87+
{
88+
auto _t = this->context->logger->time("before_step");
89+
this->_run(this->before_step);
90+
}
91+
void run_after_step()
92+
{
93+
auto _t = this->context->logger->time("after_step");
94+
this->_run(this->after_step);
95+
}
96+
bool run_is_initial_state_valid()
97+
{
98+
auto _t = this->context->logger->time("is_initial_state_valid");
99+
return this->_run_bool(true, this->is_initial_state_valid);
100+
}
101+
bool run_is_intermediate_state_valid()
102+
{
103+
auto _t = this->context->logger->time("is_intermediate_state_valid");
104+
return this->_run_bool(true, this->is_intermediate_state_valid);
105+
}
106+
void run_on_intermediate_state_invalid()
107+
{
108+
auto _t = this->context->logger->time("on_intermediate_state_invalid");
109+
this->_run(this->on_intermediate_state_invalid);
110+
}
111+
void run_on_armijo_fail()
112+
{
113+
auto _t = this->context->logger->time("on_armijo_fail");
114+
this->_run(this->on_armijo_fail);
115+
}
116+
bool run_is_converged()
117+
{
118+
auto _t = this->context->logger->time("is_converged");
119+
bool converged = false;
120+
for (auto& f : this->is_converged) {
121+
converged = f() || converged;
122+
}
123+
return converged;
124+
}
125+
bool run_is_converged_state_valid()
126+
{
127+
auto _t = this->context->logger->time("is_converged_state_valid");
128+
return this->_run_bool(true, this->is_converged_state_valid);
129+
}
130+
double run_max_allowed_step()
131+
{
132+
auto _t = this->context->logger->time("max_allowed_step");
133+
double max_step = 1.0;
134+
for (auto f : this->max_allowed_step) {
135+
max_step = std::min(max_step, f());
136+
}
137+
return max_step;
138+
}
139+
140+
// ---- Residual computation ----
141+
double compute_residual(const Eigen::VectorXd& r) { return this->residual(r); }
142+
};
143+
144+
using spSolverCallbacks = std::shared_ptr<SolverCallbacks>;
145+
}

symx/src/solver/solver_settings.h

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
#pragma once
2+
#include <string>
3+
#include <limits>
4+
5+
#include "solver_types.h"
6+
7+
namespace symx
8+
{
9+
// ================================================================
10+
// SolverSettings
11+
//
12+
// Base settings shared by all iterative solvers.
13+
// ================================================================
14+
struct SolverSettings
15+
{
16+
// Iteration limits
17+
int max_iterations = std::numeric_limits<int>::max();
18+
int min_iterations = 0;
19+
20+
// Convergence criteria
21+
double residual_tolerance_abs = 1e-6;
22+
double residual_tolerance_rel = 0.0;
23+
double step_tolerance = 0.0;
24+
bool max_iterations_as_success = false; // Treat hitting max_iterations as convergence
25+
26+
// Line search
27+
double step_cap = std::numeric_limits<double>::infinity(); // Clamp step norm to this value
28+
bool enable_armijo_backtracking = true;
29+
double line_search_armijo_beta = 1e-4; // Sufficient decrease parameter
30+
int max_backtracking_armijo_iterations = 20;
31+
int max_backtracking_invalid_state_iterations = 8;
32+
bool print_line_search_upon_failure = false;
33+
34+
std::string as_string(const std::string& prefix = "") const
35+
{
36+
std::string p = prefix;
37+
std::string out;
38+
out += "\n" + p + "Iteration limits";
39+
out += "\n" + p + " max_iterations: " + std::to_string(max_iterations);
40+
out += "\n" + p + " min_iterations: " + std::to_string(min_iterations);
41+
out += "\n" + p + "Convergence";
42+
out += "\n" + p + " residual_tolerance_abs: " + to_string_sci(residual_tolerance_abs);
43+
out += "\n" + p + " residual_tolerance_rel: " + to_string_sci(residual_tolerance_rel);
44+
out += "\n" + p + " step_tolerance: " + to_string_sci(step_tolerance);
45+
out += "\n" + p + " max_iterations_as_success: " + to_string(max_iterations_as_success);
46+
out += "\n" + p + "Line search";
47+
out += "\n" + p + " step_cap: " + to_string_sci(step_cap);
48+
out += "\n" + p + " enable_armijo_backtracking: " + to_string(enable_armijo_backtracking);
49+
out += "\n" + p + " armijo_beta: " + to_string_sci(line_search_armijo_beta);
50+
out += "\n" + p + " max_armijo_iterations: " + std::to_string(max_backtracking_armijo_iterations);
51+
out += "\n" + p + " max_invalid_state_iterations: " + std::to_string(max_backtracking_invalid_state_iterations);
52+
out += "\n" + p + " print_line_search_upon_failure: " + to_string(print_line_search_upon_failure);
53+
return out;
54+
}
55+
};
56+
57+
// ================================================================
58+
// NewtonSettings
59+
//
60+
// Settings for Newton's method, extending SolverSettings with
61+
// Hessian projection and linear solver options.
62+
// ================================================================
63+
struct NewtonSettings : public SolverSettings
64+
{
65+
// Hessian projection to PD
66+
ProjectionToPD projection_mode = ProjectionToPD::ProjectedNewton; // Safe default
67+
double projection_eps = 1e-10; // Min eigenvalue after projection
68+
bool project_to_pd_use_mirroring = false; // Mirror instead of clamp for negative eigenvalues
69+
70+
// Project-On-Demand
71+
int project_on_demand_countdown = 4; // Iterations to project after failure
72+
73+
// Progressive Projected Newton (PPN)
74+
double ppn_tightening_factor = 0.5; // Tighten threshold on non-descending step
75+
double ppn_release_factor = 2.0; // Relax threshold after successful step
76+
77+
// Linear solver
78+
LinearSolver linear_solver = LinearSolver::BDPCG;
79+
int cg_max_iterations = 10000;
80+
double cg_abs_tolerance = 1e-12;
81+
double cg_rel_tolerance = 1e-4;
82+
bool cg_stop_on_indefiniteness = true;
83+
double bailout_residual = 1e-10; // Skip step if residual below this (CG stability)
84+
85+
std::string as_string(const std::string& prefix = "") const
86+
{
87+
std::string p = prefix;
88+
std::string out;
89+
out += SolverSettings::as_string(prefix);
90+
out += "\n" + p + "Hessian projection";
91+
out += "\n" + p + " projection_mode: " + to_string(projection_mode);
92+
out += "\n" + p + " projection_eps: " + to_string_sci(projection_eps);
93+
out += "\n" + p + " use_mirroring: " + to_string(project_to_pd_use_mirroring);
94+
out += "\n" + p + " on_demand_countdown: " + std::to_string(project_on_demand_countdown);
95+
out += "\n" + p + " ppn_tightening_factor: " + to_string_fixed(ppn_tightening_factor);
96+
out += "\n" + p + " ppn_release_factor: " + to_string_fixed(ppn_release_factor);
97+
out += "\n" + p + "Linear solver";
98+
out += "\n" + p + " solver: " + to_string(linear_solver);
99+
out += "\n" + p + " cg_max_iterations: " + std::to_string(cg_max_iterations);
100+
out += "\n" + p + " cg_abs_tolerance: " + to_string_sci(cg_abs_tolerance);
101+
out += "\n" + p + " cg_rel_tolerance: " + to_string_sci(cg_rel_tolerance);
102+
out += "\n" + p + " cg_stop_on_indefiniteness: " + to_string(cg_stop_on_indefiniteness);
103+
out += "\n" + p + " bailout_residual: " + to_string_sci(bailout_residual);
104+
return out;
105+
}
106+
};
107+
}

0 commit comments

Comments
 (0)