Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 25f8967

Browse files
authored
Switch shader indexing from float to int (#93)
* switch shader indexing from float to int
* revert graph_runner_test
* self review
1 parent 25814c5 commit 25f8967

23 files changed

+458
-329
lines changed

src/math/math.ts

Lines changed: 109 additions & 67 deletions
Large diffs are not rendered by default.

src/math/math_gpu.ts

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -153,10 +153,14 @@ export class NDArrayMathGPU extends NDArrayMath {
153153

154154
protected batchNormalization3DInternal(
155155
x: Array3D, mean: Array3D|Array1D, variance: Array3D|Array1D,
156-
varianceEpsilon = 0.000001, scale?: Array3D|Array1D,
156+
varianceEpsilon: number|null, scale?: Array3D|Array1D,
157157
offset?: Array3D|Array1D): Array3D {
158158
const inputs = [x, mean, variance];
159159

160+
if (varianceEpsilon == null) {
161+
varianceEpsilon = 0.000001;
162+
}
163+
160164
let offsetShape = null;
161165
if (offset != null) {
162166
offsetShape = offset.shape;

src/math/webgl/argminmax_gpu.ts

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -20,11 +20,10 @@ export function getArgMinMaxSnippet(
2020
const compOp = (op === 'min') ? '<' : '>';
2121
return `
2222
float getArgMinMax${texName}() {
23-
float bestIndex = 0.0;
24-
float bestValue = get${texName}Flat(0.0);
23+
int bestIndex = 0;
24+
float bestValue = get${texName}Flat(0);
2525
26-
for (int ii = 0; ii < ${size}; ii++) {
27-
float i = float(ii);
26+
for (int i = 0; i < ${size}; i++) {
2827
float candidate = get${texName}Flat(i);
2928
if (isNaN(candidate)) {
3029
return candidate;
@@ -34,7 +33,7 @@ export function getArgMinMaxSnippet(
3433
bestIndex = i;
3534
}
3635
}
37-
return bestIndex;
36+
return float(bestIndex);
3837
}
3938
`;
4039
}

src/math/webgl/concat3d_gpu.ts

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -32,16 +32,16 @@ export class Concat3DProgram implements GPGPUProgram {
3232
concat3d_util.computeConcat3DOutputShape(x1Shape, x2Shape, axis);
3333
this.userCode = `
3434
void main() {
35-
vec3 coords = getOutputCoords();
36-
float yR = coords.x;
37-
float yC = coords.y;
38-
float yD = coords.z;
35+
ivec3 coords = getOutputCoords();
36+
int yR = coords.x;
37+
int yC = coords.y;
38+
int yD = coords.z;
3939
4040
float value = 0.0;
41-
if (${concatAxis} < ${x1Shape[axis]}.0) {
41+
if (${concatAxis} < ${x1Shape[axis]}) {
4242
value = getA(yR, yC, yD);
4343
} else {
44-
${concatAxis} -= ${x1Shape[axis]}.0;
44+
${concatAxis} -= ${x1Shape[axis]};
4545
value = getB(yR, yC, yD);
4646
}
4747

src/math/webgl/conv_backprop_gpu.ts

Lines changed: 31 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -36,28 +36,26 @@ export class Conv2DDerWeightsProgram implements GPGPUProgram {
3636
this.params = [stride, zeroPad];
3737
this.userCode = `
3838
void main() {
39-
vec4 coords = getOutputCoords();
40-
float wR = coords.x;
41-
float wC = coords.y;
42-
float d1 = coords.z;
43-
float d2 = coords.w;
39+
ivec4 coords = getOutputCoords();
40+
int wR = coords.x;
41+
int wC = coords.y;
42+
int d1 = coords.z;
43+
int d2 = coords.w;
4444
4545
// Convolve x(?, ?, d1) with dy(:, :, d2) to get dw(wR, wC, d1, d2).
4646
// ? = to be determined. : = across all values in that axis.
4747
float dotProd = 0.0;
48-
for (int iyR = 0; iyR < ${yNumRows}; iyR++) {
49-
float yR = float(iyR);
50-
float xR = wR + yR * ${stride}.0 - ${zeroPad}.0;
48+
for (int yR = 0; yR < ${yNumRows}; yR++) {
49+
int xR = wR + yR * ${stride} - ${zeroPad};
5150
52-
if (xR < 0.0 || xR >= ${xNumRows}.0) {
51+
if (xR < 0 || xR >= ${xNumRows}) {
5352
continue;
5453
}
5554
56-
for (int iyC = 0; iyC < ${yNumCols}; iyC++) {
57-
float yC = float(iyC);
58-
float xC = wC + yC * ${stride}.0 - ${zeroPad}.0;
55+
for (int yC = 0; yC < ${yNumCols}; yC++) {
56+
int xC = wC + yC * ${stride} - ${zeroPad};
5957
60-
if (xC < 0.0 || xC >= ${xNumCols}.0) {
58+
if (xC < 0 || xC >= ${xNumCols}) {
6159
continue;
6260
}
6361
@@ -94,42 +92,41 @@ export class Conv2DTransposeProgram implements GPGPUProgram {
9492
this.params = [pad, fSize, origStride, hasBias];
9593

9694
this.userCode = `
95+
const ivec2 pads = ivec2(${pad}, ${pad});
96+
9797
void main() {
98-
vec3 coords = getOutputCoords();
99-
float yR = coords.x;
100-
float yC = coords.y;
101-
float d2 = coords.z;
98+
ivec3 coords = getOutputCoords();
99+
int d2 = coords.z;
102100
103-
vec2 xRCCorner = vec2(yR, yC) - vec2(${pad}.0, ${pad}.0);
104-
float xRCorner = xRCCorner.x;
105-
float xCCorner = xRCCorner.y;
101+
ivec2 xRCCorner = coords.xy - pads;
102+
int xRCorner = xRCCorner.x;
103+
int xCCorner = xRCCorner.y;
106104
107105
// Convolve x(?, ?, d1) with w(:, :, d2, d1) to get y(yR, yC, d2).
108106
// ? = to be determined. : = across all values in that axis.
109107
float dotProd = 0.0;
110-
for (int iwR = 0; iwR < ${fSize}; iwR++) {
111-
float wR = float(iwR);
112-
float xR = (xRCorner + wR) / ${origStride}.0;
108+
for (int wR = 0; wR < ${fSize}; wR++) {
109+
float xR = float(xRCorner + wR) / ${origStride}.0;
113110
114111
if (xR < 0.0 || xR >= ${xRows}.0 || fract(xR) > 0.0) {
115112
continue;
116113
}
114+
int ixR = int(xR);
117115
118-
float wRPerm = ${fSize}.0 - 1.0 - wR;
116+
int wRPerm = ${fSize} - 1 - wR;
119117
120-
for (int iwC = 0; iwC < ${fSize}; iwC++) {
121-
float wC = float(iwC);
122-
float xC = (xCCorner + wC) / ${origStride}.0;
118+
for (int wC = 0; wC < ${fSize}; wC++) {
119+
float xC = float(xCCorner + wC) / ${origStride}.0;
123120
124121
if (xC < 0.0 || xC >= ${xCols}.0 || fract(xC) > 0.0) {
125122
continue;
126123
}
124+
int ixC = int(xC);
127125
128-
float wCPerm = ${fSize}.0 - 1.0 - wC;
126+
int wCPerm = ${fSize} - 1 - wC;
129127
130-
for (int id1 = 0; id1 < ${origOutputDepth}; id1++) {
131-
float d1 = float(id1);
132-
float xValue = getX(xR, xC, d1);
128+
for (int d1 = 0; d1 < ${origOutputDepth}; d1++) {
129+
float xValue = getX(ixR, ixC, d1);
133130
float wValue = getW(wRPerm, wCPerm, d2, d1);
134131
dotProd += xValue * wValue;
135132
}
@@ -153,13 +150,11 @@ export class Conv2DDerBiasProgram implements GPGPUProgram {
153150
this.outputShape = [outputDepth];
154151
this.userCode = `
155152
void main() {
156-
float d2 = getOutputCoords();
153+
int d2 = getOutputCoords();
157154
158155
float derBias = 0.0;
159-
for (int iyR = 0; iyR < ${yNumRows}; iyR++) {
160-
float yR = float(iyR);
161-
for (int iyC = 0; iyC < ${yNumCols}; iyC++) {
162-
float yC = float(iyC);
156+
for (int yR = 0; yR < ${yNumRows}; yR++) {
157+
for (int yC = 0; yC < ${yNumCols}; yC++) {
163158
derBias += getDy(yR, yC, d2);
164159
}
165160
}

src/math/webgl/conv_gpu.ts

Lines changed: 15 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -33,38 +33,35 @@ export class Conv2DProgram implements GPGPUProgram {
3333
const xNumRows = xShape[0];
3434
const xNumCols = xShape[1];
3535
this.userCode = `
36+
const ivec2 strides = ivec2(${stride}, ${stride});
37+
const ivec2 pads = ivec2(${pad}, ${pad});
38+
3639
void main() {
37-
vec3 coords = getOutputCoords();
38-
float yR = coords.x;
39-
float yC = coords.y;
40-
float d2 = coords.z;
40+
ivec3 coords = getOutputCoords();
41+
int d2 = coords.z;
4142
42-
vec2 xRCCorner = vec2(yR, yC) * vec2(${stride}.0, ${stride}.0) -
43-
vec2(${pad}.0, ${pad}.0);
44-
float xRCorner = xRCCorner.x;
45-
float xCCorner = xRCCorner.y;
43+
ivec2 xRCCorner = coords.xy * strides - pads;
44+
int xRCorner = xRCCorner.x;
45+
int xCCorner = xRCCorner.y;
4646
4747
// Convolve x(?, ?, d1) with w(:, :, d1, d2) to get y(yR, yC, d2).
4848
// ? = to be determined. : = across all values in that axis.
4949
float dotProd = 0.0;
50-
for (int iwR = 0; iwR < ${fieldSize}; iwR++) {
51-
float wR = float(iwR);
52-
float xR = xRCorner + wR;
50+
for (int wR = 0; wR < ${fieldSize}; wR++) {
51+
int xR = xRCorner + wR;
5352
54-
if (xR < 0.0 || xR >= ${xNumRows}.0) {
53+
if (xR < 0 || xR >= ${xNumRows}) {
5554
continue;
5655
}
5756
58-
for (int iwC = 0; iwC < ${fieldSize}; iwC++) {
59-
float wC = float(iwC);
60-
float xC = xCCorner + wC;
57+
for (int wC = 0; wC < ${fieldSize}; wC++) {
58+
int xC = xCCorner + wC;
6159
62-
if (xC < 0.0 || xC >= ${xNumCols}.0) {
60+
if (xC < 0 || xC >= ${xNumCols}) {
6361
continue;
6462
}
6563
66-
for (int id1 = 0; id1 < ${inputDepth}; id1++) {
67-
float d1 = float(id1);
64+
for (int d1 = 0; d1 < ${inputDepth}; d1++) {
6865
float xValue = getX(xR, xC, d1);
6966
float wValue = getW(wR, wC, d1, d2);
7067
dotProd += xValue * wValue;

src/math/webgl/copy_gpu.ts

Lines changed: 8 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -26,16 +26,14 @@ export class Copy2DProgram implements GPGPUProgram {
2626
this.outputShape = null;
2727
this.params = [srcNumCols, destNumCols];
2828
this.userCode = `
29-
uniform vec2 sourceStart;
30-
uniform vec2 destStart;
29+
uniform ivec2 sourceStart;
30+
uniform ivec2 destStart;
3131
3232
void main() {
33-
vec2 destCoords = getOutputCoords() - destStart;
34-
float index = dot(destCoords, vec2(${destNumCols}.0, 1.0));
35-
vec2 sourceCoords = sourceStart + vec2(
36-
floor(index / ${srcNumCols}.0),
37-
mod(index, ${srcNumCols}.0)
38-
);
33+
ivec2 destCoords = getOutputCoords() - destStart;
34+
int index = destCoords.x * ${destNumCols} + destCoords.y;
35+
int r = index / ${srcNumCols};
36+
ivec2 sourceCoords = sourceStart + ivec2(r, index - r * ${srcNumCols});
3937
setOutput(getSource(sourceCoords.x, sourceCoords.y));
4038
}
4139
`;
@@ -48,9 +46,9 @@ export class Copy2DProgram implements GPGPUProgram {
4846
gpgpu.setOutputMatrixWriteRegion(
4947
destStart[0], destSize[0], destStart[1], destSize[1]);
5048
const sourceStartCRLoc = gpgpu.getUniformLocation('sourceStart');
51-
gpgpu.gl.uniform2f(sourceStartCRLoc, sourceStart[0], sourceStart[1]);
49+
gpgpu.gl.uniform2i(sourceStartCRLoc, sourceStart[0], sourceStart[1]);
5250
const destStartCRLoc = gpgpu.getUniformLocation('destStart');
53-
gpgpu.gl.uniform2f(destStartCRLoc, destStart[0], destStart[1]);
51+
gpgpu.gl.uniform2i(destStartCRLoc, destStart[0], destStart[1]);
5452
};
5553
}
5654
}

src/math/webgl/gpgpu_context.ts

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -43,14 +43,15 @@ export class GPGPUContext {
4343
if (!webgl_util.isWebGL2Enabled()) {
4444
this.textureFloatExtension =
4545
webgl_util.getExtensionOrThrow(this.gl, 'OES_texture_float');
46+
this.colorBufferFloatExtension =
47+
this.gl.getExtension('WEBGL_color_buffer_float');
4648
} else {
4749
this.colorBufferFloatExtension =
4850
webgl_util.getExtensionOrThrow(this.gl, 'EXT_color_buffer_float');
4951
}
5052

51-
this.loseContextExtension =
52-
webgl_util.getExtensionOrThrow(this.gl, 'WEBGL_lose_context') as
53-
WebGLLoseContextExtension;
53+
this.loseContextExtension = webgl_util.getExtensionOrThrow(
54+
this.gl, 'WEBGL_lose_context') as WebGLLoseContextExtension;
5455
this.vertexBuffer = gpgpu_util.createVertexBuffer(this.gl);
5556
this.indexBuffer = gpgpu_util.createIndexBuffer(this.gl);
5657
this.framebuffer = webgl_util.createFramebuffer(this.gl);
@@ -258,6 +259,9 @@ export class GPGPUContext {
258259
this.throwIfDisposed();
259260
webgl_util.bindColorTextureToFramebuffer(
260261
this.gl, texture, this.framebuffer);
262+
if (this.autoDebugValidate) {
263+
webgl_util.validateFramebuffer(this.gl);
264+
}
261265
const result = downloadAndDecode();
262266
if (this.outputTexture != null) {
263267
webgl_util.bindColorTextureToFramebuffer(

src/math/webgl/gpgpu_context_test.ts

Lines changed: 20 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -41,12 +41,21 @@ describe('GPGPUContext downloadMatrixFromTexture WebGL 2.0', () => {
4141
expect(result[0]).toBeCloseTo(0.123);
4242
});
4343

44-
it('returns matrix that was uploaded', () => {
44+
it('returns 1x1 matrix that was uploaded', () => {
4545
gpgpu.uploadMatrixToTexture(texture, 1, 1, new Float32Array([1.234]));
4646
const result = gpgpu.downloadMatrixFromTexture(texture, 1, 1);
4747
expect(result[0]).toBeCloseTo(1.234);
4848
});
4949

50+
it('returns 2x2 matrix that was uploaded', () => {
51+
const texture2 = gpgpu.createMatrixTexture(2, 2);
52+
gpgpu.uploadMatrixToTexture(
53+
texture2, 2, 2, new Float32Array([1.234, 2, 3, 4]));
54+
const result = gpgpu.downloadMatrixFromTexture(texture2, 2, 2);
55+
expect(result).toEqual(new Float32Array([1.234, 2, 3, 4]));
56+
gpgpu.deleteMatrixTexture(texture2);
57+
});
58+
5059
it('uses texture parameter', () => {
5160
const texture2: WebGLTexture = gpgpu.createMatrixTexture(1, 1);
5261
gpgpu.uploadMatrixToTexture(texture, 1, 1, new Float32Array([1]));
@@ -84,12 +93,21 @@ describe('GPGPUContext downloadMatrixFromTexture WebGL 1.0', () => {
8493
expect(result[0]).toBeCloseTo(0.123);
8594
});
8695

87-
it('returns matrix that was uploaded', () => {
96+
it('returns 1x1 matrix that was uploaded', () => {
8897
gpgpu.uploadMatrixToTexture(texture, 1, 1, new Float32Array([1.234]));
8998
const result = gpgpu.downloadMatrixFromTexture(texture, 1, 1);
9099
expect(result[0]).toBeCloseTo(1.234);
91100
});
92101

102+
it('returns 2x2 matrix that was uploaded', () => {
103+
const texture2 = gpgpu.createMatrixTexture(2, 2);
104+
gpgpu.uploadMatrixToTexture(
105+
texture2, 2, 2, new Float32Array([1.234, 2, 3, 4]));
106+
const result = gpgpu.downloadMatrixFromTexture(texture2, 2, 2);
107+
expect(result).toEqual(new Float32Array([1.234, 2, 3, 4]));
108+
gpgpu.deleteMatrixTexture(texture2);
109+
});
110+
93111
it('uses texture parameter', () => {
94112
const texture2: WebGLTexture = gpgpu.createMatrixTexture(1, 1);
95113
gpgpu.uploadMatrixToTexture(texture, 1, 1, new Float32Array([1]));

src/math/webgl/gpgpu_math.ts

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -115,7 +115,7 @@ export function makeShaderKey(
115115
const params = program.params;
116116
const keyStart =
117117
inputs.concat(output).map(x => x.shape + '_' + x.getTextureShapeRC());
118-
const keyEnd = params.map(p => p.toString());
118+
const keyEnd = params.map(String);
119119
let key = [program.constructor.name];
120120
key.push((program.supportsBroadcasting === true).toString());
121121
key = key.concat(keyStart, keyEnd);

0 commit comments

Comments
 (0)