diff --git a/SPEC.md b/SPEC.md index 6031b9a..795b0ff 100644 --- a/SPEC.md +++ b/SPEC.md @@ -29,13 +29,14 @@ All fields are read sequentially. `[ ]` = conditionally present. | 9 | [topic] | uint8 length + UTF-8 | Present iff pid is NOT present. Length may be 0. | | 10 | type | uint8 + [US-ASCII string] | If flag bit 2 (common type) set: uint8 is a Common Media Type ID (see §4). Otherwise: uint8 is length of subsequent ASCII Media Type string. | | 11 | size | uint32 | Byte length of data on the wire (after compression, if zlib-deflate set). | -| 12 | attachment headers | uint8 count + [headers] | Count may be 0. Each header: see §5. | -| 13 | data | bytes | Exactly _size_ bytes. | -| 14 | [attachments data] | bytes | Concatenated attachment payloads, sizes defined by headers. | +| 12 | [expanded size] | uint32 | Decompressed byte length. Present iff zlib-deflate set. | +| 13 | attachment headers | uint8 count + [headers] | Count may be 0. Each header: see §5. | +| 14 | data | bytes | Exactly _size_ bytes. | +| 15 | [attachments data] | bytes | Concatenated attachment payloads, sizes defined by headers. | -**Message header** = fields 1–12. **Message header hash** = SHA-256(message header). **Message hash** = SHA-256(entire message, fields 1–14). +**Message header** = fields 1–13. **Message header hash** = SHA-256(message header). **Message hash** = SHA-256(entire message, fields 1–15). -The hash MUST be computed over the full message bytes: message header fields exactly as transmitted, followed by message data and any attachments data. When the zlib-deflate flag is set for message data or an attachment's data, that data MUST be decompressed prior to inclusion in the hash computation. +The hash MUST be computed over the full message bytes: message header fields exactly as transmitted, followed by message data and any attachments data. When the zlib-deflate flag is set for message data or an attachment's data, that data MUST be decompressed prior to inclusion in the hash computation and MUST exactly match the corresponding _expanded size_; mismatch means invalid → TERMINATE. **Sender** = _from_ when _has add to_ not set; _add to from_ when set. @@ -50,7 +51,7 @@ The hash MUST be computed over the full message bytes: message header fields exa | 2 | common type | type field is a 1-byte Common Media Type ID instead of length-prefixed string. | | 3 | important | Sender flags message as important. | | 4 | no reply | Sender will discard any reply. | -| 5 | zlib-deflate | Message data compressed with zlib/deflate (RFC 1950/1951). | +| 5 | zlib-deflate | Message data compressed with zlib/deflate (RFC 1950/1951); _expanded size_ field present. | | 6–7 | reserved | Must be 0. | ## 4. Common Media Types @@ -104,6 +105,7 @@ Each attachment header, in order: | type | uint8 + [ASCII string] | Same encoding rule as message type, using this attachment's own common type flag. | | filename | uint8 length + UTF-8 | < 256 bytes. Unicode letters/numbers, plus `-` `_` ` ` `.` non-consecutively, not at start/end. Unique per message (case-insensitive). | | size | uint32 | Byte length of this attachment's data on the wire (after compression, if zlib-deflate set). | +| [expanded size] | uint32 | Decompressed byte length. Present iff this attachment's zlib-deflate flag set. | Attachment data payloads follow all headers, concatenated in order. @@ -129,7 +131,7 @@ Single-value codes (sent as first/only byte): | 1 | invalid | Message header fails validation. | | 2 | unsupported version | Version not supported. | | 3 | undisclosed | No reason given. | -| 4 | too big | Exceeds MAX_SIZE. | +| 4 | too big | Exceeds MAX_SIZE or MAX_EXPANDED_SIZE. | | 5 | insufficient resources | e.g. disk full. | | 6 | parent not found | pid references unknown message. | | 7 | too old | Timestamp too far in past. | @@ -167,7 +169,8 @@ One message per connection. Two TCP connections used: Connection 1 (message tran | Variable | Description | |----------|-------------| -| MAX_SIZE | Max total bytes of data + attachment data. | +| MAX_SIZE | Max total bytes of data + attachment data on the wire. | +| MAX_EXPANDED_SIZE | Max total bytes after decompression; SHOULD normally equal MAX_SIZE. | | MAX_MESSAGE_AGE | Max seconds a message time may be in the past. | | MAX_TIME_SKEW | Max seconds a message time may be in the future. | @@ -278,6 +281,7 @@ An add-to message is a duplicate of the original message with these differences: ## 13. Security Requirements - Enforce MAX_SIZE before downloading data. +- Enforce MAX_EXPANDED_SIZE: reject when declared expanded size exceeds limit. - Enforce per-connection and per-IP rate limits. - Apply idle/slow-connection timeouts. - Verify sender IP via DNS BEFORE issuing any challenge. diff --git a/src/.env.example b/src/.env.example index 920169b..dbedf7e 100644 --- a/src/.env.example +++ b/src/.env.example @@ -7,6 +7,7 @@ FMSG_ID_URL=http://127.0.0.1:8080 FMSG_MAX_MSG_SIZE=10240 +FMSG_MAX_EXPANDED_SIZE=10240 FMSG_MAX_PAST_TIME_DELTA=604800 FMSG_MAX_FUTURE_TIME_DELTA=300 FMSG_MIN_DOWNLOAD_RATE=5000 diff --git a/src/deflate_test.go b/src/deflate_test.go index f3e7b56..7c10c34 100644 --- a/src/deflate_test.go +++ b/src/deflate_test.go @@ -337,14 +337,15 @@ func TestGetMessageHash_WithDeflate(t *testing.T) { // Build header with deflate flag pointing at compressed file h := &FMsgHeader{ - Version: 1, - Flags: FlagDeflate, - From: FMsgAddress{User: "alice", Domain: "example.com"}, - To: []FMsgAddress{{User: "bob", Domain: "other.com"}}, - Topic: "test", - Type: "text/plain;charset=UTF-8", - Size: cSize, - Filepath: dstPath, + Version: 1, + Flags: FlagDeflate, + From: FMsgAddress{User: "alice", Domain: "example.com"}, + To: []FMsgAddress{{User: "bob", Domain: "other.com"}}, + Topic: "test", + Type: "text/plain;charset=UTF-8", + Size: cSize, + ExpandedSize: uint32(len(original)), + Filepath: dstPath, } msgHash, err := h.GetMessageHash() @@ -432,6 +433,7 @@ func TestGetMessageHash_DeflateChangesHash(t *testing.T) { deflated := base deflated.Flags = FlagDeflate deflated.Size = cSize + deflated.ExpandedSize = uint32(len(original)) deflated.Filepath = dstPath hashDeflated, err := deflated.GetMessageHash() if err != nil { @@ -472,11 +474,12 @@ func TestGetMessageHash_AttachmentDeflate(t *testing.T) { Filepath: msgPath, Attachments: []FMsgAttachmentHeader{ { - Flags: 1 << 1, // attachment deflate bit - Type: "text/csv", - Filename: "data.csv", - Size: attCSize, - Filepath: attDstPath, + Flags: 1 << 1, // attachment deflate bit + Type: "text/csv", + Filename: "data.csv", + Size: attCSize, + ExpandedSize: uint32(len(attOriginal)), + Filepath: attDstPath, }, }, } diff --git a/src/defs.go b/src/defs.go index 084dc2e..a0ee4fa 100644 --- a/src/defs.go +++ b/src/defs.go @@ -18,11 +18,12 @@ type FMsgAddress struct { } type FMsgAttachmentHeader struct { - Flags uint8 - TypeID uint8 - Type string - Filename string - Size uint32 + Flags uint8 + TypeID uint8 + Type string + Filename string + Size uint32 + ExpandedSize uint32 Filepath string } @@ -41,8 +42,9 @@ type FMsgHeader struct { Type string // Size in bytes of entire message - Size uint32 - Attachments []FMsgAttachmentHeader + Size uint32 + ExpandedSize uint32 // Decompressed size; present on wire iff FlagDeflate set + Attachments []FMsgAttachmentHeader HeaderHash []byte ChallengeHash [32]byte @@ -117,6 +119,12 @@ func (h *FMsgHeader) Encode() []byte { if err := binary.Write(&b, binary.LittleEndian, h.Size); err != nil { panic(err) } + // expanded size (uint32 LE) — present iff zlib-deflate flag set + if h.Flags&FlagDeflate != 0 { + if err := binary.Write(&b, binary.LittleEndian, h.ExpandedSize); err != nil { + panic(err) + } + } // attachment headers b.WriteByte(byte(len(h.Attachments))) for _, att := range h.Attachments { @@ -138,6 +146,12 @@ func (h *FMsgHeader) Encode() []byte { if err := binary.Write(&b, binary.LittleEndian, att.Size); err != nil { panic(err) } + // attachment expanded size — present iff attachment zlib-deflate flag set + if att.Flags&(1<<1) != 0 { + if err := binary.Write(&b, binary.LittleEndian, att.ExpandedSize); err != nil { + panic(err) + } + } } return b.Bytes() } @@ -187,7 +201,7 @@ func (h *FMsgHeader) GetMessageHash() ([]byte, error) { return nil, err } - if err := hashPayload(hash, h.Filepath, int64(h.Size), h.Flags&FlagDeflate != 0); err != nil { + if err := hashPayload(hash, h.Filepath, int64(h.Size), h.Flags&FlagDeflate != 0, h.ExpandedSize); err != nil { return nil, err } @@ -195,7 +209,7 @@ func (h *FMsgHeader) GetMessageHash() ([]byte, error) { // the message body, bounded by attachment header sizes) for _, att := range h.Attachments { compressed := att.Flags&(1<<1) != 0 - if err := hashPayload(hash, att.Filepath, int64(att.Size), compressed); err != nil { + if err := hashPayload(hash, att.Filepath, int64(att.Size), compressed, att.ExpandedSize); err != nil { return nil, fmt.Errorf("hash attachment %s: %w", att.Filename, err) } } @@ -205,7 +219,7 @@ func (h *FMsgHeader) GetMessageHash() ([]byte, error) { return h.messageHash, nil } -func hashPayload(dst io.Writer, filepath string, wireSize int64, deflated bool) error { +func hashPayload(dst io.Writer, filepath string, wireSize int64, deflated bool, expandedSize uint32) error { f, err := os.Open(filepath) if err != nil { return err @@ -218,11 +232,14 @@ func hashPayload(dst io.Writer, filepath string, wireSize int64, deflated bool) if err != nil { return err } - _, err = io.Copy(dst, zr) + written, err := io.Copy(dst, zr) _ = zr.Close() if err != nil { return err } + if uint32(written) != expandedSize { + return fmt.Errorf("decompressed size %d does not match declared expanded size %d", written, expandedSize) + } return nil } diff --git a/src/defs_test.go b/src/defs_test.go index 3eee685..591b555 100644 --- a/src/defs_test.go +++ b/src/defs_test.go @@ -604,16 +604,17 @@ func TestGetMessageHashUsesDecompressedPayloads(t *testing.T) { } h := &FMsgHeader{ - Version: 1, - Flags: FlagDeflate, - From: FMsgAddress{User: "alice", Domain: "a.com"}, - To: []FMsgAddress{{User: "bob", Domain: "b.com"}}, - Timestamp: 1700000000, - Topic: "t", - Type: "text/plain", - Size: uint32(len(msgWire)), + Version: 1, + Flags: FlagDeflate, + From: FMsgAddress{User: "alice", Domain: "a.com"}, + To: []FMsgAddress{{User: "bob", Domain: "b.com"}}, + Timestamp: 1700000000, + Topic: "t", + Type: "text/plain", + Size: uint32(len(msgWire)), + ExpandedSize: uint32(len(msgPlain)), Attachments: []FMsgAttachmentHeader{ - {Flags: 1 << 1, Type: "application/octet-stream", Filename: "a.bin", Size: uint32(len(attWire)), Filepath: attPath}, + {Flags: 1 << 1, Type: "application/octet-stream", Filename: "a.bin", Size: uint32(len(attWire)), ExpandedSize: uint32(len(attPlain)), Filepath: attPath}, }, Filepath: msgPath, } @@ -639,3 +640,215 @@ func TestGetMessageHashUsesDecompressedPayloads(t *testing.T) { t.Fatalf("message hash mismatch: got %x want %x", got, want) } } + +func TestEncodeExpandedSizePresentWhenDeflateSet(t *testing.T) { + h := &FMsgHeader{ + Version: 1, + Flags: FlagDeflate, + From: FMsgAddress{User: "a", Domain: "b.com"}, + To: []FMsgAddress{{User: "c", Domain: "d.com"}}, + Timestamp: 0, + Topic: "", + Type: "text/plain", + Size: 50, + ExpandedSize: 200, + } + b := h.Encode() + r := bytes.NewReader(b) + + r.ReadByte() // version + r.ReadByte() // flags + + // skip from + fLen, _ := r.ReadByte() + r.Read(make([]byte, fLen)) + // skip to count + to[0] + r.ReadByte() + tLen, _ := r.ReadByte() + r.Read(make([]byte, tLen)) + // skip timestamp + var ts float64 + binary.Read(r, binary.LittleEndian, &ts) + // skip topic + topicLen, _ := r.ReadByte() + r.Read(make([]byte, topicLen)) + // skip type + typeLen, _ := r.ReadByte() + r.Read(make([]byte, typeLen)) + + // size + var size uint32 + binary.Read(r, binary.LittleEndian, &size) + if size != 50 { + t.Fatalf("size = %d, want 50", size) + } + + // expanded size must be present because FlagDeflate is set + var expandedSize uint32 + if err := binary.Read(r, binary.LittleEndian, &expandedSize); err != nil { + t.Fatalf("reading expanded size: %v", err) + } + if expandedSize != 200 { + t.Fatalf("expanded size = %d, want 200", expandedSize) + } + + // attachment count + attachCount, _ := r.ReadByte() + if attachCount != 0 { + t.Fatalf("attach count = %d, want 0", attachCount) + } + + if r.Len() != 0 { + t.Fatalf("unexpected %d trailing bytes", r.Len()) + } +} + +func TestEncodeNoExpandedSizeWhenDeflateUnset(t *testing.T) { + h := &FMsgHeader{ + Version: 1, + Flags: 0, + From: FMsgAddress{User: "a", Domain: "b.com"}, + To: []FMsgAddress{{User: "c", Domain: "d.com"}}, + Timestamp: 0, + Topic: "", + Type: "text/plain", + Size: 100, + ExpandedSize: 999, // must NOT appear on wire + } + b := h.Encode() + r := bytes.NewReader(b) + + r.ReadByte() // version + r.ReadByte() // flags + + fLen, _ := r.ReadByte() + r.Read(make([]byte, fLen)) + r.ReadByte() + tLen, _ := r.ReadByte() + r.Read(make([]byte, tLen)) + var ts float64 + binary.Read(r, binary.LittleEndian, &ts) + topicLen, _ := r.ReadByte() + r.Read(make([]byte, topicLen)) + typeLen, _ := r.ReadByte() + r.Read(make([]byte, typeLen)) + + var size uint32 + binary.Read(r, binary.LittleEndian, &size) + if size != 100 { + t.Fatalf("size = %d, want 100", size) + } + + // No expanded size field; next byte should be attachment count = 0 + attachCount, _ := r.ReadByte() + if attachCount != 0 { + t.Fatalf("attach count = %d, want 0", attachCount) + } + + if r.Len() != 0 { + t.Fatalf("unexpected %d trailing bytes (expanded size should not be present when deflate unset)", r.Len()) + } +} + +func TestEncodeAttachmentExpandedSizePresentWhenDeflateSet(t *testing.T) { + h := &FMsgHeader{ + Version: 1, + Flags: 0, + From: FMsgAddress{User: "a", Domain: "b.com"}, + To: []FMsgAddress{{User: "c", Domain: "d.com"}}, + Timestamp: 0, + Topic: "", + Type: "text/plain", + Size: 0, + Attachments: []FMsgAttachmentHeader{ + // attachment with zlib-deflate flag (bit 1 = 0b00000010) + {Flags: 1 << 1, Type: "text/plain", Filename: "doc.txt", Size: 60, ExpandedSize: 300}, + }, + } + b := h.Encode() + r := bytes.NewReader(b) + + r.ReadByte() // version + r.ReadByte() // flags + fLen, _ := r.ReadByte() + r.Read(make([]byte, fLen)) + r.ReadByte() + tLen, _ := r.ReadByte() + r.Read(make([]byte, tLen)) + var ts float64 + binary.Read(r, binary.LittleEndian, &ts) + topicLen, _ := r.ReadByte() + r.Read(make([]byte, topicLen)) + typeLen, _ := r.ReadByte() + r.Read(make([]byte, typeLen)) + var msgSize uint32 + binary.Read(r, binary.LittleEndian, &msgSize) + + // attachment count + attachCount, _ := r.ReadByte() + if attachCount != 1 { + t.Fatalf("attach count = %d, want 1", attachCount) + } + + // attachment flags + attFlags, _ := r.ReadByte() + if attFlags != 1<<1 { + t.Fatalf("att flags = %d, want %d", attFlags, 1<<1) + } + // type (length-prefixed) + attTypeLen, _ := r.ReadByte() + r.Read(make([]byte, attTypeLen)) + // filename + attFnLen, _ := r.ReadByte() + r.Read(make([]byte, attFnLen)) + // wire size + var attSize uint32 + binary.Read(r, binary.LittleEndian, &attSize) + if attSize != 60 { + t.Fatalf("att size = %d, want 60", attSize) + } + // expanded size must be present + var attExpandedSize uint32 + if err := binary.Read(r, binary.LittleEndian, &attExpandedSize); err != nil { + t.Fatalf("reading att expanded size: %v", err) + } + if attExpandedSize != 300 { + t.Fatalf("att expanded size = %d, want 300", attExpandedSize) + } + + if r.Len() != 0 { + t.Fatalf("unexpected %d trailing bytes", r.Len()) + } +} + +func TestHashPayloadRejectsExpandedSizeMismatch(t *testing.T) { + compress := func(data []byte) []byte { + var b bytes.Buffer + w := zlib.NewWriter(&b) + w.Write(data) + w.Close() + return b.Bytes() + } + + plain := []byte("hello world this is test data") + wire := compress(plain) + + tmpDir := t.TempDir() + p := filepath.Join(tmpDir, "data.bin") + if err := os.WriteFile(p, wire, 0600); err != nil { + t.Fatalf("write file: %v", err) + } + + // Correct expanded size should succeed + var dst bytes.Buffer + if err := hashPayload(&dst, p, int64(len(wire)), true, uint32(len(plain))); err != nil { + t.Fatalf("hashPayload with correct expanded size: %v", err) + } + + // Wrong expanded size should fail + dst.Reset() + err := hashPayload(&dst, p, int64(len(wire)), true, uint32(len(plain))+1) + if err == nil { + t.Fatal("hashPayload with wrong expanded size: expected error, got nil") + } +} diff --git a/src/go.mod b/src/go.mod index f624234..a27ea72 100644 --- a/src/go.mod +++ b/src/go.mod @@ -7,10 +7,10 @@ require ( github.com/joho/godotenv v1.5.1 github.com/levenlabs/golib v0.0.0-20180911183212-0f8974794783 github.com/lib/pq v1.10.9 + github.com/miekg/dns v1.1.68 ) require ( - github.com/miekg/dns v1.1.68 // indirect github.com/stretchr/testify v1.8.2 // indirect golang.org/x/mod v0.24.0 // indirect golang.org/x/net v0.40.0 // indirect diff --git a/src/go.sum b/src/go.sum index 0bb8650..1067b98 100644 --- a/src/go.sum +++ b/src/go.sum @@ -3,6 +3,8 @@ github.com/caitlinelfring/go-env-default v1.1.0/go.mod h1:tESXPr8zFPP/cRy3cwxrHB github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= diff --git a/src/host.go b/src/host.go index 11ef000..62815fc 100644 --- a/src/host.go +++ b/src/host.go @@ -160,6 +160,7 @@ var MinDownloadRate float64 = 5000 var MinUploadRate float64 = 5000 var ReadBufferSize = 1600 var MaxMessageSize = uint32(1024 * 10) +var MaxExpandedSize = uint32(1024 * 10) var SkipAuthorisedIPs = false var TLSInsecureSkipVerify = false var DataDir = "got on startup" @@ -210,6 +211,7 @@ func loadEnvConfig() { MinUploadRate = env.GetFloatDefault("FMSG_MIN_UPLOAD_RATE", 5000) ReadBufferSize = env.GetIntDefault("FMSG_READ_BUFFER_SIZE", 1600) MaxMessageSize = uint32(env.GetIntDefault("FMSG_MAX_MSG_SIZE", 1024*10)) + MaxExpandedSize = uint32(env.GetIntDefault("FMSG_MAX_EXPANDED_SIZE", int(MaxMessageSize))) SkipAuthorisedIPs = os.Getenv("FMSG_SKIP_AUTHORISED_IPS") == "true" TLSInsecureSkipVerify = os.Getenv("FMSG_TLS_INSECURE_SKIP_VERIFY") == "true" } @@ -882,6 +884,14 @@ func readAttachmentHeaders(c net.Conn, r *bufio.Reader, h *FMsgHeader) error { } totalSize := h.Size + // When message is compressed, expanded size comes from the header field. + // When uncompressed, the wire size IS the expanded size. + var totalExpandedSize uint32 + if h.Flags&FlagDeflate != 0 { + totalExpandedSize = h.ExpandedSize + } else { + totalExpandedSize = h.Size + } filenameSeen := make(map[string]bool) for i := uint8(0); i < attachCount; i++ { attFlags, err := r.ReadByte() @@ -922,12 +932,25 @@ func readAttachmentHeaders(c net.Conn, r *bufio.Reader, h *FMsgHeader) error { return err } + // read attachment expanded size — present iff attachment zlib-deflate flag set (§5) + var attExpandedSize uint32 + if attFlags&(1<<1) != 0 { + if err := binary.Read(r, binary.LittleEndian, &attExpandedSize); err != nil { + return err + } + totalExpandedSize += attExpandedSize + } else { + // uncompressed: expanded size equals wire size + totalExpandedSize += attSize + } + h.Attachments = append(h.Attachments, FMsgAttachmentHeader{ - Flags: attFlags, - TypeID: attTypeID, - Type: attType, - Filename: filename, - Size: attSize, + Flags: attFlags, + TypeID: attTypeID, + Type: attType, + Filename: filename, + Size: attSize, + ExpandedSize: attExpandedSize, }) totalSize += attSize } @@ -939,6 +962,13 @@ func readAttachmentHeaders(c net.Conn, r *bufio.Reader, h *FMsgHeader) error { return fmt.Errorf("total message size %d exceeds max %d", totalSize, MaxMessageSize) } + if totalExpandedSize > MaxExpandedSize { + if err := sendCode(c, RejectCodeTooBig); err != nil { + return err + } + return fmt.Errorf("total expanded size %d exceeds MAX_EXPANDED_SIZE %d", totalExpandedSize, MaxExpandedSize) + } + return nil } @@ -1018,6 +1048,18 @@ func readHeader(c net.Conn) (*FMsgHeader, *bufio.Reader, error) { if err := binary.Read(r, binary.LittleEndian, &h.Size); err != nil { return h, r, err } + // read expanded size — present iff zlib-deflate flag is set (§2 field 12) + if h.Flags&FlagDeflate != 0 { + if err := binary.Read(r, binary.LittleEndian, &h.ExpandedSize); err != nil { + return h, r, err + } + if h.ExpandedSize > MaxExpandedSize { + if err := sendCode(c, RejectCodeTooBig); err != nil { + return h, r, err + } + return h, r, fmt.Errorf("expanded size %d exceeds MAX_EXPANDED_SIZE %d", h.ExpandedSize, MaxExpandedSize) + } + } // Size check is deferred until attachment headers are parsed (see below) if err := readAttachmentHeaders(c, r, h); err != nil { diff --git a/src/host_test.go b/src/host_test.go index f0d3a58..48c54be 100644 --- a/src/host_test.go +++ b/src/host_test.go @@ -545,3 +545,78 @@ func TestResolvePostChallengeCode(t *testing.T) { }) } } + +func TestReadAttachmentHeadersReadsExpandedSizeForCompressedAttachment(t *testing.T) { + origMax := MaxMessageSize + origExpanded := MaxExpandedSize + MaxMessageSize = 1024 + MaxExpandedSize = 1024 + t.Cleanup(func() { + MaxMessageSize = origMax + MaxExpandedSize = origExpanded + }) + + h := &FMsgHeader{Size: 0} + b := []byte{1} // 1 attachment + b = append(b, 1<<1) // attachment flags: zlib-deflate (bit 1) + b = append(b, encodeUInt8String(t, "text/plain")...) + b = append(b, encodeUInt8String(t, "file.txt")...) + + var wireSize [4]byte + binary.LittleEndian.PutUint32(wireSize[:], 50) + b = append(b, wireSize[:]...) + + var expandedSize [4]byte + binary.LittleEndian.PutUint32(expandedSize[:], 200) + b = append(b, expandedSize[:]...) + + err := readAttachmentHeaders(nil, bufio.NewReader(bytes.NewReader(b)), h) + if err != nil { + t.Fatalf("readAttachmentHeaders returned error: %v", err) + } + if len(h.Attachments) != 1 { + t.Fatalf("len(h.Attachments) = %d, want 1", len(h.Attachments)) + } + att := h.Attachments[0] + if att.Size != 50 { + t.Fatalf("att.Size = %d, want 50", att.Size) + } + if att.ExpandedSize != 200 { + t.Fatalf("att.ExpandedSize = %d, want 200", att.ExpandedSize) + } +} + +func TestReadAttachmentHeadersRejectsExpandedSizeExceedsMax(t *testing.T) { + origMax := MaxMessageSize + origExpanded := MaxExpandedSize + MaxMessageSize = 1024 + MaxExpandedSize = 100 + t.Cleanup(func() { + MaxMessageSize = origMax + MaxExpandedSize = origExpanded + }) + + h := &FMsgHeader{Size: 0} + c := &testConn{} + b := []byte{1} // 1 attachment + b = append(b, 1<<1) // attachment flags: zlib-deflate (bit 1) + b = append(b, encodeUInt8String(t, "text/plain")...) + b = append(b, encodeUInt8String(t, "file.txt")...) + + var wireSize [4]byte + binary.LittleEndian.PutUint32(wireSize[:], 50) + b = append(b, wireSize[:]...) + + // expanded size exceeds MaxExpandedSize=100 + var expandedSize [4]byte + binary.LittleEndian.PutUint32(expandedSize[:], 200) + b = append(b, expandedSize[:]...) + + err := readAttachmentHeaders(c, bufio.NewReader(bytes.NewReader(b)), h) + if err == nil { + t.Fatalf("expected error when expanded size exceeds max") + } + if got := c.Bytes(); len(got) != 1 || got[0] != RejectCodeTooBig { + t.Fatalf("expected reject code %d, got %v", RejectCodeTooBig, got) + } +} diff --git a/src/sender.go b/src/sender.go index 3c54937..3d50a98 100644 --- a/src/sender.go +++ b/src/sender.go @@ -321,6 +321,7 @@ func deliverMessage(target pendingTarget) { log.Printf("INFO: sender: compressed msg %d data: %d -> %d bytes", target.MsgID, h.Size, cs) deflateCleanup = append(deflateCleanup, dp) h.Filepath = dp + h.ExpandedSize = h.Size h.Size = cs h.Flags |= FlagDeflate } @@ -335,6 +336,7 @@ func deliverMessage(target pendingTarget) { log.Printf("INFO: sender: compressed msg %d attachment %s: %d -> %d bytes", target.MsgID, att.Filename, att.Size, cs) deflateCleanup = append(deflateCleanup, dp) att.Filepath = dp + att.ExpandedSize = att.Size att.Size = cs att.Flags |= 1 << 1 }