Skip to content

Commit 00a62a7

Browse files
committed
Fixed an indexing bug
1 parent aa16408 commit 00a62a7

1 file changed

Lines changed: 25 additions & 25 deletions

File tree

tmva/tmva/src/DNN/Architectures/Cuda/Kernels.cuh

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -258,41 +258,41 @@ __global__ void Im2Col(AFloat * A,
258258
int zeroPaddingHeight,
259259
int zeroPaddingWidth)
260260
{
261-
// The row of the output matrix.
262-
int i = blockDim.y * blockIdx.y + threadIdx.y;
261+
// The row of the output matrix.
262+
int i = blockDim.y * blockIdx.y + threadIdx.y;
263263

264-
// The column of the output matrix.
265-
int j = blockDim.x * blockIdx.x + threadIdx.x;
264+
// The column of the output matrix.
265+
int j = blockDim.x * blockIdx.x + threadIdx.x;
266266

267-
// Number of column in matrix A.
268-
int NLocalViewPixels = fltHeight * fltWidth * depth;
267+
// Number of column in matrix A.
268+
int NLocalViewPixels = fltHeight * fltWidth * depth;
269269

270-
// Number of rows in matrix A.
271-
int NLocalViews = calculateDimension(imgWidth, fltWidth, zeroPaddingWidth, strideCols) *
272-
calculateDimension(imgHeight, fltHeight, zeroPaddingHeight, strideRows);
270+
// Number of rows in matrix A.
271+
int NLocalViews = calculateDimension(imgWidth, fltWidth, zeroPaddingWidth, strideCols) *
272+
calculateDimension(imgHeight, fltHeight, zeroPaddingHeight, strideRows);
273273

274-
if (i > NLocalViews || j > NLocalViewPixels) return;
274+
if (i >= NLocalViews || j >= NLocalViewPixels) return;
275275

276-
int index = j + i * NLocalViewPixels;
276+
int index = j * NLocalViews + i;
277277

278-
int numSlidesPerRow = calculateDimension(imgWidth, fltWidth, zeroPaddingWidth, strideCols);
278+
int numSlidesPerRow = calculateDimension(imgWidth, fltWidth, zeroPaddingWidth, strideCols);
279279

280-
// Which image channel of B?
281-
int bz = j / (fltHeight * fltWidth);
280+
// Which image channel of B?
281+
int bz = j / (fltHeight * fltWidth);
282282

283-
// Which row in matrix B?
284-
int by = (i / numSlidesPerRow) * strideRows - zeroPaddingHeight + j / fltWidth;
283+
// Which row in matrix B?
284+
int by = (i / numSlidesPerRow) * strideRows - zeroPaddingHeight + (j - bz * fltHeight * fltWidth) / fltWidth;
285285

286-
// Which column in matrix B?
287-
int bx = (i % numSlidesPerRow) * strideCols - zeroPaddingWidth + j % fltWidth;
286+
// Which column in matrix B?
287+
int bx = (i % numSlidesPerRow) * strideCols - zeroPaddingWidth + (j - bz * fltHeight * fltWidth) % fltWidth;
288288

289-
if (bx < 0 || by < 0 || bx >= imgWidth || by >= imgHeight) {
290-
// This is a padding element.
291-
A[index] = 0;
292-
}
293-
else {
294-
A[index] = B[bx + by * imgWidth + bz * imgHeight * imgWidth];
295-
}
289+
if (bx < 0 || by < 0 || bx >= imgWidth || by >= imgHeight) {
290+
// This is a padding element.
291+
A[index] = 0;
292+
}
293+
else {
294+
A[index] = B[(bx + by * imgWidth) * depth + bz];
295+
}
296296
}
297297
//____________________________________________________________________________
298298
template<typename AFloat>

0 commit comments

Comments
 (0)