Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions ops/where/versions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package where

import "github.com/advancedclimatesystems/gonnx/ops"

var whereVersions = ops.OperatorVersions{
9: ops.NewOperatorConstructor(newWhere, 9, whereTypeConstraints),
}

func GetVersions() ops.OperatorVersions {
return whereVersions
}
105 changes: 105 additions & 0 deletions ops/where/where.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package where

import (
"github.com/advancedclimatesystems/gonnx/onnx"
"github.com/advancedclimatesystems/gonnx/ops"
"gorgonia.org/tensor"
)

var whereTypeConstraints = [][]tensor.Dtype{
{tensor.Bool},
ops.AllTypes,
ops.AllTypes,
}

// Where represents the ONNX where operator.
type Where struct {
ops.BaseOperator
}

// newWhere creates a new where operator.
func newWhere(version int, typeConstraints [][]tensor.Dtype) ops.Operator {
return &Where{
BaseOperator: ops.NewBaseOperator(
version,
3,
3,
typeConstraints,
"where",
),
}
}

// Init initializes the where operator.
func (w *Where) Init(*onnx.NodeProto) error {
return nil
}

// Apply applies the where operator.
func (w *Where) Apply(inputs []tensor.Tensor) ([]tensor.Tensor, error) {
condition := inputs[0]

X := inputs[1]
Y := inputs[2]

X, Y, err := ops.MultidirectionalBroadcast(X, Y)
if err != nil {
return nil, err
}

condition, X, err = ops.MultidirectionalBroadcast(condition, X)
if err != nil {
return nil, err
}

out, err := where(X, Y, condition)
if err != nil {
return nil, err
}

return []tensor.Tensor{out}, err
}

func where(X, Y, condition tensor.Tensor) (tensor.Tensor, error) {
out := tensor.New(tensor.Of(X.Dtype()), tensor.WithShape(X.Shape()...))

iterator := condition.Iterator()
iterator.Reset()

for !iterator.Done() {
coords := iterator.Coord()

conditionRaw, err := condition.At(coords...)
if err != nil {
return nil, err
}

conditionValue, ok := conditionRaw.(bool)
if !ok {
return nil, ops.ErrCast
}

var value any
if conditionValue {
value, err = X.At(coords...)
} else {
value, err = Y.At(coords...)
}

if err != nil {
return nil, err
}

err = out.SetAt(value, coords...)
if err != nil {
return nil, err
}

_, err = iterator.Next()
if err != nil {
return nil, err
}
}

return out, nil
}
92 changes: 92 additions & 0 deletions ops/where/where_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package where

import (
"testing"

"github.com/stretchr/testify/assert"
"gorgonia.org/tensor"
)

func TestWhereInit(t *testing.T) {
op := whereVersions[9]()
err := op.Init(nil)
assert.Nil(t, err)
}

func TestWhere(t *testing.T) {
tests := []struct {
version int64
condition []bool
conditionShape []int
backing1 []float32
backing1Shape []int
backing2 []float32
backing2Shape []int
expectedBacking []float32
}{
{
9,
[]bool{true, false, true},
[]int{3},
[]float32{1, 2, 3},
[]int{3},
[]float32{4, 5, 6},
[]int{3},
[]float32{1, 5, 3},
},
{
9,
[]bool{true, false, true, false},
[]int{2, 2},
[]float32{1, 2, 3, 4},
[]int{2, 2},
[]float32{4, 5},
[]int{1, 2},
[]float32{1, 5, 3, 5},
},
{
9,
[]bool{false, true},
[]int{2},
[]float32{1, 2, 3, 4},
[]int{2, 2},
[]float32{4, 5},
[]int{1, 2},
[]float32{4, 2, 4, 4},
},
{
9,
[]bool{false, false, false, true, true, true},
[]int{2, 3},
[]float32{1, 2, 3, 4, 5, 6},
[]int{2, 3},
[]float32{4, 5, 6},
[]int{3},
[]float32{4, 5, 6, 4, 5, 6},
},
{
9,
[]bool{false, true, true, false, false, true},
[]int{2, 3},
[]float32{1, 2, 3, 4, 5, 6},
[]int{2, 3},
[]float32{4, 5, 6},
[]int{3},
[]float32{4, 2, 3, 4, 5, 6},
},
}

for _, test := range tests {
inputs := []tensor.Tensor{
tensor.New(tensor.WithShape(test.conditionShape...), tensor.WithBacking(test.condition)),
tensor.New(tensor.WithShape(test.backing1Shape...), tensor.WithBacking(test.backing1)),
tensor.New(tensor.WithShape(test.backing2Shape...), tensor.WithBacking(test.backing2)),
}

op := whereVersions[test.version]()

res, err := op.Apply(inputs)
assert.Nil(t, err)
assert.Equal(t, test.expectedBacking, res[0].Data())
}
}
2 changes: 2 additions & 0 deletions ops_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,8 @@ var expectedTests = []string{
"test_unsqueeze_three_axes",
"test_unsqueeze_two_axes",
"test_unsqueeze_unsorted_axes",
"test_where_example",
"test_where_long_example",
"test_xor_bcast3v1d",
"test_xor_bcast3v2d",
"test_xor_bcast4v2d",
Expand Down
2 changes: 2 additions & 0 deletions opset.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ import (
"github.com/advancedclimatesystems/gonnx/ops/tanh"
"github.com/advancedclimatesystems/gonnx/ops/transpose"
"github.com/advancedclimatesystems/gonnx/ops/unsqueeze"
"github.com/advancedclimatesystems/gonnx/ops/where"
"github.com/advancedclimatesystems/gonnx/ops/xor"
)

Expand Down Expand Up @@ -130,6 +131,7 @@ var operators = map[string]ops.OperatorVersions{
"Tanh": tanh.GetTanhVersions(),
"Transpose": transpose.GetTransposeVersions(),
"Unsqueeze": unsqueeze.GetUnsqueezeVersions(),
"Where": where.GetVersions(),
"Xor": xor.GetXorVersions(),
}

Expand Down
Loading