diff --git a/index.html b/index.html index 3f6060d6..b222ba13 100644 --- a/index.html +++ b/index.html @@ -36,9 +36,11 @@ A Neural Network Playground + + @@ -142,6 +144,7 @@

Tinker With a Neural Network R
+
@@ -308,6 +311,35 @@

Output

+
+ +
+
+
Model equations
+
+
+
Forward pass
+
+
+
+
+
Training objective
+
+
+
+
+
Backpropagation (per example)
+
+
+
+
+
Weight and bias updates (mini-batch)
+
+
+
+
+
+ diff --git a/package-lock.json b/package-lock.json index 80d87e6e..6aded10b 100644 --- a/package-lock.json +++ b/package-lock.json @@ -9,6 +9,7 @@ "version": "2016.3.10", "dependencies": { "d3": "^3.5.16", + "katex": "^0.16.11", "material-design-lite": "^1.3.0", "seedrandom": "^2.4.3" }, @@ -1882,6 +1883,29 @@ "node": "*" } }, + "node_modules/katex": { + "version": "0.16.11", + "resolved": "https://registry.npmjs.org/katex/-/katex-0.16.11.tgz", + "integrity": "sha512-RQrI8rlHY92OLf3rho/Ts8i/XvjgguEjOkO1BEXcU3N8BqPpSzBNwV/G0Ukr+P/l3ivvJUE/Fa/CwbS6HesGNQ==", + "funding": [ + "https://opencollective.com/katex", + "https://github.com/sponsors/katex" + ], + "dependencies": { + "commander": "^8.3.0" + }, + "bin": { + "katex": "cli.js" + } + }, + "node_modules/katex/node_modules/commander": { + "version": "8.3.0", + "resolved": "https://registry.npmjs.org/commander/-/commander-8.3.0.tgz", + "integrity": "sha512-OkTL9umf+He2DZkUq8f8J9of7yL6RJKI24dVITBmNfZBmri9zYZQrKkuXiKhyfPSu8tUhnVBB1iKXevvnlR4Ww==", + "engines": { + "node": ">= 12" + } + }, "node_modules/kind-of": { "version": "3.2.2", "resolved": "https://registry.npmjs.org/kind-of/-/kind-of-3.2.2.tgz", @@ -5027,6 +5051,21 @@ "through": ">=2.2.7 <3" } }, + "katex": { + "version": "0.16.11", + "resolved": "https://registry.npmjs.org/katex/-/katex-0.16.11.tgz", + "integrity": "sha512-RQrI8rlHY92OLf3rho/Ts8i/XvjgguEjOkO1BEXcU3N8BqPpSzBNwV/G0Ukr+P/l3ivvJUE/Fa/CwbS6HesGNQ==", + "requires": { + "commander": "^8.3.0" + }, + "dependencies": { + "commander": { + "version": "8.3.0", + "resolved": "https://registry.npmjs.org/commander/-/commander-8.3.0.tgz", + "integrity": "sha512-OkTL9umf+He2DZkUq8f8J9of7yL6RJKI24dVITBmNfZBmri9zYZQrKkuXiKhyfPSu8tUhnVBB1iKXevvnlR4Ww==" + } + } + }, "kind-of": { "version": "3.2.2", "resolved": "https://registry.npmjs.org/kind-of/-/kind-of-3.2.2.tgz", diff --git a/package.json b/package.json index c127686c..37f6a9ca 100644 --- a/package.json +++ b/package.json @@ -6,7 +6,7 @@ "scripts": { "clean": 
"rimraf dist", "start": "npm run serve-watch", - "prep": "copyfiles analytics.js dist && concat node_modules/material-design-lite/material.min.js node_modules/seedrandom/seedrandom.min.js > dist/lib.js", + "prep": "copyfiles analytics.js dist && copyfiles -u 3 node_modules/katex/dist/katex.min.css dist && copyfiles -u 3 node_modules/katex/dist/katex.min.js dist && copyfiles -u 3 \"node_modules/katex/dist/fonts/*.woff2\" dist && concat node_modules/material-design-lite/material.min.js node_modules/seedrandom/seedrandom.min.js > dist/lib.js", "build-css": "concat node_modules/material-design-lite/material.min.css styles.css > dist/bundle.css", "watch-css": "concat node_modules/material-design-lite/material.min.css styles.css -o dist/bundle.css", "build-html": "copyfiles index.html dist", @@ -32,6 +32,7 @@ }, "dependencies": { "d3": "^3.5.16", + "katex": "^0.16.11", "material-design-lite": "^1.3.0", "seedrandom": "^2.4.3" } diff --git a/src/equation-guides.ts b/src/equation-guides.ts new file mode 100644 index 00000000..be151e43 --- /dev/null +++ b/src/equation-guides.ts @@ -0,0 +1,232 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +interface GuideItem { + term: string; + detail: string; +} + +function fillDescriptionList( + container: HTMLElement | null, items: GuideItem[]): void { + if (container == null) { + return; + } + container.innerHTML = ""; + let dl = document.createElement("dl"); + dl.className = "nn-equation-desc-list"; + for (let i = 0; i < items.length; i++) { + let it = items[i]; + let dt = document.createElement("dt"); + dt.textContent = it.term; + let dd = document.createElement("dd"); + dd.textContent = it.detail; + dl.appendChild(dt); + dl.appendChild(dd); + } + container.appendChild(dl); +} + +const FORWARD_GUIDE: GuideItem[] = [ + { + term: "Layered formula", + detail: "Each line is one neuron: affine map (bias plus weighted sum of " + + "incoming values) then the layer activation. Numbers are the live " + + "weights and biases from the diagram." + }, + { + term: "Input symbols (X1, X2, …)", + detail: "Values from the Features column for one point, in top-to-bottom " + + "order of input nodes." + }, + { + term: "h_j^(l)", + detail: "Hidden unit j in hidden layer l (column l after inputs), top " + + "to bottom in the diagram." + }, + { + term: "ŷ", + detail: "Scalar network output after the output activation (regression: " + + "linear; classification: tanh in this demo)." + } +]; + +const OBJECTIVE_GUIDE: GuideItem[] = [ + { + term: "L (calligraphic in the formula)", + detail: "Scalar objective minimized during training: average prediction " + + "error on the training set, plus an optional weight penalty." + }, + { + term: "N", + detail: "Number of training examples used in the average." + }, + { + term: "ŷ_i", + detail: "Network output for training example i (same definition as ŷ in " + + "the forward pass)." + }, + { + term: "y_i", + detail: "Target label for example i (class or numeric target from the " + + "dataset)." 
+ }, + { + term: "½(ŷ_i − y_i)²", + detail: "Half squared error for one example; the mean over i is the data " + + "loss shown as training loss (up to the same formula)." + }, + { + term: "λ and ∑_w", + detail: "When regularization is on: λ is the regularization-rate slider; " + + "the sum is over all connection weights. L1 adds λ|w|; L2 adds λ·½w² " + + "per weight (matching the implementation)." + }, + { + term: "Gradient of L w.r.t. a weight (data + regularization)", + detail: "If L = (1/N)Σ_i E_i + λ Σ_w R(w), then for a fixed connection " + + "weight w, ∂L/∂w = (1/N)Σ_i ∂E_i/∂w + λ R′(w). The first term is what " + + "backprop builds from the chain rule through the network (prediction → " + + "layers → that edge). The second term does not pass through ŷ: R depends " + + "only on w, so its contribution is just λ R′(w). This demo averages " + + "per-example ∂E/∂w over the mini-batch, then applies a separate " + + "regularization step—equivalent to adding those two parts of the gradient." + } +]; + +const BACKPROP_GUIDE: GuideItem[] = [ + { + term: "Chain rule (overall)", + detail: "E depends on ŷ, which depends on a long composition of affine maps " + + "(z = b + Σ w a) and activations (a = σ(z)). Backprop is the chain rule " + + "applied in reverse order: start from ∂E/∂ŷ and propagate factors " + + "∂E/∂z, ∂E/∂w, and ∂E/∂a layer by layer until every weight has a " + + "∂E/∂w." + }, + { + term: "E", + detail: "Error for a single training example only: E = ½(ŷ − y)². " + + "Backpropagation computes derivatives of this E (not the full-batch " + + "average) with respect to weights and activations." + }, + { + term: "ŷ and y", + detail: "Prediction and label for that one example—the same ŷ as after " + + "forward pass, y from the data point." + }, + { + term: "∂E/∂ŷ", + detail: "Derivative of E with respect to the network output. 
For this " + + "square loss it equals ŷ − y; it is the starting signal at the output " + + "neuron before applying the output activation derivative. This is the " + + "first factor in the chain: ∂E/∂z_out = (∂E/∂ŷ)(∂ŷ/∂z_out) with " + + "∂ŷ/∂z_out = σ′(z_out) when ŷ = σ(z_out)." + }, + { + term: "δ^(z) and ∂E/∂z", + detail: "Same quantity: error signal for a neuron’s total input z (bias " + + "plus weighted sum). In code this is inputDer. It is ∂E/∂z: how much E " + + "changes if z nudges, after all later layers are folded in by the chain " + + "rule." + }, + { + term: "Chain rule through the activation (σ)", + detail: "Here a = σ(z). E depends on a, and a depends on z, so by the " + + "chain rule ∂E/∂z = (∂E/∂a)(da/dz) = (∂E/∂a)σ′(z). That product is " + + "exactly δ^(z). For a linear output, σ′ = 1 so ∂E/∂z = ∂E/∂a." + }, + { + term: "z and a", + detail: "z is pre-activation input to the neuron; a = σ(z) is the value " + + "sent along outgoing edges (what the diagram shows on the node)." + }, + { + term: "∂E/∂a (before σ′)", + detail: "How E changes with the neuron’s output a. At the output node, " + + "∂E/∂a is ∂E/∂ŷ (or merges with it depending on notation). Deeper " + + "layers receive ∂E/∂a as the sum of backward messages from the next " + + "column (see the sum formula)." + }, + { + term: "Chain rule for one weight", + detail: "The affine map gives z_to = b + Σ w a_from, so ∂z_to/∂w = a_from " + + "when w is one of those edges. Hence ∂E/∂w = (∂E/∂z_to)(∂z_to/∂w) = " + + "δ^(z)_to · a_from (errorDer on that link in code)." + }, + { + term: "Chain rule into an earlier activation (the sum)", + detail: "Each downstream z depends on a_from through z = … + w a_from + …, " + + "so ∂z/∂a_from = w for that edge. If several units feed from the same " + + "a_from, E changes through each path: ∂E/∂a_from = Σ_out " + + "(∂E/∂z_out)(∂z_out/∂a_from) = Σ_out w_out δ^(z)_out. That is the " + + "displayed sum; it becomes ∂E/∂a for the previous layer’s nodes." 
+ } +]; + +const UPDATE_GUIDE: GuideItem[] = [ + { + term: "w and b", + detail: "Connection weight and neuron bias updated after a mini-batch of " + + "examples (each example ran forward and backward and contributed to " + + "running sums)." + }, + { + term: "η (eta)", + detail: "Learning rate from the UI; scales how far each update moves " + + "along the negative gradient." + }, + { + term: "n_b and the overline", + detail: "Mini-batch size (batch size control). The bar means average: " + + "sum of per-example ∂E/∂w or ∂E/∂b over the batch, divided by n_b." + }, + { + term: "Second w line (regularization)", + detail: "If L1 or L2 is enabled: after the gradient step, each weight is " + + "moved again to shrink the penalty R(w), scaled by η and λ." + }, + { + term: "R(w) and R′(w)", + detail: "Regularization term and its derivative: L2 uses ½w² and R′=w; L1 " + + "uses |w| with a subgradient in {−1, 0, 1} (0 at w = 0), as in the " + + "implementation." + }, + { + term: "Chain rule: data loss vs. regularization in the update", + detail: "The averaged term uses ∂E/∂w from backprop—every factor is a chain " + + "rule through ŷ and the layers. The term λ R′(w) is not chained through " + + "the network: ∂(λ R(w))/∂w = λ R′(w) because R depends only on w. The " + + "code applies −(η/n_b)·(sum of ∂E/∂w) and then −ηλ R′(w), which matches " + + "descending the sum of data-gradient and regularization-gradient on w." + } +]; + +/** Populate glossary lists under each equation block. 
*/ +export function fillEquationPanelGuides( + forwardDesc: HTMLElement | null, + objectiveDesc: HTMLElement | null, + backpropDesc: HTMLElement | null, + updateDesc: HTMLElement | null): void { + fillDescriptionList(forwardDesc, FORWARD_GUIDE); + fillDescriptionList(objectiveDesc, OBJECTIVE_GUIDE); + fillDescriptionList(backpropDesc, BACKPROP_GUIDE); + fillDescriptionList(updateDesc, UPDATE_GUIDE); +} + +export const EQUATION_DESC_ELEMENT_IDS: string[] = [ + "nn-equation-forward-desc", + "nn-equation-objective-desc", + "nn-equation-backprop-desc", + "nn-equation-update-desc" +]; diff --git a/src/equation-legend.ts b/src/equation-legend.ts new file mode 100644 index 00000000..b4288fcb --- /dev/null +++ b/src/equation-legend.ts @@ -0,0 +1,175 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import * as nn from "./nn"; + +/** e.g. 6 -> "6th", 21 -> "21st" (for hidden-layer column wording). */ +function ordinalSuffix(n: number): string { + let mod100 = n % 100; + if (mod100 >= 11 && mod100 <= 13) { + return n + "th"; + } + switch (n % 10) { + case 1: + return n + "st"; + case 2: + return n + "nd"; + case 3: + return n + "rd"; + default: + return n + "th"; + } +} + +export interface LegendRow { + /** Small KaTeX fragment; omit for note-only rows. 
*/ + symbolTex?: string; + detail: string; +} + +export interface LegendSection { + title: string; + rows: LegendRow[]; +} + +const FEATURE_GUIDE: {[id: string]: string} = { + "x": "Horizontal coordinate of each data point (matches a node in the Features column, top to bottom).", + "y": "Vertical coordinate of each data point.", + "xSquared": "Square of the horizontal coordinate.", + "ySquared": "Square of the vertical coordinate.", + "xTimesY": "Product of horizontal and vertical coordinates.", + "sinX": "Sine of the horizontal coordinate.", + "sinY": "Sine of the vertical coordinate.", + "cosX": "Cosine of the horizontal coordinate.", + "cosY": "Cosine of the vertical coordinate." +}; + +/** + * Human-readable mapping between equation symbols and the diagram (columns and + * neuron order top-to-bottom) for the layered equation view. + */ +export function buildEquationLegendSections( + network: nn.Node[][], inputIds: string[], inputSymbols: string[], + hiddenActivationSummary: string, + outputActivationSummary: string): LegendSection[] { + let sections: LegendSection[] = []; + + if (network == null || network.length < 2) { + return sections; + } + + let inputRows: LegendRow[] = []; + for (let i = 0; i < inputIds.length; i++) { + let id = inputIds[i]; + let guide = FEATURE_GUIDE[id] != null ? + FEATURE_GUIDE[id] : + "Enabled input feature in the Features column."; + inputRows.push({ + symbolTex: inputSymbols[i], + detail: guide + }); + } + sections.push({ + title: "Input layer (left column in the diagram)", + rows: inputRows + }); + + let numHidden = network.length - 2; + for (let layerIdx = 1; layerIdx <= numHidden; layerIdx++) { + let layer = network[layerIdx]; + let colOrdinal = layerIdx === 1 ? "first" : + layerIdx === 2 ? "second" : + layerIdx === 3 ? "third" : + layerIdx === 4 ? "fourth" : + layerIdx === 5 ? 
"fifth" : ordinalSuffix(layerIdx); + let hiddenRows: LegendRow[] = []; + for (let i = 0; i < layer.length; i++) { + let node = layer[i]; + let ordinal = i + 1; + hiddenRows.push({ + symbolTex: "h_{" + ordinal + "}^{(" + layerIdx + ")}", + detail: "Neuron " + ordinal + " from the top in the " + colOrdinal + + " hidden column after inputs (square node id " + node.id + + " in the diagram)." + }); + } + sections.push({ + title: "Hidden layer " + layerIdx + " (" + colOrdinal + + " column of weighted neurons; " + hiddenActivationSummary + ")", + rows: hiddenRows + }); + } + + sections.push({ + title: "Output layer (rightmost column)", + rows: [{ + symbolTex: "\\hat{y}", + detail: "Network output after the output activation (" + + outputActivationSummary + ")." + }] + }); + + return sections; +} + +/** Notation for training / backprop blocks (shown after diagram symbol sections). */ +export function buildTrainingNotationLegendSection( + learningRate: number, + regularizationKey: string, + regularizationRate: number): LegendSection { + let regDetail = regularizationKey === "none" ? + "No penalty term; the objective is only the mean squared error." : + "Regularization rate slider; matches lambda in the objective and update equations."; + return { + title: "Training notation", + rows: [ + { + symbolTex: "\\eta", + detail: "Learning rate (current value " + learningRate.toFixed(3) + + ", from the Learning rate control)." + }, + { + symbolTex: "\\lambda", + detail: regDetail + + (regularizationKey !== "none" ? + " Current value " + regularizationRate.toFixed(3) + "." : + "") + }, + { + symbolTex: "z", + detail: "Total input into a neuron (bias plus weighted sum of incoming " + + "activations). The neuron output is the activation applied to z." + }, + { + symbolTex: "\\delta^{(z)}", + detail: "Derivative of per-example error E with respect to z; " + + "bias updates use the same accumulated quantity as in the code." 
+ }, + { + symbolTex: "E", + detail: "Per-example squared error (half square of output minus label) " + + "used in backprop. Training and test loss in the UI are the mean " + + "over N examples." + }, + { + symbolTex: "h_1", + detail: "In numeric backprop lines, the top hidden neuron in the first " + + "hidden column (only when the network has at least one hidden layer). " + + "Values are from the last training example in the most recent Step or " + + "Play epoch, captured before loss is recomputed over the dataset." + } + ] + }; +} diff --git a/src/network-equation.ts b/src/network-equation.ts new file mode 100644 index 00000000..cdcbf96c --- /dev/null +++ b/src/network-equation.ts @@ -0,0 +1,114 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import * as nn from "./nn"; + +/** Formats a scalar for LaTeX (fixed decimals, trim negative zero). */ +function fmt(n: number): string { + let s = n.toFixed(3); + if (s === "-0.000") { + return "0.000"; + } + return s; +} + +/** Builds bias + sum_k w_k * symbol_k with clean + / - spacing. 
*/ +function affineTex(node: nn.Node, prevTex: string[]): string { + let parts: string[] = [fmt(node.bias)]; + for (let j = 0; j < node.inputLinks.length; j++) { + parts.push(fmt(node.inputLinks[j].weight) + " \\cdot " + prevTex[j]); + } + return parts.join(" + ").replace(/\+ -/g, "- "); +} + +/** Maps activation keys (as in state.activations) to LaTeX. */ +function actWrap(activationKey: string, body: string): string { + switch (activationKey) { + case "relu": + return "\\operatorname{ReLU}\\left(" + body + "\\right)"; + case "tanh": + return "\\tanh\\left(" + body + "\\right)"; + case "sigmoid": + return "\\sigma\\left(" + body + "\\right)"; + case "linear": + return body; + default: + return "\\left(" + body + "\\right)"; + } +} + +function sigmaLegendRow(): string { + return "&\\text{with } \\sigma(t)=\\dfrac{1}{1+e^{-t}}"; +} + +function needsSigmaLegend(hiddenKey: string, outputKey: string): boolean { + return hiddenKey === "sigmoid" || outputKey === "sigmoid"; +} + +/** + * Layer-wise definition: equations reference h_j^{(l)} symbols. + */ +function buildLayeredTex( + network: nn.Node[][], inputSymbols: string[], + hiddenActivationKey: string, outputActivationKey: string): string { + let lines: string[] = []; + let prevTex = inputSymbols.slice(); + + for (let layerIdx = 1; layerIdx < network.length; layerIdx++) { + let layer = network[layerIdx]; + let isOutput = layerIdx === network.length - 1; + let actKey = isOutput ? 
outputActivationKey : hiddenActivationKey; + let nextPrev: string[] = []; + + for (let i = 0; i < layer.length; i++) { + let node = layer[i]; + let inner = affineTex(node, prevTex); + let rhs = actWrap(actKey, inner); + + if (isOutput) { + lines.push("\\hat{y} &= " + rhs); + } else { + let h = "h_{" + (i + 1) + "}^{(" + layerIdx + ")}"; + nextPrev.push(h); + lines.push(h + " &= " + rhs); + } + } + if (!isOutput) { + prevTex = nextPrev; + } + } + + if (needsSigmaLegend(hiddenActivationKey, outputActivationKey)) { + lines.push(sigmaLegendRow()); + } + return "\\begin{aligned}\n" + lines.join("\\\\\n") + "\n\\end{aligned}"; +} + +/** + * Layer-wise MLP formula with numeric weights and biases. + * inputSymbols must match network[0] order (same as constructInput / buildNetwork). + */ +export function buildNetworkEquationTex( + network: nn.Node[][], inputSymbols: string[], + hiddenActivationKey: string, outputActivationKey: string): string { + if (network == null || network.length < 2) { + return ""; + } + if (inputSymbols.length !== network[0].length) { + return ""; + } + return buildLayeredTex( + network, inputSymbols, hiddenActivationKey, outputActivationKey); +} diff --git a/src/nn.ts b/src/nn.ts index e92a13de..8d9e149e 100644 --- a/src/nn.ts +++ b/src/nn.ts @@ -321,8 +321,9 @@ export function backProp(network: Node[][], target: number, // Compute the error derivative with respect to each node's output. 
node.outputDer = 0; for (let j = 0; j < node.outputs.length; j++) { - let output = node.outputs[j]; - node.outputDer += output.weight * output.dest.inputDer; + let link = node.outputs[j]; + let dest = link.dest; + node.outputDer += link.weight * dest.inputDer; } } } diff --git a/src/playground.ts b/src/playground.ts index aeac0f9c..8e070151 100644 --- a/src/playground.ts +++ b/src/playground.ts @@ -27,6 +27,21 @@ import { } from "./state"; import {Example2D, shuffle} from "./dataset"; import {AppendingLineChart} from "./linechart"; +import {buildNetworkEquationTex} from "./network-equation"; +import { + buildEquationLegendSections, + buildTrainingNotationLegendSection +} from "./equation-legend"; +import { + buildObjectiveTex, + buildBackpropTex, + buildWeightUpdateTex, + extractBackpropSnapshot +} from "./training-equations"; +import { + fillEquationPanelGuides, + EQUATION_DESC_ELEMENT_IDS +} from "./equation-guides"; import * as d3 from 'd3'; let mainWidth; @@ -72,6 +87,92 @@ let INPUTS: {[name: string]: InputFeature} = { "sinY": {f: (x, y) => Math.sin(y), label: "sin(X_2)"}, }; +function fillEquationLegend( + katex: {[key: string]: any}, container: HTMLElement | null, + sections: ReturnType<typeof buildEquationLegendSections>): void { + if (container == null) { + return; + } + container.innerHTML = ""; + for (let s = 0; s < sections.length; s++) { + let sec = sections[s]; + let wrap = document.createElement("div"); + wrap.className = "nn-legend-section"; + let h = document.createElement("h6"); + h.className = "nn-legend-section-title"; + h.textContent = sec.title; + wrap.appendChild(h); + let ul = document.createElement("ul"); + ul.className = "nn-legend-list"; + for (let r = 0; r < sec.rows.length; r++) { + let row = sec.rows[r]; + let li = document.createElement("li"); + li.className = "nn-legend-row"; + if (row.symbolTex != null && row.symbolTex !== "") { + let sym = document.createElement("span"); + sym.className = "nn-legend-symbol"; + if (katex != null && typeof katex.renderToString === 
"function") { + sym.innerHTML = katex.renderToString(row.symbolTex, { + throwOnError: false, + displayMode: false + }); + } else { + sym.textContent = row.symbolTex; + } + li.appendChild(sym); + } + let det = document.createElement("span"); + det.className = "nn-legend-detail"; + det.textContent = row.detail; + li.appendChild(det); + ul.appendChild(li); + } + wrap.appendChild(ul); + container.appendChild(wrap); + } +} + +function renderKatexDisplay( + katex: {[key: string]: any} | undefined, + el: HTMLElement | null, + tex: string): void { + if (el == null) { + return; + } + if (tex === "") { + el.innerHTML = ""; + return; + } + if (katex != null && typeof katex.render === "function") { + katex.render(tex, el, {displayMode: true, throwOnError: false}); + } else { + el.innerHTML = ""; + } +} + +function clearEquationPanelMath(): void { + for (let id of [ + "nn-equation-forward", + "nn-equation-objective", + "nn-equation-backprop", + "nn-equation-update"]) { + let n = document.getElementById(id); + if (n != null) { + n.innerHTML = ""; + } + } + for (let i = 0; i < EQUATION_DESC_ELEMENT_IDS.length; i++) { + let d = document.getElementById(EQUATION_DESC_ELEMENT_IDS[i]); + if (d != null) { + d.innerHTML = ""; + } + } + let leg = document.getElementById("nn-equation-legend"); + if (leg != null) { + leg.innerHTML = ""; + } +} + let HIDABLE_CONTROLS = [ ["Show test data", "showTestData"], ["Discretize output", "discretize"], @@ -88,6 +189,7 @@ let HIDABLE_CONTROLS = [ ["Noise level", "noise"], ["Batch size", "batchSize"], ["# of hidden layers", "numHiddenLayers"], + ["Model equations", "equationPanel"], ]; class Player { @@ -166,6 +268,8 @@ let colorScale = d3.scale.linear() .range(["#f59322", "#e8eaeb", "#0877bd"]) .clamp(true); let iter = 0; +/** Captured after each training epoch, before getLoss (see extractBackpropSnapshot). 
*/ +let lastBackpropSnapshot: ReturnType<typeof extractBackpropSnapshot> = null; let trainData: Example2D[] = []; let testData: Example2D[] = []; let network: nn.Node[][] = null; @@ -408,13 +512,14 @@ function updateWeightsUI(network: nn.Node[][], container) { let node = currentLayer[i]; for (let j = 0; j < node.inputLinks.length; j++) { let link = node.inputLinks[j]; - container.select(`#link${link.source.id}-${link.dest.id}`) - .style({ - "stroke-dashoffset": -iter / 3, - "stroke-width": linkWidthScale(Math.abs(link.weight)), - "stroke": colorScale(link.weight) - }) - .datum(link); + let sel = container.select(`#link${link.source.id}-${link.dest.id}`); + sel.style({ + "stroke-dasharray": null, + "stroke-dashoffset": -iter / 3, + "stroke-width": linkWidthScale(Math.abs(link.weight)), + stroke: colorScale(link.weight), + opacity: 1 + }).datum(link); } } } @@ -776,14 +881,14 @@ function drawLink( // Add an invisible thick link that will be used for // showing the weight value on hover. - container.append("path") + let hover = container.append("path") .attr("d", diagonal(datum, 0)) - .attr("class", "link-hover") - .on("mouseenter", function() { - updateHoverCard(HoverType.WEIGHT, input, d3.mouse(this)); - }).on("mouseleave", function() { - updateHoverCard(null); - }); + .attr("class", "link-hover"); + hover.on("mouseenter", function() { + updateHoverCard(HoverType.WEIGHT, input, d3.mouse(this)); + }).on("mouseleave", function() { + updateHoverCard(null); + }); return line; } @@ -849,8 +954,8 @@ function getLoss(network: nn.Node[][], dataPoints: Example2D[]): number { } function updateUI(firstStep = false) { - // Update the links visually. - updateWeightsUI(network, d3.select("g.core")); + // Update the links visually (scope to network SVG, not e.g. colormap g.core). + updateWeightsUI(network, d3.select("#svg").select("g.core")); // Update the bias values visually. updateBiasesUI(network); // Get the decision boundary of the network. 
@@ -884,6 +989,65 @@ function updateUI(firstStep = false) { d3.select("#loss-test").text(humanReadable(lossTest)); d3.select("#iter-number").text(addCommas(zeroPad(iter))); lineChart.addDataPoint([lossTrain, lossTest]); + + let forwardEl = document.getElementById("nn-equation-forward"); + let objEl = document.getElementById("nn-equation-objective"); + let backEl = document.getElementById("nn-equation-backprop"); + let updEl = document.getElementById("nn-equation-update"); + let legEl = document.getElementById("nn-equation-legend"); + let katex = (window as any)["katex"]; + let eqPanelHidden = state.getHiddenProps().indexOf("equationPanel") >= 0; + if (eqPanelHidden) { + clearEquationPanelMath(); + } else { + fillEquationPanelGuides( + document.getElementById("nn-equation-forward-desc"), + document.getElementById("nn-equation-objective-desc"), + document.getElementById("nn-equation-backprop-desc"), + document.getElementById("nn-equation-update-desc")); + let regKey = getKeyFromValue(regularizations, state.regularization) || + "none"; + renderKatexDisplay( + katex, objEl, buildObjectiveTex(regKey, state.regularizationRate)); + renderKatexDisplay( + katex, backEl, buildBackpropTex(lastBackpropSnapshot)); + renderKatexDisplay( + katex, updEl, + buildWeightUpdateTex( + state.learningRate, state.regularizationRate, regKey)); + + if (network != null && forwardEl != null) { + let inputIds = constructInputIds(); + let inputSymbols = inputIds.map(id => { + let feature = INPUTS[id]; + return feature != null && feature.label != null ? feature.label : id; + }); + let hiddenKey = getKeyFromValue(activations, state.activation) || "tanh"; + let outputKey = state.problem === Problem.REGRESSION ? "linear" : "tanh"; + let tex = buildNetworkEquationTex( + network, inputSymbols, hiddenKey, outputKey); + renderKatexDisplay(katex, forwardEl, tex); + let outSummary = state.problem === Problem.REGRESSION ? 
+ "linear / identity (regression)" : + "tanh (classification)"; + let hiddenSummary = hiddenKey === "relu" ? "ReLU on each neuron" : + hiddenKey === "tanh" ? "tanh on each neuron" : + hiddenKey === "sigmoid" ? "sigmoid σ on each neuron" : + "linear (no nonlinearity) on each neuron"; + let sections = buildEquationLegendSections( + network, inputIds, inputSymbols, hiddenSummary, outSummary); + sections.push(buildTrainingNotationLegendSection( + state.learningRate, regKey, state.regularizationRate)); + fillEquationLegend(katex, legEl, sections); + } else { + if (forwardEl != null) { + forwardEl.innerHTML = ""; + } + if (legEl != null) { + legEl.innerHTML = ""; + } + } + } } function constructInputIds(): string[] { @@ -916,6 +1080,12 @@ function oneStep(): void { nn.updateWeights(network, state.learningRate, state.regularizationRate); } }); + if (trainData.length > 0) { + lastBackpropSnapshot = extractBackpropSnapshot( + network, trainData[trainData.length - 1].label); + } else { + lastBackpropSnapshot = null; + } // Compute the loss. lossTrain = getLoss(network, trainData); lossTest = getLoss(network, testData); @@ -951,6 +1121,7 @@ function reset(onStartup=false) { // Make a simple network. iter = 0; + lastBackpropSnapshot = null; let numInputs = constructInput(0 , 0).length; let shape = [numInputs].concat(state.networkShape).concat([1]); let outputActivation = (state.problem === Problem.REGRESSION) ? @@ -1024,19 +1195,25 @@ function drawDatasetThumbnails() { } } -function hideControls() { - // Set display:none to all the UI elements that are hidden. +/** Show or hide UI blocks listed in HIDABLE_CONTROLS from URL hide flags. 
*/ +function applyHidableControlVisibility() { let hiddenProps = state.getHiddenProps(); - hiddenProps.forEach(prop => { - let controls = d3.selectAll(`.ui-${prop}`); - if (controls.size() === 0) { - console.warn(`0 html elements found with class .ui-${prop}`); + HIDABLE_CONTROLS.forEach(([, id]) => { + let hidden = hiddenProps.indexOf(id) >= 0; + let controls = d3.selectAll(`.ui-${id}`); + if (hidden && controls.size() === 0) { + console.warn(`0 html elements found with class .ui-${id}`); } - controls.style("display", "none"); + controls.style("display", hidden ? "none" : null); }); +} + +function hideControls() { + applyHidableControlVisibility(); // Also add checkbox for each hidable control in the "use it in classroom" // section. + let hiddenProps = state.getHiddenProps(); let hideControls = d3.select(".hide-controls"); HIDABLE_CONTROLS.forEach(([text, id]) => { let label = hideControls.append("label") @@ -1051,8 +1228,10 @@ function hideControls() { } input.on("change", function() { state.setHideProperty(id, !this.checked); + applyHidableControlVisibility(); state.serialize(); userHasInteracted(); + updateUI(); d3.select(".hide-controls-link") .attr("href", window.location.href); }); diff --git a/src/state.ts b/src/state.ts index 42dc8154..6d350d4c 100644 --- a/src/state.ts +++ b/src/state.ts @@ -147,6 +147,8 @@ export class State { problem = Problem.CLASSIFICATION; initZero = false; hideText = false; + /** When true, the Model equations panel is hidden (classroom hide control). */ + equationPanel_hide = true; collectStats = false; numHiddenLayers = 1; hiddenLayerControls: any[] = []; diff --git a/src/training-equations.ts b/src/training-equations.ts new file mode 100644 index 00000000..7954ba53 --- /dev/null +++ b/src/training-equations.ts @@ -0,0 +1,222 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +import * as nn from "./nn"; + +/** Scalar for TeX \text{...} (matches network-equation numeric style). */ +export function fmtCoeff(n: number): string { + let s = n.toFixed(3); + if (s === "-0.000") { + return "0.000"; + } + return s; +} + +/** + * Total objective: mean half-square error plus optional L1/L2 on weights + * (see nn.updateWeights and RegularizationFunction in nn.ts). + */ +export function buildObjectiveTex( + regularizationKey: string | undefined, + lambda: number): string { + let key = regularizationKey === "L1" || regularizationKey === "L2" ? + regularizationKey : + "none"; + let dataMean = + "\\frac{1}{N}\\sum_{i=1}^{N} \\frac{1}{2}\\left(\\hat{y}_i - " + + "y_i\\right)^2"; + if (key === "none") { + return "\\begin{aligned}\n\\mathcal{L} &= " + dataMean + "\n\\end{aligned}"; + } + let regTerm = key === "L1" ? + "\\lambda \\sum_{w} \\lvert w\\rvert" : + "\\lambda \\sum_{w} \\frac{1}{2} w^2"; + let lam = fmtCoeff(lambda); + return ( + "\\begin{aligned}\n\\mathcal{L} &= " + dataMean + " + " + regTerm + + "\\\\\n&\\quad \\lambda = \\text{" + lam + + "} \\text{ (regularization rate)}\n\\end{aligned}"); +} + +/** Numeric values for the output neuron and one incoming weight (see extractBackpropSnapshot). 
*/ +export interface BackpropNumericSnapshot { + yHat: number; + y: number; + e: number; + dEdYhat: number; + zOut: number; + sigmaPrimeOut: number; + deltaOut: number; + exampleWeight: number; + exampleAFrom: number; + exampleDEdW: number; + hiddenDelta?: number; + hiddenZ?: number; + hiddenSigmaPrime?: number; + hiddenDEdA?: number; +} + +/** + * Read derivatives and activations right after nn.backProp for this example. + * Caller must ensure forwardProp+backProp just ran for the same point (before + * any further forwardProp, e.g. from getLoss). + */ +export function extractBackpropSnapshot( + network: nn.Node[][], labelY: number): BackpropNumericSnapshot | null { + if (network == null || network.length < 2) { + return null; + } + let out = network[network.length - 1][0]; + if (out == null) { + return null; + } + let yHat = out.output; + let y = labelY; + let e = 0.5 * (yHat - y) * (yHat - y); + let zOut = out.totalInput; + let sigmaPrimeOut = out.activation.der(zOut); + let dEdYhat = out.outputDer; + let deltaOut = out.inputDer; + + let linkToShow: nn.Link | null = null; + for (let j = 0; j < out.inputLinks.length; j++) { + let L = out.inputLinks[j]; + if (!L.isDead) { + linkToShow = L; + break; + } + } + if (linkToShow == null && out.inputLinks.length > 0) { + linkToShow = out.inputLinks[0]; + } + let exampleWeight = linkToShow != null ? linkToShow.weight : 0; + let exampleAFrom = linkToShow != null && linkToShow.source != null ? + linkToShow.source.output : + 0; + let exampleDEdW = linkToShow != null ? 
linkToShow.errorDer : 0; + + let snap: BackpropNumericSnapshot = { + yHat: yHat, + y: y, + e: e, + dEdYhat: dEdYhat, + zOut: zOut, + sigmaPrimeOut: sigmaPrimeOut, + deltaOut: deltaOut, + exampleWeight: exampleWeight, + exampleAFrom: exampleAFrom, + exampleDEdW: exampleDEdW + }; + + if (network.length > 2) { + let h0 = network[1][0]; + if (h0 != null) { + snap.hiddenDelta = h0.inputDer; + snap.hiddenZ = h0.totalInput; + snap.hiddenSigmaPrime = h0.activation.der(h0.totalInput); + snap.hiddenDEdA = h0.outputDer; + } + } + return snap; +} + +/** + * Backprop for scalar output and per-example squared error (nn.backProp with + * Errors.SQUARE). Optional snapshot from the last training example in an epoch. + */ +export function buildBackpropTex(snapshot: BackpropNumericSnapshot | null): string { + let f = fmtCoeff; + let rows: string[] = [ + "E &= \\tfrac{1}{2}\\left(\\hat{y} - y\\right)^2 " + + "\\quad \\text{(one example)}\\\\", + "\\frac{\\partial E}{\\partial \\hat{y}} &= \\hat{y} - y\\\\", + "\\delta^{(z)} &= \\frac{\\partial E}{\\partial z} = " + + "\\frac{\\partial E}{\\partial a}\\,\\sigma'(z), \\quad " + + "a = \\sigma(z)\\\\", + "\\frac{\\partial E}{\\partial w} &= \\delta^{(z)}_{\\text{to}}\\, " + + "a_{\\text{from}}\\\\", + "\\frac{\\partial E}{\\partial a_{\\text{from}}} &= " + + "\\sum_{\\text{out}} w\\,\\delta^{(z)}_{\\text{out}}" + ]; + if (snapshot != null) { + let s = snapshot; + rows.push( + "&\\text{Values after the last training example in the latest step:}\\\\", + "E &= \\tfrac{1}{2}(\\hat{y}-y)^2 = \\text{" + f(s.e) + "} \\quad " + + "(\\hat{y}=\\text{" + f(s.yHat) + "},\\, y=\\text{" + f(s.y) + "})\\\\", + "\\frac{\\partial E}{\\partial \\hat{y}} &= \\text{" + f(s.dEdYhat) + + "} \\quad (z_{\\text{out}}=\\text{" + f(s.zOut) + + "},\\, \\sigma'(z_{\\text{out}})=\\text{" + f(s.sigmaPrimeOut) + "})\\\\", + "\\delta^{(z)}_{\\text{out}} &= \\text{" + f(s.deltaOut) + "}\\\\", + "\\frac{\\partial E}{\\partial w} &= \\text{" + f(s.exampleDEdW) + + "} 
\\quad \\text{(first non-dead input to output; } w=\\text{" + + f(s.exampleWeight) + "},\\, a_{\\text{from}}=\\text{" + + f(s.exampleAFrom) + "})"); + if (s.hiddenDelta != null && s.hiddenZ != null && + s.hiddenSigmaPrime != null && s.hiddenDEdA != null) { + rows.push( + "\\delta^{(z)}_{h_1} &= \\text{" + f(s.hiddenDelta) + "} \\quad " + + "(z=\\text{" + f(s.hiddenZ) + "},\\, \\sigma'(z)=\\text{" + + f(s.hiddenSigmaPrime) + "})\\\\", + "\\frac{\\partial E}{\\partial a_{h_1}} &= \\text{" + + f(s.hiddenDEdA) + "}"); + } + } else { + rows.push( + "&\\text{Train with Step or Play to show numbers from the last example " + + "in each epoch.}"); + } + return "\\begin{aligned}\n" + rows.join("\\\\\n") + "\n\\end{aligned}"; +} + +/** + * nn.updateWeights: averaged gradient step, then optional regularization step. + */ +export function buildWeightUpdateTex( + learningRate: number, + lambda: number, + regularizationKey: string | undefined): string { + let eta = fmtCoeff(learningRate); + let key = regularizationKey === "L1" || regularizationKey === "L2" ? + regularizationKey : + "none"; + let rows: string[] = [ + "w &\\leftarrow w - \\frac{\\eta}{n_{\\text{b}}}\\,\\overline{" + + "\\frac{\\partial E}{\\partial w}}\\\\", + "b &\\leftarrow b - \\frac{\\eta}{n_{\\text{b}}}\\,\\overline{" + + "\\frac{\\partial E}{\\partial b}}\\\\", + "&\\text{where } \\overline{\\cdot} \\text{ is the mean over the " + + "mini-batch of size } n_{\\text{b}} \\text{; } \\eta = \\text{" + + eta + "} \\text{ (learning rate).}" + ]; + if (key !== "none") { + let lam = fmtCoeff(lambda); + rows.push( + "w &\\leftarrow w - \\eta\\lambda\\, R'(w) \\quad \\text{(after " + + "the step above; } R' \\text{ uses } w \\text{ from that step)}"); + if (key === "L1") { + rows.push( + "&\\lambda = \\text{" + lam + + "},\\ R(w)=\\lvert w\\rvert,\\ R'(w)\\in\\{-1,0,1\\} " + + "\\text{(0 at }w=0\\text{). 
L1 may set } w\\to 0 \\text{ when " + + "the update would change sign.}"); + } else { + rows.push( + "&\\lambda = \\text{" + lam + + "},\\ R(w)=\\tfrac{1}{2}w^2,\\ R'(w)=w."); + } + } + return "\\begin{aligned}\n" + rows.join("\\\\\n") + "\n\\end{aligned}"; +} diff --git a/styles.css b/styles.css index 87b9bc5d..35b2c87d 100644 --- a/styles.css +++ b/styles.css @@ -378,14 +378,29 @@ header h1 .optional { display: -ms-flexbox; display: -webkit-flex; display: flex; - -webkit-justify-content: space-between; - justify-content: space-between; + -webkit-flex-direction: column; + flex-direction: column; + -webkit-align-items: stretch; + align-items: stretch; margin-top: 30px; margin-bottom: 50px; padding-top: 2px; position: relative; } +.main-part-columns { + display: -webkit-box; + display: -moz-box; + display: -ms-flexbox; + display: -webkit-flex; + display: flex; + -webkit-justify-content: space-between; + justify-content: space-between; + -webkit-align-items: flex-start; + align-items: flex-start; + width: 100%; +} + @media (min-height: 700px) { #main-part { margin-top: 50px; @@ -647,6 +662,161 @@ g.column rect { z-index: 100; } +.nn-equation-panel { + box-sizing: border-box; + width: 100%; + max-width: 100%; + margin-top: 28px; + padding: 16px 20px 20px; + border-top: 1px solid rgba(0, 0, 0, 0.08); + background: rgba(0, 0, 0, 0.02); + border-radius: 4px; + position: relative; + z-index: 1; +} + +.nn-equation-header { + margin-bottom: 12px; +} + +.nn-equation-title { + margin: 0; + font-size: 13px; + font-weight: 500; + color: #333; + text-transform: uppercase; + letter-spacing: 0.04em; +} + +.nn-equation-block { + margin-top: 16px; +} + +.nn-equation-header + .nn-equation-block { + margin-top: 0; +} + +.nn-equation-subtitle { + margin: 0 0 8px 0; + font-size: 11px; + font-weight: 500; + text-transform: uppercase; + letter-spacing: 0.05em; + color: #555; +} + +.nn-equation { + width: 100%; + max-width: 100%; + overflow-x: auto; + overflow-y: hidden; + font-size: 14px; 
+ line-height: 1.4; + -webkit-overflow-scrolling: touch; +} + +.nn-equation .katex-display { + margin: 0.5em 0; +} + +.nn-equation-desc { + margin-top: 12px; + padding-top: 12px; + border-top: 1px solid rgba(0, 0, 0, 0.05); + font-size: 12px; + line-height: 1.5; + color: #444; +} + +.nn-equation-desc-list { + margin: 0; +} + +.nn-equation-desc-list dt { + font-weight: 500; + color: #333; + margin-top: 10px; +} + +.nn-equation-desc-list dt:first-child { + margin-top: 0; +} + +.nn-equation-desc-list dd { + margin: 4px 0 0 0; + padding: 0; + font-weight: 300; +} + +.nn-equation-legend { + margin-top: 20px; + padding-top: 16px; + border-top: 1px solid rgba(0, 0, 0, 0.06); +} + +.nn-legend-section { + margin-bottom: 14px; +} + +.nn-legend-section:last-child { + margin-bottom: 0; +} + +.nn-legend-section-title { + margin: 0 0 8px 0; + font-size: 11px; + font-weight: 500; + text-transform: uppercase; + letter-spacing: 0.05em; + color: #555; +} + +.nn-legend-list { + margin: 0; + padding: 0; + list-style: none; + color: #444; + font-size: 12px; + line-height: 1.45; + font-weight: 300; +} + +.nn-legend-row { + display: -webkit-box; + display: -webkit-flex; + display: -ms-flexbox; + display: flex; + -webkit-flex-wrap: wrap; + flex-wrap: wrap; + -webkit-align-items: baseline; + align-items: baseline; + margin-bottom: 8px; + padding-left: 10px; + border-left: 2px solid rgba(0, 0, 0, 0.08); +} + +.nn-legend-row:last-child { + margin-bottom: 0; +} + +.nn-legend-symbol { + -webkit-flex: 0 0 auto; + flex: 0 0 auto; + margin-right: 12px; + min-width: 2.5em; +} + +.nn-legend-symbol .katex { + font-size: 1em; +} + +.nn-legend-detail { + -webkit-flex: 1 1 220px; + flex: 1 1 220px; + color: #555; + min-width: 0; +} + #network svg .main-label { font-size: 13px; fill: #333;