Skip to content

Commit c9af26d

Browse files
NNotepad: Add ability to show additional tensors in the output (#311)
Adds `output()` helper which can be used to list an identifier (or multiple identifiers) that will be shown in the results pane before the result of the final expression or assignment. A few notes: - Ending with an `output()` is supported, as a convenience. - The `output()` helper takes an identifier, not an expression. In the future this could change, either to taking an expression or showing the identifier name in the results pane. - Mixing `split()` (which returns a list of tensors) and `output()` is probably a bit confusing, since split already gives you multiple tensors, you're gonna have to pay attention to how many are coming from each. Sorry! Also fixes a minor issue with the call to asyncInit in js/tests.js - it wasn't awaited, so if the first test case was a test with a call it would fail.
1 parent 1c6f0f2 commit c9af26d

4 files changed

Lines changed: 66 additions & 23 deletions

File tree

nnotepad/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ In addition to WebNN [`MLGraphBuilder`](https://webmachinelearning.github.io/web
5050

5151
* **load(_url_, _shape_, _dataType_)** - fetch a tensor resource. Must be served with appropriate [CORS](https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS) headers. Example: `load('https://www.random.org/cgi-bin/randbyte?nbytes=256', [16, 16], 'uint8')`
5252
* **zeros(_shape_, _dataType_)** - constant zero-filled tensor of the given shape. Example: `zeros([2,2,2,2], 'int8')`
53+
* **output(_identifier_, ...)** - show the named variable(s) as an additional output, in addition to the last expression result. Example: `T = [1,2] output(T) mul(T,3)`
5354

5455

5556
# Details & Gotchas

nnotepad/js/nnotepad.js

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -432,22 +432,31 @@ export class NNotepad {
432432
// Generates WebNN code as the body of a function. `_` is passed as the
433433
// `MLGraphBuilder`. The output of the last expression is returned.
434434

435+
const kUtilArgName = '_util_';
436+
const kOutputsArgName = '_outputs_';
437+
435438
const src = lines
436-
.map(
437-
(line, index) =>
438-
serializeLine(line, index === lines.length - 1))
439+
.map((line, index) =>
440+
serializeLine(line, index === lines.length - 1))
439441
.map((line) => line + ';\n')
440442
.join('');
441443
const AsyncFunction = async function() {}.constructor;
442-
return [new AsyncFunction(['_', 'Util'], src), src];
444+
return [new AsyncFunction(['_', kUtilArgName, kOutputsArgName], src), src];
443445

444446
function serializeLine(line, last) {
445447
const expr = serializeExpr(line.expr);
448+
449+
// If the last thing is an `output()` call, don't wrap it.
450+
const isOutputCall = line.type === 'expression' &&
451+
line.expr.type === 'call' && line.expr.identifier === 'output';
452+
const wrapAsOutput = last && !isOutputCall;
453+
446454
switch (line.type) {
447455
case 'assignment':
448-
return last ? `return ${expr}` : `const ${line.identifier} = ${expr}`;
456+
return wrapAsOutput ? `${kOutputsArgName}.push(${expr})` :
457+
`const ${line.identifier} = ${expr}`;
449458
case 'expression':
450-
return last ? `return ${expr}` : expr;
459+
return wrapAsOutput ? `${kOutputsArgName}.push(${expr})` : expr;
451460
}
452461
throw new Error(`unexpected line type: ${line.type}`);
453462
}
@@ -588,7 +597,7 @@ export class NNotepad {
588597
return `_.constant({dataType: "${dataType.value}", dimensions: ${
589598
Util.stringify(dims)}, shape: ${
590599
Util.stringify(dims)}}, new ${
591-
ctor.name}(await Util.loadBuffer(${
600+
ctor.name}(await ${kUtilArgName}.${Util.loadBuffer.name}(${
592601
Util.stringify(url.value)})).buffer)`;
593602
}
594603

@@ -609,6 +618,15 @@ export class NNotepad {
609618
ctor.name}(${len}).buffer)`;
610619
}
611620

621+
if (name === 'output') {
622+
return args.map((arg) => {
623+
if (arg.type !== 'identifier') {
624+
throw new TypeError('output(): expected identifier');
625+
}
626+
return `${kOutputsArgName}.push(${arg.value})`;
627+
}).join(';\n');
628+
}
629+
612630
return `_.${name}(${
613631
args.map(
614632
(arg, index) =>
@@ -630,21 +648,21 @@ export class NNotepad {
630648
const builder = new self.MLGraphBuilder(context);
631649

632650
const outputOperands = [];
633-
let output = await builderFunc(builder, Util);
634-
if (output instanceof self.MLOperand) {
635-
// TODO: remove try/catch once all back-ends support `identity()`.
636-
try {
637-
// In case `output` is a constant.
638-
output = builder.identity(output);
639-
} catch (ex) {
640-
// Just live with it for now.
651+
const outputs = [];
652+
await builderFunc(builder, Util, outputs);
653+
for (const output of outputs.flat()) {
654+
if (output instanceof self.MLOperand) {
655+
// TODO: remove try/catch once all back-ends support `identity()`.
656+
try {
657+
// In case `output` is a constant.
658+
outputOperands.push(builder.identity(output));
659+
} catch (ex) {
660+
// Just live with it for now.
661+
outputOperands.push(output);
662+
}
663+
} else {
664+
throw new ParseError(`Non-MLOperand output: ${output}`);
641665
}
642-
outputOperands.push(output);
643-
} else if (Array.isArray(output)) {
644-
outputOperands.push(...output);
645-
// no-op
646-
} else {
647-
throw new ParseError(`Non-MLOperand output: ${output}`);
648666
}
649667

650668
const namedOutputs = {};

nnotepad/js/tests.js

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import {Harness} from './testharness.js';
21
import {NNotepad} from './nnotepad.js';
2+
import {Harness} from './testharness.js';
33

44
// ============================================================
55
// Helper for NNotepad-specific tests
@@ -60,7 +60,7 @@ async function testThrows(expr) {
6060
// ============================================================
6161

6262
document.addEventListener('DOMContentLoaded', async (e) => {
63-
NNotepad.asyncInit();
63+
await NNotepad.asyncInit();
6464

6565
Harness.section('Numbers');
6666
await test('125', {dataType: 'float32', shape: [], buffer: [125]});
@@ -158,6 +158,29 @@ document.addEventListener('DOMContentLoaded', async (e) => {
158158
{dataType: 'float32', shape: [2], buffer: [1, 2]},
159159
{dataType: 'float32', shape: [2], buffer: [3, 4]},
160160
]);
161+
await test(`A = [1,2] output(A) B = [3,4]`, [
162+
{dataType: 'float32', shape: [2], buffer: [1, 2]},
163+
{dataType: 'float32', shape: [2], buffer: [3, 4]},
164+
]);
165+
await test(`A = [1,2] output(A) B = [3,4] output(B)`, [
166+
{dataType: 'float32', shape: [2], buffer: [1, 2]},
167+
{dataType: 'float32', shape: [2], buffer: [3, 4]},
168+
]);
169+
await test(`A = [1,2] output(A) split([1,2,3,4], 2)`, [
170+
{dataType: 'float32', shape: [2], buffer: [1, 2]},
171+
{dataType: 'float32', shape: [2], buffer: [1, 2]},
172+
{dataType: 'float32', shape: [2], buffer: [3, 4]},
173+
]);
174+
await test(`array = split([1,2,3,4], 2) output(array)`, [
175+
{dataType: 'float32', shape: [2], buffer: [1, 2]},
176+
{dataType: 'float32', shape: [2], buffer: [3, 4]},
177+
]);
178+
await test(
179+
`A = [[1,7],[2,4]] B = [[3,3],[5,2]] output(A, B) matmul(A,B)`, [
180+
{dataType: 'float32', shape: [2, 2], buffer: [1, 7, 2, 4]},
181+
{dataType: 'float32', shape: [2, 2], buffer: [3, 3, 5, 2]},
182+
{dataType: 'float32', shape: [2, 2], buffer: [38, 17, 26, 14]},
183+
]);
161184

162185
Harness.section('Non-operand arguments: array of operands');
163186
await test(

nnotepad/res/docs.html

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ <h1>Helpers</h1>
110110
<ul>
111111
<li><strong>load(<em>url</em>, <em>shape</em>, <em>dataType</em>)</strong> - fetch a tensor resource. Must be served with appropriate <a href="https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS">CORS</a> headers. Example: <code>load('https://www.random.org/cgi-bin/randbyte?nbytes=256', [16, 16], 'uint8')</code></li>
112112
<li><strong>zeros(<em>shape</em>, <em>dataType</em>)</strong> - constant zero-filled tensor of the given shape. Example: <code>zeros([2,2,2,2], 'int8')</code></li>
113+
<li><strong>output(<em>identifier</em>, ...)</strong> - show the named variable(s) as an additional output, in addition to the last expression result. Example: <code>T = [1,2] output(T) mul(T,3)</code></li>
113114
</ul>
114115
<h1>Details &amp; Gotchas</h1>
115116
<ul>

0 commit comments

Comments
 (0)