diff --git a/tun/offload_linux.go b/tun/offload_linux.go index fb6ac5b94..9d0c9e707 100644 --- a/tun/offload_linux.go +++ b/tun/offload_linux.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "io" + "slices" "unsafe" "github.com/tailscale/wireguard-go/conn" @@ -72,8 +73,63 @@ const ( // virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the // shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr). virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{})) + + // Vector layout: [virtioHdr | headPacket | coalescedPayloads...] + virtioNetHdrIdx = 0 + emptyVectorLen = 1 + headPacketIdx = 1 + singlePacketLen = 2 + coalescedPacketsIdx = 2 + + maxScatterGatherFragments = 1024 // Limited by UIO_MAXIOV of 1024. ) +// groToWrite holds the write-ordered scatter-gather IO vectors for +// writev. +type groToWrite struct { + // Each iov is a [][]byte: + // - iovs[i][0] is the pre-allocated virtio header + // - iovs[i][1] is the head packet with transport headers + // - iovs[i][2:] are coalesced payload fragments + iovs [][][]byte + allocated int +} + +func newGROToWrite() groToWrite { + wi := groToWrite{ + iovs: make([][][]byte, 0, conn.IdealBatchSize), + } + for range cap(wi.iovs) { + _, _ = wi.next() // pre-allocate virtio headers + } + wi.iovs = wi.iovs[:0] + return wi +} + +// next extends iovs by one, reusing the pre-allocated backing. +// Returns the index and a pointer to the new item's [][]byte. +// +// Do not use the returned pointer after the next call to next. +func (w *groToWrite) next() (int, *[][]byte) { + n := len(w.iovs) + if n < w.allocated { + w.iovs = w.iovs[:n+1] + } else { + iov := make([][]byte, 1, conn.IdealBatchSize) + iov[virtioNetHdrIdx] = make([]byte, virtioNetHdrLen) + w.iovs = append(w.iovs, iov) + w.allocated++ + } + return n, &w.iovs[n] +} + +func (w *groToWrite) reset() { + for i := range w.iovs { + w.iovs[i] = w.iovs[i][:emptyVectorLen] // keep virtio header alloc + } + w.iovs = w.iovs[:0] +} + // tcpFlowKey represents the key for a TCP flow. type tcpFlowKey struct { srcAddr, dstAddr [16]byte @@ -114,28 +170,31 @@ func newTCPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset int) tcp // lookupOrInsert looks up a flow for the provided packet and metadata, // returning the packets found for the flow, or inserting a new one if none // is found. -func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) { +func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen int, wi *groToWrite) ([]tcpGROItem, bool) { key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) items, ok := t.itemsByFlow[key] if ok { return items, ok } // TODO: insert() performs another map lookup. This could be rearranged to avoid. - t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex) + t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, wi) return nil, false } // insert an item in the table for the provided packet and packet metadata. -func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) { +func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen int, wi *groToWrite) { key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) + idx, iov := wi.next() + *iov = append(*iov, pkt) item := tcpGROItem{ - key: key, - bufsIndex: uint16(bufsIndex), - gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])), - iphLen: uint8(tcphOffset), - tcphLen: uint8(tcphLen), - sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]), - pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0, + key: key, + outputIdx: uint16(idx), + gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])), + iphLen: uint8(tcphOffset), + tcphLen: uint8(tcphLen), + sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]), + pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0, + payloadLen: uint16(len(pkt[tcphOffset+tcphLen:])), } items, ok := t.itemsByFlow[key] if !ok { @@ -159,14 +218,14 @@ func (t *tcpGROTable) deleteAt(key tcpFlowKey, i int) { // tcpGROItem represents bookkeeping data for a TCP packet during the lifetime // of a GRO evaluation across a vector of packets. type tcpGROItem struct { - key tcpFlowKey - sentSeq uint32 // the sequence number - bufsIndex uint16 // the index into the original bufs slice - numMerged uint16 // the number of packets merged into this item - gsoSize uint16 // payload size - iphLen uint8 // ip header len - tcphLen uint8 // tcp header len - pshSet bool // psh flag is set + key tcpFlowKey + sentSeq uint32 // the sequence number + outputIdx uint16 // index into groToWrite + payloadLen uint16 // accumulated payload bytes + gsoSize uint16 // payload size + iphLen uint8 // ip header len + tcphLen uint8 // tcp header len + pshSet bool // psh flag is set } func (t *tcpGROTable) newItems() []tcpGROItem { @@ -221,26 +280,29 @@ func newUDPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset int) udp // lookupOrInsert looks up a flow for the provided packet and metadata, // returning the packets found for the flow, or inserting a new one if none // is found. -func (u *udpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int) ([]udpGROItem, bool) { +func (u *udpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset int, wi *groToWrite) ([]udpGROItem, bool) { key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset) items, ok := u.itemsByFlow[key] if ok { return items, ok } // TODO: insert() performs another map lookup. This could be rearranged to avoid. - u.insert(pkt, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex, false) + u.insert(pkt, srcAddrOffset, dstAddrOffset, udphOffset, wi, false) return nil, false } // insert an item in the table for the provided packet and packet metadata. -func (u *udpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int, cSumKnownInvalid bool) { +func (u *udpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset int, wi *groToWrite, cSumKnownInvalid bool) { key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset) + idx, iov := wi.next() + *iov = append(*iov, pkt) item := udpGROItem{ key: key, - bufsIndex: uint16(bufsIndex), + outputIdx: uint16(idx), gsoSize: uint16(len(pkt[udphOffset+udphLen:])), iphLen: uint8(udphOffset), cSumKnownInvalid: cSumKnownInvalid, + payloadLen: uint16(len(pkt[udphOffset+udphLen:])), } items, ok := u.itemsByFlow[key] if !ok { @@ -259,8 +321,8 @@ func (u *udpGROTable) updateAt(item udpGROItem, i int) { // of a GRO evaluation across a vector of packets. type udpGROItem struct { key udpFlowKey - bufsIndex uint16 // the index into the original bufs slice - numMerged uint16 // the number of packets merged into this item + outputIdx uint16 // index into groToWrite + payloadLen uint16 // accumulated payload bytes gsoSize uint16 // payload size iphLen uint8 // ip header len cSumKnownInvalid bool // UDP header checksum validity; a false value DOES NOT imply valid, just unknown. @@ -325,15 +387,16 @@ func ipHeadersCanCoalesce(pktA, pktB []byte) bool { } // udpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet -// described by item. iphLen and gsoSize describe pkt. bufs is the vector of -// packets involved in the current GRO evaluation. bufsOffset is the offset at -// which packet data begins within bufs. -func udpPacketsCanCoalesce(pkt []byte, iphLen uint8, gsoSize uint16, item udpGROItem, bufs [][]byte, bufsOffset int) canCoalesce { - pktTarget := bufs[item.bufsIndex][bufsOffset:] +// described by item. iphLen and gsoSize describe pkt. +func udpPacketsCanCoalesce(pkt []byte, iphLen uint8, gsoSize uint16, item udpGROItem, wi *groToWrite) canCoalesce { + if len(wi.iovs[item.outputIdx]) >= maxScatterGatherFragments { + return coalesceUnavailable + } + pktTarget := wi.iovs[item.outputIdx][headPacketIdx] if !ipHeadersCanCoalesce(pkt, pktTarget) { return coalesceUnavailable } - if len(pktTarget[iphLen+udphLen:])%int(item.gsoSize) != 0 { + if item.payloadLen%item.gsoSize != 0 { // A smaller than gsoSize packet has been appended previously. // Nothing can come after a smaller packet on the end. return coalesceUnavailable @@ -342,14 +405,20 @@ func udpPacketsCanCoalesce(pkt []byte, iphLen uint8, gsoSize uint16, item udpGRO // We cannot have a larger packet following a smaller one. return coalesceUnavailable } + if int(item.iphLen)+udphLen+int(item.payloadLen)+int(gsoSize) > maxUint16 { + return coalesceUnavailable + } return coalesceAppend } // tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet // described by item. This function makes considerations that match the kernel's // GRO self tests, which can be found in tools/testing/selftests/net/gro.c. -func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce { - pktTarget := bufs[item.bufsIndex][bufsOffset:] +func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, wi *groToWrite) canCoalesce { + if len(wi.iovs[item.outputIdx]) >= maxScatterGatherFragments { + return coalesceUnavailable + } + pktTarget := wi.iovs[item.outputIdx][headPacketIdx] if tcphLen != item.tcphLen { // cannot coalesce with unequal tcp options len return coalesceUnavailable @@ -363,16 +432,17 @@ func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet if !ipHeadersCanCoalesce(pkt, pktTarget) { return coalesceUnavailable } + if int(item.iphLen)+int(item.tcphLen)+int(item.payloadLen)+int(gsoSize) > maxUint16 { + return coalesceUnavailable + } // seq adjacency - lhsLen := item.gsoSize - lhsLen += item.numMerged * item.gsoSize - if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective + if seq == item.sentSeq+uint32(item.payloadLen) { // pkt aligns following item from a seq num perspective if item.pshSet { // We cannot append to a segment that has the PSH flag set, PSH // can only be set on the final segment in a reassembled group. return coalesceUnavailable } - if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 { + if item.payloadLen%item.gsoSize != 0 { // A smaller than gsoSize packet has been appended previously. // Nothing can come after a smaller packet on the end. return coalesceUnavailable @@ -392,7 +462,7 @@ func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet // We cannot have a larger packet following a smaller one. return coalesceUnavailable } - if gsoSize > item.gsoSize && item.numMerged > 0 { + if gsoSize > item.gsoSize && len(wi.iovs[item.outputIdx]) > singlePacketLen { // There's at least one previous merge, and we're larger than all // previous. This would put multiple smaller packets on the end. return coalesceUnavailable @@ -414,13 +484,11 @@ func checksumValid(pkt []byte, iphLen, proto uint8, isV6 bool) bool { return ^Checksum(pkt[iphLen:], cSum) == 0 } -// coalesceResult represents the result of attempting to coalesce two TCP -// packets. +// coalesceResult represents the result of attempting to coalesce two packets. type coalesceResult int const ( - coalesceInsufficientCap coalesceResult = iota - coalescePSHEnding + coalescePSHEnding coalesceResult = iota coalesceItemInvalidCSum coalescePktInvalidCSum coalesceSuccess @@ -428,54 +496,33 @@ const ( // coalesceUDPPackets attempts to coalesce pkt with the packet described by // item, and returns the outcome. -func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult { - pktHead := bufs[item.bufsIndex][bufsOffset:] // the packet that will end up at the front - headersLen := item.iphLen + udphLen - coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen) - - if cap(pktHead)-bufsOffset < coalescedLen { - // We don't want to allocate a new underlying array if capacity is - // too small. - return coalesceInsufficientCap - } - if item.numMerged == 0 { - if item.cSumKnownInvalid || !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_UDP, isV6) { +func coalesceUDPPackets(pkt []byte, item *udpGROItem, wi *groToWrite, isV6 bool) coalesceResult { + headersLen := int(item.iphLen) + udphLen + iov := &wi.iovs[item.outputIdx] + if len(*iov) == singlePacketLen { + if item.cSumKnownInvalid || !checksumValid((*iov)[headPacketIdx], item.iphLen, unix.IPPROTO_UDP, isV6) { return coalesceItemInvalidCSum } } if !checksumValid(pkt, item.iphLen, unix.IPPROTO_UDP, isV6) { return coalescePktInvalidCSum } - extendBy := len(pkt) - int(headersLen) - bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...) - copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:]) - - item.numMerged++ + *iov = append(*iov, pkt[headersLen:]) + item.payloadLen += uint16(len(pkt) - headersLen) return coalesceSuccess } // coalesceTCPPackets attempts to coalesce pkt with the packet described by -// item, and returns the outcome. This function may swap bufs elements in the -// event of a prepend as item's bufs index is already being tracked for writing -// to a Device. -func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult { - var pktHead []byte // the packet that will end up at the front - headersLen := item.iphLen + item.tcphLen - coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen) - - // Copy data +// item, and returns the outcome. +func coalesceTCPPackets(mode canCoalesce, pkt []byte, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, wi *groToWrite, isV6 bool) coalesceResult { + headersLen := int(item.iphLen) + int(item.tcphLen) + iov := &wi.iovs[item.outputIdx] if mode == coalescePrepend { - pktHead = pkt - if cap(pkt)-bufsOffset < coalescedLen { - // We don't want to allocate a new underlying array if capacity is - // too small. - return coalesceInsufficientCap - } if pshSet { return coalescePSHEnding } - if item.numMerged == 0 { - if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) { + if len(*iov) == singlePacketLen { + if !checksumValid((*iov)[headPacketIdx], item.iphLen, unix.IPPROTO_TCP, isV6) { return coalesceItemInvalidCSum } } @@ -483,21 +530,15 @@ func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize return coalescePktInvalidCSum } item.sentSeq = seq - extendBy := coalescedLen - len(pktHead) - bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...) - copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):]) - // Flip the slice headers in bufs as part of prepend. The index of item - // is already being tracked for writing. - bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex] + oldHead := (*iov)[headPacketIdx] + (*iov)[headPacketIdx] = pkt + oldHeadPayload := oldHead[headersLen:] + if len(oldHeadPayload) > 0 { + *iov = slices.Insert(*iov, coalescedPacketsIdx, oldHeadPayload) + } } else { - pktHead = bufs[item.bufsIndex][bufsOffset:] - if cap(pktHead)-bufsOffset < coalescedLen { - // We don't want to allocate a new underlying array if capacity is - // too small. - return coalesceInsufficientCap - } - if item.numMerged == 0 { - if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) { + if len(*iov) == singlePacketLen { + if !checksumValid((*iov)[headPacketIdx], item.iphLen, unix.IPPROTO_TCP, isV6) { return coalesceItemInvalidCSum } } @@ -507,18 +548,16 @@ func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize if pshSet { // We are appending a segment with PSH set. item.pshSet = pshSet - pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH + (*iov)[headPacketIdx][item.iphLen+tcpFlagsOffset] |= tcpFlagPSH } - extendBy := len(pkt) - int(headersLen) - bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...) - copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:]) + *iov = append(*iov, pkt[headersLen:]) } if gsoSize > item.gsoSize { item.gsoSize = gsoSize } - item.numMerged++ + item.payloadLen += uint16(len(pkt) - headersLen) return coalesceSuccess } @@ -538,13 +577,12 @@ const ( groResultCoalesced ) -// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with +// tcpGRO evaluates the TCP packet for coalescing with // existing packets tracked in table. It returns a groResultNoop when no // action was taken, groResultTableInsert when the evaluated packet was // inserted into table, and groResultCoalesced when the evaluated packet was // coalesced with another packet in table. -func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) groResult { - pkt := bufs[pktI][offset:] +func tcpGRO(pkt []byte, table *tcpGROTable, wi *groToWrite, isV6 bool) groResult { if len(pkt) > maxUint16 { // A valid IPv4 or IPv6 packet will never exceed this. return groResultNoop @@ -599,7 +637,7 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) srcAddrOffset = ipv6SrcAddrOffset addrLen = 16 } - items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) + items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, wi) if !existing { return groResultTableInsert } @@ -613,9 +651,9 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) // sequence number perspective, however once an item is inserted into // the table it is never compared across other items later. item := items[i] - can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset) + can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, wi) if can != coalesceUnavailable { - result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6) + result := coalesceTCPPackets(can, pkt, gsoSize, seq, pshSet, &item, wi, isV6) switch result { case coalesceSuccess: table.updateAt(item, i) @@ -631,16 +669,19 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) } } // failed to coalesce with any other packets; store the item in the flow - table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) + table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, wi) return groResultTableInsert } -// applyTCPCoalesceAccounting updates bufs to account for coalescing based on the +// applyTCPCoalesceAccounting updates headers to account for coalescing based on the // metadata found in table. -func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) error { +func applyTCPCoalesceAccounting(wi *groToWrite, table *tcpGROTable) error { for _, items := range table.itemsByFlow { for _, item := range items { - if item.numMerged > 0 { + iov := wi.iovs[item.outputIdx] + pkt := iov[headPacketIdx] + if len(iov) > singlePacketLen { + totalLen := uint16(item.iphLen) + uint16(item.tcphLen) + item.payloadLen hdr := virtioNetHdr{ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb hdrLen: uint16(item.iphLen + item.tcphLen), @@ -648,21 +689,20 @@ func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) e csumStart: uint16(item.iphLen), csumOffset: 16, } - pkt := bufs[item.bufsIndex][offset:] // Recalculate the total len (IPv4) or payload len (IPv6). // Recalculate the (IPv4) header checksum. if item.key.isV6 { hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6 - binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len + binary.BigEndian.PutUint16(pkt[4:], totalLen-uint16(item.iphLen)) // set new IPv6 header payload len } else { hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4 pkt[10], pkt[11] = 0, 0 - binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length - iphCSum := ^Checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum - binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field + binary.BigEndian.PutUint16(pkt[2:], totalLen) // set new total length + iphCSum := ^Checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum + binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field } - err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + err := hdr.encode(iov[virtioNetHdrIdx]) if err != nil { return err } @@ -676,56 +716,53 @@ func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) e addrLen = 16 addrOffset = ipv6SrcAddrOffset } - srcAddrAt := offset + addrOffset - srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] - dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] - psum := PseudoHeaderChecksum(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) + srcAddr := pkt[addrOffset : addrOffset+addrLen] + dstAddr := pkt[addrOffset+addrLen : addrOffset+addrLen*2] + psum := PseudoHeaderChecksum(unix.IPPROTO_TCP, srcAddr, dstAddr, totalLen-uint16(item.iphLen)) binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], Checksum([]byte{}, psum)) } else { - hdr := virtioNetHdr{} - err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) - if err != nil { - return err - } + clear(iov[virtioNetHdrIdx]) } } } return nil } -// applyUDPCoalesceAccounting updates bufs to account for coalescing based on the +// applyUDPCoalesceAccounting updates headers to account for coalescing based on the // metadata found in table. -func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable) error { +func applyUDPCoalesceAccounting(wi *groToWrite, table *udpGROTable) error { for _, items := range table.itemsByFlow { for _, item := range items { - if item.numMerged > 0 { + iov := wi.iovs[item.outputIdx] + pkt := iov[headPacketIdx] + if len(iov) > singlePacketLen { + totalLen := uint16(item.iphLen) + udphLen + item.payloadLen hdr := virtioNetHdr{ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb - hdrLen: uint16(item.iphLen + udphLen), + hdrLen: uint16(item.iphLen) + udphLen, gsoSize: item.gsoSize, csumStart: uint16(item.iphLen), csumOffset: 6, } - pkt := bufs[item.bufsIndex][offset:] // Recalculate the total len (IPv4) or payload len (IPv6). // Recalculate the (IPv4) header checksum. hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_UDP_L4 if item.key.isV6 { - binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len + binary.BigEndian.PutUint16(pkt[4:], totalLen-uint16(item.iphLen)) // set new IPv6 header payload len } else { pkt[10], pkt[11] = 0, 0 - binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length - iphCSum := ^Checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum - binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field + binary.BigEndian.PutUint16(pkt[2:], totalLen) // set new total length + iphCSum := ^Checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum + binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field } - err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + err := hdr.encode(iov[virtioNetHdrIdx]) if err != nil { return err } // Recalculate the UDP len field value - binary.BigEndian.PutUint16(pkt[item.iphLen+4:], uint16(len(pkt[item.iphLen:]))) + binary.BigEndian.PutUint16(pkt[item.iphLen+4:], udphLen+item.payloadLen) // Calculate the pseudo header checksum and place it at the UDP // checksum offset. Downstream checksum offloading will combine @@ -736,17 +773,12 @@ func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable) e addrLen = 16 addrOffset = ipv6SrcAddrOffset } - srcAddrAt := offset + addrOffset - srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] - dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] - psum := PseudoHeaderChecksum(unix.IPPROTO_UDP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) + srcAddr := pkt[addrOffset : addrOffset+addrLen] + dstAddr := pkt[addrOffset+addrLen : addrOffset+addrLen*2] + psum := PseudoHeaderChecksum(unix.IPPROTO_UDP, srcAddr, dstAddr, totalLen-uint16(item.iphLen)) binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], Checksum([]byte{}, psum)) } else { - hdr := virtioNetHdr{} - err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) - if err != nil { - return err - } + clear(iov[virtioNetHdrIdx]) } } } @@ -793,13 +825,12 @@ const ( udphLen = 8 ) -// udpGRO evaluates the UDP packet at pktI in bufs for coalescing with +// udpGRO evaluates the UDP packet for coalescing with // existing packets tracked in table. It returns a groResultNoop when no // action was taken, groResultTableInsert when the evaluated packet was // inserted into table, and groResultCoalesced when the evaluated packet was // coalesced with another packet in table. -func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) groResult { - pkt := bufs[pktI][offset:] +func udpGRO(pkt []byte, table *udpGROTable, wi *groToWrite, isV6 bool) groResult { if len(pkt) > maxUint16 { // A valid IPv4 or IPv6 packet will never exceed this. return groResultNoop @@ -840,7 +871,7 @@ func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) srcAddrOffset = ipv6SrcAddrOffset addrLen = 16 } - items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI) + items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, wi) if !existing { return groResultTableInsert } @@ -848,10 +879,10 @@ func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) // for a given flow. We must also always insert a new item, or successfully // coalesce with an existing item, for the same reason. item := items[len(items)-1] - can := udpPacketsCanCoalesce(pkt, uint8(iphLen), gsoSize, item, bufs, offset) + can := udpPacketsCanCoalesce(pkt, uint8(iphLen), gsoSize, item, wi) var pktCSumKnownInvalid bool if can == coalesceAppend { - result := coalesceUDPPackets(pkt, &item, bufs, offset, isV6) + result := coalesceUDPPackets(pkt, &item, wi, isV6) switch result { case coalesceSuccess: table.updateAt(item, len(items)-1) @@ -868,44 +899,42 @@ func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) } } // failed to coalesce with any other packets; store the item in the flow - table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI, pktCSumKnownInvalid) + table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, wi, pktCSumKnownInvalid) return groResultTableInsert } -// handleGRO evaluates bufs for GRO, and writes the indices of the resulting -// packets into toWrite. toWrite, tcpTable, and udpTable should initially be +// handleGRO evaluates bufs for GRO, and populates wi with the resulting +// io vectors for writev. wi, tcpTable, and udpTable should initially be // empty (but non-nil), and are passed in to save allocs as the caller may reset // and recycle them across vectors of packets. gro indicates if TCP and UDP GRO // are supported/enabled. -func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, gro groDisablementFlags, toWrite *[]int) error { +func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, gro groDisablementFlags, wi *groToWrite) error { for i := range bufs { if offset < virtioNetHdrLen || offset > len(bufs[i])-1 { return errors.New("invalid offset") } + pkt := bufs[i][offset:] var result groResult - switch packetIsGROCandidate(bufs[i][offset:], gro) { + switch packetIsGROCandidate(pkt, gro) { case tcp4GROCandidate: - result = tcpGRO(bufs, offset, i, tcpTable, false) + result = tcpGRO(pkt, tcpTable, wi, false) case tcp6GROCandidate: - result = tcpGRO(bufs, offset, i, tcpTable, true) + result = tcpGRO(pkt, tcpTable, wi, true) case udp4GROCandidate: - result = udpGRO(bufs, offset, i, udpTable, false) + result = udpGRO(pkt, udpTable, wi, false) case udp6GROCandidate: - result = udpGRO(bufs, offset, i, udpTable, true) + result = udpGRO(pkt, udpTable, wi, true) } switch result { case groResultNoop: - hdr := virtioNetHdr{} - err := hdr.encode(bufs[i][offset-virtioNetHdrLen:]) - if err != nil { - return err - } - fallthrough + _, iov := wi.next() + clear((*iov)[virtioNetHdrIdx]) + *iov = append(*iov, pkt) case groResultTableInsert: - *toWrite = append(*toWrite, i) + // already in wi via table insert } } - errTCP := applyTCPCoalesceAccounting(bufs, offset, tcpTable) - errUDP := applyUDPCoalesceAccounting(bufs, offset, udpTable) + errTCP := applyTCPCoalesceAccounting(wi, tcpTable) + errUDP := applyUDPCoalesceAccounting(wi, udpTable) return errors.Join(errTCP, errUDP) } diff --git a/tun/offload_linux_test.go b/tun/offload_linux_test.go index 407037863..9b5aa0400 100644 --- a/tun/offload_linux_test.go +++ b/tun/offload_linux_test.go @@ -7,6 +7,7 @@ package tun import ( "net/netip" + "slices" "testing" "github.com/tailscale/wireguard-go/conn" @@ -289,32 +290,21 @@ func Fuzz_handleGRO(f *testing.F) { f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11, 0, offset) f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11 []byte, gro int, offset int) { pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11} - toWrite := make([]int, 0, len(pkts)) - handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), groDisablementFlags(gro), &toWrite) - if len(toWrite) > len(pkts) { - t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts)) - } - seenWriteI := make(map[int]bool) - for _, writeI := range toWrite { - if writeI < 0 || writeI > len(pkts)-1 { - t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts)) - } - if seenWriteI[writeI] { - t.Errorf("duplicate toWrite value: %d", writeI) - } - seenWriteI[writeI] = true + wi := newGROToWrite() + handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), groDisablementFlags(gro), &wi) + if len(wi.iovs) > len(pkts) { + t.Errorf("len(wi.iovs): %d > len(pkts): %d", len(wi.iovs), len(pkts)) } }) } func Test_handleGRO(t *testing.T) { tests := []struct { - name string - pktsIn [][]byte - gro groDisablementFlags - wantToWrite []int - wantLens []int - wantErr bool + name string + pktsIn [][]byte + gro groDisablementFlags + wantLens [][]int + wantErr bool }{ { "multiple protocols and flows", @@ -332,8 +322,15 @@ func Test_handleGRO(t *testing.T) { udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 }, 0, - []int{0, 1, 2, 4, 5, 7, 9}, - []int{240, 228, 128, 140, 260, 160, 248}, + [][]int{ + {virtioNetHdrLen, 140, 100}, // tcp4 A->B merged + {virtioNetHdrLen, 128, 100}, // udp4 A->B merged + {virtioNetHdrLen, 128}, // udp4 A->C + {virtioNetHdrLen, 140}, // tcp4 A->C + {virtioNetHdrLen, 160, 100}, // tcp6 A->B merged + {virtioNetHdrLen, 160}, // tcp6 A->C + {virtioNetHdrLen, 148, 100}, // udp6 A->B merged + }, false, }, { @@ -352,8 +349,17 @@ func Test_handleGRO(t *testing.T) { udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 }, udpGRODisabled, - []int{0, 1, 2, 4, 5, 7, 8, 9, 10}, - []int{240, 128, 128, 140, 260, 160, 128, 148, 148}, + [][]int{ + {virtioNetHdrLen, 140, 100}, // tcp4 A->B merged + {virtioNetHdrLen, 128}, // udp4 A->B noop + {virtioNetHdrLen, 128}, // udp4 A->C noop + {virtioNetHdrLen, 140}, // tcp4 A->C + {virtioNetHdrLen, 160, 100}, // tcp6 A->B merged + {virtioNetHdrLen, 160}, // tcp6 A->C + {virtioNetHdrLen, 128}, // udp4 A->B noop + {virtioNetHdrLen, 148}, // udp6 A->B noop + {virtioNetHdrLen, 148}, // udp6 A->B noop + }, false, }, { @@ -369,8 +375,12 @@ func Test_handleGRO(t *testing.T) { tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1 }, 0, - []int{0, 2, 4, 6}, - []int{240, 240, 260, 260}, + [][]int{ + {virtioNetHdrLen, 140, 100}, // v4 merged (seq 1+101 PSH) + {virtioNetHdrLen, 140, 100}, // v4 merged (seq 201+301) + {virtioNetHdrLen, 160, 100}, // v6 merged (seq 1+101 PSH) + {virtioNetHdrLen, 160, 100}, // v6 merged (seq 201+301) + }, false, }, { @@ -384,8 +394,12 @@ func Test_handleGRO(t *testing.T) { udp4Packet(ip4PortA, ip4PortB, 100), }, 0, - []int{0, 1, 3, 4}, - []int{140, 240, 128, 228}, + [][]int{ + {virtioNetHdrLen, 140}, // tcp4 bad csum, unmerged + {virtioNetHdrLen, 140, 100}, // tcp4 merged (seq 101+201) + {virtioNetHdrLen, 128}, // udp4 bad csum, unmerged + {virtioNetHdrLen, 128, 100}, // udp4 merged + }, false, }, { @@ -396,8 +410,9 @@ func Test_handleGRO(t *testing.T) { tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 }, 0, - []int{0}, - []int{340}, + [][]int{ + {virtioNetHdrLen, 140, 100, 100}, // prepend seq 1, original seq 101 payload, append seq 201 + }, false, }, { @@ -413,8 +428,12 @@ func Test_handleGRO(t *testing.T) { }), }, 0, - []int{0, 1, 2, 3}, - []int{140, 140, 128, 128}, + [][]int{ + {virtioNetHdrLen, 140}, + {virtioNetHdrLen, 140}, + {virtioNetHdrLen, 128}, + {virtioNetHdrLen, 128}, + }, false, }, { @@ -430,8 +449,12 @@ func Test_handleGRO(t *testing.T) { }), }, 0, - []int{0, 1, 2, 3}, - []int{140, 140, 128, 128}, + [][]int{ + {virtioNetHdrLen, 140}, + {virtioNetHdrLen, 140}, + {virtioNetHdrLen, 128}, + {virtioNetHdrLen, 128}, + }, false, }, { @@ -447,8 +470,12 @@ func Test_handleGRO(t *testing.T) { }), }, 0, - []int{0, 1, 2, 3}, - []int{140, 140, 128, 128}, + [][]int{ + {virtioNetHdrLen, 140}, + {virtioNetHdrLen, 140}, + {virtioNetHdrLen, 128}, + {virtioNetHdrLen, 128}, + }, false, }, { @@ -464,8 +491,12 @@ func Test_handleGRO(t *testing.T) { }), }, 0, - []int{0, 1, 2, 3}, - []int{140, 140, 128, 128}, + [][]int{ + {virtioNetHdrLen, 140}, + {virtioNetHdrLen, 140}, + {virtioNetHdrLen, 128}, + {virtioNetHdrLen, 128}, + }, false, }, { @@ -481,8 +512,12 @@ func Test_handleGRO(t *testing.T) { }), }, 0, - []int{0, 1, 2, 3}, - []int{160, 160, 148, 148}, + [][]int{ + {virtioNetHdrLen, 160}, + {virtioNetHdrLen, 160}, + {virtioNetHdrLen, 148}, + {virtioNetHdrLen, 148}, + }, false, }, { @@ -498,31 +533,46 @@ func Test_handleGRO(t *testing.T) { }), }, 0, - []int{0, 1, 2, 3}, - []int{160, 160, 148, 148}, + [][]int{ + {virtioNetHdrLen, 160}, + {virtioNetHdrLen, 160}, + {virtioNetHdrLen, 148}, + {virtioNetHdrLen, 148}, + }, false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - toWrite := make([]int, 0, len(tt.pktsIn)) - err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.gro, &toWrite) - if err != nil { - if tt.wantErr { - return + wi := newGROToWrite() + for range 2 { // validating reset() correctness + wi.reset() + // Deep copy pktsIn since coalesce accounting mutates head packet headers. + pktsIn := make([][]byte, len(tt.pktsIn)) + for k, p := range tt.pktsIn { + pktsIn[k] = slices.Clone(p) } - t.Fatalf("got err: %v", err) - } - if len(toWrite) != len(tt.wantToWrite) { - t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite)) - } - for i, pktI := range tt.wantToWrite { - if tt.wantToWrite[i] != toWrite[i] { - t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i]) + err := handleGRO(pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.gro, &wi) + if err != nil { + if tt.wantErr { + return + } + t.Fatalf("got err: %v", err) + } + if len(wi.iovs) != len(tt.wantLens) { + t.Fatalf("got %d packets, wanted %d", len(wi.iovs), len(tt.wantLens)) } - if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) { - t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:])) + for i, wantFragLens := range tt.wantLens { + iov := wi.iovs[i] + if len(iov) != len(wantFragLens) { + t.Fatalf("items[%d]: got %d fragments, wanted %d", i, len(iov), len(wantFragLens)) + } + for j, wantLen := range wantFragLens { + if len(iov[j]) != wantLen { + t.Errorf("items[%d][%d]: got len %d, want %d", i, j, len(iov[j]), wantLen) + } + } } } }) @@ -669,12 +719,11 @@ func Test_udpPacketsCanCoalesce(t *testing.T) { udp4c := udp4Packet(ip4PortA, ip4PortB, 110) type args struct { - pkt []byte - iphLen uint8 - gsoSize uint16 - item udpGROItem - bufs [][]byte - bufsOffset int + pkt []byte + iphLen uint8 + gsoSize uint16 + item udpGROItem + wi groToWrite } tests := []struct { name string @@ -688,14 +737,12 @@ func Test_udpPacketsCanCoalesce(t *testing.T) { iphLen: 20, gsoSize: 100, item: udpGROItem{ - gsoSize: 100, - iphLen: 20, + gsoSize: 100, + iphLen: 20, + outputIdx: 0, + payloadLen: 100, }, - bufs: [][]byte{ - udp4a, - udp4b, - }, - bufsOffset: offset, + wi: groToWrite{iovs: [][][]byte{{make([]byte, virtioNetHdrLen), udp4b[offset:]}}}, }, coalesceAppend, }, @@ -706,14 +753,12 @@ func Test_udpPacketsCanCoalesce(t *testing.T) { iphLen: 20, gsoSize: 10, item: udpGROItem{ - gsoSize: 100, - iphLen: 20, - }, - bufs: [][]byte{ - udp4a, - udp4b, + gsoSize: 100, + iphLen: 20, + outputIdx: 0, + payloadLen: 100, }, - bufsOffset: offset, + wi: groToWrite{iovs: [][][]byte{{make([]byte, virtioNetHdrLen), udp4b[offset:]}}}, }, coalesceAppend, }, @@ -724,14 +769,12 @@ func Test_udpPacketsCanCoalesce(t *testing.T) { iphLen: 20, gsoSize: 100, item: udpGROItem{ - gsoSize: 100, - iphLen: 20, + gsoSize: 100, + iphLen: 20, + outputIdx: 0, + payloadLen: 110, }, - bufs: [][]byte{ - udp4c, - udp4b, - }, - bufsOffset: offset, + wi: groToWrite{iovs: [][][]byte{{make([]byte, virtioNetHdrLen), udp4c[offset:]}}}, }, coalesceUnavailable, }, @@ -742,23 +785,123 @@ func Test_udpPacketsCanCoalesce(t *testing.T) { iphLen: 20, gsoSize: 110, item: udpGROItem{ - gsoSize: 100, - iphLen: 20, + gsoSize: 100, + iphLen: 20, + outputIdx: 0, + payloadLen: 100, }, - bufs: [][]byte{ - udp4a, - udp4c, + wi: groToWrite{iovs: [][][]byte{{make([]byte, virtioNetHdrLen), udp4a[offset:]}}}, + }, + coalesceUnavailable, + }, + { + "coalesceUnavailable too many fragments", + args{ + pkt: udp4a[offset:], + iphLen: 20, + gsoSize: 100, + item: udpGROItem{ + gsoSize: 100, + iphLen: 20, + outputIdx: 0, + payloadLen: 100, }, - bufsOffset: offset, + wi: groToWrite{iovs: [][][]byte{make([][]byte, maxScatterGatherFragments)}}, + }, + coalesceUnavailable, + }, + { + "coalesceUnavailable payload overflow", + args{ + pkt: udp4a[offset:], + iphLen: 20, + gsoSize: 100, + item: udpGROItem{ + gsoSize: 100, + iphLen: 20, + outputIdx: 0, + payloadLen: maxUint16 - 20 - 8, + }, + wi: groToWrite{iovs: [][][]byte{{make([]byte, virtioNetHdrLen), udp4a[offset:]}}}, }, coalesceUnavailable, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := udpPacketsCanCoalesce(tt.args.pkt, tt.args.iphLen, tt.args.gsoSize, tt.args.item, tt.args.bufs, tt.args.bufsOffset); got != tt.want { + if got := udpPacketsCanCoalesce(tt.args.pkt, tt.args.iphLen, tt.args.gsoSize, tt.args.item, &tt.args.wi); got != tt.want { t.Errorf("udpPacketsCanCoalesce() = %v, want %v", got, tt.want) } }) } } + +func Test_tcpPacketsCanCoalesce(t *testing.T) { + tcp4a := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 100) + + type args struct { + pkt []byte + iphLen uint8 + tcphLen uint8 + seq uint32 + pshSet bool + gsoSize uint16 + item tcpGROItem + wi groToWrite + } + tests := []struct { + name string + args args + want canCoalesce + }{ + { + "coalesceUnavailable too many fragments", + args{ + pkt: tcp4a[offset:], + iphLen: 20, + tcphLen: 20, + seq: 200, + pshSet: false, + gsoSize: 100, + item: tcpGROItem{ + gsoSize: 100, + iphLen: 20, + tcphLen: 20, + sentSeq: 100, + outputIdx: 0, + payloadLen: 100, + }, + wi: groToWrite{iovs: [][][]byte{make([][]byte, maxScatterGatherFragments)}}, + }, + coalesceUnavailable, + }, + { + "coalesceUnavailable payload overflow", + args{ + pkt: tcp4a[offset:], + iphLen: 20, + tcphLen: 20, + seq: 200, + pshSet: false, + gsoSize: 100, + item: tcpGROItem{ + gsoSize: 100, + iphLen: 20, + tcphLen: 20, + sentSeq: 100, + outputIdx: 0, + payloadLen: maxUint16 - 20 - 20, + }, + wi: groToWrite{iovs: [][][]byte{{make([]byte, virtioNetHdrLen), tcp4a[offset:]}}}, + }, + coalesceUnavailable, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tcpPacketsCanCoalesce(tt.args.pkt, tt.args.iphLen, tt.args.tcphLen, tt.args.seq, tt.args.pshSet, tt.args.gsoSize, tt.args.item, &tt.args.wi); got != tt.want { + t.Errorf("tcpPacketsCanCoalesce() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/tun/tun_linux.go b/tun/tun_linux.go index 7cdbf8825..57f15aba7 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -29,6 +29,7 @@ const ( type NativeTun struct { tunFile *os.File + tunRawConn syscall.RawConn index int32 // if index errors chan error // async error handling events chan Event // device related events @@ -49,7 +50,7 @@ type NativeTun struct { readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr writeOpMu sync.Mutex // writeOpMu guards the following fields - toWrite []int + toWrite groToWrite tcpGROTable *tcpGROTable udpGROTable *udpGROTable gro groDisablementFlags @@ -354,31 +355,53 @@ func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { defer func() { tun.tcpGROTable.reset() tun.udpGROTable.reset() + tun.toWrite.reset() tun.writeOpMu.Unlock() }() var ( errs error total int ) - tun.toWrite = tun.toWrite[:0] - if tun.vnetHdr { - err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.gro, &tun.toWrite) - if err != nil { - return 0, err - } - offset -= virtioNetHdrLen - } else { + if !tun.vnetHdr { for i := range bufs { - tun.toWrite = append(tun.toWrite, i) + n, err := tun.tunFile.Write(bufs[i][offset:]) + if errors.Is(err, syscall.EBADFD) { + return total, os.ErrClosed + } + if err != nil { + errs = errors.Join(errs, err) + } else { + total += n + } } + return total, errs } - for _, bufsI := range tun.toWrite { - n, err := tun.tunFile.Write(bufs[bufsI][offset:]) - if errors.Is(err, syscall.EBADFD) { + err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.gro, &tun.toWrite) + if err != nil { + return 0, err + } + for _, nb := range tun.toWrite.iovs { + var werr error + var n int + err := tun.tunRawConn.Write(func(fd uintptr) bool { + for { + n, werr = unix.Writev(int(fd), nb) + if werr == syscall.EINTR { + continue // quick retry on interrupt, EINTR is never returned with partial writes + } + return werr != syscall.EAGAIN // poller retry on "would block" + } + }) + // err is a poller error (e.g. fd closed before the syscall) + // werr is the Writev syscall error itself. + if err != nil { + return total, err + } + if errors.Is(werr, syscall.EBADFD) { return total, os.ErrClosed } - if err != nil { - errs = errors.Join(errs, err) + if werr != nil { + errs = errors.Join(errs, werr) } else { total += n } @@ -577,6 +600,7 @@ func CreateTUN(name string, mtu int) (Device, error) { // CreateTUNFromFile creates a Device from an os.File with the provided MTU. func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { + var err error tun := &NativeTun{ tunFile: file, events: make(chan Event, 5), @@ -584,7 +608,12 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { statusListenersShutdown: make(chan struct{}), tcpGROTable: newTCPGROTable(), udpGROTable: newUDPGROTable(), - toWrite: make([]int, 0, conn.IdealBatchSize), + toWrite: newGROToWrite(), + } + + tun.tunRawConn, err = tun.tunFile.SyscallConn() + if err != nil { + return nil, err } name, err := tun.Name() @@ -640,7 +669,11 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) { errors: make(chan error, 5), tcpGROTable: newTCPGROTable(), udpGROTable: newUDPGROTable(), - toWrite: make([]int, 0, conn.IdealBatchSize), + toWrite: newGROToWrite(), + } + tun.tunRawConn, err = tun.tunFile.SyscallConn() + if err != nil { + return nil, "", err } name, err := tun.Name() if err != nil {