1515 * =============================================================================
1616 */
1717
18+ import { TimingInfo } from '../engine' ;
1819import { ENV } from '../environment' ;
1920import { NDArrayMath } from '../math' ;
2021import * as axis_util from '../ops/axis_util' ;
@@ -26,6 +27,7 @@ import * as types from '../types';
2627// tslint:disable-next-line:max-line-length
2728import { DataType , DataTypeMap , Rank , RecursiveArray , TypedArray } from '../types' ;
2829import * as util from '../util' ;
30+
2931import { KernelBackend } from './backend' ;
3032import { ArgMinMaxProgram } from './webgl/argminmax_gpu' ;
3133import { AvgPool2DBackpropProgram } from './webgl/avg_pool_backprop_gpu' ;
@@ -70,12 +72,21 @@ export interface CPUTimerQuery {
7072 endMs ?: number ;
7173}
7274
75+ export interface WebGLTimingInfo extends TimingInfo {
76+ uploadWaitMs : number ;
77+ downloadWaitMs : number ;
78+ }
79+
7380export class MathBackendWebGL implements KernelBackend {
7481 private texData = new WeakMap < DataId , TextureData > ( ) ;
7582 private canvas : HTMLCanvasElement ;
7683
7784 private programTimersStack : TimerNode [ ] ;
7885 private activeTimers : TimerNode [ ] ;
86+ // Accumulated time spent (including blocking) in uploading data to webgl.
87+ private uploadWaitMs = 0 ;
88+ // Accumulated time spent (including blocking in downloading data from webgl.
89+ private downloadWaitMs = 0 ;
7990
8091 register ( dataId : DataId , shape : number [ ] , dtype : DataType ) : void {
8192 if ( this . texData . has ( dataId ) ) {
@@ -152,8 +163,16 @@ export class MathBackendWebGL implements KernelBackend {
152163 this . cacheOnCPU ( dataId ) ;
153164 return values ;
154165 }
166+ const shouldTimeProgram = this . activeTimers != null ;
167+ let start : number ;
168+ if ( shouldTimeProgram ) {
169+ start = performance . now ( ) ;
170+ }
155171 const float32Values =
156172 this . gpgpu . downloadMatrixFromTexture ( texture , texShape [ 0 ] , texShape [ 1 ] ) ;
173+ if ( shouldTimeProgram ) {
174+ this . downloadWaitMs += performance . now ( ) - start ;
175+ }
157176 this . cacheOnCPU ( dataId , float32Values ) ;
158177 return texData . values ;
159178 }
@@ -182,7 +201,7 @@ export class MathBackendWebGL implements KernelBackend {
182201 return this . readSync ( dataId ) ;
183202 }
184203
185- async time ( f : ( ) => void ) : Promise < number > {
204+ async time ( f : ( ) => void ) : Promise < WebGLTimingInfo > {
186205 const oldActiveTimers = this . activeTimers ;
187206 const newActiveTimers : TimerNode [ ] = [ ] ;
188207
@@ -204,11 +223,20 @@ export class MathBackendWebGL implements KernelBackend {
204223 this . programTimersStack = null ;
205224 }
206225
207- return Promise . all ( flattenedActiveTimers ) . then ( results => {
226+ const kernelMs = await Promise . all ( flattenedActiveTimers ) . then ( results => {
208227 let sum = 0 ;
209228 results . forEach ( result => sum += result ) ;
210229 return sum ;
211230 } ) ;
231+ const res : WebGLTimingInfo = {
232+ uploadWaitMs : this . uploadWaitMs ,
233+ downloadWaitMs : this . downloadWaitMs ,
234+ kernelMs,
235+ wallMs : null // will be filled by the engine
236+ } ;
237+ this . uploadWaitMs = 0 ;
238+ this . downloadWaitMs = 0 ;
239+ return res ;
212240 }
213241 memory ( ) {
214242 return { unreliable : false } ;
@@ -933,6 +961,11 @@ export class MathBackendWebGL implements KernelBackend {
933961 // Array is already on GPU. No-op.
934962 return ;
935963 }
964+ const shouldTimeProgram = this . activeTimers != null ;
965+ let start : number ;
966+ if ( shouldTimeProgram ) {
967+ start = performance . now ( ) ;
968+ }
936969 const texShape =
937970 webgl_util . getTextureShapeFromLogicalShape ( this . gpgpu . gl , shape ) ;
938971 texData . texShape = texShape ;
@@ -945,6 +978,9 @@ export class MathBackendWebGL implements KernelBackend {
945978 texShape [ 1 ] , typedArrayToFloat32 ( values , dtype ) ) ;
946979 // Once uploaded, don't store the values on cpu.
947980 texData . values = null ;
981+ if ( shouldTimeProgram ) {
982+ this . uploadWaitMs += performance . now ( ) - start ;
983+ }
948984 }
949985 }
950986
0 commit comments