Skip to content

Commit 526afaf

Browse files
committed
bug in optimization where I forgot to pass rec_x_hat. ugh
1 parent 2394676 commit 526afaf

3 files changed

Lines changed: 38 additions & 38 deletions

File tree

examples/1_basic_tutorial.ipynb

Lines changed: 33 additions & 33 deletions
Large diffs are not rendered by default.

pynumdiff/optimize/_optimize.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,8 @@ def _objective_function(point, func, x, dt, singleton_params, categorical_params
152152
# Evaluate estimate according to a loss function
153153
if dxdt_truth is not None:
154154
if metric == 'rmse': # minimize ||dxdt_hat - dxdt_truth||_2
155-
rms_dxdt = evaluate.rmse(dxdt_truth, dxdt_hat, padding=padding)
156-
cache[key] = rms_dxdt; return rms_dxdt
155+
rmse_dxdt = evaluate.rmse(dxdt_truth, dxdt_hat, padding=padding)
156+
cache[key] = rmse_dxdt; return rmse_dxdt
157157
elif metric == 'error_correlation':
158158
ec = evaluate.error_correlation(dxdt_truth, dxdt_hat, padding=padding)
159159
cache[key] = ec; return ec
@@ -163,7 +163,7 @@ def _objective_function(point, func, x, dt, singleton_params, categorical_params
163163
rec_x_hat = utility.integrate_dxdt_hat(dxdt_hat, dt)
164164
rec_x_hat += utility.estimate_integration_constant(x, rec_x_hat, M=huberM)
165165
# rubust_rme(,M=inf) = rmse(), so just use the simpler function if M=inf
166-
cost = evaluate.rmse(x, x_hat, padding=padding) if huberM == float('inf') else evaluate.robust_rme(x, x_hat, padding=padding, M=huberM)
166+
cost = evaluate.rmse(x, rec_x_hat, padding=padding) if huberM == float('inf') else evaluate.robust_rme(x, rec_x_hat, padding=padding, M=huberM)
167167
cost += tvgamma*evaluate.total_variation(dxdt_hat, padding=padding)
168168
cache[key] = cost; return cost
169169

pynumdiff/utils/evaluate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def robust_rme(x, x_hat, padding=0, M=6):
9292
9393
:return: **robust_rmse_x_hat** (float) -- RMS error between x_hat and data
9494
"""
95-
if padding == 'auto': padding = max([1, int(0.025*len(x))])
95+
if padding == 'auto': padding = max(1, int(0.025*len(x)))
9696
s = slice(padding, len(x)-padding) # slice out data we want to measure
9797
N = s.stop - s.start
9898

@@ -111,7 +111,7 @@ def rmse(dxdt_truth, dxdt_hat, padding=0):
111111
112112
:return: **true_rmse_dxdt** (float) -- RMS error between dxdt_hat and dxdt_truth, returns None if dxdt_hat is None
113113
"""
114-
if padding == 'auto': padding = max([1, int(0.025*len(dxdt_truth))])
114+
if padding == 'auto': padding = max(1, int(0.025*len(dxdt_truth)))
115115
s = slice(padding, len(dxdt_hat)-padding) # slice out data we want to measure
116116
N = s.stop - s.start
117117

0 commit comments

Comments
 (0)