Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions model_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package gonnx

import (
"fmt"
"testing"

"github.com/advancedclimatesystems/gonnx/onnx"
Expand Down Expand Up @@ -97,7 +96,6 @@ func TestModel(t *testing.T) {
}

for _, test := range tests {
fmt.Println(test.path)
model, err := NewModelFromFile(test.path)
assert.Nil(t, err)

Expand Down
170 changes: 170 additions & 0 deletions ops/cumsum/cumsum.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
package cumsum

import (
"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

var cumsumTypeConstraints = [][]tensor.Dtype{
{tensor.Int32, tensor.Int64, tensor.Uint32, tensor.Uint64, tensor.Float32, tensor.Float64},
{tensor.Int32, tensor.Int64},
}

// CumSum represents the ONNX cumsum operator.
type CumSum struct {
ops.BaseOperator

exclusive bool
reverse bool
}

// newCumSum creates a new cumsum operator.
func newCumSum(version int, typeConstraints [][]tensor.Dtype) ops.Operator {
return &CumSum{
BaseOperator: ops.NewBaseOperator(
version,
2,
2,
typeConstraints,
"cumsum",
),
exclusive: false,
reverse: false,
}
}

// Init initializes the cumsum operator.
func (c *CumSum) Init(n *onnx.NodeProto) error {
for _, attr := range n.GetAttribute() {
switch attr.GetName() {
case "exclusive":
c.exclusive = attr.GetI() == 1
case "reverse":
c.reverse = attr.GetI() == 1
default:
return ops.ErrInvalidAttribute(attr.GetName(), c)
}
}

return nil
}

// Apply applies the cumsum operator.
func (c *CumSum) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
axis, err := ops.AnyToInt(inputs[1].ScalarValue())
if err != nil {
return nil, err
}

out, err := cumsum(inputs[0], axis, c.exclusive, c.reverse)
if err != nil {
return nil, err
}

return []tensor.Tensor{out}, nil
}

// Performs cumulative sum of the input elements along the given axis.
// Exclusive means the the cumsum for position j will not include the j-th element.
// Reverse means the cumsum will be performed in reverse order.
func cumsum(x tensor.Tensor, axis int, exclusive, reverse bool) (tensor.Tensor, error) {
out, ok := x.Clone().(tensor.Tensor)
if !ok {
return nil, ops.ErrCast
}

nDims := len(x.Shape())
axis = ops.ConvertNegativeAxis(axis, nDims)

if axis < 0 || axis >= nDims {
return nil, ops.ErrAxisOutOfRange(0, nDims, axis)
}

axisSize := x.Shape()[axis]

var startValue int
if reverse {
startValue = axisSize - 1
} else {
startValue = 0
}

slices := make([]tensor.Slice, nDims)
slices[axis] = ops.NewSlicer(startValue, startValue+1)

prevView, err := x.Slice(slices...)
if err != nil {
return nil, err
}

prevValues := prevView.Materialize()

for i := startValue; endValueReached(i, axisSize, reverse); {
slices[axis] = ops.NewSlicer(i, i+1)

currentView, err := out.Slice(slices...)
if err != nil {
return nil, err
}

currentValues := currentView.Materialize()

switch {
// If exclusive is true, the first result in the cumsum opertaion is zero.
// We can achieve this by subtracting the current values from the current values.
// This way we don't have to infer the underlying type of the tensor.
case i == startValue && exclusive:
zeroValues, err := ops.Sub(currentValues, currentValues)
if err != nil {
return nil, err
}

err = tensor.Copy(currentView, zeroValues)
if err != nil {
return nil, err
}

case i != startValue && exclusive:
err = tensor.Copy(currentView, prevValues)
if err != nil {
return nil, err
}

newValues, err := ops.Add(currentValues, prevValues)
if err != nil {
return nil, err
}

prevValues = newValues
case i != startValue:
newValues, err := ops.Add(currentValues, prevValues)
if err != nil {
return nil, err
}

err = tensor.Copy(currentView, newValues)
if err != nil {
return nil, err
}

prevValues = newValues
}

if reverse {
i--
} else {
i++
}
}

return out, nil
}

func endValueReached(i, axisSize int, reverse bool) bool {
if reverse {
return i >= 0
}

return i < axisSize
}
118 changes: 118 additions & 0 deletions ops/cumsum/cumsum_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package cumsum

import (
"testing"

"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"github.com/stretchr/testify/assert"
"gorgonia.org/tensor"
)

func TestCumSumInit(t *testing.T) {
c := &CumSum{}
err := c.Init(
&onnx.NodeProto{
Attribute: []*onnx.AttributeProto{
{Name: "exclusive", I: 1},
{Name: "reverse", I: 1},
},
},
)

assert.Nil(t, err)
assert.Equal(t, true, c.exclusive)
assert.Equal(t, true, c.reverse)
}

func TestCumSumInitDefaults(t *testing.T) {
c := &CumSum{}
err := c.Init(
&onnx.NodeProto{
Attribute: []*onnx.AttributeProto{},
},
)

assert.Nil(t, err)
assert.Equal(t, false, c.exclusive)
assert.Equal(t, false, c.reverse)
}

func TestCumSum(t *testing.T) {
tests := []struct {
version int64
node *onnx.NodeProto
backing []float32
axis int32
shape []int
expected []float32
}{
{
11,
&onnx.NodeProto{
Attribute: []*onnx.AttributeProto{
{Name: "exclusive", I: 0},
{Name: "reverse", I: 0},
},
},
[]float32{1, 2, 3, 4},
0,
[]int{2, 2},
[]float32{1, 2, 4, 6},
},
{
11,
&onnx.NodeProto{
Attribute: []*onnx.AttributeProto{
{Name: "exclusive", I: 0},
{Name: "reverse", I: 0},
},
},
[]float32{1, 2, 3, 4},
1,
[]int{2, 2},
[]float32{1, 3, 3, 7},
},
{
11,
&onnx.NodeProto{
Attribute: []*onnx.AttributeProto{
{Name: "exclusive", I: 1},
{Name: "reverse", I: 0},
},
},
[]float32{1, 2, 3},
0,
[]int{3},
[]float32{0, 1, 3},
},
{
11,
&onnx.NodeProto{
Attribute: []*onnx.AttributeProto{
{Name: "exclusive", I: 0},
{Name: "reverse", I: 1},
},
},
[]float32{1, 2, 3},
0,
[]int{3},
[]float32{6, 5, 3},
},
}

for _, test := range tests {
inputs := []tensor.Tensor{
ops.TensorWithBackingFixture(test.backing, test.shape...),
tensor.New(tensor.FromScalar(test.axis)),
}

cumsum := cumsumVersions[test.version]()
err := cumsum.Init(test.node)
assert.Nil(t, err)

res, err := cumsum.Apply(inputs)
assert.Nil(t, err)
assert.Equal(t, test.expected, res[0].Data())
}
}
13 changes: 13 additions & 0 deletions ops/cumsum/versions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package cumsum

import (
"github.com/advancedclimatesystems/gonnx/ops"
)

var cumsumVersions = ops.OperatorVersions{
11: ops.NewOperatorConstructor(newCumSum, 11, cumsumTypeConstraints),
}

func GetVersions() ops.OperatorVersions {
return cumsumVersions
}
16 changes: 16 additions & 0 deletions ops/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,22 @@ func OffsetTensorIfNegative(t tensor.Tensor, offset int) error {
return nil
}

// AnyToInt casts the given data to an int, but only if the data is of some sort of int type.
func AnyToInt(value interface{}) (int, error) {
switch data := value.(type) {
case int8:
return int(data), nil
case int16:
return int(data), nil
case int32:
return int(data), nil
case int64:
return int(data), nil
default:
return 0, ErrCast
}
}

// AnyToIntSlice casts the data of a node to an int list. This will only
// be done if the data is of some sort of int type.
func AnyToIntSlice(value interface{}) ([]int, error) {
Expand Down
7 changes: 7 additions & 0 deletions ops_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,13 @@ var expectedTests = []string{
"test_cos_example",
"test_cosh",
"test_cosh_example",
"test_cumsum_1d",
"test_cumsum_1d_exclusive",
"test_cumsum_1d_reverse",
"test_cumsum_1d_reverse_exclusive",
"test_cumsum_2d_axis_0",
"test_cumsum_2d_axis_1",
"test_cumsum_2d_negative_axis",
"test_div",
"test_div_bcast",
"test_div_example",
Expand Down
2 changes: 2 additions & 0 deletions opset.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/advancedclimatesystems/gonnx/ops/conv"
"github.com/advancedclimatesystems/gonnx/ops/cos"
"github.com/advancedclimatesystems/gonnx/ops/cosh"
"github.com/advancedclimatesystems/gonnx/ops/cumsum"
"github.com/advancedclimatesystems/gonnx/ops/div"
"github.com/advancedclimatesystems/gonnx/ops/equal"
"github.com/advancedclimatesystems/gonnx/ops/expand"
Expand Down Expand Up @@ -85,6 +86,7 @@ var operators = map[string]ops.OperatorVersions{
"Conv": conv.GetConvVersions(),
"Cos": cos.GetCosVersions(),
"Cosh": cosh.GetCoshVersions(),
"CumSum": cumsum.GetVersions(),
"Div": div.GetDivVersions(),
"Equal": equal.GetEqualVersions(),
"Expand": expand.GetExpandVersions(),
Expand Down
Loading