230 changes: 230 additions & 0 deletions ops/opset13/batch_normalization.go
@@ -0,0 +1,230 @@
package opset13

import (
"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

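// BatchNormalization always takes exactly five inputs: X, scale, B, mean and variance.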
const (
MinBatchNormalizationInputs = 5
MaxBatchNormalizationInputs = 5
BatchNormalizationDefaultEpsilon = 1e-5
BatchNormalizationDefaultMomentum = 0.9
)

// BatchNormalization represents the ONNX batchNormalization operator.
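// In test mode, it normalizes the input as
// Y = scale * (X - mean) / sqrt(variance + epsilon) + B,
// using the precomputed mean and variance inputs.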
type BatchNormalization struct {
epsilon float32
momentum float32
testMode bool
}

// newBatchNormalization creates a new batchNormalization operator.
func newBatchNormalization() ops.Operator {
return &BatchNormalization{
epsilon: BatchNormalizationDefaultEpsilon,
momentum: BatchNormalizationDefaultMomentum,
}
}

// Init initializes the batchNormalization operator.
func (b *BatchNormalization) Init(n *onnx.NodeProto) error {
hasMomentum := false

for _, attr := range n.GetAttribute() {
switch attr.GetName() {
case "epsilon":
b.epsilon = attr.GetF()
case "momentum":
hasMomentum = true
b.momentum = attr.GetF()
default:
return ops.ErrInvalidAttribute(attr.GetName(), b)
}
}

if !hasMomentum {
b.testMode = true
}

// We only support test mode, as this is by far the most common case for inference models.
if !b.testMode {
return ops.ErrUnsupportedAttribute("momentum", b)
}

return nil
}

// Apply applies the batchNormalization operator.
func (b *BatchNormalization) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
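// The five inputs are, in order: the data tensor X, scale, bias B, mean and variance.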
X := inputs[0]
scale := inputs[1]
B := inputs[2]
mean := inputs[3]
variance := inputs[4]

out, err := b.testModeCalculation(X, scale, B, mean, variance)
if err != nil {
return nil, err
}

return []tensor.Tensor{out}, nil
}

// ValidateInputs validates the inputs that will be given to Apply for this operator.
func (b *BatchNormalization) ValidateInputs(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
return ops.ValidateInputs(b, inputs)
}

// GetMinInputs returns the minimum number of input tensors this operator expects.
func (b *BatchNormalization) GetMinInputs() int {
return MinBatchNormalizationInputs
}

// GetMaxInputs returns the maximum number of input tensors this operator expects.
func (b *BatchNormalization) GetMaxInputs() int {
return MaxBatchNormalizationInputs
}

// GetInputTypeConstraints returns a list. Every element represents a set of allowed tensor dtypes
// for the corresponding input tensor.
func (b *BatchNormalization) GetInputTypeConstraints() [][]tensor.Dtype {
return [][]tensor.Dtype{
{tensor.Float32, tensor.Float64},
{tensor.Float32, tensor.Float64},
{tensor.Float32, tensor.Float64},
{tensor.Float32, tensor.Float64},
{tensor.Float32, tensor.Float64},
}
}

// String implements the stringer interface, and can be used to format errors or messages.
func (b *BatchNormalization) String() string {
return "batchNormalization operator"
}

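// reshapeTensors reshapes the `scale`, `bias`, `mean` and `variance` tensors
// from shape (C,) to (C, 1, 1, ...) so that they broadcast against `X`, which
// has layout (N, C, D1, D2, ...). The inputs are cloned before reshaping, so
// the original tensors are left untouched.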
func (b *BatchNormalization) reshapeTensors(X, scale, bias, mean, variance tensor.Tensor) (newScale, newBias, newMean, newVariance tensor.Tensor, err error) {
nNonSpatialDims := 2

nSpatialDims := len(X.Shape()) - nNonSpatialDims
if nSpatialDims <= 0 {
return scale, bias, mean, variance, nil
}

// The new shape for the `scale`, `bias`, `mean` and `variance` tensors should
// be (C, 1, 1, ...), such that they can be broadcast to match the shape of `X`.
newShape := make([]int, 1+nSpatialDims)

// Set the channel dimension. It is the same for the `X`, `scale`, `bias`,
// `mean` and `variance` tensors.
newShape[0] = scale.Shape()[0]

// Set all the remaining dimensions to 1 to allow for broadcasting.
for i := 1; i < len(newShape); i++ {
newShape[i] = 1
}

// Now we create new tensors for all the input tensors (except `X`) and reshape
// them.
newScale, ok := scale.Clone().(tensor.Tensor)
if !ok {
return nil, nil, nil, nil, ops.ErrTypeAssert("tensor.Tensor", scale.Clone())
}

newBias, ok = bias.Clone().(tensor.Tensor)
if !ok {
return nil, nil, nil, nil, ops.ErrTypeAssert("tensor.Tensor", bias.Clone())
}

newMean, ok = mean.Clone().(tensor.Tensor)
if !ok {
return nil, nil, nil, nil, ops.ErrTypeAssert("tensor.Tensor", mean.Clone())
}

newVariance, ok = variance.Clone().(tensor.Tensor)
if !ok {
return nil, nil, nil, nil, ops.ErrTypeAssert("tensor.Tensor", variance.Clone())
}

err = newScale.Reshape(newShape...)
if err != nil {
return nil, nil, nil, nil, err
}

err = newBias.Reshape(newShape...)
if err != nil {
return nil, nil, nil, nil, err
}

err = newMean.Reshape(newShape...)
if err != nil {
return nil, nil, nil, nil, err
}

err = newVariance.Reshape(newShape...)
if err != nil {
return nil, nil, nil, nil, err
}

return newScale, newBias, newMean, newVariance, nil
}

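// testModeCalculation applies batch normalization using the stored statistics:
// Y = scale * (X - mean) / sqrt(variance + epsilon) + bias.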
func (b *BatchNormalization) testModeCalculation(X, scale, bias, mean, variance tensor.Tensor) (tensor.Tensor, error) {
newScale, newBias, newMean, newVariance, err := b.reshapeTensors(X, scale, bias, mean, variance)
if err != nil {
return nil, err
}

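// Center the input: X - mean.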
numerator, err := ops.ApplyBinaryOperation(
X,
newMean,
ops.Sub,
ops.UnidirectionalBroadcasting,
)
if err != nil {
return nil, err
}

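// Scale the centered input: (X - mean) * scale.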
numerator, err = ops.ApplyBinaryOperation(
numerator[0],
newScale,
ops.Mul,
ops.UnidirectionalBroadcasting,
)
if err != nil {
return nil, err
}

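// Compute the denominator: sqrt(variance + epsilon).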
denominator, err := tensor.Add(newVariance, b.epsilon)
if err != nil {
return nil, err
}

denominator, err = tensor.Sqrt(denominator)
if err != nil {
return nil, err
}

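// Divide the scaled, centered input by the denominator.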
outputs, err := ops.ApplyBinaryOperation(
numerator[0],
denominator,
ops.Div,
ops.UnidirectionalBroadcasting,
)
if err != nil {
return nil, err
}

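// Add the bias to obtain the final output.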
outputs, err = ops.ApplyBinaryOperation(
outputs[0],
newBias,
ops.Add,
ops.UnidirectionalBroadcasting,
)
if err != nil {
return nil, err
}

return outputs[0], nil
}
146 changes: 146 additions & 0 deletions ops/opset13/batch_normalization_test.go
@@ -0,0 +1,146 @@
package opset13

import (
"testing"

"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"github.com/stretchr/testify/assert"
"gorgonia.org/tensor"
)

func TestBatchNormalizationInit(t *testing.T) {
b := &BatchNormalization{}

err := b.Init(
&onnx.NodeProto{
Attribute: []*onnx.AttributeProto{
{Name: "epsilon", F: 0.001},
},
},
)
assert.Nil(t, err)

assert.Equal(t, float32(0.001), b.epsilon)
assert.True(t, b.testMode)
}

func TestBatchNormalizationInitTrainingMode(t *testing.T) {
b := &BatchNormalization{}

err := b.Init(
&onnx.NodeProto{
Attribute: []*onnx.AttributeProto{
{Name: "epsilon", F: 0.001},
{Name: "momentum", F: 0.99},
},
},
)
assert.Equal(t, ops.ErrUnsupportedAttribute("momentum", b), err)

assert.Equal(t, float32(0.001), b.epsilon)
assert.Equal(t, float32(0.99), b.momentum)
assert.False(t, b.testMode)
}

func TestBatchNormalization(t *testing.T) {
tests := []struct {
batchNormalization *BatchNormalization
backings [][]float32
shapes [][]int
expected []float32
}{
{
&BatchNormalization{
epsilon: 1e5,
momentum: 0.9,
testMode: true,
},
[][]float32{
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
{0.2, 0.3, 0.4},
{0.1, -0.1, 0.2},
{4, 8, 12},
{1, 2, 3},
},
[][]int{
{2, 3, 2, 2},
{3},
{3},
{3},
{3},
},
[]float32{0.097470194, 0.098102644, 0.098735094, 0.09936755, -0.103794694, -0.10284603, -0.10189735, -0.10094868, 0.19494043, 0.19620533, 0.19747022, 0.19873512, 0.10505962, 0.10569207, 0.10632452, 0.10695698, -0.09241061, -0.091461934, -0.09051326, -0.08956459, 0.21011914, 0.21138403, 0.21264893, 0.21391381},
},
}

for _, test := range tests {
inputs := []tensor.Tensor{
ops.TensorWithBackingFixture(test.backings[0], test.shapes[0]...),
ops.TensorWithBackingFixture(test.backings[1], test.shapes[1]...),
ops.TensorWithBackingFixture(test.backings[2], test.shapes[2]...),
ops.TensorWithBackingFixture(test.backings[3], test.shapes[3]...),
ops.TensorWithBackingFixture(test.backings[4], test.shapes[4]...),
}

res, err := test.batchNormalization.Apply(inputs)
assert.Nil(t, err)

assert.Equal(t, test.expected, res[0].Data())
}
}

func TestInputValidationBatchNormalization(t *testing.T) {
tests := []struct {
inputs []tensor.Tensor
err error
}{
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float32{1, 2}, 2),
ops.TensorWithBackingFixture([]float32{3, 4}, 2),
ops.TensorWithBackingFixture([]float32{3, 4}, 2),
ops.TensorWithBackingFixture([]float32{3, 4}, 2),
ops.TensorWithBackingFixture([]float32{3, 4}, 2),
},
nil,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float64{1, 2}, 2),
ops.TensorWithBackingFixture([]float64{3, 4}, 2),
ops.TensorWithBackingFixture([]float64{3, 4}, 2),
ops.TensorWithBackingFixture([]float64{3, 4}, 2),
ops.TensorWithBackingFixture([]float64{3, 4}, 2),
},
nil,
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int{1, 2}, 2),
},
ops.ErrInvalidInputCount(1, &BatchNormalization{}),
},
{
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float32{1, 2}, 2),
ops.TensorWithBackingFixture([]int{3, 4}, 2),
ops.TensorWithBackingFixture([]float32{1, 2}, 2),
ops.TensorWithBackingFixture([]float32{1, 2}, 2),
ops.TensorWithBackingFixture([]float32{1, 2}, 2),
},
ops.ErrInvalidInputType(1, "int", &BatchNormalization{}),
},
}

for _, test := range tests {
batchNormalization := &BatchNormalization{}
validated, err := batchNormalization.ValidateInputs(test.inputs)

assert.Equal(t, test.err, err)

if test.err == nil {
assert.Equal(t, test.inputs, validated)
}
}
}
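
For quick reference, here is a minimal sketch of driving the operator directly, mirroring the fixtures in the tests above. The helper name and all values are illustrative; it has to live inside the opset13 package because the constructor is unexported, and it assumes the gorgonia tensor constructors (tensor.New, tensor.WithShape, tensor.WithBacking) plus the package's onnx and fmt imports.

func exampleBatchNormalization() {
	op := newBatchNormalization()
	if err := op.Init(&onnx.NodeProto{
		Attribute: []*onnx.AttributeProto{{Name: "epsilon", F: 1e-5}},
	}); err != nil {
		panic(err)
	}

	// X has layout (N, C, H, W); scale, bias, mean and variance have shape (C,).
	inputs := []tensor.Tensor{
		tensor.New(tensor.WithShape(1, 2, 1, 2), tensor.WithBacking([]float32{0, 1, 2, 3})),
		tensor.New(tensor.WithShape(2), tensor.WithBacking([]float32{1, 1})),       // scale
		tensor.New(tensor.WithShape(2), tensor.WithBacking([]float32{0, 0})),       // bias
		tensor.New(tensor.WithShape(2), tensor.WithBacking([]float32{0.5, 2.5})),   // mean
		tensor.New(tensor.WithShape(2), tensor.WithBacking([]float32{0.25, 0.25})), // variance
	}

	validated, err := op.ValidateInputs(inputs)
	if err != nil {
		panic(err)
	}

	out, err := op.Apply(validated)
	if err != nil {
		panic(err)
	}

	fmt.Println(out[0].Data()) // approximately [-1 1 -1 1]
}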