File tree Expand file tree Collapse file tree 1 file changed +11
-11
lines changed
plugins-INDArrayLayers/src/main/scala-2.11/com/thoughtworks/deeplearning/plugins Expand file tree Collapse file tree 1 file changed +11
-11
lines changed Original file line number Diff line number Diff line change @@ -64,17 +64,17 @@ object INDArrayLayers {
6464 private [plugins] implicit final class Nd4jIssues1869Workaround (indArray : INDArray ) {
6565 def broadcastFix (outputShape : Int * ): INDArray = {
6666 indArray.shape match {
67- case currentShape if (currentShape : Seq [Int ]) == outputShape => indArray
68- case currentShape =>
69- currentShape.padTo(outputShape.length, 1 ).indices.foldLeft(indArray.reshape( currentShape : _* )) {
70- (indArray, i) =>
71- val o = outputShape(i)
72- if (o != 1 && o != currentShape(i)) {
73- currentShape(i) = o
74- indArray.broadcast(currentShape : _* )
75- } else {
76- indArray
77- }
67+ case oldShape if (oldShape : Seq [Int ]) == outputShape => indArray
68+ case oldShape =>
69+ val currentShape = oldShape .padTo(outputShape.length, 1 )
70+ currentShape.indices.foldLeft(indArray) { (indArray, i) =>
71+ val o = outputShape(i)
72+ if (o != 1 && o != currentShape(i)) {
73+ currentShape(i) = o
74+ indArray.broadcast(currentShape : _* )
75+ } else {
76+ indArray
77+ }
7878 }
7979 }
8080 }
You can’t perform that action at this time.
0 commit comments