diff --git a/nfqueue_linux.go b/nfqueue_linux.go index baaefb54..63710c92 100644 --- a/nfqueue_linux.go +++ b/nfqueue_linux.go @@ -54,10 +54,11 @@ func newNFQueueHandler(options nfqueueOptions) (*nfqueueHandler, error) { }, nil } -func (h *nfqueueHandler) setVerdict(packetID uint32, verdict int, mark uint32) { +func (h *nfqueueHandler) setVerdict(packetID uint32, verdict int, mark uint32, existingMark uint32) { var err error if mark != 0 { - err = h.nfq.SetVerdictWithOption(packetID, verdict, nfqueue.WithMark(mark)) + combined := (existingMark &^ AutoRedirectMarkMask) | mark + err = h.nfq.SetVerdictWithOption(packetID, verdict, nfqueue.WithMark(combined)) } else { err = h.nfq.SetVerdict(packetID, verdict) } @@ -165,10 +166,15 @@ func (h *nfqueueHandler) handlePacket(attr nfqueue.Attribute) int { payload := *attr.Payload if len(payload) < header.IPv4MinimumSize { - h.setVerdict(packetID, nfqueue.NfAccept, 0) + h.setVerdict(packetID, nfqueue.NfAccept, 0, 0) return 0 } + var existingMark uint32 + if attr.Mark != nil { + existingMark = *attr.Mark + } + var srcAddr, dstAddr M.Socksaddr var tcpOffset int @@ -176,7 +182,7 @@ func (h *nfqueueHandler) handlePacket(attr nfqueue.Attribute) int { if version == 4 { ipv4 := header.IPv4(payload) if !ipv4.IsValid(len(payload)) || ipv4.Protocol() != uint8(unix.IPPROTO_TCP) { - h.setVerdict(packetID, nfqueue.NfAccept, 0) + h.setVerdict(packetID, nfqueue.NfAccept, 0, 0) return 0 } srcAddr = M.SocksaddrFrom(ipv4.SourceAddr(), 0) @@ -185,7 +191,7 @@ func (h *nfqueueHandler) handlePacket(attr nfqueue.Attribute) int { } else if version == 6 { transportProto, transportOffset, ok := parseIPv6TransportHeader(payload) if !ok || transportProto != unix.IPPROTO_TCP { - h.setVerdict(packetID, nfqueue.NfAccept, 0) + h.setVerdict(packetID, nfqueue.NfAccept, 0, 0) return 0 } ipv6 := header.IPv6(payload) @@ -193,12 +199,12 @@ func (h *nfqueueHandler) handlePacket(attr nfqueue.Attribute) int { dstAddr = M.SocksaddrFrom(ipv6.DestinationAddr(), 0) tcpOffset = transportOffset } else { - h.setVerdict(packetID, nfqueue.NfAccept, 0) + h.setVerdict(packetID, nfqueue.NfAccept, 0, 0) return 0 } if len(payload) < tcpOffset+header.TCPMinimumSize { - h.setVerdict(packetID, nfqueue.NfAccept, 0) + h.setVerdict(packetID, nfqueue.NfAccept, 0, 0) return 0 } @@ -208,7 +214,7 @@ func (h *nfqueueHandler) handlePacket(attr nfqueue.Attribute) int { flags := tcp.Flags() if !flags.Contains(header.TCPFlagSyn) || flags.Contains(header.TCPFlagAck) { - h.setVerdict(packetID, nfqueue.NfAccept, 0) + h.setVerdict(packetID, nfqueue.NfAccept, 0, 0) return 0 } @@ -220,13 +226,13 @@ func (h *nfqueueHandler) handlePacket(attr nfqueue.Attribute) int { // the chain immediately, skipping any rules after the queue statement. switch { case errors.Is(pErr, ErrBypass): - h.setVerdict(packetID, nfqueue.NfRepeat, h.outputMark) + h.setVerdict(packetID, nfqueue.NfRepeat, h.outputMark, existingMark) case errors.Is(pErr, ErrReset): - h.setVerdict(packetID, nfqueue.NfRepeat, h.resetMark) + h.setVerdict(packetID, nfqueue.NfRepeat, h.resetMark, existingMark) case errors.Is(pErr, ErrDrop): - h.setVerdict(packetID, nfqueue.NfDrop, 0) + h.setVerdict(packetID, nfqueue.NfDrop, 0, 0) default: - h.setVerdict(packetID, nfqueue.NfAccept, 0) + h.setVerdict(packetID, nfqueue.NfAccept, 0, 0) } return 0 diff --git a/redirect.go b/redirect.go index dcf3e720..0dd29e8f 100644 --- a/redirect.go +++ b/redirect.go @@ -14,6 +14,11 @@ const ( DefaultAutoRedirectOutputMark = 0x2024 DefaultAutoRedirectResetMark = 0x2025 DefaultAutoRedirectNFQueue = 100 + + // AutoRedirectMarkMask defines which bits of the 32-bit mark field are + // reserved for auto_redirect loop prevention. Bits outside this mask + // (the upper 16) are available for routing_mark / WAN selection. + AutoRedirectMarkMask = 0x0000FFFF ) type AutoRedirect interface { diff --git a/redirect_nftables.go b/redirect_nftables.go index 266bbe91..5c6d100d 100644 --- a/redirect_nftables.go +++ b/redirect_nftables.go @@ -213,69 +213,29 @@ func (r *autoRedirect) setupNFTables() error { }, }, }) + inputMarkValue := binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectInputMark) + outputMarkValue := binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectOutputMark) + // If iface != tun AND (ct mark & mask) == InputMark → set InputMark on meta mark (preserving high bits) nft.AddRule(&nftables.Rule{ Table: table, Chain: chainPreRoutingUDP, - Exprs: []expr.Any{ - &expr.Meta{ - Key: expr.MetaKeyIIFNAME, - Register: 1, - }, - &expr.Cmp{ - Op: expr.CmpOpNeq, - Register: 1, - Data: nftablesIfname(r.tunOptions.Name), - }, - &expr.Ct{ - Key: expr.CtKeyMARK, - Register: 1, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectInputMark), - }, - &expr.Meta{ - Key: expr.MetaKeyMARK, - Register: 1, - SourceRegister: true, - }, + Exprs: append(append(append([]expr.Any{ + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{Op: expr.CmpOpNeq, Register: 1, Data: nftablesIfname(r.tunOptions.Name)}, + }, maskedCtMarkCmp(expr.CmpOpEq, nftMarkMaskBytes, inputMarkValue)...), + maskedMetaMarkSet(inputMarkValue)...), &expr.Counter{}, - }, + ), }) + // If (ct mark & mask) != InputMark → set OutputMark (preserving high bits), copy to ct mark nft.AddRule(&nftables.Rule{ Table: table, Chain: chainPreRoutingUDP, - Exprs: []expr.Any{ - &expr.Ct{ - Key: expr.CtKeyMARK, - Register: 1, - }, - &expr.Cmp{ - Op: expr.CmpOpNeq, - Register: 1, - Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectInputMark), - }, - &expr.Immediate{ - Register: 1, - Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectOutputMark), - }, - &expr.Meta{ - Key: expr.MetaKeyMARK, - Register: 1, - SourceRegister: true, - }, - &expr.Meta{ - Key: expr.MetaKeyMARK, - Register: 1, - }, - &expr.Ct{ - Key: expr.CtKeyMARK, - Register: 1, - SourceRegister: true, - }, + Exprs: append(append( + maskedCtMarkCmp(expr.CmpOpNeq, nftMarkMaskBytes, inputMarkValue), + maskedMetaMarkSetWithCtCopy(outputMarkValue)...), &expr.Counter{}, - }, + ), }) } @@ -432,38 +392,36 @@ func (r *autoRedirect) nftablesAddPreMatchRules(nft *nftables.Conn, table *nftab nft.AddRule(&nftables.Rule{ Table: table, Chain: chain, - Exprs: []expr.Any{ - &expr.Meta{Key: expr.MetaKeyMARK, Register: 1}, - &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.NativeEndian.PutUint32(r.effectiveOutputMark())}, + Exprs: append( + maskedMetaMarkCmp(expr.CmpOpEq, nftMarkMaskBytes, + binaryutil.NativeEndian.PutUint32(r.effectiveOutputMark())), &expr.Ct{Key: expr.CtKeyMARK, Register: 1, SourceRegister: true}, &expr.Counter{}, &expr.Verdict{Kind: expr.VerdictReturn}, - }, + ), }) // Reset mark: reject with TCP RST. - // When the NFQUEUE handler returns NF_REPEAT with the reset mark, - // the packet re-enters this chain and is rejected here. nft.AddRule(&nftables.Rule{ Table: table, Chain: chain, - Exprs: []expr.Any{ - &expr.Meta{Key: expr.MetaKeyMARK, Register: 1}, - &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.NativeEndian.PutUint32(r.effectiveResetMark())}, + Exprs: append( + maskedMetaMarkCmp(expr.CmpOpEq, nftMarkMaskBytes, + binaryutil.NativeEndian.PutUint32(r.effectiveResetMark())), &expr.Counter{}, &expr.Reject{Type: unix.NFT_REJECT_TCP_RST}, - }, + ), }) // Already-tracked bypass connections: return immediately. nft.AddRule(&nftables.Rule{ Table: table, Chain: chain, - Exprs: []expr.Any{ - &expr.Ct{Key: expr.CtKeyMARK, Register: 1}, - &expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: binaryutil.NativeEndian.PutUint32(r.effectiveOutputMark())}, + Exprs: append( + maskedCtMarkCmp(expr.CmpOpEq, nftMarkMaskBytes, + binaryutil.NativeEndian.PutUint32(r.effectiveOutputMark())), &expr.Verdict{Kind: expr.VerdictReturn}, - }, + ), }) // TCP SYN: send to NFQUEUE for pre-match evaluation. diff --git a/redirect_nftables_rules.go b/redirect_nftables_rules.go index 1b9af3c2..d912b76b 100644 --- a/redirect_nftables_rules.go +++ b/redirect_nftables_rules.go @@ -26,6 +26,55 @@ func init() { allocSetID = 6 } +var ( + nftMarkMaskBytes = binaryutil.NativeEndian.PutUint32(AutoRedirectMarkMask) + nftMarkInvMaskBytes = binaryutil.NativeEndian.PutUint32(^uint32(AutoRedirectMarkMask)) + nftZeroBytes4 = []byte{0, 0, 0, 0} + nftAllOnesBytes4 = binaryutil.NativeEndian.PutUint32(0xFFFFFFFF) +) + +// maskedMetaMarkCmp returns exprs: (meta mark & mask) == value +func maskedMetaMarkCmp(op expr.CmpOp, mask, value []byte) []expr.Any { + return []expr.Any{ + &expr.Meta{Key: expr.MetaKeyMARK, Register: 1}, + &expr.Bitwise{SourceRegister: 1, DestRegister: 1, Len: 4, Mask: mask, Xor: nftZeroBytes4}, + &expr.Cmp{Op: op, Register: 1, Data: value}, + } +} + +// maskedCtMarkCmp returns exprs: (ct mark & mask) == value +func maskedCtMarkCmp(op expr.CmpOp, mask, value []byte) []expr.Any { + return []expr.Any{ + &expr.Ct{Key: expr.CtKeyMARK, Register: 1}, + &expr.Bitwise{SourceRegister: 1, DestRegister: 1, Len: 4, Mask: mask, Xor: nftZeroBytes4}, + &expr.Cmp{Op: op, Register: 1, Data: value}, + } +} + +// maskedMetaMarkSet returns exprs: meta mark = (meta mark & ~mask) | value +// Preserves bits outside the mask. +func maskedMetaMarkSet(value []byte) []expr.Any { + return []expr.Any{ + &expr.Meta{Key: expr.MetaKeyMARK, Register: 1}, + &expr.Bitwise{SourceRegister: 1, DestRegister: 1, Len: 4, Mask: nftMarkInvMaskBytes, Xor: nftZeroBytes4}, + &expr.Bitwise{SourceRegister: 1, DestRegister: 1, Len: 4, Mask: nftAllOnesBytes4, Xor: value}, + &expr.Meta{Key: expr.MetaKeyMARK, Register: 1, SourceRegister: true}, + } +} + +// maskedMetaMarkSetWithCtCopy returns exprs that set meta mark (preserving +// high bits) and copy the result to ct mark. +func maskedMetaMarkSetWithCtCopy(value []byte) []expr.Any { + return []expr.Any{ + &expr.Meta{Key: expr.MetaKeyMARK, Register: 1}, + &expr.Bitwise{SourceRegister: 1, DestRegister: 1, Len: 4, Mask: nftMarkInvMaskBytes, Xor: nftZeroBytes4}, + &expr.Bitwise{SourceRegister: 1, DestRegister: 1, Len: 4, Mask: nftAllOnesBytes4, Xor: value}, + &expr.Meta{Key: expr.MetaKeyMARK, Register: 1, SourceRegister: true}, + &expr.Meta{Key: expr.MetaKeyMARK, Register: 1}, + &expr.Ct{Key: expr.CtKeyMARK, Register: 1, SourceRegister: true}, + } +} + func (r *autoRedirect) nftablesCreateAddressSets( nft *nftables.Conn, table *nftables.Table, update bool, @@ -171,41 +220,23 @@ func (r *autoRedirect) nftablesCreateExcludeRules(nft *nftables.Conn, table *nft nft.AddRule(&nftables.Rule{ Table: table, Chain: chain, - Exprs: []expr.Any{ - &expr.Meta{ - Key: expr.MetaKeyMARK, - Register: 1, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectOutputMark), - }, + Exprs: append( + maskedMetaMarkCmp(expr.CmpOpEq, nftMarkMaskBytes, + binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectOutputMark)), &expr.Counter{}, - &expr.Verdict{ - Kind: expr.VerdictReturn, - }, - }, + &expr.Verdict{Kind: expr.VerdictReturn}, + ), }) if chain.Type == nftables.ChainTypeRoute { nft.AddRule(&nftables.Rule{ Table: table, Chain: chain, - Exprs: []expr.Any{ - &expr.Ct{ - Key: expr.CtKeyMARK, - Register: 1, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectOutputMark), - }, + Exprs: append( + maskedCtMarkCmp(expr.CmpOpEq, nftMarkMaskBytes, + binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectOutputMark)), &expr.Counter{}, - &expr.Verdict{ - Kind: expr.VerdictReturn, - }, - }, + &expr.Verdict{Kind: expr.VerdictReturn}, + ), }) } } @@ -213,42 +244,24 @@ func (r *autoRedirect) nftablesCreateExcludeRules(nft *nftables.Conn, table *nft nft.AddRule(&nftables.Rule{ Table: table, Chain: chain, - Exprs: []expr.Any{ - &expr.Ct{ - Key: expr.CtKeyMARK, - Register: 1, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: binaryutil.NativeEndian.PutUint32(r.effectiveOutputMark()), - }, + Exprs: append( + maskedCtMarkCmp(expr.CmpOpEq, nftMarkMaskBytes, + binaryutil.NativeEndian.PutUint32(r.effectiveOutputMark())), &expr.Counter{}, - &expr.Verdict{ - Kind: expr.VerdictReturn, - }, - }, + &expr.Verdict{Kind: expr.VerdictReturn}, + ), }) } if r.nfqueueEnabled && chain.Hooknum == nftables.ChainHookOutput && chain.Type == nftables.ChainTypeNAT { nft.AddRule(&nftables.Rule{ Table: table, Chain: chain, - Exprs: []expr.Any{ - &expr.Ct{ - Key: expr.CtKeyMARK, - Register: 1, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: binaryutil.NativeEndian.PutUint32(r.effectiveOutputMark()), - }, + Exprs: append( + maskedCtMarkCmp(expr.CmpOpEq, nftMarkMaskBytes, + binaryutil.NativeEndian.PutUint32(r.effectiveOutputMark())), &expr.Counter{}, - &expr.Verdict{ - Kind: expr.VerdictReturn, - }, - }, + &expr.Verdict{Kind: expr.VerdictReturn}, + ), }) } if chain.Hooknum == nftables.ChainHookPrerouting { @@ -615,33 +628,12 @@ func (r *autoRedirect) nftablesCreateExcludeRules(nft *nftables.Conn, table *nft } func (r *autoRedirect) nftablesCreateMark(nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain) { + markValue := binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectInputMark) + markExprs := maskedMetaMarkSetWithCtCopy(markValue) nft.AddRule(&nftables.Rule{ Table: table, Chain: chain, - Exprs: []expr.Any{ - &expr.Immediate{ - Register: 1, - Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectInputMark), - }, - &expr.Meta{ - Key: expr.MetaKeyMARK, - Register: 1, - SourceRegister: true, - }, - &expr.Meta{ - Key: expr.MetaKeyMARK, - Register: 1, - }, // output meta mark set myMark ct mark set meta mark - &expr.Ct{ - Key: expr.CtKeyMARK, - Register: 1, - SourceRegister: true, - }, - &expr.Counter{}, - &expr.Verdict{ - Kind: expr.VerdictReturn, - }, - }, + Exprs: append(markExprs, &expr.Counter{}, &expr.Verdict{Kind: expr.VerdictReturn}), }) } @@ -749,7 +741,7 @@ func (r *autoRedirect) nftablesCreateRedirect( func (r *autoRedirect) nftablesCreateLoopbackReroute( nft *nftables.Conn, table *nftables.Table, chain *nftables.Chain, ) error { - exprs := []expr.Any{ + exprs := append([]expr.Any{ &expr.Meta{ Key: expr.MetaKeyL4PROTO, Register: 1, @@ -759,16 +751,8 @@ func (r *autoRedirect) nftablesCreateLoopbackReroute( Register: 1, Data: []byte{unix.IPPROTO_TCP}, }, - &expr.Meta{ - Key: expr.MetaKeyMARK, - Register: 1, - }, - &expr.Cmp{ - Op: expr.CmpOpNeq, - Register: 1, - Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectInputMark), - }, - } + }, maskedMetaMarkCmp(expr.CmpOpNeq, nftMarkMaskBytes, + binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectInputMark))...) var exprs4 []expr.Any if r.enableIPv4 && len(r.tunOptions.Inet4LoopbackAddress) > 0 { exprs4 = append(exprs, nftablesCreateDestinationIPSetExprs(7, "inet4_local_redirect_address_set", nftables.TableFamilyIPv4, false)...) @@ -777,42 +761,12 @@ func (r *autoRedirect) nftablesCreateLoopbackReroute( if r.enableIPv6 && len(r.tunOptions.Inet6LoopbackAddress) > 0 { exprs6 = append(exprs, nftablesCreateDestinationIPSetExprs(8, "inet6_local_redirect_address_set", nftables.TableFamilyIPv6, false)...) } + markValue := binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectInputMark) var exprsCreateMark []expr.Any if chain.Hooknum == nftables.ChainHookPrerouting { - exprsCreateMark = []expr.Any{ - &expr.Immediate{ - Register: 1, - Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectInputMark), - }, - &expr.Meta{ - Key: expr.MetaKeyMARK, - Register: 1, - SourceRegister: true, - }, - &expr.Counter{}, - } + exprsCreateMark = append(maskedMetaMarkSet(markValue), &expr.Counter{}) } else { - exprsCreateMark = []expr.Any{ - &expr.Immediate{ - Register: 1, - Data: binaryutil.NativeEndian.PutUint32(r.tunOptions.AutoRedirectInputMark), - }, - &expr.Meta{ - Key: expr.MetaKeyMARK, - Register: 1, - SourceRegister: true, - }, - &expr.Meta{ - Key: expr.MetaKeyMARK, - Register: 1, - }, - &expr.Ct{ - Key: expr.CtKeyMARK, - Register: 1, - SourceRegister: true, - }, - &expr.Counter{}, - } + exprsCreateMark = append(maskedMetaMarkSetWithCtCopy(markValue), &expr.Counter{}) } if len(exprs4) > 0 { exprs4 = append(exprs4, exprsCreateMark...) diff --git a/tun_linux.go b/tun_linux.go index 20fdce23..b63ba886 100644 --- a/tun_linux.go +++ b/tun_linux.go @@ -587,10 +587,12 @@ func (t *NativeTun) rules() []*netlink.Rule { priority6 := priority if t.options.AutoRedirectMarkMode { + fwMask := int(AutoRedirectMarkMask) if p4 { it = netlink.NewRule() it.Priority = priority it.Mark = t.options.AutoRedirectOutputMark + it.Mask = fwMask it.MarkSet = true it.Goto = priority + 2 it.Family = unix.AF_INET @@ -600,6 +602,7 @@ func (t *NativeTun) rules() []*netlink.Rule { it = netlink.NewRule() it.Priority = priority it.Mark = t.options.AutoRedirectInputMark + it.Mask = fwMask it.MarkSet = true it.Table = t.options.IPRoute2TableIndex it.Family = unix.AF_INET @@ -615,6 +618,7 @@ func (t *NativeTun) rules() []*netlink.Rule { it = netlink.NewRule() it.Priority = priority6 it.Mark = t.options.AutoRedirectOutputMark + it.Mask = fwMask it.MarkSet = true it.Goto = priority6 + 2 it.Family = unix.AF_INET6 @@ -624,6 +628,7 @@ func (t *NativeTun) rules() []*netlink.Rule { it = netlink.NewRule() it.Priority = priority6 it.Mark = t.options.AutoRedirectInputMark + it.Mask = fwMask it.MarkSet = true it.Table = t.options.IPRoute2TableIndex it.Family = unix.AF_INET6