diff --git a/ops/cast/cast.go b/ops/cast/cast.go index 71f63a8..b51ca1a 100644 --- a/ops/cast/cast.go +++ b/ops/cast/cast.go @@ -7,7 +7,7 @@ import ( ) var castTypeConstraints = [][]tensor.Dtype{ - {tensor.Int16, tensor.Uint16, tensor.Int32, tensor.Uint32, tensor.Int64, tensor.Uint64, tensor.Float32, tensor.Float64}, + {tensor.Bool, tensor.Int16, tensor.Uint16, tensor.Int32, tensor.Uint32, tensor.Int64, tensor.Uint64, tensor.Float32, tensor.Float64}, } // Cast represents the ONNX cast operator. diff --git a/ops/cast/cast_test.go b/ops/cast/cast_test.go index 5d22875..5db92a9 100644 --- a/ops/cast/cast_test.go +++ b/ops/cast/cast_test.go @@ -60,6 +60,20 @@ func TestCast(t *testing.T) { 3, []int8{1, 1}, }, + { + 13, + []float64{1.0, 0.0}, + []int{2}, + 9, + []bool{true, false}, + }, + { + 13, + []bool{false, true}, + []int{2}, + 1, + []float32{0.0, 1.0}, + }, } for _, test := range tests { @@ -102,9 +116,9 @@ func TestInputValidationCast(t *testing.T) { { 13, []tensor.Tensor{ - ops.TensorWithBackingFixture([]bool{true, false}, 2), + ops.TensorWithBackingFixture([]int{1, 0}, 2), }, - ops.ErrInvalidInputType(0, "bool", cast13BaseOpFixture()), + ops.ErrInvalidInputType(0, "int", cast13BaseOpFixture()), }, } diff --git a/ops/convert.go b/ops/convert.go index dd313cc..26881c3 100644 --- a/ops/convert.go +++ b/ops/convert.go @@ -5,9 +5,33 @@ import ( "gorgonia.org/tensor" ) -// Number is a type which represents a number. -type Number interface { - float32 | float64 | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64 +func DTypeToONNXType(t tensor.Dtype) (int32, error) { + switch t { + case tensor.Float32: + return int32(onnx.TensorProto_FLOAT), nil + case tensor.Float64: + return int32(onnx.TensorProto_DOUBLE), nil + case tensor.Int8: + return int32(onnx.TensorProto_INT8), nil + case tensor.Int16: + return int32(onnx.TensorProto_INT16), nil + case tensor.Int32: + return int32(onnx.TensorProto_INT32), nil + case tensor.Int64: + return int32(onnx.TensorProto_INT64), nil + case tensor.Uint8: + return int32(onnx.TensorProto_UINT8), nil + case tensor.Uint16: + return int32(onnx.TensorProto_UINT16), nil + case tensor.Uint32: + return int32(onnx.TensorProto_UINT32), nil + case tensor.Uint64: + return int32(onnx.TensorProto_UINT64), nil + case tensor.Bool: + return int32(onnx.TensorProto_BOOL), nil + default: + return 0, ErrUnknownTensorONNXDtype(t) + } } // ConvertTensorDtype converts an interface of a specific dtype to a new dtype. @@ -40,6 +64,8 @@ func ConvertTensorDtype(t tensor.Tensor, newType int32) (tensor.Tensor, error) { newBacking, err = convertBacking(backing.([]uint32), newType) case tensor.Uint64: newBacking, err = convertBacking(backing.([]uint64), newType) + case tensor.Bool: + newBacking, err = convertBooleanBacking(backing.([]bool), newType) default: return nil, ErrConversionInvalidType(t.Dtype(), newType) } @@ -51,34 +77,7 @@ func ConvertTensorDtype(t tensor.Tensor, newType int32) (tensor.Tensor, error) { return tensor.New(tensor.WithShape(t.Shape()...), tensor.WithBacking(newBacking)), nil } -func DTypeToONNXType(t tensor.Dtype) (int32, error) { - switch t { - case tensor.Float32: - return int32(onnx.TensorProto_FLOAT), nil - case tensor.Float64: - return int32(onnx.TensorProto_DOUBLE), nil - case tensor.Int8: - return int32(onnx.TensorProto_INT8), nil - case tensor.Int16: - return int32(onnx.TensorProto_INT16), nil - case tensor.Int32: - return int32(onnx.TensorProto_INT32), nil - case tensor.Int64: - return int32(onnx.TensorProto_INT64), nil - case tensor.Uint8: - return int32(onnx.TensorProto_UINT8), nil - case tensor.Uint16: - return int32(onnx.TensorProto_UINT16), nil - case tensor.Uint32: - return int32(onnx.TensorProto_UINT32), nil - case tensor.Uint64: - return int32(onnx.TensorProto_UINT64), nil - default: - return 0, ErrUnknownTensorONNXDtype(t) - } -} - -func convertBacking[B Number](backing []B, dataType int32) (any, error) { +func convertBacking[B NumericType](backing []B, dataType int32) (any, error) { switch onnx.TensorProto_DataType(dataType) { case onnx.TensorProto_FLOAT: return createNewBacking[B, float32](backing), nil @@ -100,14 +99,47 @@ func convertBacking[B Number](backing []B, dataType int32) (any, error) { return createNewBacking[B, uint32](backing), nil case onnx.TensorProto_UINT64: return createNewBacking[B, uint64](backing), nil - case onnx.TensorProto_BFLOAT16, onnx.TensorProto_BOOL, onnx.TensorProto_COMPLEX64, onnx.TensorProto_COMPLEX128, onnx.TensorProto_FLOAT16, onnx.TensorProto_UNDEFINED, onnx.TensorProto_STRING: + case onnx.TensorProto_BOOL: + return createNewBooleanBacking[B](backing), nil + case onnx.TensorProto_BFLOAT16, onnx.TensorProto_COMPLEX64, onnx.TensorProto_COMPLEX128, onnx.TensorProto_FLOAT16, onnx.TensorProto_UNDEFINED, onnx.TensorProto_STRING: return nil, ErrConversionNotSupported(dataType) default: return nil, ErrConversionNotSupported(dataType) } } -func createNewBacking[B Number, R Number](backing []B) []R { +func convertBooleanBacking(backing []bool, dataType int32) (any, error) { + switch onnx.TensorProto_DataType(dataType) { + case onnx.TensorProto_FLOAT: + return createNewBackingFromBoolean[float32](backing), nil + case onnx.TensorProto_DOUBLE: + return createNewBackingFromBoolean[float64](backing), nil + case onnx.TensorProto_INT8: + return createNewBackingFromBoolean[int8](backing), nil + case onnx.TensorProto_INT16: + return createNewBackingFromBoolean[int16](backing), nil + case onnx.TensorProto_INT32: + return createNewBackingFromBoolean[int32](backing), nil + case onnx.TensorProto_INT64: + return createNewBackingFromBoolean[int64](backing), nil + case onnx.TensorProto_UINT8: + return createNewBackingFromBoolean[uint8](backing), nil + case onnx.TensorProto_UINT16: + return createNewBackingFromBoolean[uint16](backing), nil + case onnx.TensorProto_UINT32: + return createNewBackingFromBoolean[uint32](backing), nil + case onnx.TensorProto_UINT64: + return createNewBackingFromBoolean[uint64](backing), nil + case onnx.TensorProto_BOOL: + return backing, nil + case onnx.TensorProto_BFLOAT16, onnx.TensorProto_COMPLEX64, onnx.TensorProto_COMPLEX128, onnx.TensorProto_FLOAT16, onnx.TensorProto_UNDEFINED, onnx.TensorProto_STRING: + return nil, ErrConversionNotSupported(dataType) + default: + return nil, ErrConversionNotSupported(dataType) + } +} + +func createNewBacking[B NumericType, R NumericType](backing []B) []R { newBacking := make([]R, len(backing)) for i := range backing { newBacking[i] = R(backing[i]) @@ -115,3 +147,26 @@ func createNewBacking[B Number, R Number](backing []B) []R { return newBacking } + +func createNewBooleanBacking[B NumericType](backing []B) []bool { + newBacking := make([]bool, len(backing)) + for i := range backing { + newBacking[i] = backing[i] != 0 + } + + return newBacking +} + +func createNewBackingFromBoolean[T NumericType](backing []bool) []T { + newBacking := make([]T, len(backing)) + + for i := range backing { + if backing[i] { + newBacking[i] = 1 + } else { + newBacking[i] = 0 + } + } + + return newBacking +} diff --git a/ops/convert_test.go b/ops/convert_test.go index 400f67e..be2ea4e 100644 --- a/ops/convert_test.go +++ b/ops/convert_test.go @@ -87,10 +87,10 @@ func TestConvertTensorDtype(t *testing.T) { nil, }, { - tensor.New(tensor.WithShape(2), tensor.WithBacking([]bool{true, false})), + tensor.New(tensor.WithShape(2), tensor.WithBacking([]string{"joe", "joe"})), tensor.New(tensor.WithShape(2), tensor.WithBacking([]float32{1.0, 2.0})), 1, - ErrConversionInvalidType(tensor.Bool, 1), + ErrConversionInvalidType(tensor.String, 1), }, { tensor.New(tensor.WithShape(2), tensor.WithBacking([]float32{1.0, 2.1})), diff --git a/ops/types.go b/ops/types.go index 385ba0c..a01ed4b 100644 --- a/ops/types.go +++ b/ops/types.go @@ -7,8 +7,12 @@ type FloatType interface { float32 | float64 } +type IntType interface { + uint8 | uint16 | uint32 | uint64 | int8 | int16 | int32 | int64 +} + type NumericType interface { - uint8 | uint16 | uint32 | uint64 | int8 | int16 | int32 | int64 | FloatType + IntType | FloatType } // AllTypes is a type constraint which allows all types.