diff --git a/ops/opset13/batch_normalization.go b/ops/opset13/batch_normalization.go
new file mode 100644
index 0000000..50c9e19
--- /dev/null
+++ b/ops/opset13/batch_normalization.go
@@ -0,0 +1,239 @@
+package opset13
+
+import (
+    "github.com/advancedclimatesystems/gonnx/onnx"
+    "github.com/advancedclimatesystems/gonnx/ops"
+    "gorgonia.org/tensor"
+)
+
+const (
+    MinBatchNormalizationInputs       = 5
+    MaxBatchNormalizationInputs       = 5
+    BatchNormalizationDefaultEpsilon  = 1e-5
+    BatchNormalizationDefaultMomentum = 0.9
+)
+
+// BatchNormalization represents the ONNX BatchNormalization operator.
+type BatchNormalization struct {
+    epsilon  float32
+    momentum float32
+    testMode bool
+}
+
+// newBatchNormalization creates a new BatchNormalization operator.
+func newBatchNormalization() ops.Operator {
+    return &BatchNormalization{
+        epsilon:  BatchNormalizationDefaultEpsilon,
+        momentum: BatchNormalizationDefaultMomentum,
+    }
+}
+
+// Init initializes the BatchNormalization operator.
+func (b *BatchNormalization) Init(n *onnx.NodeProto) error {
+    hasMomentum := false
+
+    for _, attr := range n.GetAttribute() {
+        switch attr.GetName() {
+        case "epsilon":
+            b.epsilon = attr.GetF()
+        case "momentum":
+            hasMomentum = true
+            b.momentum = attr.GetF()
+        default:
+            return ops.ErrInvalidAttribute(attr.GetName(), b)
+        }
+    }
+
+    if !hasMomentum {
+        b.testMode = true
+    }
+
+    // We only support test mode, as this is by far the most common mode for inference models.
+    if !b.testMode {
+        return ops.ErrUnsupportedAttribute("momentum", b)
+    }
+
+    return nil
+}
+
+// Apply applies the BatchNormalization operator.
+func (b *BatchNormalization) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
+    X := inputs[0]
+    scale := inputs[1]
+    B := inputs[2]
+    mean := inputs[3]
+    variance := inputs[4]
+
+    out, err := b.testModeCalculation(X, scale, B, mean, variance)
+    if err != nil {
+        return nil, err
+    }
+
+    return []tensor.Tensor{out}, nil
+}
+
+// ValidateInputs validates the inputs that will be given to Apply for this operator.
+func (b *BatchNormalization) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
+    return ops.ValidateInputs(b, inputs)
+}
+
+// GetMinInputs returns the minimum number of input tensors this operator expects.
+func (b *BatchNormalization) GetMinInputs() int {
+    return MinBatchNormalizationInputs
+}
+
+// GetMaxInputs returns the maximum number of input tensors this operator expects.
+func (b *BatchNormalization) GetMaxInputs() int {
+    return MaxBatchNormalizationInputs
+}
+
+// GetInputTypeConstraints returns a list, where every element represents a set of
+// allowed tensor dtypes for the corresponding input tensor.
+func (b *BatchNormalization) GetInputTypeConstraints() [][]tensor.Dtype {
+    return [][]tensor.Dtype{
+        {tensor.Float32, tensor.Float64},
+        {tensor.Float32, tensor.Float64},
+        {tensor.Float32, tensor.Float64},
+        {tensor.Float32, tensor.Float64},
+        {tensor.Float32, tensor.Float64},
+    }
+}
+
+// String implements the fmt.Stringer interface and can be used to format errors or messages.
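+// reshapeTensors reshapes the scale, bias, mean and variance tensors from shape
+// (C) to (C, 1, 1, ...) so that they broadcast against X. For example, with X of
+// shape (N, C, H, W), the parameter tensors are reshaped to (C, 1, 1).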
+func (b *BatchNormalization) String() string {
+    return "batchNormalization operator"
+}
+
+func (b *BatchNormalization) reshapeTensors(X, scale, bias, mean, variance tensor.Tensor) (newScale, newBias, newMean, newVariance tensor.Tensor, err error) {
+    nNonSpatialDims := 2
+
+    nSpatialDims := len(X.Shape()) - nNonSpatialDims
+    if nSpatialDims <= 0 {
+        return scale, bias, mean, variance, nil
+    }
+
+    // The new shape for the `scale`, `bias`, `mean` and `variance` tensors should
+    // be (C, 1, 1, ...), such that they can be broadcast to match the shape of `X`.
+    newShape := make([]int, 1+nSpatialDims)
+
+    // Here we set the channel dimension. The channel dimension is the same
+    // for the `X`, `scale`, `bias`, `mean` and `variance` tensors.
+    newShape[0] = scale.Shape()[0]
+
+    // Set all the remaining dimensions to 1 to allow for broadcasting.
+    for i := 1; i < len(newShape); i++ {
+        newShape[i] = 1
+    }
+
+    // Now we clone all the parameter tensors (except `X`) and reshape the
+    // clones, leaving the original input tensors untouched.
+    newScale, ok := scale.Clone().(tensor.Tensor)
+    if !ok {
+        return nil, nil, nil, nil, ops.ErrTypeAssert("tensor.Tensor", scale.Clone())
+    }
+
+    newBias, ok = bias.Clone().(tensor.Tensor)
+    if !ok {
+        return nil, nil, nil, nil, ops.ErrTypeAssert("tensor.Tensor", bias.Clone())
+    }
+
+    newMean, ok = mean.Clone().(tensor.Tensor)
+    if !ok {
+        return nil, nil, nil, nil, ops.ErrTypeAssert("tensor.Tensor", mean.Clone())
+    }
+
+    newVariance, ok = variance.Clone().(tensor.Tensor)
+    if !ok {
+        return nil, nil, nil, nil, ops.ErrTypeAssert("tensor.Tensor", variance.Clone())
+    }
+
+    err = newScale.Reshape(newShape...)
+    if err != nil {
+        return nil, nil, nil, nil, err
+    }
+
+    err = newBias.Reshape(newShape...)
+    if err != nil {
+        return nil, nil, nil, nil, err
+    }
+
+    err = newMean.Reshape(newShape...)
+    if err != nil {
+        return nil, nil, nil, nil, err
+    }
+
+    err = newVariance.Reshape(newShape...)
+    if err != nil {
+        return nil, nil, nil, nil, err
+    }
+
+    return
+}
+
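+// testModeCalculation computes the inference-time (test mode) output as defined
+// by the ONNX specification:
+//
+//	Y = scale * (X - mean) / sqrt(variance + epsilon) + bias
+//
+// where the reshaped parameter tensors broadcast over the spatial dimensions of X.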
+func (b *BatchNormalization) testModeCalculation(X, scale, bias, mean, variance tensor.Tensor) (tensor.Tensor, error) {
+    newScale, newBias, newMean, newVariance, err := b.reshapeTensors(X, scale, bias, mean, variance)
+    if err != nil {
+        return nil, err
+    }
+
+    numerator, err := ops.ApplyBinaryOperation(
+        X,
+        newMean,
+        ops.Sub,
+        ops.UnidirectionalBroadcasting,
+    )
+    if err != nil {
+        return nil, err
+    }
+
+    numerator, err = ops.ApplyBinaryOperation(
+        numerator[0],
+        newScale,
+        ops.Mul,
+        ops.UnidirectionalBroadcasting,
+    )
+    if err != nil {
+        return nil, err
+    }
+
+    denominator, err := tensor.Add(newVariance, b.epsilon)
+    if err != nil {
+        return nil, err
+    }
+
+    denominator, err = tensor.Sqrt(denominator)
+    if err != nil {
+        return nil, err
+    }
+
+    outputs, err := ops.ApplyBinaryOperation(
+        numerator[0],
+        denominator,
+        ops.Div,
+        ops.UnidirectionalBroadcasting,
+    )
+    if err != nil {
+        return nil, err
+    }
+
+    outputs, err = ops.ApplyBinaryOperation(
+        outputs[0],
+        newBias,
+        ops.Add,
+        ops.UnidirectionalBroadcasting,
+    )
+    if err != nil {
+        return nil, err
+    }
+
+    return outputs[0], nil
+}
diff --git a/ops/opset13/batch_normalization_test.go b/ops/opset13/batch_normalization_test.go
new file mode 100644
index 0000000..6feee4e
--- /dev/null
+++ b/ops/opset13/batch_normalization_test.go
@@ -0,0 +1,146 @@
+package opset13
+
+import (
+    "testing"
+
+    "github.com/advancedclimatesystems/gonnx/onnx"
+    "github.com/advancedclimatesystems/gonnx/ops"
+    "github.com/stretchr/testify/assert"
+    "gorgonia.org/tensor"
+)
+
+func TestBatchNormalizationInit(t *testing.T) {
+    b := &BatchNormalization{}
+
+    err := b.Init(
+        &onnx.NodeProto{
+            Attribute: []*onnx.AttributeProto{
+                {Name: "epsilon", F: 0.001},
+            },
+        },
+    )
+    assert.Nil(t, err)
+
+    assert.Equal(t, float32(0.001), b.epsilon)
+    assert.True(t, b.testMode)
+}
+
+func TestBatchNormalizationInitTrainingMode(t *testing.T) {
+    b := &BatchNormalization{}
+
+    err := b.Init(
+        &onnx.NodeProto{
+            Attribute: []*onnx.AttributeProto{
+                {Name: "epsilon", F: 0.001},
+                {Name: "momentum", F: 0.99},
+            },
+        },
+    )
+    assert.Equal(t, ops.ErrUnsupportedAttribute("momentum", b), err)
+
+    assert.Equal(t, float32(0.001), b.epsilon)
+    assert.Equal(t, float32(0.99), b.momentum)
+    assert.False(t, b.testMode)
+}
+
+func TestBatchNormalization(t *testing.T) {
+    tests := []struct {
+        batchNormalization *BatchNormalization
+        backings           [][]float32
+        shapes             [][]int
+        expected           []float32
+    }{
+        {
+            &BatchNormalization{
+                epsilon:  1e5,
+                momentum: 0.9,
+                testMode: true,
+            },
+            [][]float32{
+                {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+                {0.2, 0.3, 0.4},
+                {0.1, -0.1, 0.2},
+                {4, 8, 12},
+                {1, 2, 3},
+            },
+            [][]int{
+                {2, 3, 2, 2},
+                {3},
+                {3},
+                {3},
+                {3},
+            },
+            []float32{0.097470194, 0.098102644, 0.098735094, 0.09936755, -0.103794694, -0.10284603, -0.10189735, -0.10094868, 0.19494043, 0.19620533, 0.19747022, 0.19873512, 0.10505962, 0.10569207, 0.10632452, 0.10695698, -0.09241061, -0.091461934, -0.09051326, -0.08956459, 0.21011914, 0.21138403, 0.21264893, 0.21391381},
+        },
+    }
+
+    for _, test := range tests {
+        inputs := []tensor.Tensor{
+            ops.TensorWithBackingFixture(test.backings[0], test.shapes[0]...),
+            ops.TensorWithBackingFixture(test.backings[1], test.shapes[1]...),
+            ops.TensorWithBackingFixture(test.backings[2], test.shapes[2]...),
+            ops.TensorWithBackingFixture(test.backings[3], test.shapes[3]...),
+            ops.TensorWithBackingFixture(test.backings[4], test.shapes[4]...),
+        }
+
+        res, err := test.batchNormalization.Apply(inputs)
+        assert.Nil(t, err)
+
+        assert.Equal(t, test.expected, res[0].Data())
+    }
+}
+
+func TestInputValidationBatchNormalization(t *testing.T) {
+    tests := []struct {
+        inputs []tensor.Tensor
+        err    error
+    }{
+        {
+            []tensor.Tensor{
+                ops.TensorWithBackingFixture([]float32{1, 2}, 2),
+                ops.TensorWithBackingFixture([]float32{3, 4}, 2),
+                ops.TensorWithBackingFixture([]float32{3, 4}, 2),
+                ops.TensorWithBackingFixture([]float32{3, 4}, 2),
+                ops.TensorWithBackingFixture([]float32{3, 4}, 2),
+            },
+            nil,
+        },
+        {
+            []tensor.Tensor{
+                ops.TensorWithBackingFixture([]float64{1, 2}, 2),
+                ops.TensorWithBackingFixture([]float64{3, 4}, 2),
+                ops.TensorWithBackingFixture([]float64{3, 4}, 2),
+                ops.TensorWithBackingFixture([]float64{3, 4}, 2),
+                ops.TensorWithBackingFixture([]float64{3, 4}, 2),
+            },
+            nil,
+        },
+        {
+            []tensor.Tensor{
+                ops.TensorWithBackingFixture([]int{1, 2}, 2),
+            },
+            ops.ErrInvalidInputCount(1, &BatchNormalization{}),
+        },
+        {
+            []tensor.Tensor{
+                ops.TensorWithBackingFixture([]float32{1, 2}, 2),
+                ops.TensorWithBackingFixture([]int{3, 4}, 2),
+                ops.TensorWithBackingFixture([]float32{1, 2}, 2),
+                ops.TensorWithBackingFixture([]float32{1, 2}, 2),
+                ops.TensorWithBackingFixture([]float32{1, 2}, 2),
+            },
+            ops.ErrInvalidInputType(1, "int", &BatchNormalization{}),
+        },
+    }
+
+    for _, test := range tests {
+        batchNormalization := &BatchNormalization{}
+        validated, err := batchNormalization.ValidateInputs(test.inputs)
+
+        assert.Equal(t, test.err, err)
+
+        if test.err == nil {
+            assert.Equal(t, test.inputs, validated)
+        }
+    }
+}
diff --git a/ops/opset13/opset13.go b/ops/opset13/opset13.go
index 3930ab3..0f680b8 100644
--- a/ops/opset13/opset13.go
+++ b/ops/opset13/opset13.go
@@ -5,61 +5,62 @@ import (
 )
 
 var operators13 = map[string]func() ops.Operator{
-    "Abs":             newAbs,
-    "Acos":            newAcos,
-    "Acosh":           newAcosh,
-    "Add":             newAdd,
-    "And":             newAnd,
-    "ArgMax":          newArgMax,
-    "Asin":            newAsin,
-    "Asinh":           newAsinh,
-    "Atan":            newAtan,
-    "Atanh":           newAtanh,
-    "Cast":            newCast,
-    "Concat":          newConcat,
-    "Constant":        newConstant,
-    "ConstantOfShape": newConstantOfShape,
-    "Conv":            newConv,
-    "Cos":             newCos,
-    "Cosh":            newCosh,
-    "Div":             newDiv,
-    "Equal":           newEqual,
-    "Expand":          newExpand,
-    "Flatten":         newFlatten,
-    "Gather":          newGather,
-    "Gemm":            newGemm,
-    "Greater":         newGreater,
-    "GreaterOrEqual":  newGreaterOrEqual,
-    "GRU":             newGRU,
-    "Less":            newLess,
-    "LessOrEqual":     newLessOrEqual,
-    "LinearRegressor": newLinearRegressor,
-    "LogSoftmax":      newLogSoftmax,
-    "LSTM":            newLSTM,
-    "MatMul":          newMatMul,
-    "Mul":             newMul,
-    "Not":             newNot,
-    "Or":              newOr,
-    "PRelu":           newPRelu,
-    "ReduceMax":       newReduceMax,
-    "ReduceMin":       newReduceMin,
-    "Relu":            newRelu,
-    "Reshape":         newReshape,
-    "RNN":             newRNN,
-    "Scaler":          newScaler,
-    "Shape":           newShape,
-    "Sigmoid":         newSigmoid,
-    "Sin":             newSin,
-    "Sinh":            newSinh,
-    "Slice":           newSlice,
-    "Softmax":         newSoftmax,
-    "Squeeze":         newSqueeze,
-    "Sub":             newSub,
-    "Tan":             newTan,
-    "Tanh":            newTanh,
-    "Transpose":       newTranspose,
-    "Unsqueeze":       newUnsqueeze,
-    "Xor":             newXor,
+    "Abs":                newAbs,
+    "Acos":               newAcos,
+    "Acosh":              newAcosh,
+    "Add":                newAdd,
+    "And":                newAnd,
+    "ArgMax":             newArgMax,
+    "Asin":               newAsin,
+    "Asinh":              newAsinh,
+    "Atan":               newAtan,
+    "Atanh":              newAtanh,
+    "BatchNormalization": newBatchNormalization,
+    "Cast":               newCast,
+    "Concat":             newConcat,
+    "Constant":           newConstant,
+    "ConstantOfShape":    newConstantOfShape,
+    "Conv":               newConv,
+    "Cos":                newCos,
+    "Cosh":               newCosh,
+    "Div":                newDiv,
+    "Equal":              newEqual,
+    "Expand":             newExpand,
"Flatten": newFlatten, + "Gather": newGather, + "Gemm": newGemm, + "Greater": newGreater, + "GreaterOrEqual": newGreaterOrEqual, + "GRU": newGRU, + "Less": newLess, + "LessOrEqual": newLessOrEqual, + "LinearRegressor": newLinearRegressor, + "LogSoftmax": newLogSoftmax, + "LSTM": newLSTM, + "MatMul": newMatMul, + "Mul": newMul, + "Not": newNot, + "Or": newOr, + "PRelu": newPRelu, + "ReduceMax": newReduceMax, + "ReduceMin": newReduceMin, + "Relu": newRelu, + "Reshape": newReshape, + "RNN": newRNN, + "Scaler": newScaler, + "Shape": newShape, + "Sigmoid": newSigmoid, + "Sin": newSin, + "Sinh": newSinh, + "Slice": newSlice, + "Softmax": newSoftmax, + "Squeeze": newSqueeze, + "Sub": newSub, + "Tan": newTan, + "Tanh": newTanh, + "Transpose": newTranspose, + "Unsqueeze": newUnsqueeze, + "Xor": newXor, } // GetOperator maps strings as found in the ModelProto to Operators from opset 13. diff --git a/ops/opset13/opset13_test.go b/ops/opset13/opset13_test.go index 01008ee..c4ed60b 100644 --- a/ops/opset13/opset13_test.go +++ b/ops/opset13/opset13_test.go @@ -54,6 +54,11 @@ func TestGetOperator(t *testing.T) { newAsinh(), nil, }, + { + "BatchNormalization", + newBatchNormalization(), + nil, + }, { "Atan", newAtan(), diff --git a/ops/opset13/rnn.go b/ops/opset13/rnn.go index b3248d8..ff0420b 100644 --- a/ops/opset13/rnn.go +++ b/ops/opset13/rnn.go @@ -219,6 +219,7 @@ func (r *RNN) layerCalculation( // getWeights returns the weights from a concatenated weight tensor. The result is // a single weight matrix. W has shape (num_directions, hidden_size, ...). +// This function extracts 1 weight matrix from tensor W. // The W tensor, by GONNX definition, has 3 dimensions with 1 weight // tensor in it (2 if bidirectional, but that is not supported). func (r *RNN) getWeights(W tensor.Tensor) (tensor.Tensor, error) { diff --git a/ops_test.go b/ops_test.go index a07a351..d4af96e 100644 --- a/ops_test.go +++ b/ops_test.go @@ -25,6 +25,8 @@ import ( // implemented, or lower, which we also haven't implemented yet. var ignoredTests = []string{ "test_add_uint8", // Opset14 + "test_batchnorm_epsilon_training_mode", // Opset14 + "test_batchnorm_example_training_mode", // Opset14 "test_div_uint8", // Opset14 "test_gru_batchwise", // Opset14 "test_logsoftmax_axis_1_expanded_ver18", // Opset18 @@ -252,6 +254,12 @@ func shouldRunTest(folder, opFilter string) bool { } } + // For some reason ONNX decided to not let these testcases match the operator name. + // Here we manually replace the filter with the name ONNX uses for this test case. + if opFilter == "test_batchnormalization" { + opFilter = "test_batchnorm" + } + if strings.Contains(folder, opFilter) { remaining := strings.ReplaceAll(folder, opFilter, "") if len(remaining) == 0 || remaining[:1] == "_" { @@ -379,6 +387,8 @@ var expectedTests = []string{ "test_atan_example", "test_atanh", "test_atanh_example", + "test_batchnorm_epsilon", + "test_batchnorm_example", "test_cast_DOUBLE_to_FLOAT", "test_cast_FLOAT_to_DOUBLE", "test_concat_1d_axis_0",