diff --git a/android/src/main/java/com/tailscale/ipn/App.kt b/android/src/main/java/com/tailscale/ipn/App.kt index abf4ea9e47..255b266458 100644 --- a/android/src/main/java/com/tailscale/ipn/App.kt +++ b/android/src/main/java/com/tailscale/ipn/App.kt @@ -314,6 +314,8 @@ class App : UninitializedApp(), libtailscale.AppContext, ViewModelStoreOwner { override fun getOSVersion(): String = Build.VERSION.RELEASE + override fun getSDKInt(): Long = Build.VERSION.SDK_INT.toLong() + override fun isChromeOS(): Boolean { return packageManager.hasSystemFeature("android.hardware.type.pc") } diff --git a/android/src/main/java/com/tailscale/ipn/ui/view/ExitNodePicker.kt b/android/src/main/java/com/tailscale/ipn/ui/view/ExitNodePicker.kt index 827925b40e..7e8b02f008 100644 --- a/android/src/main/java/com/tailscale/ipn/ui/view/ExitNodePicker.kt +++ b/android/src/main/java/com/tailscale/ipn/ui/view/ExitNodePicker.kt @@ -100,9 +100,7 @@ fun ExitNodePicker( } } - // https://developer.android.com/reference/android/net/VpnService.Builder#excludeRoute(android.net.IpPrefix) - excludeRoute is only supported in API 33+, so don't show the option if allow LAN access is not enabled. - if (!allowLanAccessMDMDisposition.value.hiddenFromUser && - Build.VERSION.SDK_INT >= Build.VERSION_CODES.TIRAMISU) { + if (!allowLanAccessMDMDisposition.value.hiddenFromUser) { item(key = "allowLANAccess") { Lists.SectionDivider() diff --git a/libtailscale/interfaces.go b/libtailscale/interfaces.go index cb6ac258dd..7f6c54bb58 100644 --- a/libtailscale/interfaces.go +++ b/libtailscale/interfaces.go @@ -36,6 +36,9 @@ type AppContext interface { // GetOSVersion gets the Android version. GetOSVersion() (string, error) + // GetSDKInt returns the Android SDK_INT (android.os.Build.VERSION.SDK_INT). + GetSDKInt() (int, error) + // GetDeviceName gets the Android device's user-set name, or hardware model name as a fallback. GetDeviceName() (string, error) diff --git a/libtailscale/net.go b/libtailscale/net.go index b005c98821..17986e0c00 100644 --- a/libtailscale/net.go +++ b/libtailscale/net.go @@ -13,6 +13,7 @@ import ( "syscall" "github.com/tailscale/tailscale-android/libtailscale/ifaceparse" + rangescalc "github.com/tailscale/tailscale-android/libtailscale/ranges_calc" "github.com/tailscale/wireguard-go/tun" "tailscale.com/net/dns" "tailscale.com/net/netmon" @@ -134,25 +135,69 @@ func (b *backend) updateTUN(rcfg *router.Config, dcfg *dns.OSConfig) (err error) b.logger.Logf("updateTUN: set nameservers") } - for _, route := range rcfg.Routes { - // Normalize route address; Builder.addRoute does not accept non-zero masked bits. - route = route.Masked() - if err := builder.AddRoute(route.Addr().String(), int32(route.Bits())); err != nil { - return err - } + // Decide whether to use ExcludeRoute (API 33+) or compute included prefixes + // and pass them to AddRoute (older APIs). + useExclude := false + if sdk, err := b.appCtx.GetSDKInt(); err == nil && sdk >= 33 { + useExclude = true } - for _, route := range rcfg.LocalRoutes { - addr := route.Addr() - if addr.IsLoopback() { - continue // Skip the loopback addresses since VpnService throws an exception for those (both IPv4 and IPv6) - see https://android.googlesource.com/platform/frameworks/base/+/c741553/core/java/android/net/VpnService.java#303 + if useExclude { + // For API 33+, use ExcludeRoute for LocalRoutes and AddRoute for Routes. + for _, route := range rcfg.Routes { + // Normalize route address; Builder.addRoute does not accept non-zero masked bits. + route = route.Masked() + if err := builder.AddRoute(route.Addr().String(), int32(route.Bits())); err != nil { + return err + } + } + + for _, route := range rcfg.LocalRoutes { + addr := route.Addr() + if addr.IsLoopback() { + continue // Skip the loopback addresses since VpnService throws an exception for those (both IPv4 and IPv6) - see https://android.googlesource.com/platform/frameworks/base/+/c741553/core/java/android/net/VpnService.java#303 + } + route = route.Masked() + if err := builder.ExcludeRoute(route.Addr().String(), int32(route.Bits())); err != nil { + return err + } } - route = route.Masked() - if err := builder.ExcludeRoute(route.Addr().String(), int32(route.Bits())); err != nil { + + b.logger.Logf("updateTUN: added %d routes (exclude-mode), localRoutes=%d", len(rcfg.Routes), len(rcfg.LocalRoutes)) + } else { + // Older APIs: compute allowed-minus-disallowed prefixes and AddRoute them. + prefixesV4, prefixesV6, err := rangescalc.Calculate(rcfg.Routes, rcfg.LocalRoutes) + if err != nil { + b.logger.Logf("updateTUN: route calculation error: %v", err) return err } + + for _, route := range prefixesV4 { + route = route.Masked() + if err := builder.AddRoute(route.Addr().String(), int32(route.Bits())); err != nil { + return err + } + } + for _, route := range prefixesV6 { + route = route.Masked() + if err := builder.AddRoute(route.Addr().String(), int32(route.Bits())); err != nil { + return err + } + } + + b.logger.Logf( + "updateTUN: added routes: v4=%d v6=%d total=%d (input routes=%d, localRoutes=%d)", + len(prefixesV4), + len(prefixesV6), + len(prefixesV4)+len(prefixesV6), + len(rcfg.Routes), + len(rcfg.LocalRoutes), + ) + b.logger.Logf("updateTUN: input routes: %v", rcfg.Routes) + b.logger.Logf("updateTUN: input local routes: %v", rcfg.LocalRoutes) + b.logger.Logf("updateTUN: effective routes v4: %v", prefixesV4) + b.logger.Logf("updateTUN: effective routes v6: %v", prefixesV6) } - b.logger.Logf("updateTUN: added %d routes", len(rcfg.Routes)) for _, addr := range rcfg.LocalAddrs { if err := builder.AddAddress(addr.Addr().String(), int32(addr.Bits())); err != nil { diff --git a/libtailscale/ranges_calc/ranges_calc.go b/libtailscale/ranges_calc/ranges_calc.go new file mode 100644 index 0000000000..dd70464a33 --- /dev/null +++ b/libtailscale/ranges_calc/ranges_calc.go @@ -0,0 +1,275 @@ +package ranges_calc + +import ( + "fmt" + "math/big" + "net/netip" + "sort" +) + +// Internal representation of an IP range [Start, End] (inclusive) +type ipRange struct { + Start netip.Addr + End netip.Addr +} + +// space describes the address space (32 for IPv4, 128 for IPv6) +type space struct { + bits uint +} + +// ---------- netip.Addr <-> big.Int ---------- +func (s space) addrToInt(a netip.Addr) *big.Int { + if s.bits == 32 { + b := a.As4() + return new(big.Int).SetBytes(b[:]) + } + b := a.As16() + return new(big.Int).SetBytes(b[:]) +} + +func (s space) intToAddr(i *big.Int) netip.Addr { + b := i.FillBytes(make([]byte, s.bits/8)) + if s.bits == 32 { + var a [4]byte + copy(a[:], b) + return netip.AddrFrom4(a) + } + var a [16]byte + copy(a[:], b) + return netip.AddrFrom16(a) +} + +// ---------- merge overlapping ranges ---------- +func (s space) mergeRanges(ranges []ipRange) []ipRange { + if len(ranges) == 0 { + return nil + } + sort.Slice(ranges, func(i, j int) bool { + return ranges[i].Start.Compare(ranges[j].Start) < 0 + }) + merged := []ipRange{ranges[0]} + one := big.NewInt(1) + for _, r := range ranges[1:] { + last := &merged[len(merged)-1] + lastEnd := s.addrToInt(last.End) + curStart := s.addrToInt(r.Start) + if curStart.Cmp(new(big.Int).Add(lastEnd, one)) <= 0 { + if r.End.Compare(last.End) > 0 { + last.End = r.End + } + } else { + merged = append(merged, r) + } + } + return merged +} + +// ---------- range -> minimal number of CIDRs ---------- +// Every IP range defined by a start and end address can be represented +// by one or more CIDR prefixes. This function calculates the minimal set of CIDR +// prefixes that cover the given range. +func (s space) rangeToCIDRs(r ipRange) []netip.Prefix { + var result []netip.Prefix + cur := s.addrToInt(r.Start) + last := s.addrToInt(r.End) + one := big.NewInt(1) + + for cur.Cmp(last) <= 0 { + // Find the largest power-of-2 block starting at cur + var maxSize uint + for size := uint(0); size <= s.bits; size++ { + block := new(big.Int).Lsh(one, size) + if new(big.Int).And(cur, new(big.Int).Sub(block, one)).Cmp(big.NewInt(0)) != 0 { + break + } + maxSize = size + } + + // Shrink maxSize if it would go past last + for { + block := new(big.Int).Lsh(one, maxSize) + lastAddr := new(big.Int).Add(cur, new(big.Int).Sub(block, one)) + if lastAddr.Cmp(last) <= 0 { + break + } + if maxSize == 0 { + break + } + maxSize-- + } + + prefixLen := int(s.bits - maxSize) + result = append(result, netip.PrefixFrom(s.intToAddr(cur), prefixLen)) + cur = cur.Add(cur, new(big.Int).Lsh(one, maxSize)) + } + + return result +} + +// ---------- CIDR -> range ---------- +// prefixToRange converts a netip.Prefix to an ipRange with Start and End addresses. +// Start is the network address and End is the broadcast address. +func (s space) prefixToRange(p netip.Prefix) ipRange { + start := s.addrToInt(p.Addr()) + hostBits := int(s.bits) - p.Bits() + size := new(big.Int).Lsh(big.NewInt(1), uint(hostBits)) + size.Sub(size, big.NewInt(1)) + end := new(big.Int).Add(start, size) + return ipRange{Start: p.Addr(), End: s.intToAddr(end)} +} + +// ---------- helper: subtract disallowed from allowed ---------- +func (s space) subtractRanges(allowed []ipRange, disallowed []ipRange) []ipRange { + if len(allowed) == 0 { + return nil + } + if len(disallowed) == 0 { + return allowed + } + + var result []ipRange + for _, a := range allowed { + cur := []ipRange{a} + for _, d := range disallowed { + cur2 := []ipRange{} + for _, r := range cur { + cur2 = append(cur2, s.subtractOneRange(r, d)...) + } + cur = cur2 + if len(cur) == 0 { + break + } + } + result = append(result, cur...) + } + return s.mergeRanges(result) +} + +// subtractOneRange subtracts a single disallowed range from a single allowed range +func (s space) subtractOneRange(allowed ipRange, disallowed ipRange) []ipRange { + aStart := s.addrToInt(allowed.Start) + aEnd := s.addrToInt(allowed.End) + dStart := s.addrToInt(disallowed.Start) + dEnd := s.addrToInt(disallowed.End) + one := big.NewInt(1) + + // No overlap + if aEnd.Cmp(dStart) < 0 || aStart.Cmp(dEnd) > 0 { + return []ipRange{allowed} + } + + var result []ipRange + + // left side + if aStart.Cmp(dStart) < 0 { + result = append(result, ipRange{ + Start: allowed.Start, + End: s.intToAddr(new(big.Int).Sub(dStart, one)), + }) + } + + // right side + if aEnd.Cmp(dEnd) > 0 { + result = append(result, ipRange{ + Start: s.intToAddr(new(big.Int).Add(dEnd, one)), + End: allowed.End, + }) + } + + return result +} + +// rangesCalc performs the calculation: Routes (allowed) minus LocalRoutes (disallowed) +type rangesCalc struct { + allowed []netip.Prefix + disallowed []netip.Prefix +} + +func newRangesCalc(routes, localRoutes []netip.Prefix) *rangesCalc { + return &rangesCalc{allowed: routes, disallowed: localRoutes} +} + +const maxCalculatedRoutes = 500 + +// calculate computes allowed routes (Routes minus LocalRoutes) and returns +// separate IPv4 and IPv6 prefix lists. If the resulting route set exceeds +// a conservative cap, an error is returned so the caller can fail fast. +func (rc *rangesCalc) calculate() (ipv4 []netip.Prefix, ipv6 []netip.Prefix, err error) { + var out4 []netip.Prefix + var out6 []netip.Prefix + + // Collect IPv4 and IPv6 separately + var allowed4 []ipRange + var disallowed4 []ipRange + var allowed6 []ipRange + var disallowed6 []ipRange + + for _, p := range rc.allowed { + if p.Addr().Is4() { + s := space{bits: 32} + r := s.prefixToRange(p) + allowed4 = append(allowed4, r) + } else { + s := space{bits: 128} + r := s.prefixToRange(p) + allowed6 = append(allowed6, r) + } + } + + for _, p := range rc.disallowed { + // Skip loopback prefixes; mirror behavior of ExcludeRoutes handling. + if p.Addr().IsLoopback() { + continue + } + if p.Addr().Is4() { + s := space{bits: 32} + r := s.prefixToRange(p) + disallowed4 = append(disallowed4, r) + } else { + s := space{bits: 128} + r := s.prefixToRange(p) + disallowed6 = append(disallowed6, r) + } + } + + // Process IPv4 + if len(allowed4) > 0 { + s := space{bits: 32} + mergedAllowed := s.mergeRanges(allowed4) + mergedDisallowed := s.mergeRanges(disallowed4) + finalAllowed := s.subtractRanges(mergedAllowed, mergedDisallowed) + for _, r := range finalAllowed { + for _, pref := range s.rangeToCIDRs(r) { + out4 = append(out4, pref) + } + } + } + + // Process IPv6 + if len(allowed6) > 0 { + s := space{bits: 128} + mergedAllowed := s.mergeRanges(allowed6) + mergedDisallowed := s.mergeRanges(disallowed6) + finalAllowed := s.subtractRanges(mergedAllowed, mergedDisallowed) + for _, r := range finalAllowed { + for _, pref := range s.rangeToCIDRs(r) { + out6 = append(out6, pref) + } + } + } + + total := len(out4) + len(out6) + if total > maxCalculatedRoutes { + return nil, nil, fmt.Errorf("calculated routes (%d) exceed cap (%d)", total, maxCalculatedRoutes) + } + + return out4, out6, nil +} + +// Calculate is the exported helper that computes effective allowed prefixes +// given allowed routes and localRoutes to exclude. +func Calculate(routes, localRoutes []netip.Prefix) (ipv4 []netip.Prefix, ipv6 []netip.Prefix, err error) { + rc := newRangesCalc(routes, localRoutes) + return rc.calculate() +} diff --git a/libtailscale/ranges_calc/ranges_calc_test.go b/libtailscale/ranges_calc/ranges_calc_test.go new file mode 100644 index 0000000000..301788c003 --- /dev/null +++ b/libtailscale/ranges_calc/ranges_calc_test.go @@ -0,0 +1,72 @@ +package ranges_calc + +import ( + "fmt" + "net/netip" + "testing" +) + +func TestCalculate_NoDisallowed(t *testing.T) { + allowed := []netip.Prefix{} + p, _ := netip.ParsePrefix("10.0.0.0/8") + allowed = append(allowed, p) + + v4, v6, err := Calculate(allowed, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(v6) != 0 { + t.Fatalf("expected no IPv6 prefixes, got %d", len(v6)) + } + if len(v4) == 0 { + t.Fatalf("expected some IPv4 prefixes, got none") + } +} + +func TestCalculate_LoopbackIgnored(t *testing.T) { + allowed := []netip.Prefix{} + a, _ := netip.ParsePrefix("127.0.0.0/8") + allowed = append(allowed, a) + + // disallowed contains a loopback address which should be ignored. + d := []netip.Prefix{} + lp, _ := netip.ParsePrefix("127.0.0.1/32") + d = append(d, lp) + + v4a, _, err := Calculate(allowed, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + v4b, _, err := Calculate(allowed, d) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Results should be identical because loopback in disallowed is skipped. + if len(v4a) != len(v4b) { + t.Fatalf("loopback disallowed altered result: before=%d after=%d", len(v4a), len(v4b)) + } +} + +func TestCalculate_CapExceeded(t *testing.T) { + // Create more than maxCalculatedRoutes separate /32 prefixes. + want := maxCalculatedRoutes + 1 + allowed := make([]netip.Prefix, 0, want) + for i := 0; i < want; i++ { + // Generate addresses 10.X.Y.1 where X = i/256, Y = i%256 + x := (i / 256) % 256 + y := i % 256 + s := fmt.Sprintf("10.%d.%d.1/32", x, y) + p, err := netip.ParsePrefix(s) + if err != nil { + t.Fatalf("parse prefix %q: %v", s, err) + } + allowed = append(allowed, p) + } + + _, _, err := Calculate(allowed, nil) + if err == nil { + t.Fatalf("expected error when exceeding cap (%d), got nil", maxCalculatedRoutes) + } +}