From e4c12218dccad791a543347b362a535887112ea5 Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Sat, 29 Mar 2025 08:44:50 +0100 Subject: [PATCH 1/2] Added Identity operator --- ops/identity/identity.go | 42 ++++++++++++++ ops/identity/identity_test.go | 100 ++++++++++++++++++++++++++++++++++ ops/identity/versions.go | 13 +++++ ops_test.go | 4 ++ opset.go | 4 +- 5 files changed, 162 insertions(+), 1 deletion(-) create mode 100644 ops/identity/identity.go create mode 100644 ops/identity/identity_test.go create mode 100644 ops/identity/versions.go diff --git a/ops/identity/identity.go b/ops/identity/identity.go new file mode 100644 index 0000000..1bc11d5 --- /dev/null +++ b/ops/identity/identity.go @@ -0,0 +1,42 @@ +package identity + +import ( + "github.com/advancedclimatesystems/gonnx/onnx" + "github.com/advancedclimatesystems/gonnx/ops" + "gorgonia.org/tensor" +) + +var identityTypeConstraints = [][]tensor.Dtype{ops.AllTypes} + +// Identity represents the ONNX identity operator. +type Identity struct { + ops.BaseOperator +} + +// newIdentity creates a new identity operator. +func newIdentity(version int, typeConstraints [][]tensor.Dtype) ops.Operator { + return &Identity{ + BaseOperator: ops.NewBaseOperator( + version, + 1, + 1, + typeConstraints, + "identity", + ), + } +} + +// Init initializes the identity operator. +func (a *Identity) Init(*onnx.NodeProto) error { + return nil +} + +// Apply applies the identity operator. +func (a *Identity) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) { + out, ok := inputs[0].Clone().(tensor.Tensor) + if !ok { + return nil, ops.ErrTypeAssert("tensor.Tensor", inputs[0].Clone()) + } + + return []tensor.Tensor{out}, nil +} diff --git a/ops/identity/identity_test.go b/ops/identity/identity_test.go new file mode 100644 index 0000000..54416eb --- /dev/null +++ b/ops/identity/identity_test.go @@ -0,0 +1,100 @@ +package identity + +import ( + "testing" + + "github.com/advancedclimatesystems/gonnx/ops" + "github.com/stretchr/testify/assert" + "gorgonia.org/tensor" +) + +func TestIdentityInit(t *testing.T) { + i := &Identity{} + + // since 'identity' does not have any attributes we pass in nil. This should not + // fail initializing the identity. + err := i.Init(nil) + assert.Nil(t, err) +} + +func TestIdentity(t *testing.T) { + tests := []struct { + version int64 + backing []float32 + shape []int + expected []float32 + }{ + { + 13, + []float32{0, 1, 2, 3}, + []int{2, 2}, + []float32{0, 1, 2, 3}, + }, + } + + for _, test := range tests { + inputs := []tensor.Tensor{ + ops.TensorWithBackingFixture(test.backing, test.shape...), + } + + identity := identityVersions[test.version]() + + res, err := identity.Apply(inputs) + assert.Nil(t, err) + + assert.Equal(t, test.expected, res[0].Data()) + } +} + +func TestInputValidationIdentity(t *testing.T) { + tests := []struct { + version int64 + inputs []tensor.Tensor + err error + }{ + { + 13, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]uint32{1, 2}, 2), + }, + nil, + }, + { + 13, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + }, + nil, + }, + { + 13, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]float32{1, 2}, 2), + ops.TensorWithBackingFixture([]float32{3, 4}, 2), + }, + ops.ErrInvalidInputCount(2, identity13BaseOpFixture()), + }, + { + 13, + []tensor.Tensor{ + ops.TensorWithBackingFixture([]int{1, 2}, 2), + }, + ops.ErrInvalidInputType(0, "int", identity13BaseOpFixture()), + }, + } + + for _, test := range tests { + identity := identityVersions[test.version]() + validated, err := identity.ValidateInputs(test.inputs) + + assert.Equal(t, test.err, err) + + if test.err == nil { + assert.Equal(t, test.inputs, validated) + } + } +} + +func identity13BaseOpFixture() ops.BaseOperator { + return ops.NewBaseOperator(13, 1, 1, identityTypeConstraints, "identity") +} diff --git a/ops/identity/versions.go b/ops/identity/versions.go new file mode 100644 index 0000000..96ad876 --- /dev/null +++ b/ops/identity/versions.go @@ -0,0 +1,13 @@ +package identity + +import ( + "github.com/advancedclimatesystems/gonnx/ops" +) + +var identityVersions = ops.OperatorVersions{ + 13: ops.NewOperatorConstructor(newIdentity, 13, identityTypeConstraints), +} + +func GetVersions() ops.OperatorVersions { + return identityVersions +} diff --git a/ops_test.go b/ops_test.go index d065549..7e78bd9 100644 --- a/ops_test.go +++ b/ops_test.go @@ -82,6 +82,8 @@ var ignoredTests = []string{ "test_softmax_axis_2_expanded_ver18", // Opset18 "test_reshape_allowzero_reordered", // Opset14 + "test_identity_opt", // Bug in test? Can't read in input tensor. + "test_constant_pad", // Pad is not implemented yet. "test_constant_pad_axes", // Pad is not implemented yet. "test_logsoftmax_large_number_expanded", // Requires 'Exp' operator. @@ -458,6 +460,8 @@ var expectedTests = []string{ "test_gru_defaults", "test_gru_seq_length", "test_gru_with_initial_bias", + "test_identity", + "test_identity_sequence", "test_less", "test_less_bcast", "test_less_equal", diff --git a/opset.go b/opset.go index f1f67ad..c589f5c 100644 --- a/opset.go +++ b/opset.go @@ -30,6 +30,7 @@ import ( "github.com/advancedclimatesystems/gonnx/ops/greater" "github.com/advancedclimatesystems/gonnx/ops/greaterorequal" "github.com/advancedclimatesystems/gonnx/ops/gru" + "github.com/advancedclimatesystems/gonnx/ops/identity" "github.com/advancedclimatesystems/gonnx/ops/less" "github.com/advancedclimatesystems/gonnx/ops/lessorequal" "github.com/advancedclimatesystems/gonnx/ops/linearregressor" @@ -102,6 +103,7 @@ var operators = map[string]ops.OperatorVersions{ "Greater": greater.GetVersions(), "GreaterOrEqual": greaterorequal.GetVersions(), "GRU": gru.GetVersions(), + "Identity": identity.GetVersions(), "Less": less.GetVersions(), "LessOrEqual": lessorequal.GetVersions(), "LinearRegressor": linearregressor.GetVersions(), @@ -134,7 +136,7 @@ var operators = map[string]ops.OperatorVersions{ "Transpose": transpose.GetVersions(), "Unsqueeze": unsqueeze.GetVersions(), "Xor": xor.GetVersions(), - "Where": where.GetVersions(), + "Where": where.GetVersions(), } // GetClosestOperatorVersion resolves, given a certain opset version, the operator version that is closest From 544fda5bbf19205f26910831f7e1894110b2d568 Mon Sep 17 00:00:00 2001 From: Swopper050 Date: Sat, 29 Mar 2025 08:46:52 +0100 Subject: [PATCH 2/2] Fix lint --- ops_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ops_test.go b/ops_test.go index 7e78bd9..9dfbcce 100644 --- a/ops_test.go +++ b/ops_test.go @@ -82,7 +82,7 @@ var ignoredTests = []string{ "test_softmax_axis_2_expanded_ver18", // Opset18 "test_reshape_allowzero_reordered", // Opset14 - "test_identity_opt", // Bug in test? Can't read in input tensor. + "test_identity_opt", // Error in test? Can not read in input tensor. https://github.com/onnx/onnx/issues/6842 "test_constant_pad", // Pad is not implemented yet. "test_constant_pad_axes", // Pad is not implemented yet.