diff --git a/model_test.go b/model_test.go index 54b9ed3..ed1c4b9 100644 --- a/model_test.go +++ b/model_test.go @@ -1,7 +1,6 @@ package gonnx import ( - "fmt" "testing" "github.com/advancedclimatesystems/gonnx/onnx" @@ -97,7 +96,6 @@ func TestModel(t *testing.T) { } for _, test := range tests { - fmt.Println(test.path) model, err := NewModelFromFile(test.path) assert.Nil(t, err) diff --git a/ops/cumsum/cumsum.go b/ops/cumsum/cumsum.go new file mode 100644 index 0000000..970b928 --- /dev/null +++ b/ops/cumsum/cumsum.go @@ -0,0 +1,170 @@ +package cumsum + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var cumsumTypeConstraints = [][]tensor.Dtype{ + {tensor.Int32, tensor.Int64, tensor.Uint32, tensor.Uint64, tensor.Float32, tensor.Float64}, + {tensor.Int32, tensor.Int64}, +} + +// CumSum represents the ONNX cumsum operator. +type CumSum struct { + ops.BaseOperator + + exclusive bool + reverse bool +} + +// newCumSum creates a new cumsum operator. +func newCumSum(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &CumSum{ + BaseOperator: ops.NewBaseOperator( + version, + 2, + 2, + typeConstraints, + "cumsum", + ), + exclusive: false, + reverse: false, + } +} + +// Init initializes the cumsum operator. +func (c *CumSum) Init(n *onnx.NodeProto) error { + for _, attr := range n.GetAttribute() { + switch attr.GetName() { + case "exclusive": + c.exclusive = attr.GetI() == 1 + case "reverse": + c.reverse = attr.GetI() == 1 + default: + return ops.ErrInvalidAttribute(attr.GetName(), c) + } + } + + return nil +} + +// Apply applies the cumsum operator. +func (c *CumSum) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + axis, err := ops.AnyToInt(inputs[1].ScalarValue()) + if err != nil { + return nil, err + } + + out, err := cumsum(inputs[0], axis, c.exclusive, c.reverse) + if err != nil { + return nil, err + } + + return []tensor.Tensor{out}, nil +} + +// Performs cumulative sum of the input elements along the given axis. +// Exclusive means the the cumsum for position j will not include the j-th element. +// Reverse means the cumsum will be performed in reverse order. +func cumsum(x tensor.Tensor, axis int, exclusive, reverse bool) (tensor.Tensor, error) { + out, ok := x.Clone().(tensor.Tensor) + if !ok { + return nil, ops.ErrCast + } + + nDims := len(x.Shape()) + axis = ops.ConvertNegativeAxis(axis, nDims) + + if axis < 0 || axis >= nDims { + return nil, ops.ErrAxisOutOfRange(0, nDims, axis) + } + + axisSize := x.Shape()[axis] + + var startValue int + if reverse { + startValue = axisSize - 1 + } else { + startValue = 0 + } + + slices := make([]tensor.Slice, nDims) + slices[axis] = ops.NewSlicer(startValue, startValue+1) + + prevView, err := x.Slice(slices...) + if err != nil { + return nil, err + } + + prevValues := prevView.Materialize() + + for i := startValue; endValueReached(i, axisSize, reverse); { + slices[axis] = ops.NewSlicer(i, i+1) + + currentView, err := out.Slice(slices...) + if err != nil { + return nil, err + } + + currentValues := currentView.Materialize() + + switch { + // If exclusive is true, the first result in the cumsum opertaion is zero. + // We can achieve this by subtracting the current values from the current values. + // This way we don't have to infer the underlying type of the tensor. + case i == startValue && exclusive: + zeroValues, err := ops.Sub(currentValues, currentValues) + if err != nil { + return nil, err + } + + err = tensor.Copy(currentView, zeroValues) + if err != nil { + return nil, err + } + + case i != startValue && exclusive: + err = tensor.Copy(currentView, prevValues) + if err != nil { + return nil, err + } + + newValues, err := ops.Add(currentValues, prevValues) + if err != nil { + return nil, err + } + + prevValues = newValues + case i != startValue: + newValues, err := ops.Add(currentValues, prevValues) + if err != nil { + return nil, err + } + + err = tensor.Copy(currentView, newValues) + if err != nil { + return nil, err + } + + prevValues = newValues + } + + if reverse { + i-- + } else { + i++ + } + } + + return out, nil +} + +func endValueReached(i, axisSize int, reverse bool) bool { + if reverse { + return i >= 0 + } + + return i < axisSize +} diff --git a/ops/cumsum/cumsum_test.go b/ops/cumsum/cumsum_test.go new file mode 100644 index 0000000..6595b29 --- /dev/null +++ b/ops/cumsum/cumsum_test.go @@ -0,0 +1,118 @@ +package cumsum + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestCumSumInit(t *testing.T) { + c := &CumSum{} + err := c.Init( + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "exclusive", I: 1}, + {Name: "reverse", I: 1}, + }, + }, + ) + + assert.Nil(t, err) + assert.Equal(t, true, c.exclusive) + assert.Equal(t, true, c.reverse) +} + +func TestCumSumInitDefaults(t *testing.T) { + c := &CumSum{} + err := c.Init( + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{}, + }, + ) + + assert.Nil(t, err) + assert.Equal(t, false, c.exclusive) + assert.Equal(t, false, c.reverse) +} + +func TestCumSum(t *testing.T) { + tests := []struct { + version int64 + node *onnx.NodeProto + backing []float32 + axis int32 + shape []int + expected []float32 + }{ + { + 11, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "exclusive", I: 0}, + {Name: "reverse", I: 0}, + }, + }, + []float32{1, 2, 3, 4}, + 0, + []int{2, 2}, + []float32{1, 2, 4, 6}, + }, + { + 11, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "exclusive", I: 0}, + {Name: "reverse", I: 0}, + }, + }, + []float32{1, 2, 3, 4}, + 1, + []int{2, 2}, + []float32{1, 3, 3, 7}, + }, + { + 11, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "exclusive", I: 1}, + {Name: "reverse", I: 0}, + }, + }, + []float32{1, 2, 3}, + 0, + []int{3}, + []float32{0, 1, 3}, + }, + { + 11, + &onnx.NodeProto{ + Attribute: []*onnx.AttributeProto{ + {Name: "exclusive", I: 0}, + {Name: "reverse", I: 1}, + }, + }, + []float32{1, 2, 3}, + 0, + []int{3}, + []float32{6, 5, 3}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backing, test.shape...), + tensor.New(tensor.FromScalar(test.axis)), + } + + cumsum := cumsumVersions[test.version]() + err := cumsum.Init(test.node) + assert.Nil(t, err) + + res, err := cumsum.Apply(inputs) + assert.Nil(t, err) + assert.Equal(t, test.expected, res[0].Data()) + } +} diff --git a/ops/cumsum/versions.go b/ops/cumsum/versions.go new file mode 100644 index 0000000..89fceab --- /dev/null +++ b/ops/cumsum/versions.go @@ -0,0 +1,13 @@ +package cumsum + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var cumsumVersions = ops.OperatorVersions{ + 11: ops.NewOperatorConstructor(newCumSum, 11, cumsumTypeConstraints), +} + +func GetVersions() ops.OperatorVersions { + return cumsumVersions +} diff --git a/ops/utils.go b/ops/utils.go index 1dcee9e..f41d911 100644 --- a/ops/utils.go +++ b/ops/utils.go @@ -78,6 +78,22 @@ func OffsetTensorIfNegative(t tensor.Tensor, offset int) error { return nil } +// AnyToInt casts the given data to an int, but only if the data is of some sort of int type. +func AnyToInt(value interface{}) (int, error) { + switch data := value.(type) { + case int8: + return int(data), nil + case int16: + return int(data), nil + case int32: + return int(data), nil + case int64: + return int(data), nil + default: + return 0, ErrCast + } +} + // AnyToIntSlice casts the data of a node to an int list. This will only // be done if the data is of some sort of int type. func AnyToIntSlice(value interface{}) ([]int, error) { diff --git a/ops_test.go b/ops_test.go index 07fb258..1a7df89 100644 --- a/ops_test.go +++ b/ops_test.go @@ -402,6 +402,13 @@ var expectedTests = []string{ "test_cos_example", "test_cosh", "test_cosh_example", + "test_cumsum_1d", + "test_cumsum_1d_exclusive", + "test_cumsum_1d_reverse", + "test_cumsum_1d_reverse_exclusive", + "test_cumsum_2d_axis_0", + "test_cumsum_2d_axis_1", + "test_cumsum_2d_negative_axis", "test_div", "test_div_bcast", "test_div_example", diff --git a/opset.go b/opset.go index 0bfcdb9..4542915 100644 --- a/opset.go +++ b/opset.go @@ -19,6 +19,7 @@ import ( "github.com/advancedclimatesystems/gonnx/ops/conv" "github.com/advancedclimatesystems/gonnx/ops/cos" "github.com/advancedclimatesystems/gonnx/ops/cosh" + "github.com/advancedclimatesystems/gonnx/ops/cumsum" "github.com/advancedclimatesystems/gonnx/ops/div" "github.com/advancedclimatesystems/gonnx/ops/equal" "github.com/advancedclimatesystems/gonnx/ops/expand" @@ -85,6 +86,7 @@ var operators = map[string]ops.OperatorVersions{ "Conv": conv.GetConvVersions(), "Cos": cos.GetCosVersions(), "Cosh": cosh.GetCoshVersions(), + "CumSum": cumsum.GetVersions(), "Div": div.GetDivVersions(), "Equal": equal.GetEqualVersions(), "Expand": expand.GetExpandVersions(),