@@ -258,41 +258,41 @@ __global__ void Im2Col(AFloat * A,
258258 int zeroPaddingHeight,
259259 int zeroPaddingWidth)
260260{
261- // The row of the output matrix.
262- int i = blockDim .y * blockIdx .y + threadIdx .y ;
261+ // The row of the output matrix.
262+ int i = blockDim .y * blockIdx .y + threadIdx .y ;
263263
264- // The column of the output matrix.
265- int j = blockDim .x * blockIdx .x + threadIdx .x ;
264+ // The column of the output matrix.
265+ int j = blockDim .x * blockIdx .x + threadIdx .x ;
266266
267- // Number of column in matrix A.
268- int NLocalViewPixels = fltHeight * fltWidth * depth;
267+ // Number of column in matrix A.
268+ int NLocalViewPixels = fltHeight * fltWidth * depth;
269269
270- // Number of rows in matrix A.
271- int NLocalViews = calculateDimension (imgWidth, fltWidth, zeroPaddingWidth, strideCols) *
272- calculateDimension (imgHeight, fltHeight, zeroPaddingHeight, strideRows);
270+ // Number of rows in matrix A.
271+ int NLocalViews = calculateDimension (imgWidth, fltWidth, zeroPaddingWidth, strideCols) *
272+ calculateDimension (imgHeight, fltHeight, zeroPaddingHeight, strideRows);
273273
274- if (i > NLocalViews || j > NLocalViewPixels) return ;
274+ if (i >= NLocalViews || j >= NLocalViewPixels) return ;
275275
276- int index = j + i * NLocalViewPixels ;
276+ int index = j * NLocalViews + i ;
277277
278- int numSlidesPerRow = calculateDimension (imgWidth, fltWidth, zeroPaddingWidth, strideCols);
278+ int numSlidesPerRow = calculateDimension (imgWidth, fltWidth, zeroPaddingWidth, strideCols);
279279
280- // Which image channel of B?
281- int bz = j / (fltHeight * fltWidth);
280+ // Which image channel of B?
281+ int bz = j / (fltHeight * fltWidth);
282282
283- // Which row in matrix B?
284- int by = (i / numSlidesPerRow) * strideRows - zeroPaddingHeight + j / fltWidth;
283+ // Which row in matrix B?
284+ int by = (i / numSlidesPerRow) * strideRows - zeroPaddingHeight + (j - bz * fltHeight * fltWidth) / fltWidth;
285285
286- // Which column in matrix B?
287- int bx = (i % numSlidesPerRow) * strideCols - zeroPaddingWidth + j % fltWidth;
286+ // Which column in matrix B?
287+ int bx = (i % numSlidesPerRow) * strideCols - zeroPaddingWidth + (j - bz * fltHeight * fltWidth) % fltWidth;
288288
289- if (bx < 0 || by < 0 || bx >= imgWidth || by >= imgHeight) {
290- // This is a padding element.
291- A[index] = 0 ;
292- }
293- else {
294- A[index] = B[bx + by * imgWidth + bz * imgHeight * imgWidth ];
295- }
289+ if (bx < 0 || by < 0 || bx >= imgWidth || by >= imgHeight) {
290+ // This is a padding element.
291+ A[index] = 0 ;
292+ }
293+ else {
294+ A[index] = B[( bx + by * imgWidth) * depth + bz ];
295+ }
296296}
297297// ____________________________________________________________________________
298298template <typename AFloat>
0 commit comments