diff --git a/internal/object/wire.go b/internal/object/wire.go index 548f337094..7f04839d6f 100644 --- a/internal/object/wire.go +++ b/internal/object/wire.go @@ -4,7 +4,9 @@ import ( "errors" "fmt" "io" + "strconv" + iprotobuf "github.com/nspcc-dev/neofs-node/internal/protobuf" "github.com/nspcc-dev/neofs-sdk-go/object" protoobject "github.com/nspcc-dev/neofs-sdk-go/proto/object" "github.com/nspcc-dev/neofs-sdk-go/proto/refs" @@ -19,8 +21,31 @@ const ( fieldObjectSignature fieldObjectHeader fieldObjectPayload + + FieldHeaderVersion = 1 + FieldHeaderContainerID = 2 + FieldHeaderOwnerID = 3 + FieldHeaderCreationEpoch = 4 + FieldHeaderPayloadLength = 5 + FieldHeaderPayloadHash = 6 + FieldHeaderType = 7 + FieldHeaderHomoHash = 8 + FieldHeaderSessionToken = 9 + FieldHeaderAttributes = 10 + FieldHeaderSplit = 11 + FieldHeaderSessionTokenV2 = 12 + + FieldHeaderSplitParent = 1 + FieldHeaderSplitPrevious = 2 + FieldHeaderSplitParentSignature = 3 + FieldHeaderSplitParentHeader = 4 + FieldHeaderSplitChildren = 5 + FieldHeaderSplitSplitID = 6 + FieldHeaderSplitFirst = 7 ) +var errEmptyData = errors.New("empty data") + // WriteWithoutPayload writes the object header to the given writer without the payload. func WriteWithoutPayload(w io.Writer, obj object.Object) error { header := obj.CutPayload().Marshal() @@ -41,7 +66,7 @@ func ExtractHeaderAndPayload(data []byte) (*object.Object, []byte, error) { ) if len(data) == 0 { - return nil, nil, fmt.Errorf("empty data") + return nil, nil, errEmptyData } for offset < len(data) { @@ -108,3 +133,145 @@ func ReadHeaderPrefix(r io.Reader) (*object.Object, []byte, error) { } return ExtractHeaderAndPayload(buf[:n]) } + +// GetNonPayloadFieldBounds seeks ID, signature and header in object message and +// parses their boundaries. +// +// If buf is empty, GetNonPayloadFieldBounds returns an error. +// +// If any field is missing, no error is returned. +// +// Message should have ascending field order, otherwise error returns. +func GetNonPayloadFieldBounds(buf []byte) (iprotobuf.FieldBounds, iprotobuf.FieldBounds, iprotobuf.FieldBounds, error) { + var idf, sigf, hdrf iprotobuf.FieldBounds + if len(buf) == 0 { + return idf, sigf, hdrf, errEmptyData + } + + var off int + var prevNum protowire.Number +loop: + for { + num, typ, n, err := iprotobuf.ParseTag(buf[off:]) + if err != nil { + return idf, sigf, hdrf, err + } + + if num > fieldObjectHeader { + break + } + if num < prevNum { + return idf, sigf, hdrf, iprotobuf.NewUnorderedFieldsError(prevNum, num) + } + if num == prevNum { + return idf, sigf, hdrf, iprotobuf.NewRepeatedFieldError(num) + } + prevNum = num + + f, err := iprotobuf.ParseLENFieldBounds(buf, off, n, num, typ) + if err != nil { + return idf, sigf, hdrf, err + } + + switch num { + case fieldObjectID: + idf = f + case fieldObjectSignature: + sigf = f + case fieldObjectHeader: + hdrf = f + break loop + default: + panic("unreachable with num " + strconv.Itoa(int(num))) + } + + off = f.To + + if off == len(buf) { + break + } + } + + return idf, sigf, hdrf, nil +} + +// GetParentNonPayloadFieldBounds seeks parent's ID, signature and header in child +// object message and parses their boundaries. +// +// If buf is empty, GetParentNonPayloadFieldBounds returns an error. +// +// If any field is missing, no error is returned. +// +// Message should have ascending field order, otherwise error returns. +func GetParentNonPayloadFieldBounds(buf []byte) (iprotobuf.FieldBounds, iprotobuf.FieldBounds, iprotobuf.FieldBounds, error) { + var idf, sigf, hdrf iprotobuf.FieldBounds + if len(buf) == 0 { + return idf, sigf, hdrf, errEmptyData + } + + rootHdrf, err := iprotobuf.GetLENFieldBounds(buf, fieldObjectHeader) + if err != nil { + return idf, sigf, hdrf, err + } + + if rootHdrf.IsMissing() { + return idf, sigf, hdrf, nil + } + + splitf, err := iprotobuf.GetLENFieldBounds(buf[rootHdrf.ValueFrom:rootHdrf.To], FieldHeaderSplit) + if err != nil { + return idf, sigf, hdrf, err + } + + if splitf.IsMissing() { + return idf, sigf, hdrf, nil + } + + buf = buf[:rootHdrf.ValueFrom+splitf.To] + off := rootHdrf.ValueFrom + splitf.ValueFrom + var prevNum protowire.Number +loop: + for { + num, typ, n, err := iprotobuf.ParseTag(buf[off:]) + if err != nil { + return idf, sigf, hdrf, err + } + + if num > FieldHeaderSplitParentHeader { + break + } + if num < prevNum { + return idf, sigf, hdrf, iprotobuf.NewUnorderedFieldsError(prevNum, num) + } + if num == prevNum { + return idf, sigf, hdrf, iprotobuf.NewRepeatedFieldError(num) + } + prevNum = num + + f, err := iprotobuf.ParseLENFieldBounds(buf, off, n, num, typ) + if err != nil { + return idf, sigf, hdrf, err + } + + switch num { + case FieldHeaderSplitParent: + idf = f + case FieldHeaderSplitPrevious: + case FieldHeaderSplitParentSignature: + sigf = f + case FieldHeaderSplitParentHeader: + hdrf = f + break loop + default: + panic("unreachable with num " + strconv.Itoa(int(num))) + } + + off = f.To + + if off == len(buf) { + break + } + } + + return idf, sigf, hdrf, nil +} diff --git a/internal/object/wire_test.go b/internal/object/wire_test.go index cfeb77f361..b5b877a3f6 100644 --- a/internal/object/wire_test.go +++ b/internal/object/wire_test.go @@ -3,15 +3,52 @@ package object_test import ( "bytes" "crypto/rand" + "crypto/sha256" "io" + "math" + "slices" "testing" iobject "github.com/nspcc-dev/neofs-node/internal/object" + iprotobuf "github.com/nspcc-dev/neofs-node/internal/protobuf" + "github.com/nspcc-dev/neofs-node/internal/testutil" + "github.com/nspcc-dev/neofs-sdk-go/checksum" + cidtest "github.com/nspcc-dev/neofs-sdk-go/container/id/test" + neofscrypto "github.com/nspcc-dev/neofs-sdk-go/crypto" + neofscryptotest "github.com/nspcc-dev/neofs-sdk-go/crypto/test" "github.com/nspcc-dev/neofs-sdk-go/object" + oid "github.com/nspcc-dev/neofs-sdk-go/object/id" + oidtest "github.com/nspcc-dev/neofs-sdk-go/object/id/test" objecttest "github.com/nspcc-dev/neofs-sdk-go/object/test" + "github.com/nspcc-dev/neofs-sdk-go/version" + "github.com/nspcc-dev/tzhash/tz" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/encoding/protowire" ) +func TestFields(t *testing.T) { + require.EqualValues(t, 1, iobject.FieldHeaderVersion) + require.EqualValues(t, 2, iobject.FieldHeaderContainerID) + require.EqualValues(t, 3, iobject.FieldHeaderOwnerID) + require.EqualValues(t, 4, iobject.FieldHeaderCreationEpoch) + require.EqualValues(t, 5, iobject.FieldHeaderPayloadLength) + require.EqualValues(t, 6, iobject.FieldHeaderPayloadHash) + require.EqualValues(t, 7, iobject.FieldHeaderType) + require.EqualValues(t, 8, iobject.FieldHeaderHomoHash) + require.EqualValues(t, 9, iobject.FieldHeaderSessionToken) + require.EqualValues(t, 10, iobject.FieldHeaderAttributes) + require.EqualValues(t, 11, iobject.FieldHeaderSplit) + require.EqualValues(t, 12, iobject.FieldHeaderSessionTokenV2) + + require.EqualValues(t, 1, iobject.FieldHeaderSplitParent) + require.EqualValues(t, 2, iobject.FieldHeaderSplitPrevious) + require.EqualValues(t, 3, iobject.FieldHeaderSplitParentSignature) + require.EqualValues(t, 4, iobject.FieldHeaderSplitParentHeader) + require.EqualValues(t, 5, iobject.FieldHeaderSplitChildren) + require.EqualValues(t, 6, iobject.FieldHeaderSplitSplitID) + require.EqualValues(t, 7, iobject.FieldHeaderSplitFirst) +} + func TestWriteWithoutPayload(t *testing.T) { t.Run("write empty object", func(t *testing.T) { var buf bytes.Buffer @@ -97,3 +134,168 @@ func TestReadHeaderPrefix(t *testing.T) { require.Equal(t, expectedSize, len(payloadPrefix)) require.Equal(t, payload[:expectedSize], payloadPrefix) } + +func TestGetNonPayloadFieldBounds(t *testing.T) { + t.Run("empty data", func(t *testing.T) { + _, _, _, err := iobject.GetNonPayloadFieldBounds([]byte{}) + require.EqualError(t, err, "empty data") + }) + + id := oidtest.ID() + sig := neofscryptotest.Signature() + + obj := objecttest.Object() + obj.SetID(id) + obj.SetSignature(&sig) + + buf := obj.Marshal() + + idf, sigf, hdrf, err := iobject.GetNonPayloadFieldBounds(buf) + require.NoError(t, err) + + assertFound := func(t *testing.T, f iprotobuf.FieldBounds, tag byte, exp []byte) { + require.False(t, f.IsMissing()) + require.EqualValues(t, tag, buf[f.From]) + ln, n, err := iprotobuf.ParseLENField(buf[f.From+1:], 42, protowire.BytesType) + require.NoError(t, err) + require.EqualValues(t, 1+n, f.ValueFrom-f.From) + require.EqualValues(t, ln, f.To-f.ValueFrom) + require.True(t, bytes.Equal(exp, buf[f.ValueFrom:f.To])) + } + + assertFound(t, idf, iprotobuf.TagBytes1, id.Marshal()) + assertFound(t, sigf, iprotobuf.TagBytes2, sig.Marshal()) + + hdr := obj.ProtoMessage().Header + hdrBuf := make([]byte, hdr.MarshaledSize()) + hdr.MarshalStable(hdrBuf) + assertFound(t, hdrf, iprotobuf.TagBytes3, hdrBuf) +} + +func BenchmarkGetNonPayloadFieldBounds(b *testing.B) { + id := oidtest.ID() + const sigLen = 100 + const hdrLen = 16 << 10 + + buf := slices.Concat( + []byte{iprotobuf.TagBytes1, oid.Size + 2, iprotobuf.TagBytes1, oid.Size}, id[:], + []byte{iprotobuf.TagBytes2, sigLen}, testutil.RandByteSlice(sigLen), + []byte{iprotobuf.TagBytes3, 128, 128, 1}, testutil.RandByteSlice(hdrLen), + ) + + idf, sigf, hdrf, err := iobject.GetNonPayloadFieldBounds(buf) + require.NoError(b, err) + require.EqualValues(b, 0, idf.From) + require.EqualValues(b, idf.From+2, idf.ValueFrom) + require.EqualValues(b, idf.ValueFrom+oid.Size, idf.To) + require.EqualValues(b, idf.To, sigf.From) + require.EqualValues(b, sigf.From+2, sigf.ValueFrom) + require.EqualValues(b, sigf.ValueFrom+sigLen, sigf.To) + require.EqualValues(b, sigf.To, hdrf.From) + require.EqualValues(b, hdrf.From+4, hdrf.ValueFrom) + require.EqualValues(b, hdrf.ValueFrom+hdrLen, hdrf.To) + require.EqualValues(b, len(buf), hdrf.To) + + b.ReportAllocs() + for b.Loop() { + _, _, _, err = iobject.GetNonPayloadFieldBounds(buf) + require.NoError(b, err) + } +} + +func TestGetParentNonPayloadFieldBounds(t *testing.T) { + t.Run("empty data", func(t *testing.T) { + _, _, _, err := iobject.GetParentNonPayloadFieldBounds([]byte{}) + require.EqualError(t, err, "empty data") + }) + + parID := oidtest.ID() + parSig := neofscryptotest.Signature() + + par := objecttest.Object() + par.SetID(parID) + par.SetSignature(&parSig) + par.ResetRelations() + + obj := objecttest.Object() + obj.SetParent(&par) + + buf := obj.Marshal() + + idf, sigf, hdrf, err := iobject.GetParentNonPayloadFieldBounds(buf) + require.NoError(t, err) + + assertFound := func(t *testing.T, f iprotobuf.FieldBounds, tag byte, exp []byte) { + require.False(t, f.IsMissing()) + require.EqualValues(t, tag, buf[f.From]) + ln, n, err := iprotobuf.ParseLENField(buf[f.From+1:], 42, protowire.BytesType) + require.NoError(t, err) + require.EqualValues(t, 1+n, f.ValueFrom-f.From) + require.EqualValues(t, ln, f.To-f.ValueFrom) + require.True(t, bytes.Equal(exp, buf[f.ValueFrom:f.To])) + } + + assertFound(t, idf, iprotobuf.TagBytes1, parID.Marshal()) + assertFound(t, sigf, iprotobuf.TagBytes3, parSig.Marshal()) + + parHdr := par.ProtoMessage().Header + parHdrBuf := make([]byte, parHdr.MarshaledSize()) + parHdr.MarshalStable(parHdrBuf) + assertFound(t, hdrf, iprotobuf.TagBytes4, parHdrBuf) +} + +func BenchmarkGetParentNonPayloadFieldBounds(b *testing.B) { + parID := oidtest.ID() + + sig := neofscrypto.NewSignatureFromRawKey(neofscrypto.ECDSA_DETERMINISTIC_SHA256, testutil.RandByteSlice(33), testutil.RandByteSlice(64)) + ver := version.New(123, 456) + pldHash := checksum.New(checksum.SHA256, testutil.RandByteSlice(sha256.Size)) + pldHomoHash := checksum.New(checksum.TillichZemor, testutil.RandByteSlice(tz.Size)) + + fillHeader := func(obj *object.Object) { + obj.SetID(oidtest.ID()) + obj.SetSignature(&sig) + obj.SetVersion(&ver) + obj.SetContainerID(cidtest.ID()) + obj.SetCreationEpoch(math.MaxUint64) + obj.SetPayloadSize(math.MaxUint64) + obj.SetPayloadChecksum(pldHash) + obj.SetType(math.MaxInt32) + obj.SetPayloadHomomorphicHash(pldHomoHash) + obj.SetAttributes( + object.NewAttribute("key1", "val1"), + object.NewAttribute("key2", "val2"), + object.NewAttribute("key3", "val3"), + ) + } + + var par object.Object + fillHeader(&par) + par.SetID(parID) + + var obj object.Object + fillHeader(&obj) + obj.SetPreviousID(oidtest.ID()) + obj.SetParent(&par) + + buf := obj.Marshal() + + idf, sigf, hdrf, err := iobject.GetParentNonPayloadFieldBounds(buf) + require.NoError(b, err) + require.Positive(b, idf.From) + require.EqualValues(b, idf.From+2, idf.ValueFrom) + require.EqualValues(b, idf.ValueFrom+2+oid.Size, idf.To) + require.EqualValues(b, idf.To+2+2+oid.Size, sigf.From) + require.EqualValues(b, sigf.From+2, sigf.ValueFrom) + require.EqualValues(b, sigf.ValueFrom+len(sig.Marshal()), sigf.To) + require.EqualValues(b, sigf.To, hdrf.From) + require.EqualValues(b, hdrf.From+3, hdrf.ValueFrom) + require.EqualValues(b, hdrf.ValueFrom+par.HeaderLen(), hdrf.To) + require.EqualValues(b, len(buf), hdrf.To) + + b.ReportAllocs() + for b.Loop() { + _, _, _, err = iobject.GetNonPayloadFieldBounds(buf) + require.NoError(b, err) + } +} diff --git a/internal/protobuf/errors.go b/internal/protobuf/errors.go new file mode 100644 index 0000000000..95a21016f1 --- /dev/null +++ b/internal/protobuf/errors.go @@ -0,0 +1,19 @@ +package protobuf + +import ( + "fmt" + + "google.golang.org/protobuf/encoding/protowire" +) + +// NewUnorderedFieldsError returns common error for field order violation when +// field #n2 goes after #n1. +func NewUnorderedFieldsError(n1, n2 protowire.Number) error { + return fmt.Errorf("unordered fields: #%d after #%d", n2, n1) +} + +// NewRepeatedFieldError returns common error for field #n repeated more than +// once. +func NewRepeatedFieldError(n protowire.Number) error { + return fmt.Errorf("repeated field #%d", n) +} diff --git a/internal/protobuf/parsers.go b/internal/protobuf/parsers.go new file mode 100644 index 0000000000..b11f6c57c2 --- /dev/null +++ b/internal/protobuf/parsers.go @@ -0,0 +1,213 @@ +package protobuf + +import ( + "errors" + "fmt" + "math" + + "google.golang.org/protobuf/encoding/protowire" +) + +// ParseVarint parses varint-encoded uint64 from buf. Returns parsed value and +// number of bytes read. +func ParseVarint(buf []byte) (uint64, int, error) { + u, n := protowire.ConsumeVarint(buf) + if n < 0 { + // TODO: protowire package adds 'proto' prefix to error. We don't need it. + return 0, 0, protowire.ParseError(n) + } + + return u, n, nil +} + +// ParseTag parses field tag from buf. Returns field number, type and number of +// bytes read. +func ParseTag(buf []byte) (protowire.Number, protowire.Type, int, error) { + u, n, err := ParseVarint(buf) + if err != nil { + return 0, 0, 0, fmt.Errorf("parse varint: %w", err) + } + + num, typ := protowire.DecodeTag(u) + if err = checkFieldNumber(num); err != nil { + return 0, 0, 0, err + } + + return num, typ, n, nil +} + +// ParseLEN parses varint-encoded length from buf and check its overflow. Returns +// parsed value and number of bytes read. +func ParseLEN(buf []byte) (int, int, error) { + ln, n, err := ParseVarint(buf) + if err != nil { + return 0, 0, fmt.Errorf("parse varint: %w", err) + } + + if ln > math.MaxInt { + return 0, 0, fmt.Errorf("value %d overflows int", ln) + } + + if rem := len(buf) - n; int(ln) > rem { + return 0, 0, newTruncatedBufferError(int(ln), rem) + } + + return int(ln), n, nil +} + +// ParseLENField parses length of LEN field with preread number and type from +// buf. Returns parsed value and number of bytes read. +// +// If there is an error, its text contains num and typ. +func ParseLENField(buf []byte, num protowire.Number, typ protowire.Type) (int, int, error) { + err := checkFieldType(num, protowire.BytesType, typ) + if err != nil { + return 0, 0, err + } + + ln, n, err := ParseLEN(buf) + if err != nil { + return 0, 0, wrapParseFieldError(num, protowire.BytesType, err) + } + + return ln, n, nil +} + +// ParseLENFieldBounds parses boundaries of LEN field with preread tag length, +// number and type at given offset from buf. +// +// If there is an error, its text contains num and typ. +func ParseLENFieldBounds(buf []byte, off int, tagLn int, num protowire.Number, typ protowire.Type) (FieldBounds, error) { + ln, n, err := ParseLENField(buf[off+tagLn:], num, typ) + if err != nil { + return FieldBounds{}, err + } + + var f FieldBounds + f.From = off + f.ValueFrom = f.From + tagLn + n + f.To = f.ValueFrom + ln + + return f, nil +} + +// ParseEnum parses enum value from buf. Returns parsed value and number of +// bytes read. +func ParseEnum[T ~int32](buf []byte) (T, int, error) { + u, n, err := ParseVarint(buf) + if err != nil { + return 0, 0, fmt.Errorf("parse varint: %w", err) + } + + if u > math.MaxInt32 { + return 0, 0, fmt.Errorf("value %d overflows int32", u) + } + + return T(u), n, nil +} + +// ParseEnumField parses value of enum field with preread number and type from +// buf. Returns parsed value and number of bytes read. +// +// If there is an error, its text contains num and typ. +func ParseEnumField[T ~int32](buf []byte, num protowire.Number, typ protowire.Type) (T, int, error) { + err := checkFieldType(num, protowire.VarintType, typ) + if err != nil { + return 0, 0, err + } + + e, n, err := ParseEnum[T](buf) + if err != nil { + return 0, 0, wrapParseFieldError(num, protowire.VarintType, err) + } + + return e, n, nil +} + +// ParseUint32 parses varint-encoded uint32 from buf. Returns parsed value and +// number of bytes read. +func ParseUint32(buf []byte) (uint32, int, error) { + u, n, err := ParseVarint(buf) + if err != nil { + return 0, 0, fmt.Errorf("parse varint: %w", err) + } + + if u > math.MaxUint32 { + return 0, 0, fmt.Errorf("value %d overflows uint32", u) + } + + return uint32(u), n, nil +} + +// ParseUint32Field parses value of uint32 field from buf. Returns parsed value +// and number of bytes read. +// +// If there is an error, its text contains num and typ. +func ParseUint32Field(buf []byte, num protowire.Number, typ protowire.Type) (uint32, int, error) { + err := checkFieldType(num, protowire.VarintType, typ) + if err != nil { + return 0, 0, err + } + + u, n, err := ParseUint32(buf) + if err != nil { + return 0, 0, wrapParseFieldError(num, protowire.VarintType, err) + } + + return u, n, nil +} + +// ParseUint64Field parses value of uint64 field with preread number and type +// from buf. Returns value and its length. +// +// If there is an error, its text contains num and typ. +func ParseUint64Field(buf []byte, num protowire.Number, typ protowire.Type) (uint64, int, error) { + err := checkFieldType(num, protowire.VarintType, typ) + if err != nil { + return 0, 0, err + } + + u, n, err := ParseVarint(buf) + if err != nil { + return 0, 0, wrapParseFieldError(num, protowire.VarintType, fmt.Errorf("parse varint: %w", err)) + } + + return u, n, nil +} + +// SkipField parses length of skipped field with preread number and type from +// buf and checks its overflow. Returns number of bytes read. +// +// If there is an error, its text contains num and typ. +func SkipField(buf []byte, num protowire.Number, typ protowire.Type) (int, error) { + var err error + + switch typ { + case protowire.VarintType: + var n int + if _, n, err = ParseVarint(buf); err == nil { + return n, nil + } + case protowire.Fixed64Type: + if len(buf) >= fixed64Len { + return fixed64Len, nil + } + err = newTruncatedBufferError(fixed64Len, len(buf)) + case protowire.BytesType: + var ln, n int + if ln, n, err = ParseLEN(buf); err == nil { + return n + ln, nil + } + case protowire.StartGroupType, protowire.EndGroupType: + err = errors.New("type is not supported") + case protowire.Fixed32Type: + if len(buf) >= fixed32Len { + return fixed32Len, nil + } + err = newTruncatedBufferError(fixed32Len, len(buf)) + default: + return 0, newUnknownFieldTypeError(typ) + } + + return 0, wrapParseFieldError(num, typ, err) +} diff --git a/internal/protobuf/parsers_test.go b/internal/protobuf/parsers_test.go new file mode 100644 index 0000000000..9af32cc41b --- /dev/null +++ b/internal/protobuf/parsers_test.go @@ -0,0 +1,503 @@ +package protobuf_test + +import ( + "encoding/binary" + "math" + "slices" + "strconv" + "testing" + + iprotobuf "github.com/nspcc-dev/neofs-node/internal/protobuf" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/encoding/protowire" +) + +var ( + int32OverflowVarint = []byte{128, 128, 128, 128, 8} // 2147483648 + uint32OverflowVarint = []byte{128, 128, 128, 128, 16} // 4294967296 + + uint64OverflowVarint = []byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 2} +) + +var varintTestcases = []struct { + val uint64 + ln int + buf []byte +}{ + {val: 0, ln: 1, buf: []byte{0}}, + {val: 127, ln: 1, buf: []byte{127}}, + {val: 128, ln: 2, buf: []byte{128, 1}}, + {val: 16256, ln: 2, buf: []byte{128, 127}}, + {val: 16384, ln: 3, buf: []byte{128, 128, 1}}, + {val: math.MaxUint64, ln: 10, buf: []byte{255, 255, 255, 255, 255, 255, 255, 255, 255, 1}}, +} + +var invalidVarintTestcases = []struct { + name string + err string + buf []byte +}{ + {name: "empty buffer", err: "unexpected EOF", buf: []byte{}}, + {name: "truncated", err: "unexpected EOF", buf: []byte{128}}, + {name: "overflow", err: "variable length integer overflow", buf: uint64OverflowVarint}, +} + +func TestParseVarint(t *testing.T) { + for _, tc := range invalidVarintTestcases { + t.Run(tc.name, func(t *testing.T) { + _, _, err := iprotobuf.ParseVarint(tc.buf) + require.ErrorContains(t, err, tc.err) + }) + } + + for i, tc := range varintTestcases { + u, n, err := iprotobuf.ParseVarint(tc.buf) + require.NoError(t, err, i) + require.EqualValues(t, tc.val, u, i) + require.EqualValues(t, tc.ln, n, i) + } +} + +func TestParseTag(t *testing.T) { + t.Run("invalid varint", func(t *testing.T) { + for _, tc := range invalidVarintTestcases { + t.Run(tc.name, func(t *testing.T) { + _, _, _, err := iprotobuf.ParseTag(tc.buf) + require.ErrorContains(t, err, "parse varint") + require.ErrorContains(t, err, tc.err) + }) + } + }) + + t.Run("invalid number", func(t *testing.T) { + t.Run("0", func(t *testing.T) { + _, _, _, err := iprotobuf.ParseTag([]byte{2}) // 0,LEN + require.EqualError(t, err, "invalid number 0") + }) + t.Run("negative", func(t *testing.T) { + _, _, _, err := iprotobuf.ParseTag([]byte{250, 255, 255, 255, 255, 255, 255, 255, 255, 1}) // -1,LEN + require.EqualError(t, err, "invalid number -1") + }) + t.Run("too big", func(t *testing.T) { + _, _, _, err := iprotobuf.ParseTag([]byte{130, 128, 128, 128, 16}) // 536870912,LEN + require.EqualError(t, err, "invalid number 536870912") + }) + }) + + check := func(t *testing.T, buf []byte, expNum protowire.Number, expTyp protowire.Type, expN int) { + num, typ, n, err := iprotobuf.ParseTag(buf) + require.NoError(t, err) + require.EqualValues(t, expNum, num) + require.EqualValues(t, expTyp, typ) + require.EqualValues(t, expN, n) + } + + check(t, []byte{130, 64}, 1024, protowire.BytesType, 2) + check(t, []byte{24, 1, 2, 3}, 3, protowire.VarintType, 1) +} + +func TestParseLEN(t *testing.T) { + t.Run("invalid len", func(t *testing.T) { + for _, tc := range invalidVarintTestcases { + t.Run(tc.name, func(t *testing.T) { + _, _, err := iprotobuf.ParseLEN(tc.buf) + require.ErrorContains(t, err, "parse varint") + require.ErrorContains(t, err, tc.err) + }) + } + }) + + t.Run("buffer overflow", func(t *testing.T) { + t.Run("int overflow", func(t *testing.T) { + buf := make([]byte, binary.MaxVarintLen64) + n := binary.PutUvarint(buf, uint64(math.MaxInt+1)) + + _, _, err := iprotobuf.ParseLEN(buf[:n]) + require.EqualError(t, err, "value "+strconv.FormatUint(math.MaxInt+1, 10)+" overflows int") + }) + + for i, tc := range varintTestcases { + if tc.val == 0 || tc.val > 1<<20 { + continue + } + + buf := slices.Concat(tc.buf, make([]byte, tc.val-1)) + _, _, err := iprotobuf.ParseLEN(buf) + require.EqualError(t, err, "unexpected EOF: need "+strconv.FormatUint(tc.val, 10)+" bytes, left "+strconv.FormatUint(tc.val-1, 10)+" in buffer", i) + } + }) + + for i, tc := range varintTestcases { + if tc.val > 1<<20 { + continue + } + + buf := make([]byte, tc.val) + u, n, err := iprotobuf.ParseLEN(slices.Concat(tc.buf, buf)) + require.NoError(t, err, i) + require.EqualValues(t, tc.val, u, i) + require.EqualValues(t, tc.ln, n, i) + } +} + +func TestParseLENField(t *testing.T) { + t.Run("wrong type", func(t *testing.T) { + _, _, err := iprotobuf.ParseLENField([]byte{}, 42, protowire.VarintType) + require.EqualError(t, err, "wrong type of field #42: expected LEN, got VARINT") + }) + + t.Run("invalid len", func(t *testing.T) { + for _, tc := range invalidVarintTestcases { + t.Run(tc.name, func(t *testing.T) { + _, _, err := iprotobuf.ParseLENField(tc.buf, 42, protowire.BytesType) + require.ErrorContains(t, err, "parse field #42 of LEN type: parse varint") + require.ErrorContains(t, err, tc.err) + }) + } + }) + + t.Run("buffer overflow", func(t *testing.T) { + t.Run("int overflow", func(t *testing.T) { + buf := make([]byte, binary.MaxVarintLen64) + n := binary.PutUvarint(buf, uint64(math.MaxInt+1)) + + _, _, err := iprotobuf.ParseLENField(buf[:n], 42, protowire.BytesType) + require.EqualError(t, err, "parse field #42 of LEN type: value "+strconv.FormatUint(math.MaxInt+1, 10)+" overflows int") + }) + + for i, tc := range varintTestcases { + if tc.val == 0 || tc.val > 1<<20 { + continue + } + + buf := slices.Concat(tc.buf, make([]byte, tc.val-1)) + _, _, err := iprotobuf.ParseLENField(buf, 42, protowire.BytesType) + require.EqualError(t, err, "parse field #42 of LEN type: unexpected EOF: need "+strconv.FormatUint(tc.val, 10)+" bytes, left "+strconv.FormatUint(tc.val-1, 10)+" in buffer", i) + } + }) + + for i, tc := range varintTestcases { + if tc.val > 1<<20 { + continue + } + + buf := make([]byte, tc.val) + u, n, err := iprotobuf.ParseLENField(slices.Concat(tc.buf, buf), 42, protowire.BytesType) + require.NoError(t, err, i) + require.EqualValues(t, tc.val, u, i) + require.EqualValues(t, tc.ln, n, i) + } +} + +func TestParseLENFieldBounds(t *testing.T) { + prefix := []byte{1, 2, 3, 4} + const off = 2 + tagLn := len(prefix) - off + + t.Run("cut prefix", func(t *testing.T) { + require.Panics(t, func() { + _, _ = iprotobuf.ParseLENFieldBounds(prefix, off+1, tagLn, 42, protowire.BytesType) + }) + require.Panics(t, func() { + _, _ = iprotobuf.ParseLENFieldBounds(prefix, off, tagLn+1, 42, protowire.BytesType) + }) + }) + + t.Run("wrong type", func(t *testing.T) { + _, err := iprotobuf.ParseLENFieldBounds(prefix, off, tagLn, 42, protowire.VarintType) + require.EqualError(t, err, "wrong type of field #42: expected LEN, got VARINT") + }) + + t.Run("invalid len", func(t *testing.T) { + for _, tc := range invalidVarintTestcases { + t.Run(tc.name, func(t *testing.T) { + buf := slices.Concat(prefix, tc.buf) + _, err := iprotobuf.ParseLENFieldBounds(buf, off, tagLn, 42, protowire.BytesType) + require.ErrorContains(t, err, "parse field #42 of LEN type: parse varint") + require.ErrorContains(t, err, tc.err) + }) + } + }) + + t.Run("buffer overflow", func(t *testing.T) { + t.Run("int overflow", func(t *testing.T) { + buf := slices.Concat(prefix, make([]byte, binary.MaxVarintLen64)) + n := binary.PutUvarint(buf[len(prefix):], uint64(math.MaxInt+1)) + + _, err := iprotobuf.ParseLENFieldBounds(buf[:len(prefix)+n], off, tagLn, 42, protowire.BytesType) + require.EqualError(t, err, "parse field #42 of LEN type: value "+strconv.FormatUint(math.MaxInt+1, 10)+" overflows int") + }) + + for i, tc := range varintTestcases { + if tc.val == 0 || tc.val > 1<<20 { + continue + } + + buf := slices.Concat(prefix, tc.buf, make([]byte, tc.val-1)) + _, err := iprotobuf.ParseLENFieldBounds(buf, off, tagLn, 42, protowire.BytesType) + require.EqualError(t, err, "parse field #42 of LEN type: unexpected EOF: need "+strconv.FormatUint(tc.val, 10)+" bytes, left "+strconv.FormatUint(tc.val-1, 10)+" in buffer", i) + } + }) + + for i, tc := range varintTestcases { + if tc.val > 1<<20 { + continue + } + + buf := make([]byte, tc.val) + f, err := iprotobuf.ParseLENFieldBounds(slices.Concat(prefix, tc.buf, buf), off, tagLn, 42, protowire.BytesType) + require.NoError(t, err, i) + require.EqualValues(t, off, f.From) + require.EqualValues(t, f.From+tagLn+tc.ln, f.ValueFrom) + require.EqualValues(t, f.ValueFrom+int(tc.val), f.To) + } +} + +func TestParseEnum(t *testing.T) { + t.Run("invalid varint", func(t *testing.T) { + for _, tc := range invalidVarintTestcases { + t.Run(tc.name, func(t *testing.T) { + _, _, err := iprotobuf.ParseEnum[int32](tc.buf) + require.ErrorContains(t, err, tc.err) + }) + } + }) + + t.Run("value overflow", func(t *testing.T) { + _, _, err := iprotobuf.ParseEnum[int32](int32OverflowVarint) + require.EqualError(t, err, "value 2147483648 overflows int32") + }) + + for i, tc := range varintTestcases { + if tc.val > math.MaxInt32 { + continue + } + + u, n, err := iprotobuf.ParseEnum[int32](tc.buf) + require.NoError(t, err, i) + require.EqualValues(t, tc.val, u, i) + require.EqualValues(t, tc.ln, n, i) + } +} + +func TestParseEnumField(t *testing.T) { + t.Run("wrong type", func(t *testing.T) { + _, _, err := iprotobuf.ParseEnumField[int32]([]byte{}, 42, protowire.BytesType) + require.EqualError(t, err, "wrong type of field #42: expected VARINT, got LEN") + }) + + t.Run("invalid varint", func(t *testing.T) { + for _, tc := range invalidVarintTestcases { + t.Run(tc.name, func(t *testing.T) { + _, _, err := iprotobuf.ParseEnumField[int32](tc.buf, 42, protowire.VarintType) + require.ErrorContains(t, err, "parse field #42 of VARINT type: parse varint") + require.ErrorContains(t, err, tc.err) + }) + } + }) + + t.Run("value overflow", func(t *testing.T) { + _, _, err := iprotobuf.ParseEnumField[int32](int32OverflowVarint, 42, protowire.VarintType) + require.EqualError(t, err, "parse field #42 of VARINT type: value 2147483648 overflows int32") + }) + + for i, tc := range varintTestcases { + if tc.val > math.MaxInt32 { + continue + } + + u, n, err := iprotobuf.ParseEnumField[int32](tc.buf, 42, protowire.VarintType) + require.NoError(t, err, i) + require.EqualValues(t, tc.val, u, i) + require.EqualValues(t, tc.ln, n, i) + } +} + +func TestParseUint32(t *testing.T) { + t.Run("invalid varint", func(t *testing.T) { + for _, tc := range invalidVarintTestcases { + t.Run(tc.name, func(t *testing.T) { + _, _, err := iprotobuf.ParseUint32(tc.buf) + require.ErrorContains(t, err, tc.err) + }) + } + }) + + t.Run("value overflow", func(t *testing.T) { + _, _, err := iprotobuf.ParseUint32(uint32OverflowVarint) + require.EqualError(t, err, "value 4294967296 overflows uint32") + }) + + for i, tc := range varintTestcases { + if tc.val > math.MaxUint32 { + continue + } + + u, n, err := iprotobuf.ParseUint32(tc.buf) + require.NoError(t, err, i) + require.EqualValues(t, tc.val, u, i) + require.EqualValues(t, tc.ln, n, i) + } +} + +func TestParseUint32Field(t *testing.T) { + t.Run("wrong type", func(t *testing.T) { + _, _, err := iprotobuf.ParseUint32Field([]byte{}, 42, protowire.BytesType) + require.EqualError(t, err, "wrong type of field #42: expected VARINT, got LEN") + }) + + t.Run("invalid varint", func(t *testing.T) { + for _, tc := range invalidVarintTestcases { + t.Run(tc.name, func(t *testing.T) { + _, _, err := iprotobuf.ParseUint32Field(tc.buf, 42, protowire.VarintType) + require.ErrorContains(t, err, "parse field #42 of VARINT type: parse varint") + require.ErrorContains(t, err, tc.err) + }) + } + }) + + t.Run("value overflow", func(t *testing.T) { + _, _, err := iprotobuf.ParseUint32Field(uint32OverflowVarint, 42, protowire.VarintType) + require.EqualError(t, err, "parse field #42 of VARINT type: value 4294967296 overflows uint32") + }) + + for i, tc := range varintTestcases { + if tc.val > math.MaxUint32 { + continue + } + + u, n, err := iprotobuf.ParseUint32Field(tc.buf, 42, protowire.VarintType) + require.NoError(t, err, i) + require.EqualValues(t, tc.val, u, i) + require.EqualValues(t, tc.ln, n, i) + } +} + +func TestParseUint64Field(t *testing.T) { + t.Run("wrong type", func(t *testing.T) { + _, _, err := iprotobuf.ParseUint64Field([]byte{}, 42, protowire.BytesType) + require.EqualError(t, err, "wrong type of field #42: expected VARINT, got LEN") + }) + + t.Run("invalid varint", func(t *testing.T) { + for _, tc := range invalidVarintTestcases { + t.Run(tc.name, func(t *testing.T) { + _, _, err := iprotobuf.ParseUint64Field(tc.buf, 42, protowire.VarintType) + require.ErrorContains(t, err, "parse field #42 of VARINT type: parse varint") + require.ErrorContains(t, err, tc.err) + }) + } + }) + + for i, tc := range varintTestcases { + if tc.val > math.MaxUint32 { + continue + } + + u, n, err := iprotobuf.ParseUint64Field(tc.buf, 42, protowire.VarintType) + require.NoError(t, err, i) + require.EqualValues(t, tc.val, u, i) + require.EqualValues(t, tc.ln, n, i) + } +} + +func TestSkipField(t *testing.T) { + t.Run("unknown type", func(t *testing.T) { + _, err := iprotobuf.SkipField([]byte{}, 10, -1) + require.EqualError(t, err, "unknown field type -1") + + _, err = iprotobuf.SkipField([]byte{}, 10, 6) + require.EqualError(t, err, "unknown field type 6") + }) + + t.Run("VARINT", func(t *testing.T) { + for _, tc := range invalidVarintTestcases { + t.Run(tc.name, func(t *testing.T) { + _, err := iprotobuf.SkipField(tc.buf, 42, protowire.VarintType) + require.ErrorContains(t, err, "parse field #42 of VARINT type") + require.ErrorContains(t, err, tc.err) + }) + } + + for i, tc := range varintTestcases { + n, err := iprotobuf.SkipField(tc.buf, 42, protowire.VarintType) + require.NoError(t, err, i) + require.EqualValues(t, tc.ln, n, i) + } + }) + + t.Run("I32", func(t *testing.T) { + for ln := range 4 { + _, err := iprotobuf.SkipField(make([]byte, ln), 42, protowire.Fixed32Type) + require.EqualError(t, err, "parse field #42 of I32 type: unexpected EOF: need 4 bytes, left "+strconv.Itoa(ln)+" in buffer") + } + + n, err := iprotobuf.SkipField(make([]byte, 4), 42, protowire.Fixed32Type) + require.NoError(t, err) + require.EqualValues(t, 4, n) + }) + + t.Run("I64", func(t *testing.T) { + for ln := range 8 { + _, err := iprotobuf.SkipField(make([]byte, ln), 42, protowire.Fixed64Type) + require.EqualError(t, err, "parse field #42 of I64 type: unexpected EOF: need 8 bytes, left "+strconv.Itoa(ln)+" in buffer") + } + + n, err := iprotobuf.SkipField(make([]byte, 8), 42, protowire.Fixed64Type) + require.NoError(t, err) + require.EqualValues(t, 8, n) + }) + + t.Run("LEN", func(t *testing.T) { + t.Run("invalid len", func(t *testing.T) { + for _, tc := range invalidVarintTestcases { + t.Run(tc.name, func(t *testing.T) { + _, err := iprotobuf.SkipField(tc.buf, 42, protowire.BytesType) + require.ErrorContains(t, err, "parse field #42 of LEN type: parse varint") + require.ErrorContains(t, err, tc.err) + }) + } + }) + + t.Run("buffer overflow", func(t *testing.T) { + t.Run("int overflow", func(t *testing.T) { + buf := make([]byte, binary.MaxVarintLen64) + n := binary.PutUvarint(buf, uint64(math.MaxInt+1)) + + _, err := iprotobuf.SkipField(buf[:n], 42, protowire.BytesType) + require.EqualError(t, err, "parse field #42 of LEN type: value "+strconv.FormatUint(math.MaxInt+1, 10)+" overflows int") + }) + + for i, tc := range varintTestcases { + if tc.val == 0 || tc.val > 1<<20 { + continue + } + + buf := slices.Concat(tc.buf, make([]byte, tc.val-1)) + _, err := iprotobuf.SkipField(buf, 42, protowire.BytesType) + require.EqualError(t, err, "parse field #42 of LEN type: unexpected EOF: need "+strconv.FormatUint(tc.val, 10)+" bytes, left "+strconv.FormatUint(tc.val-1, 10)+" in buffer", i) + } + }) + + for i, tc := range varintTestcases { + if tc.val > 1<<20 { + continue + } + + buf := make([]byte, tc.val) + n, err := iprotobuf.SkipField(slices.Concat(tc.buf, buf), 42, protowire.BytesType) + require.NoError(t, err, i) + require.EqualValues(t, tc.ln+int(tc.val), n, i) + } + }) + + t.Run("SGROUP", func(t *testing.T) { + _, err := iprotobuf.SkipField([]byte{}, 42, protowire.StartGroupType) + require.EqualError(t, err, "parse field #42 of SGROUP type: type is not supported") + }) + + t.Run("EGROUP", func(t *testing.T) { + _, err := iprotobuf.SkipField([]byte{}, 42, protowire.EndGroupType) + require.EqualError(t, err, "parse field #42 of EGROUP type: type is not supported") + }) +} diff --git a/internal/protobuf/protobuf.go b/internal/protobuf/protobuf.go index 807deefc37..34e38a78ae 100644 --- a/internal/protobuf/protobuf.go +++ b/internal/protobuf/protobuf.go @@ -6,6 +6,18 @@ import ( "google.golang.org/protobuf/encoding/protowire" ) +// FieldBounds represents boundaries of a field in a particular buffer. +type FieldBounds struct { + From int // first byte index + ValueFrom int // first value byte index + To int // last byte index +} + +// IsMissing returns field absence flag. +func (x FieldBounds) IsMissing() bool { + return x.From == x.To +} + // GetFirstBytesField gets VARLEN field with number = 1 from b. // // GetFirstBytesField returns slice of b, not copy. diff --git a/internal/protobuf/protobuf_test.go b/internal/protobuf/protobuf_test.go new file mode 100644 index 0000000000..5c75ecaee4 --- /dev/null +++ b/internal/protobuf/protobuf_test.go @@ -0,0 +1,22 @@ +package protobuf_test + +import ( + "testing" + + iprotobuf "github.com/nspcc-dev/neofs-node/internal/protobuf" + "github.com/stretchr/testify/require" +) + +func TestFieldBounds_IsMissing(t *testing.T) { + var f iprotobuf.FieldBounds + require.True(t, f.IsMissing()) + + f.To = 10 + require.False(t, f.IsMissing()) + + f.From = 9 + require.False(t, f.IsMissing()) + + f.To = 9 + require.True(t, f.IsMissing()) +} diff --git a/internal/protobuf/seekers.go b/internal/protobuf/seekers.go new file mode 100644 index 0000000000..12281d329b --- /dev/null +++ b/internal/protobuf/seekers.go @@ -0,0 +1,78 @@ +package protobuf + +import ( + "fmt" + + "google.golang.org/protobuf/encoding/protowire" +) + +// SeekFieldByNumber seeks field in buf by number and returns its offset, tag +// length and type. If field is missing, negative offset returns. +// +// Message should have ascending field order, otherwise error returns. +// +// Note that SeekFieldByNumber does not check value of found field, but checks +// intermediate ones for correct message traverse. +func SeekFieldByNumber(buf []byte, seekNum protowire.Number) (int, int, protowire.Type, error) { + if err := checkFieldNumber(seekNum); err != nil { + return 0, 0, 0, err + } + + if len(buf) == 0 { + return -1, 0, 0, nil + } + + var off int + var prevNum protowire.Number + + for { + num, typ, n, err := ParseTag(buf[off:]) + if err != nil { + return 0, 0, 0, fmt.Errorf("parse tag at offset %d: %w", off, err) + } + + if num == seekNum { + return off, n, typ, nil + } + if num > seekNum { + break + } + if num < prevNum { + return 0, 0, 0, NewUnorderedFieldsError(prevNum, num) + } + prevNum = num + + off += n + + n, err = SkipField(buf[off:], num, typ) + if err != nil { + return 0, 0, 0, err + } + off += n + + if off == len(buf) { + break + } + } + + return -1, 0, 0, nil +} + +// GetLENFieldBounds seeks LEN field in buf by number and parses its boundaries. +// If field is missing, no error is returned. +// +// Message should have ascending field order, otherwise error returns. +// +// If there is an error, its text contains num. +func GetLENFieldBounds(buf []byte, num protowire.Number) (FieldBounds, error) { + off, tagLn, typ, err := SeekFieldByNumber(buf, num) + if err != nil { + return FieldBounds{}, err + } + + if off < 0 { + return FieldBounds{}, nil + } + + return ParseLENFieldBounds(buf, off, tagLn, num, typ) +} diff --git a/internal/protobuf/seekers_test.go b/internal/protobuf/seekers_test.go new file mode 100644 index 0000000000..0fcd6c7cb1 --- /dev/null +++ b/internal/protobuf/seekers_test.go @@ -0,0 +1,378 @@ +package protobuf_test + +import ( + "encoding/binary" + "math" + "slices" + "strconv" + "testing" + + iprotobuf "github.com/nspcc-dev/neofs-node/internal/protobuf" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/encoding/protowire" +) + +func TestSeekFieldByNumber(t *testing.T) { + t.Run("invalid number", func(t *testing.T) { + t.Run("0", func(t *testing.T) { + for _, n := range []protowire.Number{-1, 0, 536870912} { + _, _, _, err := iprotobuf.SeekFieldByNumber([]byte{}, n) + require.EqualError(t, err, "invalid number "+strconv.Itoa(int(n))) + } + }) + }) + + t.Run("empty buffer", func(t *testing.T) { + off, _, _, err := iprotobuf.SeekFieldByNumber([]byte{}, 42) + require.NoError(t, err) + require.Negative(t, off) + }) + + // #1, VARINT, 1234567890 + fld1 := []byte{8, 210, 133, 216, 204, 4} + // #100, I64, 2345678901 + fld2 := []byte{161, 6, 210, 56, 251, 13, 0, 0, 0, 0} + // #5K, LEN, Hello, world! + fld3 := []byte{194, 184, 2, 13, 72, 101, 108, 108, 111, 44, 32, 119, 111, 114, 108, 100, 33} + // #1KK, I32, 3456789012 + fld4 := []byte{133, 164, 232, 3, 20, 106, 10, 206} + + t.Run("invalid tag", func(t *testing.T) { + t.Run("invalid varint", func(t *testing.T) { + for _, tc := range invalidVarintTestcases { + if len(tc.buf) == 0 { + continue + } + t.Run(tc.name, func(t *testing.T) { + buf := slices.Concat(fld1, tc.buf) + _, _, _, err := iprotobuf.SeekFieldByNumber(buf, 42) + require.ErrorContains(t, err, "parse tag at offset "+strconv.Itoa(len(fld1))) + require.ErrorContains(t, err, "parse varint") + require.ErrorContains(t, err, tc.err) + }) + } + }) + + t.Run("invalid number", func(t *testing.T) { + t.Run("0", func(t *testing.T) { + buf := slices.Concat(fld1, []byte{2}) + _, _, _, err := iprotobuf.SeekFieldByNumber(buf, 42) // 0,LEN + require.EqualError(t, err, "parse tag at offset "+strconv.Itoa(len(fld1))+": invalid number 0") + }) + t.Run("negative", func(t *testing.T) { + buf := slices.Concat(fld1, []byte{250, 255, 255, 255, 255, 255, 255, 255, 255, 1}) // -1,LEN + _, _, _, err := iprotobuf.SeekFieldByNumber(buf, 42) + require.EqualError(t, err, "parse tag at offset "+strconv.Itoa(len(fld1))+": invalid number -1") + }) + }) + }) + + t.Run("unordered fields", func(t *testing.T) { + buf := slices.Concat(fld1, fld3, fld2) + _, _, _, err := iprotobuf.SeekFieldByNumber(buf, 5001) + require.EqualError(t, err, "unordered fields: #100 after #5000") + }) + + t.Run("parse intermediate field failure", func(t *testing.T) { + t.Run("varint", func(t *testing.T) { + fld := slices.Concat([]byte{208, 2}, uint64OverflowVarint) + buf := slices.Concat(fld1, fld, fld2, fld3, fld4) + _, _, _, err := iprotobuf.SeekFieldByNumber(buf, 5000) + require.ErrorContains(t, err, "parse field #42 of VARINT type") + require.ErrorContains(t, err, "variable length integer overflow") + }) + t.Run("I64", func(t *testing.T) { + buf := slices.Concat(fld1, fld2[:len(fld2)-1]) + _, _, _, err := iprotobuf.SeekFieldByNumber(buf, 5000) + require.EqualError(t, err, "parse field #100 of I64 type: unexpected EOF: need 8 bytes, left 7 in buffer") + }) + t.Run("LEN", func(t *testing.T) { + tag := []byte{210, 2} + t.Run("invalid len", func(t *testing.T) { + for _, tc := range invalidVarintTestcases { + t.Run(tc.name, func(t *testing.T) { + buf := slices.Concat(fld1, tag, tc.buf) + _, _, _, err := iprotobuf.SeekFieldByNumber(buf, 43) + require.ErrorContains(t, err, "parse field #42 of LEN type") + require.ErrorContains(t, err, "parse varint") + require.ErrorContains(t, err, tc.err) + }) + } + }) + t.Run("buffer overflow", func(t *testing.T) { + t.Run("int overflow", func(t *testing.T) { + buf := make([]byte, binary.MaxVarintLen64) + n := binary.PutUvarint(buf, uint64(math.MaxInt+1)) + + buf = slices.Concat(fld1, tag, buf[:n]) + + _, _, _, err := iprotobuf.SeekFieldByNumber(buf, 43) + require.EqualError(t, err, "parse field #42 of LEN type: value "+strconv.FormatUint(math.MaxInt+1, 10)+" overflows int") + }) + + for i, tc := range varintTestcases { + if tc.val == 0 || tc.val > 1<<20 { + continue + } + + buf := slices.Concat(tc.buf, make([]byte, tc.val-1)) + buf = slices.Concat(fld1, tag, buf) + _, _, _, err := iprotobuf.SeekFieldByNumber(buf, 43) + require.EqualError(t, err, "parse field #42 of LEN type: unexpected EOF: need "+strconv.FormatUint(tc.val, 10)+" bytes, left "+strconv.FormatUint(tc.val-1, 10)+" in buffer", i) + } + }) + }) + t.Run("SGROUP", func(t *testing.T) { + buf := slices.Concat(fld1, []byte{211, 2}) + _, _, _, err := iprotobuf.SeekFieldByNumber(buf, 43) + require.EqualError(t, err, "parse field #42 of SGROUP type: type is not supported") + }) + t.Run("EGROUP", func(t *testing.T) { + buf := slices.Concat(fld1, []byte{212, 2}) + _, _, _, err := iprotobuf.SeekFieldByNumber(buf, 43) + require.EqualError(t, err, "parse field #42 of EGROUP type: type is not supported") + }) + t.Run("I32", func(t *testing.T) { + buf := slices.Concat(fld1, fld4[:len(fld4)-1]) + _, _, _, err := iprotobuf.SeekFieldByNumber(buf, 1_000_001) + require.EqualError(t, err, "parse field #1000000 of I32 type: unexpected EOF: need 4 bytes, left 3 in buffer") + }) + }) + + message := slices.Concat(fld1, fld2, fld3, fld4) + + t.Run("missing", func(t *testing.T) { + for _, n := range []protowire.Number{2, 99, 101, 4999, 50001, 999_999, 1_000_001} { + off, _, _, err := iprotobuf.SeekFieldByNumber(message, n) + require.NoError(t, err, n) + require.Negative(t, off, n) + } + }) + + off, tagLn, typ, err := iprotobuf.SeekFieldByNumber(message, 1) + require.NoError(t, err) + require.EqualValues(t, 0, off) + require.EqualValues(t, 1, tagLn) + require.EqualValues(t, protowire.VarintType, typ) + + off, tagLn, typ, err = iprotobuf.SeekFieldByNumber(message, 100) + require.NoError(t, err) + require.EqualValues(t, len(fld1), off) + require.EqualValues(t, 2, tagLn) + require.EqualValues(t, protowire.Fixed64Type, typ) + + off, tagLn, typ, err = iprotobuf.SeekFieldByNumber(message, 5_000) + require.NoError(t, err) + require.EqualValues(t, len(fld1)+len(fld2), off) + require.EqualValues(t, 3, tagLn) + require.EqualValues(t, protowire.BytesType, typ) + + off, tagLn, typ, err = iprotobuf.SeekFieldByNumber(message, 1_000_000) + require.NoError(t, err) + require.EqualValues(t, len(fld1)+len(fld2)+len(fld3), off) + require.EqualValues(t, 4, tagLn) + require.EqualValues(t, protowire.Fixed32Type, typ) +} + +func TestGetLENFieldBounds(t *testing.T) { + t.Run("invalid number", func(t *testing.T) { + t.Run("0", func(t *testing.T) { + for _, n := range []protowire.Number{-1, 0, 536870912} { + _, err := iprotobuf.GetLENFieldBounds([]byte{}, n) + require.EqualError(t, err, "invalid number "+strconv.Itoa(int(n))) + } + }) + }) + + t.Run("empty buffer", func(t *testing.T) { + f, err := iprotobuf.GetLENFieldBounds([]byte{}, 42) + require.NoError(t, err) + require.True(t, f.IsMissing()) + }) + + // #1, VARINT, 1234567890 + fld1 := []byte{8, 210, 133, 216, 204, 4} + // #100, I64, 2345678901 + fld2 := []byte{161, 6, 210, 56, 251, 13, 0, 0, 0, 0} + // #5K, LEN, Hello, world! + fld3 := []byte{194, 184, 2, 13, 72, 101, 108, 108, 111, 44, 32, 119, 111, 114, 108, 100, 33} + // #1KK, I32, 3456789012 + fld4 := []byte{133, 164, 232, 3, 20, 106, 10, 206} + + t.Run("seek failure", func(t *testing.T) { + t.Run("invalid tag", func(t *testing.T) { + t.Run("invalid varint", func(t *testing.T) { + for _, tc := range invalidVarintTestcases { + if len(tc.buf) == 0 { + continue + } + t.Run(tc.name, func(t *testing.T) { + buf := slices.Concat(fld1, tc.buf) + _, err := iprotobuf.GetLENFieldBounds(buf, 42) + require.ErrorContains(t, err, "parse tag at offset "+strconv.Itoa(len(fld1))) + require.ErrorContains(t, err, "parse varint") + require.ErrorContains(t, err, tc.err) + }) + } + }) + + t.Run("invalid number", func(t *testing.T) { + t.Run("0", func(t *testing.T) { + buf := slices.Concat(fld1, []byte{2}) + _, err := iprotobuf.GetLENFieldBounds(buf, 42) // 0,LEN + require.EqualError(t, err, "parse tag at offset "+strconv.Itoa(len(fld1))+": invalid number 0") + }) + t.Run("negative", func(t *testing.T) { + buf := slices.Concat(fld1, []byte{250, 255, 255, 255, 255, 255, 255, 255, 255, 1}) // -1,LEN + _, err := iprotobuf.GetLENFieldBounds(buf, 42) + require.EqualError(t, err, "parse tag at offset "+strconv.Itoa(len(fld1))+": invalid number -1") + }) + }) + }) + + t.Run("unordered fields", func(t *testing.T) { + buf := slices.Concat(fld1, fld3, fld2) + _, err := iprotobuf.GetLENFieldBounds(buf, 5001) + require.EqualError(t, err, "unordered fields: #100 after #5000") + }) + + t.Run("parse intermediate field failure", func(t *testing.T) { + t.Run("varint", func(t *testing.T) { + fld := slices.Concat([]byte{208, 2}, uint64OverflowVarint) + buf := slices.Concat(fld1, fld, fld2, fld3, fld4) + _, err := iprotobuf.GetLENFieldBounds(buf, 5000) + require.ErrorContains(t, err, "parse field #42 of VARINT type") + require.ErrorContains(t, err, "variable length integer overflow") + }) + t.Run("I64", func(t *testing.T) { + buf := slices.Concat(fld1, fld2[:len(fld2)-1]) + _, err := iprotobuf.GetLENFieldBounds(buf, 5000) + require.EqualError(t, err, "parse field #100 of I64 type: unexpected EOF: need 8 bytes, left 7 in buffer") + }) + t.Run("LEN", func(t *testing.T) { + tag := []byte{210, 2} + t.Run("invalid len", func(t *testing.T) { + for _, tc := range invalidVarintTestcases { + t.Run(tc.name, func(t *testing.T) { + buf := slices.Concat(fld1, tag, tc.buf) + _, err := iprotobuf.GetLENFieldBounds(buf, 43) + require.ErrorContains(t, err, "parse field #42 of LEN type") + require.ErrorContains(t, err, "parse varint") + require.ErrorContains(t, err, tc.err) + }) + } + }) + t.Run("buffer overflow", func(t *testing.T) { + t.Run("int overflow", func(t *testing.T) { + buf := make([]byte, binary.MaxVarintLen64) + n := binary.PutUvarint(buf, uint64(math.MaxInt+1)) + + buf = slices.Concat(fld1, tag, buf[:n]) + + _, err := iprotobuf.GetLENFieldBounds(buf, 43) + require.EqualError(t, err, "parse field #42 of LEN type: value "+strconv.FormatUint(math.MaxInt+1, 10)+" overflows int") + }) + + for i, tc := range varintTestcases { + if tc.val == 0 || tc.val > 1<<20 { + continue + } + + buf := slices.Concat(tc.buf, make([]byte, tc.val-1)) + buf = slices.Concat(fld1, tag, buf) + _, err := iprotobuf.GetLENFieldBounds(buf, 43) + require.EqualError(t, err, "parse field #42 of LEN type: unexpected EOF: need "+strconv.FormatUint(tc.val, 10)+" bytes, left "+strconv.FormatUint(tc.val-1, 10)+" in buffer", i) + } + }) + }) + t.Run("SGROUP", func(t *testing.T) { + buf := slices.Concat(fld1, []byte{211, 2}) + _, err := iprotobuf.GetLENFieldBounds(buf, 43) + require.EqualError(t, err, "parse field #42 of SGROUP type: type is not supported") + }) + t.Run("EGROUP", func(t *testing.T) { + buf := slices.Concat(fld1, []byte{212, 2}) + _, err := iprotobuf.GetLENFieldBounds(buf, 43) + require.EqualError(t, err, "parse field #42 of EGROUP type: type is not supported") + }) + t.Run("I32", func(t *testing.T) { + buf := slices.Concat(fld1, fld4[:len(fld4)-1]) + _, err := iprotobuf.GetLENFieldBounds(buf, 1_000_001) + require.EqualError(t, err, "parse field #1000000 of I32 type: unexpected EOF: need 4 bytes, left 3 in buffer") + }) + }) + }) + + t.Run("wrong type", func(t *testing.T) { + _, err := iprotobuf.GetLENFieldBounds([]byte{208, 2}, 42) + require.EqualError(t, err, "wrong type of field #42: expected LEN, got VARINT") + }) + + t.Run("parse failure", func(t *testing.T) { + tag := []byte{210, 2} + t.Run("invalid len", func(t *testing.T) { + for _, tc := range invalidVarintTestcases { + t.Run(tc.name, func(t *testing.T) { + buf := slices.Concat(fld1, tag, tc.buf) + _, err := iprotobuf.GetLENFieldBounds(buf, 42) + require.ErrorContains(t, err, "parse field #42 of LEN type") + require.ErrorContains(t, err, "parse varint") + require.ErrorContains(t, err, tc.err) + }) + } + }) + t.Run("buffer overflow", func(t *testing.T) { + t.Run("int overflow", func(t *testing.T) { + buf := make([]byte, binary.MaxVarintLen64) + n := binary.PutUvarint(buf, uint64(math.MaxInt+1)) + + buf = slices.Concat(fld1, tag, buf[:n]) + + _, err := iprotobuf.GetLENFieldBounds(buf, 42) + require.EqualError(t, err, "parse field #42 of LEN type: value "+strconv.FormatUint(math.MaxInt+1, 10)+" overflows int") + }) + + for i, tc := range varintTestcases { + if tc.val == 0 || tc.val > 1<<20 { + continue + } + + buf := slices.Concat(tc.buf, make([]byte, tc.val-1)) + buf = slices.Concat(fld1, tag, buf) + _, err := iprotobuf.GetLENFieldBounds(buf, 42) + require.EqualError(t, err, "parse field #42 of LEN type: unexpected EOF: need "+strconv.FormatUint(tc.val, 10)+" bytes, left "+strconv.FormatUint(tc.val-1, 10)+" in buffer", i) + } + }) + }) + + message := slices.Concat(fld1, fld2, fld3, fld4) + + t.Run("missing", func(t *testing.T) { + for _, n := range []protowire.Number{2, 99, 101, 4999, 50001, 999_999, 1_000_001} { + f, err := iprotobuf.GetLENFieldBounds(message, n) + require.NoError(t, err, n) + require.True(t, f.IsMissing()) + } + }) + + f, err := iprotobuf.GetLENFieldBounds(fld3, 5000) + require.NoError(t, err) + require.False(t, f.IsMissing()) + require.EqualValues(t, 0, f.From) + require.EqualValues(t, 4, f.ValueFrom) + require.EqualValues(t, len(fld3), f.To) + + f, err = iprotobuf.GetLENFieldBounds(slices.Concat(fld1, fld2, fld3), 5000) + require.NoError(t, err) + require.False(t, f.IsMissing()) + require.EqualValues(t, len(fld1)+len(fld2), f.From) + require.EqualValues(t, f.From+4, f.ValueFrom) + require.EqualValues(t, f.From+len(fld3), f.To) + + f, err = iprotobuf.GetLENFieldBounds(message, 5_000) + require.NoError(t, err) + require.False(t, f.IsMissing()) + require.EqualValues(t, len(fld1)+len(fld2), f.From) + require.EqualValues(t, f.From+4, f.ValueFrom) + require.EqualValues(t, f.From+len(fld3), f.To) +} diff --git a/internal/protobuf/tags.go b/internal/protobuf/tags.go new file mode 100644 index 0000000000..5705c7b8b5 --- /dev/null +++ b/internal/protobuf/tags.go @@ -0,0 +1,21 @@ +package protobuf + +// One-byte tags for LEN fields. +const ( + TagBytes1 = 10 + TagBytes2 = 18 + TagBytes3 = 26 + TagBytes4 = 34 + TagBytes5 = 42 + TagBytes6 = 50 +) + +// One-byte tags for VARINT fields. +const ( + TagVarint1 = 8 + TagVarint2 = 16 + TagVarint3 = 24 + TagVarint4 = 32 + TagVarint5 = 40 + TagVarint6 = 48 +) diff --git a/internal/protobuf/tags_test.go b/internal/protobuf/tags_test.go new file mode 100644 index 0000000000..09f0c25db6 --- /dev/null +++ b/internal/protobuf/tags_test.go @@ -0,0 +1,55 @@ +package protobuf_test + +import ( + "testing" + + iprotobuf "github.com/nspcc-dev/neofs-node/internal/protobuf" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/encoding/protowire" +) + +func TestTags(t *testing.T) { + t.Run("LEN", func(t *testing.T) { + for _, tc := range []struct { + tag byte + num int + }{ + {tag: iprotobuf.TagBytes1, num: 1}, + {tag: iprotobuf.TagBytes2, num: 2}, + {tag: iprotobuf.TagBytes3, num: 3}, + {tag: iprotobuf.TagBytes4, num: 4}, + {tag: iprotobuf.TagBytes5, num: 5}, + {tag: iprotobuf.TagBytes6, num: 6}, + } { + require.EqualValues(t, protowire.EncodeTag(protowire.Number(tc.num), protowire.BytesType), tc.tag) + + num, typ, n, err := iprotobuf.ParseTag([]byte{tc.tag}) + require.NoError(t, err) + require.EqualValues(t, 1, n) + require.EqualValues(t, protowire.BytesType, typ) + require.EqualValues(t, tc.num, num) + } + }) + + t.Run("VARINT", func(t *testing.T) { + for _, tc := range []struct { + tag byte + num int + }{ + {tag: iprotobuf.TagVarint1, num: 1}, + {tag: iprotobuf.TagVarint2, num: 2}, + {tag: iprotobuf.TagVarint3, num: 3}, + {tag: iprotobuf.TagVarint4, num: 4}, + {tag: iprotobuf.TagVarint5, num: 5}, + {tag: iprotobuf.TagVarint6, num: 6}, + } { + require.EqualValues(t, protowire.EncodeTag(protowire.Number(tc.num), protowire.VarintType), tc.tag) + + num, typ, n, err := iprotobuf.ParseTag([]byte{tc.tag}) + require.NoError(t, err) + require.EqualValues(t, 1, n) + require.EqualValues(t, protowire.VarintType, typ) + require.EqualValues(t, tc.num, num) + } + }) +} diff --git a/internal/protobuf/util.go b/internal/protobuf/util.go new file mode 100644 index 0000000000..1e34752d20 --- /dev/null +++ b/internal/protobuf/util.go @@ -0,0 +1,61 @@ +package protobuf + +import ( + "fmt" + "io" + "strconv" + + "google.golang.org/protobuf/encoding/protowire" +) + +const ( + fixed32Len = 4 + fixed64Len = 8 +) + +type wireType protowire.Type + +func (x wireType) String() string { + switch protowire.Type(x) { + case protowire.VarintType: + return "VARINT" + case protowire.Fixed64Type: + return "I64" + case protowire.BytesType: + return "LEN" + case protowire.StartGroupType: + return "SGROUP" + case protowire.EndGroupType: + return "EGROUP" + case protowire.Fixed32Type: + return "I32" + default: + return strconv.Itoa(int(x)) + } +} + +func checkFieldType(num protowire.Number, exp, got protowire.Type) error { + if exp != got { + return fmt.Errorf("wrong type of field #%d: expected %s, got %s", num, wireType(exp), wireType(got)) + } + return nil +} + +func checkFieldNumber(num protowire.Number) error { + if !num.IsValid() { + return fmt.Errorf("invalid number %d", num) + } + return nil +} + +func newUnknownFieldTypeError(t protowire.Type) error { + return fmt.Errorf("unknown field type %s", wireType(t)) +} + +func newTruncatedBufferError(need, left int) error { + return fmt.Errorf("%w: need %d bytes, left %d in buffer", io.ErrUnexpectedEOF, need, left) +} + +func wrapParseFieldError(n protowire.Number, t protowire.Type, cause error) error { + return fmt.Errorf("parse field #%d of %s type: %w", n, wireType(t), cause) +}