Skip to content

Commit fb943a5

Browse files
committed
Fix out of bound exception
1 parent 06316f9 commit fb943a5

File tree

1 file changed

+11
-11
lines changed
  • plugins-INDArrayLayers/src/main/scala-2.11/com/thoughtworks/deeplearning/plugins

1 file changed

+11
-11
lines changed

plugins-INDArrayLayers/src/main/scala-2.11/com/thoughtworks/deeplearning/plugins/INDArrayLayers.scala

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -64,17 +64,17 @@ object INDArrayLayers {
6464
private[plugins] implicit final class Nd4jIssues1869Workaround(indArray: INDArray) {
6565
def broadcastFix(outputShape: Int*): INDArray = {
6666
indArray.shape match {
67-
case currentShape if (currentShape: Seq[Int]) == outputShape => indArray
68-
case currentShape =>
69-
currentShape.padTo(outputShape.length, 1).indices.foldLeft(indArray.reshape(currentShape: _*)) {
70-
(indArray, i) =>
71-
val o = outputShape(i)
72-
if (o != 1 && o != currentShape(i)) {
73-
currentShape(i) = o
74-
indArray.broadcast(currentShape: _*)
75-
} else {
76-
indArray
77-
}
67+
case oldShape if (oldShape: Seq[Int]) == outputShape => indArray
68+
case oldShape =>
69+
val currentShape = oldShape.padTo(outputShape.length, 1)
70+
currentShape.indices.foldLeft(indArray) { (indArray, i) =>
71+
val o = outputShape(i)
72+
if (o != 1 && o != currentShape(i)) {
73+
currentShape(i) = o
74+
indArray.broadcast(currentShape: _*)
75+
} else {
76+
indArray
77+
}
7878
}
7979
}
8080
}

0 commit comments

Comments
 (0)