From 988d82da21a96bf1517e9cf39cc631e32f1bbcbf Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Sun, 22 Dec 2024 09:43:36 +0100 Subject: [PATCH 1/3] Added CumSum operator --- model_test.go | 2 - ops/cumsum/cumsum.go | 169 ++++++++++++++++++++++++++++++++++++++ ops/cumsum/cumsum_test.go | 118 ++++++++++++++++++++++++++ ops/cumsum/versions.go | 13 +++ ops/utils.go | 16 ++++ ops_test.go | 7 ++ opset.go | 2 + 7 files changed, 325 insertions(+), 2 deletions(-) create mode 100644 ops/cumsum/cumsum.go create mode 100644 ops/cumsum/cumsum_test.go create mode 100644 ops/cumsum/versions.go diff --git a/model_test.go b/model_test.go index 54b9ed3..ed1c4b9 100644 --- a/model_test.go +++ b/model_test.go @@ -1,7 +1,6 @@ package gonnx import ( - "fmt" "testing" "github.com/advancedclimatesystems/gonnx/onnx" @@ -97,7 +96,6 @@ func TestModel(t *testing.T) { } for _, test := range tests { - fmt.Println(test.path) model, err := NewModelFromFile(test.path) assert.Nil(t, err) diff --git a/ops/cumsum/cumsum.go b/ops/cumsum/cumsum.go new file mode 100644 index 0000000..a11d77b --- /dev/null +++ b/ops/cumsum/cumsum.go @@ -0,0 +1,169 @@ +package cumsum + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var cumsumTypeConstraints = [][]tensor.Dtype{ + {tensor.Int32, tensor.Int64, tensor.Uint32, tensor.Uint64, tensor.Float32, tensor.Float64}, + {tensor.Int32, tensor.Int64}, +} + +// CumSum represents the ONNX cumsum operator. +type CumSum struct { + ops.BaseOperator + + exclusive bool + reverse bool +} + +// newCumSum creates a new cumsum operator. +func newCumSum(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &CumSum{ + BaseOperator: ops.NewBaseOperator( + version, + 2, + 2, + typeConstraints, + "cumsum", + ), + exclusive: false, + reverse: false, + } +} + +// Init initializes the cumsum operator. +func (c *CumSum) Init(n *onnx.NodeProto) error { + for _, attr := range n.GetAttribute() { + switch attr.GetName() { + case "exclusive": + c.exclusive = attr.GetI() == 1 + case "reverse": + c.reverse = attr.GetI() == 1 + default: + return ops.ErrInvalidAttribute(attr.GetName(), c) + } + } + + return nil +} + +// Apply applies the cumsum operator. +func (c *CumSum) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + axis, err := ops.AnyToInt(inputs[1].ScalarValue()) + if err != nil { + return nil, err + } + + out, err := cumsum(inputs[0], axis, c.exclusive, c.reverse) + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +// Performs cumulative sum of the input elements along the given axis. By default, it will do the sum inclusively meaning the first element is copied as is. Through an exclusive attribute, this behavior can change to exclude the first element. It can also perform summation in the opposite direction of the axis. For that, set reverse attribute to 1. +func cumsum(x tensor.Tensor, axis int, exclusive, reverse bool) (tensor.Tensor, error) { + // First we copy the input tensor to the output tensor. + out, ok := x.Clone().(tensor.Tensor) + if !ok { + return nil, ops.ErrCast + } + + nDims := len(x.Shape()) + axis = ops.ConvertNegativeAxis(axis, nDims) + + if axis < 0 || axis >= nDims { + return nil, ops.ErrAxisOutOfRange(0, nDims, axis) + } + + axisSize := x.Shape()[axis] + + var startValue int + if reverse { + startValue = axisSize - 1 + } else { + startValue = 0 + } + + slices := make([]tensor.Slice, nDims) + slices[axis] = ops.NewSlicer(startValue, startValue+1) + + prevView, err := x.Slice(slices...) + if err != nil { + return nil, err + } + + prevValues := prevView.Materialize() + + for i := startValue; endValueReached(i, axisSize, reverse); { + slices[axis] = ops.NewSlicer(i, i+1) + + currentView, err := out.Slice(slices...) + if err != nil { + return nil, err + } + + currentValues := currentView.Materialize() + + // If exclusive is true, the first result in the cumsum opertaion is zero. + // We can achieve this by subtracting the current values from the current values. + // This way we don't have to infer the underlying type of the tensor. + if i == startValue && exclusive { + zeroValues, err := ops.Sub(currentValues, currentValues) + if err != nil { + return nil, err + } + + err = tensor.Copy(currentView, zeroValues) + if err != nil { + return nil, err + } + } + + if (i != startValue) && exclusive { + err = tensor.Copy(currentView, prevValues) + if err != nil { + return nil, err + } + + newValues, err := ops.Add(currentValues, prevValues) + if err != nil { + return nil, err + } + + prevValues = newValues + } else if i != startValue { + newValues, err := ops.Add(currentValues, prevValues) + if err != nil { + return nil, err + } + + err = tensor.Copy(currentView, newValues) + if err != nil { + return nil, err + } + + prevValues = newValues + } + + if reverse { + i-- + } else { + i++ + } + } + + return out, nil +} + +func endValueReached(i, axisSize int, reverse bool) bool { + if reverse { + return i >= 0 + } + + return i < axisSize +} diff --git a/ops/cumsum/cumsum_test.go b/ops/cumsum/cumsum_test.go new file mode 100644 index 0000000..6595b29 --- /dev/null +++ b/ops/cumsum/cumsum_test.go @@ -0,0 +1,118 @@ +package cumsum + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestCumSumInit(t *testing.T) { + c := &CumSum{} + err := c.Init( + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "exclusive", I: 1}, + {Name: "reverse", I: 1}, + }, + }, + ) + + assert.Nil(t, err) + assert.Equal(t, true, c.exclusive) + assert.Equal(t, true, c.reverse) +} + +func TestCumSumInitDefaults(t *testing.T) { + c := &CumSum{} + err := c.Init( + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{}, + }, + ) + + assert.Nil(t, err) + assert.Equal(t, false, c.exclusive) + assert.Equal(t, false, c.reverse) +} + +func TestCumSum(t *testing.T) { + tests := []struct { + version int64 + node *onnx.NodeProto + backing []float32 + axis int32 + shape []int + expected []float32 + }{ + { + 11, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "exclusive", I: 0}, + {Name: "reverse", I: 0}, + }, + }, + []float32{1, 2, 3, 4}, + 0, + []int{2, 2}, + []float32{1, 2, 4, 6}, + }, + { + 11, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "exclusive", I: 0}, + {Name: "reverse", I: 0}, + }, + }, + []float32{1, 2, 3, 4}, + 1, + []int{2, 2}, + []float32{1, 3, 3, 7}, + }, + { + 11, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "exclusive", I: 1}, + {Name: "reverse", I: 0}, + }, + }, + []float32{1, 2, 3}, + 0, + []int{3}, + []float32{0, 1, 3}, + }, + { + 11, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "exclusive", I: 0}, + {Name: "reverse", I: 1}, + }, + }, + []float32{1, 2, 3}, + 0, + []int{3}, + []float32{6, 5, 3}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backing, test.shape...), + tensor.New(tensor.FromScalar(test.axis)), + } + + cumsum := cumsumVersions[test.version]() + err := cumsum.Init(test.node) + assert.Nil(t, err) + + res, err := cumsum.Apply(inputs) + assert.Nil(t, err) + assert.Equal(t, test.expected, res[0].Data()) + } +} diff --git a/ops/cumsum/versions.go b/ops/cumsum/versions.go new file mode 100644 index 0000000..89fceab --- /dev/null +++ b/ops/cumsum/versions.go @@ -0,0 +1,13 @@ +package cumsum + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var cumsumVersions = ops.OperatorVersions{ + 11: ops.NewOperatorConstructor(newCumSum, 11, cumsumTypeConstraints), +} + +func GetVersions() ops.OperatorVersions { + return cumsumVersions +} diff --git a/ops/utils.go b/ops/utils.go index 1dcee9e..f41d911 100644 --- a/ops/utils.go +++ b/ops/utils.go @@ -78,6 +78,22 @@ func OffsetTensorIfNegative(t tensor.Tensor, offset int) error { return nil } +// AnyToInt casts the given data to an int, but only if the data is of some sort of int type. +func AnyToInt(value interface{}) (int, error) { + switch data := value.(type) { + case int8: + return int(data), nil + case int16: + return int(data), nil + case int32: + return int(data), nil + case int64: + return int(data), nil + default: + return 0, ErrCast + } +} + // AnyToIntSlice casts the data of a node to an int list. This will only // be done if the data is of some sort of int type. func AnyToIntSlice(value interface{}) ([]int, error) { diff --git a/ops_test.go b/ops_test.go index 07fb258..1a7df89 100644 --- a/ops_test.go +++ b/ops_test.go @@ -402,6 +402,13 @@ var expectedTests = []string{ "test_cos_example", "test_cosh", "test_cosh_example", + "test_cumsum_1d", + "test_cumsum_1d_exclusive", + "test_cumsum_1d_reverse", + "test_cumsum_1d_reverse_exclusive", + "test_cumsum_2d_axis_0", + "test_cumsum_2d_axis_1", + "test_cumsum_2d_negative_axis", "test_div", "test_div_bcast", "test_div_example", diff --git a/opset.go b/opset.go index 0bfcdb9..4542915 100644 --- a/opset.go +++ b/opset.go @@ -19,6 +19,7 @@ import ( "github.com/advancedclimatesystems/gonnx/ops/conv" "github.com/advancedclimatesystems/gonnx/ops/cos" "github.com/advancedclimatesystems/gonnx/ops/cosh" + "github.com/advancedclimatesystems/gonnx/ops/cumsum" "github.com/advancedclimatesystems/gonnx/ops/div" "github.com/advancedclimatesystems/gonnx/ops/equal" "github.com/advancedclimatesystems/gonnx/ops/expand" @@ -85,6 +86,7 @@ var operators = map[string]ops.OperatorVersions{ "Conv": conv.GetConvVersions(), "Cos": cos.GetCosVersions(), "Cosh": cosh.GetCoshVersions(), + "CumSum": cumsum.GetVersions(), "Div": div.GetDivVersions(), "Equal": equal.GetEqualVersions(), "Expand": expand.GetExpandVersions(), From 3002ff263e83187a18f5f79e937b95774cc81134 Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Sun, 22 Dec 2024 10:01:57 +0100 Subject: [PATCH 2/3] Clean up comments --- ops/cumsum/cumsum.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ops/cumsum/cumsum.go b/ops/cumsum/cumsum.go index a11d77b..28e346a 100644 --- a/ops/cumsum/cumsum.go +++ b/ops/cumsum/cumsum.go @@ -65,9 +65,10 @@ func (c *CumSum) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { return []tensor.Tensor{out}, nil } -// Performs cumulative sum of the input elements along the given axis. By default, it will do the sum inclusively meaning the first element is copied as is. Through an exclusive attribute, this behavior can change to exclude the first element. It can also perform summation in the opposite direction of the axis. For that, set reverse attribute to 1. +// Performs cumulative sum of the input elements along the given axis. +// Exclusive means the the cumsum for position j will not include the j-th element. +// Reverse means the cumsum will be performed in reverse order. func cumsum(x tensor.Tensor, axis int, exclusive, reverse bool) (tensor.Tensor, error) { - // First we copy the input tensor to the output tensor. out, ok := x.Clone().(tensor.Tensor) if !ok { return nil, ops.ErrCast From 960900c638fb276ec206baa232afbf48a066dbb2 Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Sun, 22 Dec 2024 10:49:58 +0100 Subject: [PATCH 3/3] Rewrite to switch --- ops/cumsum/cumsum.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ops/cumsum/cumsum.go b/ops/cumsum/cumsum.go index 28e346a..970b928 100644 --- a/ops/cumsum/cumsum.go +++ b/ops/cumsum/cumsum.go @@ -110,10 +110,11 @@ func cumsum(x tensor.Tensor, axis int, exclusive, reverse bool) (tensor.Tensor, currentValues := currentView.Materialize() + switch { // If exclusive is true, the first result in the cumsum opertaion is zero. // We can achieve this by subtracting the current values from the current values. // This way we don't have to infer the underlying type of the tensor. - if i == startValue && exclusive { + case i == startValue && exclusive: zeroValues, err := ops.Sub(currentValues, currentValues) if err != nil { return nil, err @@ -123,9 +124,8 @@ func cumsum(x tensor.Tensor, axis int, exclusive, reverse bool) (tensor.Tensor, if err != nil { return nil, err } - } - if (i != startValue) && exclusive { + case i != startValue && exclusive: err = tensor.Copy(currentView, prevValues) if err != nil { return nil, err @@ -137,7 +137,7 @@ func cumsum(x tensor.Tensor, axis int, exclusive, reverse bool) (tensor.Tensor, } prevValues = newValues - } else if i != startValue { + case i != startValue: newValues, err := ops.Add(currentValues, prevValues) if err != nil { return nil, err