diff --git a/arrow/avro/reader_test.go b/arrow/avro/reader_test.go
index 2ba91846..5c57e2d6 100644
--- a/arrow/avro/reader_test.go
+++ b/arrow/avro/reader_test.go
@@ -25,8 +25,11 @@ import (
 	"testing"
 
 	"github.com/apache/arrow-go/v18/arrow"
+	"github.com/apache/arrow-go/v18/arrow/array"
 	"github.com/apache/arrow-go/v18/arrow/avro/testdata"
+	"github.com/apache/arrow-go/v18/arrow/memory"
 	hamba "github.com/hamba/avro/v2"
+	"github.com/hamba/avro/v2/ocf"
 	"github.com/stretchr/testify/assert"
 )
 
@@ -227,3 +230,109 @@ func TestReader(t *testing.T) {
 		})
 	}
 }
+
+// TestOCFReaderBytesValues exercises avro `bytes` fields, both plain and as a
+// ["null","bytes"] union: hamba hands the decoded value to the appenders as a
+// bare []byte, which previously fell into appendBinaryData's fmt fallback and
+// appended the formatted text (e.g. "[1 2 3]") instead of the payload.
+func TestOCFReaderBytesValues(t *testing.T) {
+	schema := `{
+		"type": "record",
+		"name": "rec",
+		"fields": [
+			{"name": "plain", "type": "bytes"},
+			{"name": "nullable", "type": ["null", "bytes"]}
+		]
+	}`
+	payload := []byte{0x00, 0x01, 0xfe, 0xff}
+
+	var buf bytes.Buffer
+	enc, err := ocf.NewEncoder(schema, &buf)
+	assert.NoError(t, err)
+	assert.NoError(t, enc.Encode(map[string]any{
+		"plain":    payload,
+		"nullable": map[string]any{"bytes": payload},
+	}))
+	assert.NoError(t, enc.Encode(map[string]any{
+		"plain":    []byte{},
+		"nullable": nil,
+	}))
+	assert.NoError(t, enc.Close())
+
+	ar, err := NewOCFReader(bytes.NewReader(buf.Bytes()), WithChunk(-1))
+	assert.NoError(t, err)
+	defer ar.Close()
+
+	assert.True(t, ar.Next())
+	assert.NoError(t, ar.Err())
+	rec := ar.RecordBatch()
+
+	plain := rec.Column(0).(*array.Binary)
+	assert.Equal(t, payload, plain.Value(0))
+	assert.Equal(t, []byte{}, plain.Value(1))
+
+	nullable := rec.Column(1).(*array.Binary)
+	assert.Equal(t, payload, nullable.Value(0))
+	assert.True(t, nullable.IsNull(1))
+}
+
+// Types outside what the hamba decoder produces must error rather than append
+// a fmt-formatted rendering of the value.
+func TestAppendBinaryAndStringDataUnexpectedTypes(t *testing.T) {
+	bb := array.NewBinaryBuilder(memory.DefaultAllocator, arrow.BinaryTypes.Binary)
+	defer bb.Release()
+
+	assert.NoError(t, appendBinaryData(bb, []byte{0x01}))
+	assert.NoError(t, appendBinaryData(bb, nil))
+	assert.NoError(t, appendBinaryData(bb, map[string]any{"bytes": []byte{0x02}}))
+	assert.ErrorContains(t, appendBinaryData(bb, 42), "unexpected type int")
+	assert.ErrorContains(t, appendBinaryData(bb, map[string]any{"bytes": "text"}), "unexpected type string")
+	assert.Equal(t, 3, bb.Len())
+
+	sb := array.NewStringBuilder(memory.DefaultAllocator)
+	defer sb.Release()
+
+	assert.NoError(t, appendStringData(sb, "ok"))
+	assert.NoError(t, appendStringData(sb, []byte("ok")))
+	assert.NoError(t, appendStringData(sb, nil))
+	assert.NoError(t, appendStringData(sb, map[string]any{"string": "ok"}))
+	assert.ErrorContains(t, appendStringData(sb, 42), "unexpected type int")
+	assert.ErrorContains(t, appendStringData(sb, map[string]any{"string": 42}), "unexpected type int")
+	assert.Equal(t, 4, sb.Len())
+}
+
+// loadDatum must surface appender errors from nested paths (map values,
+// list items), not only from top-level and struct fields.
+func TestLoadDatumPropagatesNestedAppendErrors(t *testing.T) {
+	newLoader := func(t *testing.T, avroSchema string) (*dataLoader, *array.RecordBuilder) {
+		t.Helper()
+		schema, err := hamba.Parse(avroSchema)
+		assert.NoError(t, err)
+		arrowSchema, err := ArrowSchemaFromAvro(schema)
+		assert.NoError(t, err)
+		bld := array.NewRecordBuilder(memory.DefaultAllocator, arrowSchema)
+		pos := newFieldPos()
+		ldr := newDataLoader()
+		for idx, fb := range bld.Fields() {
+			mapFieldBuilders(fb, arrowSchema.Field(idx), pos)
+		}
+		ldr.drawTree(pos)
+		return ldr, bld
+	}
+
+	t.Run("map value", func(t *testing.T) {
+		ldr, bld := newLoader(t, `{"type":"record","name":"r","fields":[
+			{"name":"m","type":{"type":"map","values":"bytes"}}]}`)
+		defer bld.Release()
+		assert.NoError(t, ldr.loadDatum(map[string]any{"m": map[string]any{"k": []byte{0x01}}}))
+		assert.ErrorContains(t, ldr.loadDatum(map[string]any{"m": map[string]any{"k": 42}}), "unexpected type int")
+	})
+
+	t.Run("list item", func(t *testing.T) {
+		ldr, bld := newLoader(t, `{"type":"record","name":"r","fields":[
+			{"name":"l","type":{"type":"array","items":"bytes"}}]}`)
+		defer bld.Release()
+		assert.NoError(t, ldr.loadDatum(map[string]any{"l": []any{[]byte{0x01}}}))
+		assert.ErrorContains(t, ldr.loadDatum(map[string]any{"l": []any{42}}), "unexpected type int")
+	})
+}
diff --git a/arrow/avro/reader_types.go b/arrow/avro/reader_types.go
index aabad17e..45a7b145 100644
--- a/arrow/avro/reader_types.go
+++ b/arrow/avro/reader_types.go
@@ -92,10 +92,12 @@ func (d *dataLoader) drawTree(field *fieldPos) {
 // Since array.StructBuilder.AppendNull() will recursively append null to all of the
 // struct's fields, in the case of nil being passed to a struct's builderFunc it will
 // return a ErrNullStructData error to signal that all its sub-fields can be skipped.
 func (d *dataLoader) loadDatum(data any) error {
 	if d.list == nil && d.mapField == nil {
 		if d.mapValue != nil {
-			d.mapValue.appendFunc(data)
+			if err := filterNullStruct(d.mapValue.appendFunc(data)); err != nil {
+				return err
+			}
 		}
 		var NullParent *fieldPos
 		for _, f := range d.fields {
@@ -136,7 +138,9 @@ func (d *dataLoader) loadDatum(data any) error {
 						}
 					} else {
 						for _, e := range dt {
-							d.children[0].loadDatum(e)
+							if err := d.children[0].loadDatum(e); err != nil {
+								return err
+							}
 						}
 					}
 				case map[string]any:
@@ -154,16 +158,24 @@ func (d *dataLoader) loadDatum(data any) error {
 		}
 		for _, c := range d.children {
 			if c.list != nil {
-				c.loadDatum(c.list.getValue(data))
+				if err := c.loadDatum(c.list.getValue(data)); err != nil {
+					return err
+				}
 			}
 			if c.mapField != nil {
 				switch dt := data.(type) {
 				case nil:
-					c.loadDatum(dt)
+					if err := c.loadDatum(dt); err != nil {
+						return err
+					}
 				case map[string]any:
-					c.loadDatum(c.mapField.getValue(dt))
+					if err := c.loadDatum(c.mapField.getValue(dt)); err != nil {
+						return err
+					}
 				default:
-					c.loadDatum(c.mapField.getValue(data))
+					if err := c.loadDatum(c.mapField.getValue(data)); err != nil {
+						return err
+					}
 				}
 			}
 		}
@@ -171,12 +183,18 @@ func (d *dataLoader) loadDatum(data any) error {
 	if d.list != nil {
 		switch dt := data.(type) {
 		case nil:
-			d.list.appendFunc(dt)
+			if err := filterNullStruct(d.list.appendFunc(dt)); err != nil {
+				return err
+			}
 		case []any:
-			d.list.appendFunc(dt)
+			if err := filterNullStruct(d.list.appendFunc(dt)); err != nil {
+				return err
+			}
 			for _, e := range dt {
 				if d.item != nil {
-					d.item.appendFunc(e)
+					if err := filterNullStruct(d.item.appendFunc(e)); err != nil {
+						return err
+					}
 				}
 				var NullParent *fieldPos
 				for _, f := range d.fields {
@@ -194,18 +212,26 @@ func (d *dataLoader) loadDatum(data any) error {
 				}
 				for _, c := range d.children {
 					if c.list != nil {
-						c.loadDatum(c.list.getValue(e))
+						if err := c.loadDatum(c.list.getValue(e)); err != nil {
+							return err
+						}
 					}
 					if c.mapField != nil {
-						c.loadDatum(c.mapField.getValue(e))
+						if err := c.loadDatum(c.mapField.getValue(e)); err != nil {
+							return err
+						}
 					}
 				}
 			}
 		case map[string]any:
-			d.list.appendFunc(dt["array"])
+			if err := filterNullStruct(d.list.appendFunc(dt["array"])); err != nil {
+				return err
+			}
 			for _, e := range dt["array"].([]any) {
 				if d.item != nil {
-					d.item.appendFunc(e)
+					if err := filterNullStruct(d.item.appendFunc(e)); err != nil {
+						return err
+					}
 				}
 				var NullParent *fieldPos
 				for _, f := range d.fields {
@@ -222,27 +248,40 @@ func (d *dataLoader) loadDatum(data any) error {
 					}
 				}
 				for _, c := range d.children {
-					c.loadDatum(c.list.getValue(e))
+					if err := c.loadDatum(c.list.getValue(e)); err != nil {
+						return err
+					}
 				}
 			}
 		default:
-			d.list.appendFunc(data)
-			d.item.appendFunc(dt)
+			if err := filterNullStruct(d.list.appendFunc(data)); err != nil {
+				return err
+			}
+			if err := filterNullStruct(d.item.appendFunc(dt)); err != nil {
+				return err
+			}
 		}
 	}
 	if d.mapField != nil {
 		switch dt := data.(type) {
 		case nil:
-			d.mapField.appendFunc(dt)
+			if err := filterNullStruct(d.mapField.appendFunc(dt)); err != nil {
+				return err
+			}
 		case map[string]any:
-
-			d.mapField.appendFunc(dt)
+			if err := filterNullStruct(d.mapField.appendFunc(dt)); err != nil {
+				return err
+			}
 			for k, v := range dt {
-				d.mapKey.appendFunc(k)
+				if err := filterNullStruct(d.mapKey.appendFunc(k)); err != nil {
+					return err
+				}
 				if d.mapValue != nil {
-					d.mapValue.appendFunc(v)
-				} else {
-					d.children[0].loadDatum(v)
+					if err := filterNullStruct(d.mapValue.appendFunc(v)); err != nil {
+						return err
+					}
+				} else if err := d.children[0].loadDatum(v); err != nil {
+					return err
 				}
 			}
 		}
@@ -397,8 +436,7 @@ func mapFieldBuilders(b array.Builder, field arrow.Field, parent *fieldPos) {
 	switch bt := b.(type) {
 	case *array.BinaryBuilder:
 		f.appendFunc = func(data interface{}) error {
-			appendBinaryData(bt, data)
-			return nil
+			return appendBinaryData(bt, data)
 		}
 	case *array.BinaryDictionaryBuilder:
 		// has metadata for Avro enum symbols
@@ -551,8 +589,7 @@ func mapFieldBuilders(b array.Builder, field arrow.Field, parent *fieldPos) {
 		}
 	case *array.StringBuilder:
 		f.appendFunc = func(data interface{}) error {
-			appendStringData(bt, data)
-			return nil
+			return appendStringData(bt, data)
 		}
 	case *array.StructBuilder:
 		// has metadata for Avro Union named types
@@ -590,20 +627,34 @@ func mapFieldBuilders(b array.Builder, field arrow.Field, parent *fieldPos) {
 	}
 }
 
+// filterNullStruct drops ErrNullStructData, which signals a null struct
+// whose sub-fields can be skipped rather than a failure.
+func filterNullStruct(err error) error {
+	if err == ErrNullStructData {
+		return nil
+	}
+	return err
+}
+
-func appendBinaryData(b *array.BinaryBuilder, data interface{}) {
+func appendBinaryData(b *array.BinaryBuilder, data interface{}) error {
 	switch dt := data.(type) {
 	case nil:
 		b.AppendNull()
+	case []byte:
+		b.Append(dt)
 	case map[string]any:
 		switch ct := dt["bytes"].(type) {
 		case nil:
 			b.AppendNull()
+		case []byte:
+			b.Append(ct)
 		default:
-			b.Append(ct.([]byte))
+			return fmt.Errorf("unexpected type %T for avro bytes union value", ct)
 		}
 	default:
-		b.Append(fmt.Append([]byte{}, data))
+		return fmt.Errorf("unexpected type %T for avro bytes value", data)
 	}
+	return nil
 }
 
 func appendBinaryDictData(b *array.BinaryDictionaryBuilder, data interface{}) {
@@ -853,22 +904,27 @@ func appendInt64Data(b *array.Int64Builder, data interface{}) {
 	}
 }
 
-func appendStringData(b *array.StringBuilder, data interface{}) {
+func appendStringData(b *array.StringBuilder, data interface{}) error {
 	switch dt := data.(type) {
 	case nil:
 		b.AppendNull()
 	case string:
 		b.Append(dt)
+	case []byte:
+		b.Append(string(dt))
 	case map[string]any:
 		switch v := dt["string"].(type) {
 		case nil:
 			b.AppendNull()
 		case string:
 			b.Append(v)
+		default:
+			return fmt.Errorf("unexpected type %T for avro string union value", v)
 		}
 	default:
-		b.Append(fmt.Sprint(data))
+		return fmt.Errorf("unexpected type %T for avro string value", data)
 	}
+	return nil
 }
 
 func appendTime32Data(b *array.Time32Builder, data interface{}) {
diff --git a/arrow/avro/testdata/testdata.go b/arrow/avro/testdata/testdata.go
index a5090b40..4c8bac71 100644
--- a/arrow/avro/testdata/testdata.go
+++ b/arrow/avro/testdata/testdata.go
@@ -42,9 +42,7 @@ const (
 type ByteArray []byte
 
 func (b ByteArray) MarshalJSON() ([]byte, error) {
-	s := fmt.Sprint(b)
-	encoded := base64.StdEncoding.EncodeToString([]byte(s))
-	return json.Marshal(encoded)
+	return json.Marshal(base64.StdEncoding.EncodeToString(b))
 }
 
 type TimestampMicros int64