Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions ops/identity/identity.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package identity

import (
"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

var identityTypeConstraints = [][]tensor.Dtype{ops.AllTypes}

// Identity represents the ONNX identity operator.
type Identity struct {
ops.BaseOperator
}

// newIdentity creates a new identity operator.
func newIdentity(version int, typeConstraints [][]tensor.Dtype) ops.Operator {
return &Identity{
BaseOperator: ops.NewBaseOperator(
version,
1,
1,
typeConstraints,
"identity",
),
}
}

// Init initializes the identity operator.
func (a *Identity) Init(*onnx.NodeProto) error {
return nil
}

// Apply applies the identity operator.
func (a *Identity) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
out, ok := inputs[0].Clone().(tensor.Tensor)
if !ok {
return nil, ops.ErrTypeAssert("tensor.Tensor", inputs[0].Clone())
}

return []tensor.Tensor{out}, nil
}
100 changes: 100 additions & 0 deletions ops/identity/identity_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package identity

import (
"testing"

"github.com/advancedclimatesystems/gonnx/ops"
"github.com/stretchr/testify/assert"
"gorgonia.org/tensor"
)

func TestIdentityInit(t *testing.T) {
i := &Identity{}

// since 'identity' does not have any attributes we pass in nil. This should not
// fail initializing the identity.
err := i.Init(nil)
assert.Nil(t, err)
}

func TestIdentity(t *testing.T) {
tests := []struct {
version int64
backing []float32
shape []int
expected []float32
}{
{
13,
[]float32{0, 1, 2, 3},
[]int{2, 2},
[]float32{0, 1, 2, 3},
},
}

for _, test := range tests {
inputs := []tensor.Tensor{
ops.TensorWithBackingFixture(test.backing, test.shape...),
}

identity := identityVersions[test.version]()

res, err := identity.Apply(inputs)
assert.Nil(t, err)

assert.Equal(t, test.expected, res[0].Data())
}
}

func TestInputValidationIdentity(t *testing.T) {
tests := []struct {
version int64
inputs []tensor.Tensor
err error
}{
{
13,
[]tensor.Tensor{
ops.TensorWithBackingFixture([]uint32{1, 2}, 2),
},
nil,
},
{
13,
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float32{1, 2}, 2),
},
nil,
},
{
13,
[]tensor.Tensor{
ops.TensorWithBackingFixture([]float32{1, 2}, 2),
ops.TensorWithBackingFixture([]float32{3, 4}, 2),
},
ops.ErrInvalidInputCount(2, identity13BaseOpFixture()),
},
{
13,
[]tensor.Tensor{
ops.TensorWithBackingFixture([]int{1, 2}, 2),
},
ops.ErrInvalidInputType(0, "int", identity13BaseOpFixture()),
},
}

for _, test := range tests {
identity := identityVersions[test.version]()
validated, err := identity.ValidateInputs(test.inputs)

assert.Equal(t, test.err, err)

if test.err == nil {
assert.Equal(t, test.inputs, validated)
}
}
}

func identity13BaseOpFixture() ops.BaseOperator {
return ops.NewBaseOperator(13, 1, 1, identityTypeConstraints, "identity")
}
13 changes: 13 additions & 0 deletions ops/identity/versions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package identity

import (
"github.com/advancedclimatesystems/gonnx/ops"
)

var identityVersions = ops.OperatorVersions{
13: ops.NewOperatorConstructor(newIdentity, 13, identityTypeConstraints),
}

func GetVersions() ops.OperatorVersions {
return identityVersions
}
4 changes: 4 additions & 0 deletions ops_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ var ignoredTests = []string{
"test_softmax_axis_2_expanded_ver18", // Opset18
"test_reshape_allowzero_reordered", // Opset14

"test_identity_opt", // Error in test? Can not read in input tensor. https://github.com/onnx/onnx/issues/6842

"test_constant_pad", // Pad is not implemented yet.
"test_constant_pad_axes", // Pad is not implemented yet.
"test_logsoftmax_large_number_expanded", // Requires 'Exp' operator.
Expand Down Expand Up @@ -458,6 +460,8 @@ var expectedTests = []string{
"test_gru_defaults",
"test_gru_seq_length",
"test_gru_with_initial_bias",
"test_identity",
"test_identity_sequence",
"test_less",
"test_less_bcast",
"test_less_equal",
Expand Down
4 changes: 3 additions & 1 deletion opset.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/advancedclimatesystems/gonnx/ops/greater"
"github.com/advancedclimatesystems/gonnx/ops/greaterorequal"
"github.com/advancedclimatesystems/gonnx/ops/gru"
"github.com/advancedclimatesystems/gonnx/ops/identity"
"github.com/advancedclimatesystems/gonnx/ops/less"
"github.com/advancedclimatesystems/gonnx/ops/lessorequal"
"github.com/advancedclimatesystems/gonnx/ops/linearregressor"
Expand Down Expand Up @@ -102,6 +103,7 @@ var operators = map[string]ops.OperatorVersions{
"Greater": greater.GetVersions(),
"GreaterOrEqual": greaterorequal.GetVersions(),
"GRU": gru.GetVersions(),
"Identity": identity.GetVersions(),
"Less": less.GetVersions(),
"LessOrEqual": lessorequal.GetVersions(),
"LinearRegressor": linearregressor.GetVersions(),
Expand Down Expand Up @@ -134,7 +136,7 @@ var operators = map[string]ops.OperatorVersions{
"Transpose": transpose.GetVersions(),
"Unsqueeze": unsqueeze.GetVersions(),
"Xor": xor.GetVersions(),
"Where": where.GetVersions(),
"Where": where.GetVersions(),
}

// GetClosestOperatorVersion resolves, given a certain opset version, the operator version that is closest
Expand Down