Skip to content

Commit ae32af4

Browse files
authored
fix: integrate fmin internally (#6867)
* fix: integrate fmin internally
* fix: add unit test
* fix: add LICENSE
* fix: remove unused param
1 parent 27089ef commit ae32af4

10 files changed

Lines changed: 640 additions & 2 deletions

File tree

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
import {
2+
conjugateGradient,
3+
conjugateGradientSolve,
4+
gradientDescent,
5+
gradientDescentLineSearch,
6+
nelderMead,
7+
} from '../../../../../src/data/utils/venn/fmin';
8+
9+
const SMALL = 1e-5;
10+
11+
function nearlyEqual(
12+
left,
13+
right,
14+
tolerance = SMALL,
15+
message = 'assertNearlyEqual',
16+
) {
17+
expect(Math.abs(left - right)).toBeLessThan(tolerance);
18+
console.log(`${message}: ${left} ~== ${right}`);
19+
}
20+
21+
function lessThan(test, left, right, message) {
22+
message = message || 'lessThan';
23+
test.ok(left < right, `${message}: ${left} < ${right}`);
24+
}
25+
26+
// Optimizers exercised by the shared convergence tests below.
// NOTE: optimizerNames[i] is the display label for optimizers[i],
// so the two arrays must stay index-aligned.
const optimizers = [
  nelderMead,
  gradientDescent,
  gradientDescentLineSearch,
  conjugateGradient,
];

const optimizerNames = [
  'Nelder Mead',
  'Gradient Descent',
  'Gradient Descent w/ Line Search',
  'Conjugate Gradient',
];
39+
40+
describe('fmin', () => {
41+
test('himmelblau', () => {
42+
// due to a bug, this used to not converge to the minimum
43+
const x = 4.9515014216303825;
44+
const y = 0.07301421370357275;
45+
46+
const params = { learnRate: 0.1 };
47+
48+
const himmelblau = (X, fxprime = [0, 0]) => {
49+
const [x, y] = X;
50+
fxprime[0] = 2 * (x + 2 * y - 7) + 4 * (2 * x + y - 5);
51+
fxprime[1] = 4 * (x + 2 * y - 7) + 2 * (2 * x + y - 5);
52+
// biome-ignore lint/style/useExponentiationOperator: TODO: use **
53+
return Math.pow(x + 2 * y - 7, 2) + Math.pow(2 * x + y - 5, 2);
54+
};
55+
56+
optimizers.forEach((optimizer, index) => {
57+
const solution = optimizer(himmelblau, [x, y], params);
58+
nearlyEqual(solution.fx, 0, SMALL, `himmelblau:${optimizerNames[index]}`);
59+
});
60+
});
61+
62+
test('banana', () => {
63+
const x = 1.6084564160555601;
64+
const y = -1.5980748860165477;
65+
66+
const banana = (X, fxprime) => {
67+
fxprime = fxprime || [0, 0];
68+
const x = X[0];
69+
const y = X[1];
70+
fxprime[0] = 400 * x * x * x - 400 * y * x + 2 * x - 2;
71+
fxprime[1] = 200 * y - 200 * x * x;
72+
return (1 - x) * (1 - x) + 100 * (y - x * x) * (y - x * x);
73+
};
74+
75+
const params = { learnRate: 0.0003, maxIterations: 50000 };
76+
for (let i = 0; i < optimizers.length; ++i) {
77+
const solution = optimizers[i](banana, [x, y], params);
78+
nearlyEqual(solution.fx, 0, 1e-3, `banana:${optimizerNames[i]}`);
79+
}
80+
});
81+
82+
test('quadratic1D', () => {
83+
const loss = (x, xprime) => {
84+
xprime = xprime || [0, 0];
85+
xprime[0] = 2 * (x[0] - 10);
86+
return (x[0] - 10) * (x[0] - 10);
87+
};
88+
89+
const params = { learnRate: 0.5 };
90+
91+
for (let i = 0; i < optimizers.length; ++i) {
92+
const solution = optimizers[i](loss, [0], params);
93+
nearlyEqual(solution.fx, 0, SMALL, `quadratic_1d:${optimizerNames[i]}`);
94+
}
95+
});
96+
97+
test('nelderMead', () => {
98+
const loss = (X) => {
99+
const x = X[0];
100+
const y = X[1];
101+
return Math.sin(y) * x + Math.sin(x) * y + x * x + y * y;
102+
};
103+
104+
const solution = nelderMead(loss, [-3.5, 3.5]);
105+
nearlyEqual(solution.fx, 0, SMALL, 'nelderMead');
106+
});
107+
108+
test('conjugateGradientSolve', () => {
109+
// matyas function
110+
let A = [
111+
[0.52, -0.48],
112+
[-0.48, 0.52],
113+
];
114+
let b = [0, 0];
115+
const initial = [-9.08, -7.83];
116+
let x = conjugateGradientSolve(A, b, initial);
117+
nearlyEqual(x[0], 0, SMALL, 'matyas.x');
118+
nearlyEqual(x[1], 0, SMALL, 'matyas.y');
119+
120+
// booth's function
121+
const history = [];
122+
A = [
123+
[10, 8],
124+
[8, 10],
125+
];
126+
b = [34, 38];
127+
x = conjugateGradientSolve(A, b, initial, history);
128+
nearlyEqual(x[0], 1, SMALL, 'booth.x');
129+
nearlyEqual(x[1], 3, SMALL, 'booth.y');
130+
console.log(history);
131+
});
132+
});

package.json

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,6 @@
7878
"@antv/util": "^3.3.10",
7979
"@antv/vendor": "^1.0.8",
8080
"flru": "^1.0.2",
81-
"fmin": "0.0.2",
8281
"pdfast": "^0.2.0"
8382
},
8483
"devDependencies": {

src/data/utils/venn/fmin/bisect.ts

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/** finds the zeros of a function, given two starting points (which must
2+
* have opposite signs */
3+
export function bisect(f, a, b, parameters?: any) {
4+
parameters = parameters || {};
5+
const maxIterations = parameters.maxIterations || 100;
6+
const tolerance = parameters.tolerance || 1e-10;
7+
const fA = f(a);
8+
const fB = f(b);
9+
let delta = b - a;
10+
11+
if (fA * fB > 0) {
12+
throw 'Initial bisect points must have opposite signs';
13+
}
14+
15+
if (fA === 0) return a;
16+
if (fB === 0) return b;
17+
18+
for (let i = 0; i < maxIterations; ++i) {
19+
delta /= 2;
20+
const mid = a + delta;
21+
const fMid = f(mid);
22+
23+
if (fMid * fA >= 0) {
24+
a = mid;
25+
}
26+
27+
if (Math.abs(delta) < tolerance || fMid === 0) {
28+
return mid;
29+
}
30+
}
31+
return a + delta;
32+
}

src/data/utils/venn/fmin/blas1.ts

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
// need some basic operations on vectors, rather than adding a dependency,
2+
// just define here
3+
export function zeros(x) {
4+
const r = new Array(x);
5+
for (let i = 0; i < x; ++i) {
6+
r[i] = 0;
7+
}
8+
return r;
9+
}
10+
export function zerosM(x, y) {
11+
return zeros(x).map(() => zeros(y));
12+
}
13+
14+
export function dot(a, b) {
15+
let ret = 0;
16+
for (let i = 0; i < a.length; ++i) {
17+
ret += a[i] * b[i];
18+
}
19+
return ret;
20+
}
21+
22+
export function norm2(a) {
23+
return Math.sqrt(dot(a, a));
24+
}
25+
26+
export function scale(ret, value, c?: any) {
27+
for (let i = 0; i < value.length; ++i) {
28+
ret[i] = value[i] * c;
29+
}
30+
}
31+
32+
export function weightedSum(ret, w1, v1, w2, v2) {
33+
for (let j = 0; j < ret.length; ++j) {
34+
ret[j] = w1 * v1[j] + w2 * v2[j];
35+
}
36+
}
37+
38+
export function gemv(output, A, x) {
39+
for (let i = 0; i < output.length; ++i) {
40+
output[i] = dot(A[i], x);
41+
}
42+
}
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import { dot, gemv, norm2, scale, weightedSum } from './blas1';
2+
import { wolfeLineSearch } from './linesearch';
3+
4+
/**
 * Minimizes `f` with the nonlinear conjugate gradient method
 * (Polak–Ribière variant) using a Wolfe-conditions line search.
 *
 * @param f        objective: f(x, fxprime) returns the value at `x` and
 *                 writes the gradient into `fxprime`
 * @param initial  starting point (array of numbers)
 * @param params   optional: `maxIterations` (default 20·dim) and `history`
 *                 (an array that, when present, collects per-iteration
 *                 snapshots of x / fx / fxprime / step size)
 * @returns the final `{ x, fx, fxprime }` state
 */
export function conjugateGradient(f, initial, params) {
  // Allocate all buffers up front and reuse them, keeping allocation out of
  // the loop for performance reasons.
  let current = { x: initial.slice(), fx: 0, fxprime: initial.slice() };
  let next = { x: initial.slice(), fx: 0, fxprime: initial.slice() };
  const yk = initial.slice();
  let temp;
  let a = 1; // line-search step size; reused as the next iteration's guess

  params = params || {};
  const maxIterations = params.maxIterations || initial.length * 20;

  current.fx = f(current.x, current.fxprime);
  // Initial search direction: steepest descent, pk = -∇f(x0).
  const pk = current.fxprime.slice();
  scale(pk, current.fxprime, -1);

  for (let i = 0; i < maxIterations; ++i) {
    // Find a step along pk satisfying the Wolfe conditions; the candidate
    // point/value/gradient are written into `next`. Returns 0 on failure.
    a = wolfeLineSearch(f, pk, current, next, a);

    // TODO(review): is this history push in the wrong spot? It records the
    // pre-update state after the line search has already run.
    if (params.history) {
      params.history.push({
        x: current.x.slice(),
        fx: current.fx,
        fxprime: current.fxprime.slice(),
        alpha: a,
      });
    }

    if (!a) {
      // Failed to find a point that satisfies the Wolfe conditions:
      // reset to the steepest-descent direction for the next iteration.
      scale(pk, current.fxprime, -1);
    } else {
      // Update the direction using the Polak–Ribière CG method:
      // yk = ∇f(x_{k+1}) − ∇f(x_k)
      weightedSum(yk, 1, next.fxprime, -1, current.fxprime);

      const delta_k = dot(current.fxprime, current.fxprime);
      // Math.max(0, ·) restarts with steepest descent when beta is negative.
      const beta_k = Math.max(0, dot(yk, next.fxprime) / delta_k);

      // pk = beta_k·pk − ∇f(x_{k+1})
      weightedSum(pk, beta_k, pk, -1, next.fxprime);

      // Swap `current` and `next`, reusing the buffers.
      temp = current;
      current = next;
      next = temp;
    }

    // Converged once the gradient has (nearly) vanished.
    if (norm2(current.fxprime) <= 1e-5) {
      break;
    }
  }

  // Record the final state as well.
  if (params.history) {
    params.history.push({
      x: current.x.slice(),
      fx: current.fx,
      fxprime: current.fxprime.slice(),
      alpha: a,
    });
  }

  return current;
}
67+
68+
/// Solves a system of lienar equations Ax =b for x
69+
/// using the conjugate gradient method.
70+
export function conjugateGradientSolve(A, b, x, history?: any) {
71+
const r = x.slice();
72+
const Ap = x.slice();
73+
let rsold;
74+
let rsnew;
75+
let alpha;
76+
77+
// r = b - A*x
78+
gemv(Ap, A, x);
79+
weightedSum(r, 1, b, -1, Ap);
80+
const p = r.slice();
81+
rsold = dot(r, r);
82+
83+
for (let i = 0; i < b.length; ++i) {
84+
gemv(Ap, A, p);
85+
alpha = rsold / dot(p, Ap);
86+
if (history) {
87+
history.push({ x: x.slice(), p: p.slice(), alpha: alpha });
88+
}
89+
90+
//x=x+alpha*p;
91+
weightedSum(x, 1, x, alpha, p);
92+
93+
// r=r-alpha*Ap;
94+
weightedSum(r, 1, r, -alpha, Ap);
95+
rsnew = dot(r, r);
96+
if (Math.sqrt(rsnew) <= 1e-10) break;
97+
98+
// p=r+(rsnew/rsold)*p;
99+
weightedSum(p, 1, r, rsnew / rsold, p);
100+
rsold = rsnew;
101+
}
102+
if (history) {
103+
history.push({ x: x.slice(), p: p.slice(), alpha: alpha });
104+
}
105+
return x;
106+
}
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import { dot, norm2, scale, weightedSum, zeros } from './blas1';
2+
import { wolfeLineSearch } from './linesearch';
3+
4+
export function gradientDescent(f, initial, params) {
5+
params = params || {};
6+
const maxIterations = params.maxIterations || initial.length * 100;
7+
const learnRate = params.learnRate || 0.001;
8+
const current = { x: initial.slice(), fx: 0, fxprime: initial.slice() };
9+
10+
for (let i = 0; i < maxIterations; ++i) {
11+
current.fx = f(current.x, current.fxprime);
12+
if (params.history) {
13+
params.history.push({
14+
x: current.x.slice(),
15+
fx: current.fx,
16+
fxprime: current.fxprime.slice(),
17+
});
18+
}
19+
20+
weightedSum(current.x, 1, current.x, -learnRate, current.fxprime);
21+
if (norm2(current.fxprime) <= 1e-5) {
22+
break;
23+
}
24+
}
25+
26+
return current;
27+
}
28+
29+
/**
 * Minimizes `f` by gradient descent where the step size is chosen each
 * iteration by a Wolfe-conditions line search rather than a fixed rate.
 *
 * @param f        objective: f(x, fxprime) returns the value at `x` and
 *                 writes the gradient into `fxprime`
 * @param initial  starting point (array of numbers)
 * @param params   optional: `maxIterations` (default 100·dim), `learnRate`
 *                 (initial step size, default 1), `c1`/`c2` (Wolfe-condition
 *                 constants), `history` (collects per-iteration snapshots,
 *                 including every x sampled by the line search)
 * @returns the final `{ x, fx, fxprime }` state
 */
export function gradientDescentLineSearch(f, initial, params) {
  params = params || {};
  // Buffers are allocated once and swapped between iterations.
  let current = { x: initial.slice(), fx: 0, fxprime: initial.slice() };
  let next = { x: initial.slice(), fx: 0, fxprime: initial.slice() };
  const maxIterations = params.maxIterations || initial.length * 100;
  let learnRate = params.learnRate || 1;
  const pk = initial.slice(); // search direction (steepest descent)
  const c1 = params.c1 || 1e-3;
  const c2 = params.c2 || 0.1;
  let temp;
  let functionCalls = [];

  if (params.history) {
    // wrap the function call to track linesearch samples
    const inner = f;
    f = (x, fxprime) => {
      functionCalls.push(x.slice());
      return inner(x, fxprime);
    };
  }

  current.fx = f(current.x, current.fxprime);
  for (let i = 0; i < maxIterations; ++i) {
    // pk = -∇f(x): steepest-descent direction for this iteration.
    scale(pk, current.fxprime, -1);
    // Writes the accepted candidate point into `next`; returns 0 on failure.
    learnRate = wolfeLineSearch(f, pk, current, next, learnRate, c1, c2);

    if (params.history) {
      params.history.push({
        x: current.x.slice(),
        fx: current.fx,
        fxprime: current.fxprime.slice(),
        functionCalls: functionCalls,
        learnRate: learnRate,
        alpha: learnRate,
      });
      // Reset the sample log for the next line search.
      functionCalls = [];
    }

    // Swap `current` and `next`, reusing the buffers.
    temp = current;
    current = next;
    next = temp;

    // Stop when the line search failed or the gradient has vanished.
    if (learnRate === 0 || norm2(current.fxprime) < 1e-5) break;
  }

  return current;
}

0 commit comments

Comments
 (0)