From 5b51ddd927416f73bef522009f4760db44a94a39 Mon Sep 17 00:00:00 2001 From: Vladimir Parfenov Date: Sat, 13 Dec 2025 12:09:26 +0200 Subject: [PATCH 1/3] feat: native go tunnel and connection management --- cmd/root.go | 6 +- cmd/setup.go | 63 ++++ go.mod | 10 + go.sum | 22 ++ pkg/console/console.go | 117 ------- pkg/console/console_test.go | 3 - pkg/vpn/vpn.go | 251 +++++++++------ pkg/vpn/vpn_test.go | 468 ++++++---------------------- pkg/wireguard/linux_manager.go | 399 ++++++++++++++++++++++++ pkg/wireguard/linux_manager_test.go | 261 ++++++++++++++++ pkg/wireguard/manager.go | 73 +++++ 11 files changed, 1085 insertions(+), 588 deletions(-) create mode 100644 cmd/setup.go delete mode 100644 pkg/console/console.go delete mode 100644 pkg/console/console_test.go create mode 100644 pkg/wireguard/linux_manager.go create mode 100644 pkg/wireguard/linux_manager_test.go create mode 100644 pkg/wireguard/manager.go diff --git a/cmd/root.go b/cmd/root.go index 5993407..49ee7b9 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -72,7 +72,11 @@ func init() { holocron := remote.NewDefaultHolocron(machineIdProvider) sm := session.NewDefaultSessionManager(cp, holocron) ss := servers.NewDefaultServerStorage(dirProvider) - vpn := vpn.NewDefaultVpn(cp, holocron, ss, dirProvider) + vpn, err := vpn.NewDefaultVpn(cp, holocron, ss, dirProvider) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: failed to initialize VPN manager: %v\n", err) + os.Exit(1) + } rootCmd.AddCommand(NewLoginCommand(sm)) rootCmd.AddCommand(NewLogoutCommand(sm)) diff --git a/cmd/setup.go b/cmd/setup.go new file mode 100644 index 0000000..d532fe3 --- /dev/null +++ b/cmd/setup.go @@ -0,0 +1,63 @@ +package cmd + +import ( + "fmt" + "os" + "os/exec" + + "github.com/malwarebytes/mbvpn-linux/pkg/output" + "github.com/spf13/cobra" +) + +func NewSetupCommand() *cobra.Command { + cmd := &cobra.Command{ + Use: "setup", + Short: "Set up mbvpn with required capabilities", + Long: `Sets the CAP_NET_ADMIN capability on the mbvpn binary. + +This command must be run as root (using sudo) and only needs to be run once +after installation. After setup, all VPN commands can be run without sudo. + +Example: + sudo mbvpn setup`, + RunE: func(cmd *cobra.Command, args []string) error { + return runSetup() + }, + } + + return cmd +} + +func runSetup() error { + // Check if running as root + if os.Geteuid() != 0 { + return fmt.Errorf("setup must be run as root (use: sudo mbvpn setup)") + } + + // Get the path to the current executable + execPath, err := os.Executable() + if err != nil { + return fmt.Errorf("failed to get executable path: %w", err) + } + + output.PrintMsg(fmt.Sprintf("Setting up capabilities for: %s", execPath), output.MsgOutput) + + // Set CAP_NET_ADMIN capability using setcap + setcapCmd := exec.Command("setcap", "cap_net_admin+ep", execPath) + setcapCmd.Stdout = os.Stdout + setcapCmd.Stderr = os.Stderr + + if err := setcapCmd.Run(); err != nil { + return fmt.Errorf("failed to set capabilities: %w\n\nMake sure 'setcap' is installed (usually in libcap2-bin package)", err) + } + + output.PrintMsg("Setup complete. You can now run mbvpn commands without sudo.", output.MsgSuccess) + output.PrintMsg("", output.MsgOutput) + output.PrintMsg("Note: If you reinstall or update mbvpn, you'll need to run 'sudo mbvpn setup' again.", output.MsgOutput) + + return nil +} + +func init() { + rootCmd.AddCommand(NewSetupCommand()) +} diff --git a/go.mod b/go.mod index bc88f97..73f5a2f 100644 --- a/go.mod +++ b/go.mod @@ -8,19 +8,29 @@ require ( github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.8.1 github.com/stretchr/testify v1.10.0 + github.com/vishvananda/netlink v1.3.0 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 gopkg.in/yaml.v3 v3.0.1 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect + github.com/google/go-cmp v0.6.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/josharian/native v1.1.0 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mdlayher/genetlink v1.3.2 // indirect + github.com/mdlayher/netlink v1.7.2 // indirect + github.com/mdlayher/socket v0.5.1 // indirect github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/vishvananda/netns v0.0.4 // indirect golang.org/x/crypto v0.35.0 // indirect + golang.org/x/net v0.33.0 // indirect + golang.org/x/sync v0.10.0 // indirect golang.org/x/sys v0.30.0 // indirect + golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 // indirect gopkg.in/check.v1 v1.0.0-20200902074654-038fdea0a05b // indirect ) diff --git a/go.sum b/go.sum index 20c4592..1415675 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/josharian/native v1.1.0 h1:uuaP0hAbW7Y4l0ZRQ6C9zfb7Mg1mbFKry/xzDAfmtLA= +github.com/josharian/native v1.1.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= @@ -18,6 +20,14 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw= +github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= +github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= +github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= +github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= +github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= +github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws= +github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWbfPhv4DMiApHyliiK5xCTNVSPiaAs= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -33,13 +43,25 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/vishvananda/netlink v1.3.0 h1:X7l42GfcV4S6E4vHTsw48qbrV+9PVojNfIhZcwQdrZk= +github.com/vishvananda/netlink v1.3.0/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs= +github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= +github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs= golang.org/x/crypto v0.35.0/go.mod h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ= +golang.org/x/net v0.33.0 h1:74SYHlV8BIgHIFC/LrYkOGIwL19eTYXQ5wc6TBuO36I= +golang.org/x/net v0.33.0/go.mod h1:HXLR5J+9DxmrqMwG9qjGCxZ+zKXxBru04zlTvWlWuN4= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173 h1:/jFs0duh4rdb8uIfPMv78iAJGcPKDeqAFnaLBropIC4= +golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU= golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/pkg/console/console.go b/pkg/console/console.go deleted file mode 100644 index 4e3bca1..0000000 --- a/pkg/console/console.go +++ /dev/null @@ -1,117 +0,0 @@ -package console - -import ( - "fmt" - "os" - "os/exec" - "path/filepath" - "strings" - - "github.com/malwarebytes/mbvpn-linux/pkg/config" -) - -func WgShow() (string, error) { - var stdout strings.Builder - var stderr strings.Builder - - cmd := exec.Command("sudo", "wg", "show") - - cmd.Stdout = &stdout - cmd.Stderr = &stderr - - err := cmd.Run() - if err != nil { - return stdout.String(), fmt.Errorf("wg show command execution failed: %w, stderr: %v, stdout: %v", err, stderr.String(), stdout.String()) - } - - return stdout.String(), nil -} - -func WgUp(cfgPath string, dirProvider config.DirectoryProvider) error { - cleanedPath, err := sanitizeWgConfigPath(cfgPath, dirProvider) - if err != nil { - return fmt.Errorf("invalid WireGuard config path: %w", err) - } - - var stdout strings.Builder - var stderr strings.Builder - - cmd := exec.Command("sudo", "wg-quick", "up", cleanedPath) - - cmd.Stdout = &stdout - cmd.Stderr = &stderr - - err = cmd.Run() - if err != nil { - return fmt.Errorf("wg-quick up command execution failed: %w, stderr: %v, stdout: %v", err, stderr.String(), stdout.String()) - } - - return nil -} - -func WgDown(cfgPath string, dirProvider config.DirectoryProvider) error { - cleanedPath, err := sanitizeWgConfigPath(cfgPath, dirProvider) - if err != nil { - return fmt.Errorf("invalid WireGuard config path: %w", err) - } - - var stdout strings.Builder - var stderr strings.Builder - - cmd := exec.Command("sudo", "wg-quick", "down", cleanedPath) - - cmd.Stdout = &stdout - cmd.Stderr = &stderr - - err = cmd.Run() - if err != nil { - return fmt.Errorf("wg-quick down command execution failed: %w, stderr: %v, stdout: %v", err, stderr.String(), stdout.String()) - } - - return nil -} - -// TODO: avoid passing dirProvider here -func sanitizeWgConfigPath(cfgPath string, dirProvider config.DirectoryProvider) (string, error) { - cleanPath := cfgPath - - // Expand home directory if present - home, err := os.UserHomeDir() - if err != nil { - return "", fmt.Errorf("failed to get home directory: %w", err) - } else if strings.HasPrefix(cleanPath, "~/") { - cleanPath = filepath.Join(home, cleanPath[2:]) - } - - // Convert to absolute path - absPath, err := filepath.Abs(cleanPath) - if err != nil { - return "", fmt.Errorf("failed to resolve absolute path: %w", err) - } - - // Verify the file exists and is a regular file - fileInfo, err := os.Stat(absPath) - if err != nil { - return "", fmt.Errorf("config file not accessible: %w", err) - } - - if !fileInfo.Mode().IsRegular() { - return "", fmt.Errorf("config path is not a regular file") - } - - // Ensure it's within the expected config directory - expectedBase, err := dirProvider.GetServersDir() - if err != nil { - return "", fmt.Errorf("failed to get servers directory: %w", err) - } - if !strings.HasPrefix(absPath, expectedBase) { - return "", fmt.Errorf("config file must be within %s", expectedBase) - } - - // Verify file extension (WireGuard configs must be .conf) - if filepath.Ext(absPath) != ".conf" { - return "", fmt.Errorf("config file must have .conf extension") - } - - return absPath, nil -} diff --git a/pkg/console/console_test.go b/pkg/console/console_test.go deleted file mode 100644 index 86c6b9e..0000000 --- a/pkg/console/console_test.go +++ /dev/null @@ -1,3 +0,0 @@ -package console - -//TODO: implement unit tests for wg commands diff --git a/pkg/vpn/vpn.go b/pkg/vpn/vpn.go index 015e550..d1a36de 100644 --- a/pkg/vpn/vpn.go +++ b/pkg/vpn/vpn.go @@ -2,16 +2,15 @@ package vpn import ( "fmt" - "os" - "path/filepath" + "net" "strings" "github.com/malwarebytes/mbvpn-linux/pkg/config" - "github.com/malwarebytes/mbvpn-linux/pkg/console" "github.com/malwarebytes/mbvpn-linux/pkg/errors" "github.com/malwarebytes/mbvpn-linux/pkg/output" "github.com/malwarebytes/mbvpn-linux/pkg/remote" "github.com/malwarebytes/mbvpn-linux/pkg/servers" + "github.com/malwarebytes/mbvpn-linux/pkg/wireguard" log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) @@ -28,15 +27,22 @@ type DefaultVpn struct { holocron remote.Holocron serverStorage servers.ServerStorage dirProvider config.DirectoryProvider + wgManager wireguard.Manager } -func NewDefaultVpn(cfgProvider config.ConfigProvider, holocron remote.Holocron, serverStorage servers.ServerStorage, dirProvider config.DirectoryProvider) Vpn { +func NewDefaultVpn(cfgProvider config.ConfigProvider, holocron remote.Holocron, serverStorage servers.ServerStorage, dirProvider config.DirectoryProvider) (Vpn, error) { + mgr, err := wireguard.NewManager() + if err != nil { + return nil, fmt.Errorf("failed to create wireguard manager: %w", err) + } + return &DefaultVpn{ cfgProvider: cfgProvider, holocron: holocron, serverStorage: serverStorage, dirProvider: dirProvider, - } + wgManager: mgr, + }, nil } func (vpn *DefaultVpn) Servers(showCities bool, showServers bool) error { @@ -235,9 +241,9 @@ func (vpn *DefaultVpn) Connect(cfg string) error { return errors.NewUserError(fmt.Sprintf("Server %s not found", cfg), errors.ErrNotFound) } - cfgName := strings.Split(server.Hostname, ".")[0] + ifaceName := sanitizeInterfaceName(server.Hostname) - output.PrintMsg(fmt.Sprintf("Connecting to %s...", cfgName), output.MsgOutput) + output.PrintMsg(fmt.Sprintf("Connecting to %s...", ifaceName), output.MsgOutput) keyData, err := vpn.cfgProvider.Get() if err != nil || keyData.PrivateKey == "" || keyData.PublicKey == "" { @@ -263,47 +269,132 @@ func (vpn *DefaultVpn) Connect(cfg string) error { return errors.NewNetworkError("register public key", err) } - cfgPath, err := vpn.writeConfig(cfgName, *server, keyData.PrivateKey, ipAddrs.IpV4, ipAddrs.IpV6) - if err != nil { - return errors.NewVPNError("write config", err) + // Create WireGuard interface + if err := vpn.wgManager.CreateInterface(ifaceName); err != nil { + // If interface already exists, try to remove it first + if strings.Contains(err.Error(), "already exists") { + log.Infof("Interface %s already exists, removing and recreating", ifaceName) + if rmErr := vpn.wgManager.RemoveInterface(ifaceName); rmErr != nil { + return errors.NewVPNError("remove existing interface", rmErr) + } + if err := vpn.wgManager.CreateInterface(ifaceName); err != nil { + return errors.NewVPNError("create interface", err) + } + } else { + return errors.NewVPNError("create interface", err) + } } - output.PrintMsg(fmt.Sprintf("Calling 'wg-quick up %s'", cfgPath), output.MsgOutput) - err = console.WgUp(cfgPath, vpn.dirProvider) - if err != nil { - return errors.NewVPNError("connect", err) + // Determine port from server's port ranges, fallback to 51820 + port := 51820 + if len(server.PortRanges) > 0 { + port = server.PortRanges[0].From + } + + // Configure WireGuard + wgCfg := wireguard.DeviceConfig{ + PrivateKey: keyData.PrivateKey, + Peers: []wireguard.PeerConfig{{ + PublicKey: server.PublicKey, + Endpoint: fmt.Sprintf("%s:%d", server.IPv4AddrIn, port), + AllowedIPs: []string{"0.0.0.0/0", "::/0"}, + }}, + } + if err := vpn.wgManager.Configure(ifaceName, wgCfg); err != nil { + vpn.wgManager.RemoveInterface(ifaceName) // Cleanup on failure + return errors.NewVPNError("configure device", err) + } + + // Assign addresses + addrs := []string{ipAddrs.IpV4, ipAddrs.IpV6} + if err := vpn.wgManager.AssignAddresses(ifaceName, addrs); err != nil { + vpn.wgManager.RemoveInterface(ifaceName) + return errors.NewVPNError("assign addresses", err) + } + + // Bring interface up + if err := vpn.wgManager.SetInterfaceUp(ifaceName); err != nil { + vpn.wgManager.RemoveInterface(ifaceName) + return errors.NewVPNError("set interface up", err) + } + + // Parse endpoint IP for routing + endpointIP := net.ParseIP(server.IPv4AddrIn) + + // Add routes + if err := vpn.wgManager.AddRoutes(ifaceName, []string{"0.0.0.0/0", "::/0"}, endpointIP); err != nil { + vpn.wgManager.RemoveInterface(ifaceName) + return errors.NewVPNError("add routes", err) } output.PrintMsg("Connected.", output.MsgSuccess) return nil } +// sanitizeInterfaceName converts a server hostname to a valid interface name +// Linux interface names are limited to 15 characters +func sanitizeInterfaceName(hostname string) string { + // Extract just the server name part (before the first dot) + name := strings.Split(hostname, ".")[0] + + // Replace any non-alphanumeric characters with empty string + var sanitized strings.Builder + for _, r := range name { + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '-' || r == '_' { + sanitized.WriteRune(r) + } + } + + result := sanitized.String() + + // Truncate to 15 characters (Linux interface name limit) + if len(result) > 15 { + result = result[:15] + } + + return result +} + func (vpn *DefaultVpn) Disconnect(cfg string) error { if cfg == "" { - servers, err := getConnectedServers() + devices, err := vpn.wgManager.ListDevices() if err != nil { - return errors.NewVPNError("get connected servers", err) + return errors.NewVPNError("list devices", err) } - for _, s := range servers { - err := vpn.Disconnect(s) - if err != nil { - // Continue trying to disconnect other servers even if one fails - log.Errorf("Failed to disconnect from %s: %v", s, err) + if len(devices) == 0 { + output.PrintMsg("No active connections.", output.MsgOutput) + return nil + } + + for _, name := range devices { + output.PrintMsg(fmt.Sprintf("Disconnecting from %s...", name), output.MsgOutput) + if err := vpn.wgManager.RemoveInterface(name); err != nil { + log.Errorf("Failed to disconnect from %s: %v", name, err) + } else { + output.PrintMsg(fmt.Sprintf("Disconnected from %s.", name), output.MsgSuccess) } } return nil } - output.PrintMsg(fmt.Sprintf("Disconnecting from %s...", cfg), output.MsgOutput) - - cfgDir, err := vpn.dirProvider.GetServersDir() + // If a specific server name is provided, find and disconnect it + server, err := vpn.serverStorage.GetByServerName(cfg) if err != nil { - return errors.NewConfigError("get servers directory", err) + return errors.NewConfigError("get server by name", err) } - err = console.WgDown(filepath.Join(cfgDir, cfg+".conf"), vpn.dirProvider) - if err != nil { + var ifaceName string + if server != nil { + ifaceName = sanitizeInterfaceName(server.Hostname) + } else { + // Maybe they provided the interface name directly + ifaceName = sanitizeInterfaceName(cfg) + } + + output.PrintMsg(fmt.Sprintf("Disconnecting from %s...", ifaceName), output.MsgOutput) + + if err := vpn.wgManager.RemoveInterface(ifaceName); err != nil { return errors.NewVPNError("disconnect", err) } @@ -312,16 +403,32 @@ func (vpn *DefaultVpn) Disconnect(cfg string) error { } func (vpn *DefaultVpn) Status() error { - servers, err := getConnectedServers() + devices, err := vpn.wgManager.ListDevices() if err != nil { return errors.NewVPNError("get status", err) } - if len(servers) == 0 { + if len(devices) == 0 { output.PrintMsg("No active connections.", output.MsgOutput) } else { - for _, s := range servers { - output.PrintMsg(fmt.Sprintf("Connected to: %s", s), output.MsgSuccess) + for _, name := range devices { + dev, err := vpn.wgManager.GetDevice(name) + if err != nil { + log.Errorf("Failed to get device %s: %v", name, err) + continue + } + output.PrintMsg(fmt.Sprintf("Connected to: %s", name), output.MsgSuccess) + if dev != nil && len(dev.Peers) > 0 { + for _, peer := range dev.Peers { + if peer.Endpoint != "" { + output.PrintMsg(fmt.Sprintf(" Endpoint: %s", peer.Endpoint), output.MsgOutput) + } + if !peer.LastHandshakeTime.IsZero() { + output.PrintMsg(fmt.Sprintf(" Last handshake: %s", peer.LastHandshakeTime.Format("2006-01-02 15:04:05")), output.MsgOutput) + } + output.PrintMsg(fmt.Sprintf(" Transfer: ↓ %s / ↑ %s", formatBytes(peer.ReceiveBytes), formatBytes(peer.TransmitBytes)), output.MsgOutput) + } + } } } @@ -330,14 +437,26 @@ func (vpn *DefaultVpn) Status() error { return errors.NewNetworkError("get network details", err) } - output.PrintMsg(fmt.Sprintf("IP Address: %s", network.Ip), output.MsgOutput) - output.PrintMsg(fmt.Sprintf("VPN enabled: %t", network.VpnEnabled), output.MsgOutput) - output.PrintMsg(fmt.Sprintf("Country: %s", network.Geo.Country), output.MsgOutput) - output.PrintMsg(fmt.Sprintf("City: %s", network.Geo.City), output.MsgOutput) + output.PrintMsg(fmt.Sprintf(" Country: %s", network.Geo.Country), output.MsgOutput) + output.PrintMsg(fmt.Sprintf(" City: %s", network.Geo.City), output.MsgOutput) return nil } +// formatBytes formats bytes into a human-readable string +func formatBytes(bytes int64) string { + const unit = 1024 + if bytes < unit { + return fmt.Sprintf("%d B", bytes) + } + div, exp := int64(unit), 0 + for n := bytes / unit; n >= unit; n /= unit { + div *= unit + exp++ + } + return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp]) +} + func generateKeys() (wgtypes.Key, wgtypes.Key, wgtypes.Key, error) { privateKey, err := wgtypes.GeneratePrivateKey() if err != nil { @@ -355,65 +474,3 @@ func generateKeys() (wgtypes.Key, wgtypes.Key, wgtypes.Key, error) { return publicKey, preSharedKey, privateKey, nil } - -func (vpn *DefaultVpn) writeConfig(cfgName string, server remote.Server, privateKey string, ipv4 string, ipv6 string) (string, error) { - content := fmt.Sprintf(`[Interface] -PrivateKey = %s -Address = %s, %s - -[Peer] -PublicKey = %s -Endpoint = %s:51820 -AllowedIPs = 0.0.0.0/0, ::/0`, - privateKey, - ipv4, - ipv6, - server.PublicKey, - server.IPv4AddrIn, - ) - - fullPath, err := vpn.saveWgConfig(cfgName, content) - if err != nil { - return "", fmt.Errorf("failed to save config file: %w", err) - } - - return fullPath, nil -} - -func getConnectedServers() ([]string, error) { - output, err := console.WgShow() - if err != nil { - return nil, fmt.Errorf("failed to execute 'wg show': %w", err) - } - - lines := strings.Split(string(output), "\n") - n := strings.Count(string(output), "interface:") - s := make([]string, n) - i := 0 - for _, l := range lines { - if strings.HasPrefix(l, "interface") { - s[i] = strings.TrimPrefix(l, "interface: ") - i++ - } - } - return s, nil -} - -func (vpn *DefaultVpn) saveWgConfig(serverName string, configContent string) (string, error) { - configDir, err := vpn.dirProvider.GetServersDir() - if err != nil { - return "", fmt.Errorf("failed to get servers directory: %w", err) - } - - // Ensure directory exists - if err := os.MkdirAll(configDir, 0o755); err != nil { - return "", fmt.Errorf("failed to create servers directory: %w", err) - } - - // Sanitize the server name to avoid path traversal attacks - serverName = filepath.Base(serverName) - - // Create the file with restricted permissions (600) as it contains private keys - filePath := filepath.Join(configDir, serverName+".conf") - return filePath, os.WriteFile(filePath, []byte(configContent), 0o600) -} diff --git a/pkg/vpn/vpn_test.go b/pkg/vpn/vpn_test.go index 7b44d00..3eefe35 100644 --- a/pkg/vpn/vpn_test.go +++ b/pkg/vpn/vpn_test.go @@ -2,26 +2,12 @@ package vpn import ( "os" - "path/filepath" "strings" "testing" "github.com/malwarebytes/mbvpn-linux/pkg/config" - "github.com/malwarebytes/mbvpn-linux/pkg/remote" ) -// Helper function to set up test directory -func setupTestDir(t *testing.T) (config.DirectoryProvider, func()) { - tempDir := t.TempDir() - configDir := filepath.Join(tempDir, "mbvpn") - serversDir := filepath.Join(configDir, "servers") - if err := os.MkdirAll(serversDir, 0755); err != nil { - t.Fatalf("Failed to create test directories: %v", err) - } - dirProvider := config.NewDirectoryProvider(tempDir) - return dirProvider, func() {} -} - // Mock DirectoryProvider for error testing type mockDirectoryProvider struct { shouldError bool @@ -228,264 +214,6 @@ func TestGetCountryFlag_AllMappedCountries(t *testing.T) { } } -func TestGetServersDir(t *testing.T) { - // Test creating config directory - dirProvider, cleanup := setupTestDir(t) - defer cleanup() - - vpn := &DefaultVpn{dirProvider: dirProvider} - - configDir, err := vpn.dirProvider.GetServersDir() - if err != nil { - t.Fatalf("GetServersDir should succeed: %v", err) - } - - // Verify directory path is correct - if !strings.Contains(configDir, "mbvpn") || !strings.Contains(configDir, "servers") { - t.Errorf("Expected config dir to contain 'mbvpn/servers', got '%s'", configDir) - } - - // Verify directory was created during setup - if _, err := os.Stat(configDir); os.IsNotExist(err) { - t.Error("Config directory should exist") - } - - // Verify directory permissions - info, err := os.Stat(configDir) - if err != nil { - t.Fatalf("Failed to stat config directory: %v", err) - } - - if info.Mode().Perm() != 0755 { - t.Errorf("Expected directory permissions 0755, got %o", info.Mode().Perm()) - } -} - -func TestGetServersDir_ErrorCase(t *testing.T) { - // Test error case using mockDirectoryProvider - mockDirProvider := &mockDirectoryProvider{shouldError: true} - vpn := &DefaultVpn{dirProvider: mockDirProvider} - - _, err := vpn.dirProvider.GetServersDir() - if err == nil { - t.Error("GetServersDir should fail when directory provider fails") - } -} - -func TestSaveWgConfig(t *testing.T) { - dirProvider, cleanup := setupTestDir(t) - defer cleanup() - - vpn := &DefaultVpn{dirProvider: dirProvider} - - serverName := "test-server" - configContent := `[Interface] -PrivateKey = test-private-key -Address = 10.0.0.1/32, 2001:db8::1/128 - -[Peer] -PublicKey = test-public-key -Endpoint = 203.0.113.1:51820 -AllowedIPs = 0.0.0.0/0, ::/0` - - filePath, err := vpn.saveWgConfig(serverName, configContent) - if err != nil { - t.Fatalf("saveWgConfig should succeed: %v", err) - } - - // Verify file path contains expected components - if !strings.Contains(filePath, "mbvpn") || !strings.Contains(filePath, "servers") || !strings.Contains(filePath, "test-server.conf") { - t.Errorf("Expected file path to contain 'mbvpn/servers/test-server.conf', got '%s'", filePath) - } - - // Verify file was created - if _, err := os.Stat(filePath); os.IsNotExist(err) { - t.Error("Config file should be created") - } - - // Verify file permissions (should be restrictive for private keys) - info, err := os.Stat(filePath) - if err != nil { - t.Fatalf("Failed to stat config file: %v", err) - } - - if info.Mode().Perm() != 0600 { - t.Errorf("Expected file permissions 0600, got %o", info.Mode().Perm()) - } - - // Verify file contents - savedContent, err := os.ReadFile(filePath) - if err != nil { - t.Fatalf("Failed to read saved file: %v", err) - } - - if string(savedContent) != configContent { - t.Errorf("File content mismatch.\nExpected:\n%s\nGot:\n%s", configContent, string(savedContent)) - } -} - -func TestSaveWgConfig_PathTraversal(t *testing.T) { - dirProvider, cleanup := setupTestDir(t) - defer cleanup() - - vpn := &DefaultVpn{dirProvider: dirProvider} - - // Test that path traversal attempts are sanitized - maliciousNames := []string{ - "../../../etc/passwd", - "/etc/passwd", - "server/../../../etc/passwd", - } - - for _, name := range maliciousNames { - t.Run(name, func(t *testing.T) { - filePath, err := vpn.saveWgConfig(name, "test content") - if err != nil { - t.Fatalf("saveWgConfig should not fail due to path sanitization: %v", err) - } - - // Verify the file is created in the correct directory (contains mbvpn/servers) - if !strings.Contains(filePath, "mbvpn") || !strings.Contains(filePath, "servers") { - t.Errorf("File should be created in mbvpn/servers directory, got: %s", filePath) - } - - // Verify the filename is sanitized (should not contain path separators) - fileName := filepath.Base(filePath) - if strings.Contains(fileName, "..") || strings.Contains(fileName, "/") { - t.Errorf("Filename should be sanitized, got: %s", fileName) - } - }) - } -} - -func TestWriteConfig(t *testing.T) { - dirProvider, cleanup := setupTestDir(t) - defer cleanup() - - vpn := &DefaultVpn{dirProvider: dirProvider} - - cfgName := "test-server" - server := remote.Server{ - Hostname: "vpn.example.com", - IPv4AddrIn: "203.0.113.1", - PublicKey: "test-server-public-key", - } - privateKey := "test-private-key" - ipv4 := "10.0.0.1/32" - ipv6 := "2001:db8::1/128" - - configPath, err := vpn.writeConfig(cfgName, server, privateKey, ipv4, ipv6) - if err != nil { - t.Fatalf("writeConfig should succeed: %v", err) - } - - // Verify file was created - if _, err := os.Stat(configPath); os.IsNotExist(err) { - t.Error("Config file should be created") - } - - // Read and verify config content - content, err := os.ReadFile(configPath) - if err != nil { - t.Fatalf("Failed to read config file: %v", err) - } - - configStr := string(content) - - // Verify required sections and values are present - expectedValues := []string{ - "[Interface]", - "PrivateKey = " + privateKey, - "Address = " + ipv4 + ", " + ipv6, - "[Peer]", - "PublicKey = " + server.PublicKey, - "Endpoint = " + server.IPv4AddrIn + ":51820", - "AllowedIPs = 0.0.0.0/0, ::/0", - } - - for _, expected := range expectedValues { - if !strings.Contains(configStr, expected) { - t.Errorf("Config should contain '%s'\nActual config:\n%s", expected, configStr) - } - } -} - -func TestGetConnectedServers_ParseOutput(t *testing.T) { - // This tests the parsing logic without actually running wg command - // We'll test the internal parsing logic by extracting it to a separate function - - // Test parsing wg show output - testOutput := `interface: wg0 - public key: test-public-key-1 - private key: (hidden) - listening port: 51820 - -peer: peer-public-key-1 - endpoint: 203.0.113.1:51820 - allowed ips: 0.0.0.0/0, ::/0 - -interface: wg1 - public key: test-public-key-2 - private key: (hidden) - listening port: 51821 - -peer: peer-public-key-2 - endpoint: 203.0.113.2:51820 - allowed ips: 0.0.0.0/0, ::/0` - - // Extract interface names from output - lines := strings.Split(testOutput, "\n") - interfaceCount := strings.Count(testOutput, "interface:") - interfaces := make([]string, interfaceCount) - i := 0 - for _, line := range lines { - if strings.HasPrefix(line, "interface") { - interfaceName := strings.TrimPrefix(line, "interface: ") - interfaces[i] = interfaceName - i++ - } - } - - expectedInterfaces := []string{"wg0", "wg1"} - if len(interfaces) != len(expectedInterfaces) { - t.Errorf("Expected %d interfaces, got %d", len(expectedInterfaces), len(interfaces)) - } - - for i, expected := range expectedInterfaces { - if i < len(interfaces) && interfaces[i] != expected { - t.Errorf("Expected interface %s, got %s", expected, interfaces[i]) - } - } -} - -func TestGetConnectedServers_EmptyOutput(t *testing.T) { - // Test parsing empty wg show output - testOutput := "" - - interfaceCount := strings.Count(testOutput, "interface:") - if interfaceCount != 0 { - t.Errorf("Expected 0 interfaces for empty output, got %d", interfaceCount) - } -} - -func TestGetConnectedServers_NoInterfaces(t *testing.T) { - // Test parsing wg show output with no interfaces - testOutput := `No interfaces configured` - - lines := strings.Split(testOutput, "\n") - interfaces := make([]string, 0) - for _, line := range lines { - if strings.HasPrefix(line, "interface") { - interfaceName := strings.TrimPrefix(line, "interface: ") - interfaces = append(interfaces, interfaceName) - } - } - - if len(interfaces) != 0 { - t.Errorf("Expected no interfaces, got %v", interfaces) - } -} - // TestVpnInterface verifies that DefaultVpn implements the Vpn interface func TestVpnInterface(t *testing.T) { // This test ensures DefaultVpn implements all required methods @@ -532,122 +260,71 @@ func TestGenerateKeys(t *testing.T) { } } -// Test NewDefaultVpn constructor - simplified test without mocks +// Test NewDefaultVpn constructor - requires root/CAP_NET_ADMIN func TestNewDefaultVpn(t *testing.T) { - // This test just ensures the constructor works without testing complex interactions - vpn := &DefaultVpn{} - if vpn == nil { - t.Error("DefaultVpn should be constructible") - } - - // Verify it implements the Vpn interface - var _ Vpn = vpn -} - -// Test individual functions that can be tested independently - -// Test error handling in helper functions -func TestSaveWgConfig_ErrorHandling(t *testing.T) { - // Test with mock directory provider that returns errors - mockDirProvider := &mockDirectoryProvider{shouldError: true} - vpn := &DefaultVpn{dirProvider: mockDirProvider} - - _, err := vpn.saveWgConfig("test", "content") - if err == nil { - t.Error("saveWgConfig should fail when directory provider fails") - } -} - -func TestWriteConfig_Integration(t *testing.T) { - dirProvider, cleanup := setupTestDir(t) - defer cleanup() - - vpn := &DefaultVpn{dirProvider: dirProvider} - - server := remote.Server{ - Hostname: "test.example.com", - IPv4AddrIn: "203.0.113.1", - PublicKey: "server-public-key", + if os.Geteuid() != 0 { + t.Skip("requires root or CAP_NET_ADMIN") } - configPath, err := vpn.writeConfig("test-server", server, "private-key", "10.0.0.1/32", "2001:db8::1/128") + dirProvider, err := config.NewDefaultDirectoryProvider() if err != nil { - t.Fatalf("writeConfig should succeed: %v", err) + t.Fatalf("Failed to create directory provider: %v", err) } - // Verify file was created and contains expected content - content, err := os.ReadFile(configPath) + _, err = NewDefaultVpn(nil, nil, nil, dirProvider) if err != nil { - t.Fatalf("Failed to read config file: %v", err) - } - - configStr := string(content) - expectedValues := []string{ - "PrivateKey = private-key", - "PublicKey = server-public-key", - "Endpoint = 203.0.113.1:51820", - "Address = 10.0.0.1/32, 2001:db8::1/128", - } - - for _, expected := range expectedValues { - if !strings.Contains(configStr, expected) { - t.Errorf("Config should contain '%s'\nActual config:\n%s", expected, configStr) - } + t.Fatalf("NewDefaultVpn should succeed with root/CAP_NET_ADMIN: %v", err) } } -// Additional test functions to improve coverage -func TestGetConnectedServers_EmptyInput(t *testing.T) { - // Test with empty string to cover error paths - lines := strings.Split("", "\n") - interfaceCount := strings.Count("", "interface:") - if interfaceCount != 0 { - t.Errorf("Expected 0 interfaces for empty input, got %d", interfaceCount) - } - - // Test the interface parsing logic - interfaces := make([]string, 0) - for _, line := range lines { - if strings.HasPrefix(line, "interface") { - interfaceName := strings.TrimPrefix(line, "interface: ") - interfaces = append(interfaces, interfaceName) - } +func TestSanitizeInterfaceName(t *testing.T) { + tests := []struct { + name string + hostname string + expected string + }{ + {"simple hostname", "server1.example.com", "server1"}, + {"hostname with dashes", "us-east-01.vpn.example.com", "us-east-01"}, + {"long hostname", "very-long-server-name-that-exceeds-limit.example.com", "very-long-serve"}, + {"special chars", "server!@#$%^&*().example.com", "server"}, + {"underscores", "server_01.example.com", "server_01"}, + {"mixed case", "ServerName.example.com", "ServerName"}, + {"numbers only", "12345.example.com", "12345"}, + {"empty after sanitize", "!@#$.example.com", ""}, } - if len(interfaces) != 0 { - t.Errorf("Expected no interfaces, got %v", interfaces) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := sanitizeInterfaceName(tt.hostname) + if result != tt.expected { + t.Errorf("sanitizeInterfaceName(%s) = %s, expected %s", tt.hostname, result, tt.expected) + } + }) } } -func TestGetConnectedServers_MultipleInterfaces(t *testing.T) { - // Test parsing multiple interfaces - testOutput := "interface: wg0\ninterface: wg1\ninterface: wg2" - - lines := strings.Split(testOutput, "\n") - interfaceCount := strings.Count(testOutput, "interface:") - - if interfaceCount != 3 { - t.Errorf("Expected 3 interfaces, got %d", interfaceCount) +func TestFormatBytes(t *testing.T) { + tests := []struct { + name string + bytes int64 + expected string + }{ + {"zero bytes", 0, "0 B"}, + {"small bytes", 512, "512 B"}, + {"one KB", 1024, "1.0 KB"}, + {"1.5 KB", 1536, "1.5 KB"}, + {"one MB", 1024 * 1024, "1.0 MB"}, + {"one GB", 1024 * 1024 * 1024, "1.0 GB"}, + {"one TB", 1024 * 1024 * 1024 * 1024, "1.0 TB"}, } - // Test interface extraction - interfaces := make([]string, interfaceCount) - i := 0 - for _, line := range lines { - if strings.HasPrefix(line, "interface") { - interfaceName := strings.TrimPrefix(line, "interface: ") - if i < len(interfaces) { - interfaces[i] = interfaceName - i++ + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := formatBytes(tt.bytes) + if result != tt.expected { + t.Errorf("formatBytes(%d) = %s, expected %s", tt.bytes, result, tt.expected) } - } - } - - expectedInterfaces := []string{"wg0", "wg1", "wg2"} - for idx, expected := range expectedInterfaces { - if idx >= len(interfaces) || interfaces[idx] != expected { - t.Errorf("Expected interface %s at index %d, got %s", expected, idx, interfaces[idx]) - } + }) } } @@ -672,6 +349,51 @@ func TestGenerateKeys_ErrorPaths(t *testing.T) { } } +// Test parsing logic for connected servers output (similar to wg show output) +func TestParseWgShowOutput(t *testing.T) { + testOutput := `interface: wg0 + public key: test-public-key-1 + private key: (hidden) + listening port: 51820 + +peer: peer-public-key-1 + endpoint: 203.0.113.1:51820 + allowed ips: 0.0.0.0/0, ::/0 + +interface: wg1 + public key: test-public-key-2 + private key: (hidden) + listening port: 51821 + +peer: peer-public-key-2 + endpoint: 203.0.113.2:51820 + allowed ips: 0.0.0.0/0, ::/0` + + // Extract interface names from output + lines := strings.Split(testOutput, "\n") + interfaceCount := strings.Count(testOutput, "interface:") + interfaces := make([]string, interfaceCount) + i := 0 + for _, line := range lines { + if strings.HasPrefix(line, "interface") { + interfaceName := strings.TrimPrefix(line, "interface: ") + interfaces[i] = interfaceName + i++ + } + } + + expectedInterfaces := []string{"wg0", "wg1"} + if len(interfaces) != len(expectedInterfaces) { + t.Errorf("Expected %d interfaces, got %d", len(expectedInterfaces), len(interfaces)) + } + + for i, expected := range expectedInterfaces { + if i < len(interfaces) && interfaces[i] != expected { + t.Errorf("Expected interface %s, got %s", expected, interfaces[i]) + } + } +} + // Benchmark tests func BenchmarkGetCountryFlag(b *testing.B) { for i := 0; i < b.N; i++ { @@ -690,3 +412,9 @@ func BenchmarkGenerateKeys(b *testing.B) { generateKeys() } } + +func BenchmarkSanitizeInterfaceName(b *testing.B) { + for i := 0; i < b.N; i++ { + sanitizeInterfaceName("very-long-server-name.example.com") + } +} diff --git a/pkg/wireguard/linux_manager.go b/pkg/wireguard/linux_manager.go new file mode 100644 index 0000000..ad2bbdd --- /dev/null +++ b/pkg/wireguard/linux_manager.go @@ -0,0 +1,399 @@ +package wireguard + +import ( + "errors" + "fmt" + "net" + "os" + "strings" + "syscall" + "time" + + "github.com/vishvananda/netlink" + "golang.zx2c4.com/wireguard/wgctrl" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +// LinuxManager implements Manager using netlink and wgctrl +type LinuxManager struct { + wgClient *wgctrl.Client +} + +// NewManager creates a new WireGuard manager for Linux +func NewManager() (*LinuxManager, error) { + client, err := wgctrl.New() + if err != nil { + return nil, fmt.Errorf("failed to create wgctrl client: %w", err) + } + return &LinuxManager{wgClient: client}, nil +} + +// Close releases resources held by the manager +func (m *LinuxManager) Close() error { + if m.wgClient != nil { + return m.wgClient.Close() + } + return nil +} + +// CreateInterface creates a new WireGuard interface +func (m *LinuxManager) CreateInterface(name string) error { + la := netlink.NewLinkAttrs() + la.Name = name + + wg := &netlink.Wireguard{LinkAttrs: la} + err := netlink.LinkAdd(wg) + if err != nil { + // Check if interface already exists + if errors.Is(err, syscall.EEXIST) { + return fmt.Errorf("interface %s already exists", name) + } + return fmt.Errorf("failed to create interface %s: %w", name, err) + } + return nil +} + +// RemoveInterface removes a WireGuard interface +func (m *LinuxManager) RemoveInterface(name string) error { + link, err := netlink.LinkByName(name) + if err != nil { + // Interface doesn't exist - not an error + var linkNotFoundErr netlink.LinkNotFoundError + if errors.As(err, &linkNotFoundErr) { + return nil + } + return fmt.Errorf("failed to find interface %s: %w", name, err) + } + return netlink.LinkDel(link) +} + +// Configure applies WireGuard configuration to an interface +func (m *LinuxManager) Configure(name string, cfg DeviceConfig) error { + // Parse private key + privKey, err := wgtypes.ParseKey(cfg.PrivateKey) + if err != nil { + return fmt.Errorf("invalid private key: %w", err) + } + + // Build peer configs + var peers []wgtypes.PeerConfig + for _, p := range cfg.Peers { + peerCfg, err := buildPeerConfig(p) + if err != nil { + return fmt.Errorf("invalid peer config: %w", err) + } + peers = append(peers, peerCfg) + } + + // Build device config + wgCfg := wgtypes.Config{ + PrivateKey: &privKey, + ReplacePeers: true, + Peers: peers, + } + + // Set listen port if specified + if cfg.ListenPort > 0 { + wgCfg.ListenPort = &cfg.ListenPort + } + + // Apply configuration + if err := m.wgClient.ConfigureDevice(name, wgCfg); err != nil { + return fmt.Errorf("failed to configure device %s: %w", name, err) + } + + return nil +} + +func buildPeerConfig(p PeerConfig) (wgtypes.PeerConfig, error) { + pubKey, err := wgtypes.ParseKey(p.PublicKey) + if err != nil { + return wgtypes.PeerConfig{}, fmt.Errorf("invalid public key: %w", err) + } + + var endpoint *net.UDPAddr + if p.Endpoint != "" { + endpoint, err = net.ResolveUDPAddr("udp", p.Endpoint) + if err != nil { + return wgtypes.PeerConfig{}, fmt.Errorf("invalid endpoint %s: %w", p.Endpoint, err) + } + } + + var allowedIPs []net.IPNet + for _, cidr := range p.AllowedIPs { + _, ipnet, err := net.ParseCIDR(cidr) + if err != nil { + return wgtypes.PeerConfig{}, fmt.Errorf("invalid allowed IP %s: %w", cidr, err) + } + allowedIPs = append(allowedIPs, *ipnet) + } + + peerCfg := wgtypes.PeerConfig{ + PublicKey: pubKey, + Endpoint: endpoint, + AllowedIPs: allowedIPs, + ReplaceAllowedIPs: true, + } + + if p.PersistentKeepalive > 0 { + keepalive := time.Duration(p.PersistentKeepalive) * time.Second + peerCfg.PersistentKeepaliveInterval = &keepalive + } + + return peerCfg, nil +} + +// AssignAddresses assigns IP addresses to an interface +func (m *LinuxManager) AssignAddresses(name string, addrs []string) error { + link, err := netlink.LinkByName(name) + if err != nil { + return fmt.Errorf("failed to find interface %s: %w", name, err) + } + + for _, addr := range addrs { + // Ensure address is in CIDR notation + cidrAddr := ensureCIDR(addr) + nlAddr, err := netlink.ParseAddr(cidrAddr) + if err != nil { + return fmt.Errorf("failed to parse address %s: %w", cidrAddr, err) + } + if err := netlink.AddrAdd(link, nlAddr); err != nil { + // Ignore if address already exists + if !errors.Is(err, syscall.EEXIST) { + return fmt.Errorf("failed to add address %s: %w", cidrAddr, err) + } + } + } + return nil +} + +// ensureCIDR adds a CIDR suffix if the address doesn't have one +func ensureCIDR(addr string) string { + // Skip empty addresses + if addr == "" { + return addr + } + + // Already has CIDR notation + if strings.Contains(addr, "/") { + return addr + } + + // Parse the IP to determine if it's IPv4 or IPv6 + ip := net.ParseIP(addr) + if ip == nil { + return addr // Return as-is if not a valid IP + } + + if ip.To4() != nil { + return addr + "/32" + } + return addr + "/128" +} + +// SetInterfaceUp brings the interface up +func (m *LinuxManager) SetInterfaceUp(name string) error { + link, err := netlink.LinkByName(name) + if err != nil { + return fmt.Errorf("failed to find interface %s: %w", name, err) + } + return netlink.LinkSetUp(link) +} + +// AddRoutes adds routes for allowed IPs through the interface +// endpointIP is used to add a specific route to the VPN server via the original gateway +func (m *LinuxManager) AddRoutes(name string, allowedIPs []string, endpointIP net.IP) error { + link, err := netlink.LinkByName(name) + if err != nil { + return fmt.Errorf("failed to find interface %s: %w", name, err) + } + + // Get current default gateway for routing VPN endpoint traffic + var defaultGW net.IP + var defaultLinkIndex int + routes, err := netlink.RouteList(nil, netlink.FAMILY_V4) + if err != nil { + return fmt.Errorf("failed to list routes: %w", err) + } + for _, r := range routes { + // Default route can be Dst == nil OR Dst == 0.0.0.0/0 + isDefault := r.Dst == nil + if !isDefault && r.Dst != nil { + ones, bits := r.Dst.Mask.Size() + isDefault = ones == 0 && bits == 32 + } + if isDefault && r.Gw != nil { + defaultGW = r.Gw + defaultLinkIndex = r.LinkIndex + break + } + } + + // Add route to VPN endpoint via original gateway (prevents routing loop) + if endpointIP != nil && defaultGW != nil { + var mask net.IPMask + if endpointIP.To4() != nil { + mask = net.CIDRMask(32, 32) + } else { + mask = net.CIDRMask(128, 128) + } + endpointRoute := &netlink.Route{ + LinkIndex: defaultLinkIndex, + Dst: &net.IPNet{IP: endpointIP, Mask: mask}, + Gw: defaultGW, + } + if err := netlink.RouteAdd(endpointRoute); err != nil { + // Ignore if route already exists + if !errors.Is(err, syscall.EEXIST) { + return fmt.Errorf("failed to add endpoint route: %w", err) + } + } + } + + // Add routes for allowed IPs through the WireGuard interface + for _, cidr := range allowedIPs { + _, dst, err := net.ParseCIDR(cidr) + if err != nil { + return fmt.Errorf("failed to parse CIDR %s: %w", cidr, err) + } + + // For default routes (0.0.0.0/0 or ::/0), we need special handling + // Using two /1 routes avoids replacing the system default route + if isDefaultRoute(dst) { + if err := addSplitDefaultRoutes(link.Attrs().Index, dst); err != nil { + return err + } + continue + } + + route := &netlink.Route{ + LinkIndex: link.Attrs().Index, + Dst: dst, + } + if err := netlink.RouteAdd(route); err != nil { + if !errors.Is(err, syscall.EEXIST) { + return fmt.Errorf("failed to add route for %s: %w", cidr, err) + } + } + } + return nil +} + +// isDefaultRoute checks if a route is a default route (0.0.0.0/0 or ::/0) +func isDefaultRoute(ipnet *net.IPNet) bool { + ones, bits := ipnet.Mask.Size() + return ones == 0 && (bits == 32 || bits == 128) +} + +// addSplitDefaultRoutes adds two /1 routes that effectively cover all traffic +// without replacing the system's default route +func addSplitDefaultRoutes(linkIndex int, dst *net.IPNet) error { + _, bits := dst.Mask.Size() + + if bits == 32 { + // IPv4: Add 0.0.0.0/1 and 128.0.0.0/1 + routes := []*netlink.Route{ + { + LinkIndex: linkIndex, + Dst: &net.IPNet{IP: net.IPv4(0, 0, 0, 0), Mask: net.CIDRMask(1, 32)}, + }, + { + LinkIndex: linkIndex, + Dst: &net.IPNet{IP: net.IPv4(128, 0, 0, 0), Mask: net.CIDRMask(1, 32)}, + }, + } + for _, route := range routes { + if err := netlink.RouteAdd(route); err != nil { + if !errors.Is(err, syscall.EEXIST) { + return fmt.Errorf("failed to add split route: %w", err) + } + } + } + } else { + // IPv6: Add ::/1 and 8000::/1 + routes := []*netlink.Route{ + { + LinkIndex: linkIndex, + Dst: &net.IPNet{IP: net.ParseIP("::"), Mask: net.CIDRMask(1, 128)}, + }, + { + LinkIndex: linkIndex, + Dst: &net.IPNet{IP: net.ParseIP("8000::"), Mask: net.CIDRMask(1, 128)}, + }, + } + for _, route := range routes { + if err := netlink.RouteAdd(route); err != nil { + if !errors.Is(err, syscall.EEXIST) { + return fmt.Errorf("failed to add split IPv6 route: %w", err) + } + } + } + } + return nil +} + +// GetDevice returns information about a WireGuard device +func (m *LinuxManager) GetDevice(name string) (*Device, error) { + dev, err := m.wgClient.Device(name) + if err != nil { + if os.IsNotExist(err) { + return nil, nil + } + return nil, fmt.Errorf("failed to get device %s: %w", name, err) + } + + device := &Device{ + Name: dev.Name, + PublicKey: dev.PublicKey.String(), + ListenPort: dev.ListenPort, + } + + for _, p := range dev.Peers { + var allowedIPs []string + for _, ip := range p.AllowedIPs { + allowedIPs = append(allowedIPs, ip.String()) + } + + peer := Peer{ + PublicKey: p.PublicKey.String(), + AllowedIPs: allowedIPs, + LastHandshakeTime: p.LastHandshakeTime, + TransmitBytes: p.TransmitBytes, + ReceiveBytes: p.ReceiveBytes, + PersistentKeepaliveInterval: int(p.PersistentKeepaliveInterval.Seconds()), + } + if p.Endpoint != nil { + peer.Endpoint = p.Endpoint.String() + } + device.Peers = append(device.Peers, peer) + } + + return device, nil +} + +// ListDevices returns a list of all WireGuard interface names +func (m *LinuxManager) ListDevices() ([]string, error) { + devices, err := m.wgClient.Devices() + if err != nil { + return nil, fmt.Errorf("failed to list devices: %w", err) + } + + names := make([]string, len(devices)) + for i, d := range devices { + names[i] = d.Name + } + return names, nil +} + +// IsInterfaceUp checks if an interface is up +func (m *LinuxManager) IsInterfaceUp(name string) (bool, error) { + link, err := netlink.LinkByName(name) + if err != nil { + var linkNotFoundErr netlink.LinkNotFoundError + if errors.As(err, &linkNotFoundErr) { + return false, nil + } + return false, fmt.Errorf("failed to find interface %s: %w", name, err) + } + return link.Attrs().Flags&net.FlagUp != 0, nil +} diff --git a/pkg/wireguard/linux_manager_test.go b/pkg/wireguard/linux_manager_test.go new file mode 100644 index 0000000..3387214 --- /dev/null +++ b/pkg/wireguard/linux_manager_test.go @@ -0,0 +1,261 @@ +package wireguard + +import ( + "net" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func skipIfNotPrivileged(t *testing.T) { + if os.Geteuid() != 0 { + t.Skip("requires root or CAP_NET_ADMIN") + } +} + +func TestNewManager(t *testing.T) { + skipIfNotPrivileged(t) + + mgr, err := NewManager() + require.NoError(t, err) + defer mgr.Close() + + assert.NotNil(t, mgr.wgClient) +} + +func TestCreateAndRemoveInterface(t *testing.T) { + skipIfNotPrivileged(t) + + mgr, err := NewManager() + require.NoError(t, err) + defer mgr.Close() + + const testInterface = "wgtest0" + + // Ensure clean state + _ = mgr.RemoveInterface(testInterface) + + // Create interface + err = mgr.CreateInterface(testInterface) + require.NoError(t, err) + + // Verify interface exists + up, err := mgr.IsInterfaceUp(testInterface) + require.NoError(t, err) + assert.False(t, up) // Interface is created but not up + + // Remove interface + err = mgr.RemoveInterface(testInterface) + require.NoError(t, err) + + // Verify interface is gone + up, err = mgr.IsInterfaceUp(testInterface) + require.NoError(t, err) + assert.False(t, up) +} + +func TestCreateInterfaceAlreadyExists(t *testing.T) { + skipIfNotPrivileged(t) + + mgr, err := NewManager() + require.NoError(t, err) + defer mgr.Close() + + const testInterface = "wgtest1" + + // Ensure clean state + _ = mgr.RemoveInterface(testInterface) + defer mgr.RemoveInterface(testInterface) + + // Create interface + err = mgr.CreateInterface(testInterface) + require.NoError(t, err) + + // Try to create again - should fail + err = mgr.CreateInterface(testInterface) + assert.Error(t, err) + assert.Contains(t, err.Error(), "already exists") +} + +func TestRemoveNonExistentInterface(t *testing.T) { + skipIfNotPrivileged(t) + + mgr, err := NewManager() + require.NoError(t, err) + defer mgr.Close() + + // Removing non-existent interface should not error + err = mgr.RemoveInterface("nonexistent123") + assert.NoError(t, err) +} + +func TestListDevices(t *testing.T) { + skipIfNotPrivileged(t) + + mgr, err := NewManager() + require.NoError(t, err) + defer mgr.Close() + + const testInterface = "wgtest2" + + // Ensure clean state + _ = mgr.RemoveInterface(testInterface) + defer mgr.RemoveInterface(testInterface) + + // Get initial device count + initialDevices, err := mgr.ListDevices() + require.NoError(t, err) + + // Create interface + err = mgr.CreateInterface(testInterface) + require.NoError(t, err) + + // List devices should include our new interface + devices, err := mgr.ListDevices() + require.NoError(t, err) + assert.Len(t, devices, len(initialDevices)+1) + assert.Contains(t, devices, testInterface) +} + +func TestConfigureDevice(t *testing.T) { + skipIfNotPrivileged(t) + + mgr, err := NewManager() + require.NoError(t, err) + defer mgr.Close() + + const testInterface = "wgtest3" + + // Ensure clean state + _ = mgr.RemoveInterface(testInterface) + defer mgr.RemoveInterface(testInterface) + + // Create interface + err = mgr.CreateInterface(testInterface) + require.NoError(t, err) + + // Configure with a test configuration + cfg := DeviceConfig{ + PrivateKey: "WG8ZcfSVD/oaEWcUW5lrJsLVWgYA8X0m9Iuv9U3XSWA=", // Test key only + Peers: []PeerConfig{ + { + PublicKey: "bPfJDdgBzYmjXLq0S+VQkzf5GdOOKZ5zjP8HCBXDYV4=", // Test key only + Endpoint: "192.0.2.1:51820", + AllowedIPs: []string{"10.0.0.0/24"}, + }, + }, + } + + err = mgr.Configure(testInterface, cfg) + require.NoError(t, err) + + // Verify configuration + dev, err := mgr.GetDevice(testInterface) + require.NoError(t, err) + require.NotNil(t, dev) + assert.Len(t, dev.Peers, 1) +} + +func TestAssignAddresses(t *testing.T) { + skipIfNotPrivileged(t) + + mgr, err := NewManager() + require.NoError(t, err) + defer mgr.Close() + + const testInterface = "wgtest4" + + // Ensure clean state + _ = mgr.RemoveInterface(testInterface) + defer mgr.RemoveInterface(testInterface) + + // Create interface + err = mgr.CreateInterface(testInterface) + require.NoError(t, err) + + // Assign addresses + err = mgr.AssignAddresses(testInterface, []string{"10.100.100.1/32", "fd00::1/128"}) + require.NoError(t, err) +} + +func TestSetInterfaceUp(t *testing.T) { + skipIfNotPrivileged(t) + + mgr, err := NewManager() + require.NoError(t, err) + defer mgr.Close() + + const testInterface = "wgtest5" + + // Ensure clean state + _ = mgr.RemoveInterface(testInterface) + defer mgr.RemoveInterface(testInterface) + + // Create interface + err = mgr.CreateInterface(testInterface) + require.NoError(t, err) + + // Assign an address first (required for bringing up) + err = mgr.AssignAddresses(testInterface, []string{"10.100.100.2/32"}) + require.NoError(t, err) + + // Bring interface up + err = mgr.SetInterfaceUp(testInterface) + require.NoError(t, err) + + // Verify interface is up + up, err := mgr.IsInterfaceUp(testInterface) + require.NoError(t, err) + assert.True(t, up) +} + +func TestInvalidPrivateKey(t *testing.T) { + skipIfNotPrivileged(t) + + mgr, err := NewManager() + require.NoError(t, err) + defer mgr.Close() + + const testInterface = "wgtest6" + + // Ensure clean state + _ = mgr.RemoveInterface(testInterface) + defer mgr.RemoveInterface(testInterface) + + // Create interface + err = mgr.CreateInterface(testInterface) + require.NoError(t, err) + + // Try to configure with invalid key + cfg := DeviceConfig{ + PrivateKey: "invalid-key", + } + + err = mgr.Configure(testInterface, cfg) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid private key") +} + +func TestIsDefaultRoute(t *testing.T) { + tests := []struct { + cidr string + expected bool + }{ + {"0.0.0.0/0", true}, + {"::/0", true}, + {"10.0.0.0/8", false}, + {"192.168.0.0/16", false}, + {"0.0.0.0/1", false}, + {"128.0.0.0/1", false}, + } + + for _, tt := range tests { + t.Run(tt.cidr, func(t *testing.T) { + _, ipnet, err := net.ParseCIDR(tt.cidr) + require.NoError(t, err) + assert.Equal(t, tt.expected, isDefaultRoute(ipnet)) + }) + } +} diff --git a/pkg/wireguard/manager.go b/pkg/wireguard/manager.go new file mode 100644 index 0000000..d0459ed --- /dev/null +++ b/pkg/wireguard/manager.go @@ -0,0 +1,73 @@ +package wireguard + +import ( + "net" + "time" +) + +// Manager defines the interface for WireGuard operations +type Manager interface { + // CreateInterface creates a new WireGuard interface + CreateInterface(name string) error + + // RemoveInterface removes a WireGuard interface + RemoveInterface(name string) error + + // Configure applies WireGuard configuration to an interface + Configure(name string, cfg DeviceConfig) error + + // AssignAddresses assigns IP addresses to an interface + AssignAddresses(name string, addrs []string) error + + // SetInterfaceUp brings the interface up + SetInterfaceUp(name string) error + + // AddRoutes adds routes for allowed IPs through the interface + AddRoutes(name string, allowedIPs []string, endpointIP net.IP) error + + // GetDevice returns information about a WireGuard device + GetDevice(name string) (*Device, error) + + // ListDevices returns a list of all WireGuard interface names + ListDevices() ([]string, error) + + // IsInterfaceUp checks if an interface is up + IsInterfaceUp(name string) (bool, error) + + // Close releases any resources held by the manager + Close() error +} + +// DeviceConfig represents WireGuard device configuration +type DeviceConfig struct { + PrivateKey string + ListenPort int + Peers []PeerConfig +} + +// PeerConfig represents a WireGuard peer configuration +type PeerConfig struct { + PublicKey string + Endpoint string // "host:port" + AllowedIPs []string // CIDR notation: ["0.0.0.0/0", "::/0"] + PersistentKeepalive int // seconds, 0 to disable +} + +// Device represents a WireGuard device's current state +type Device struct { + Name string + PublicKey string + ListenPort int + Peers []Peer +} + +// Peer represents a WireGuard peer's current state +type Peer struct { + PublicKey string + Endpoint string + AllowedIPs []string + LastHandshakeTime time.Time + TransmitBytes int64 + ReceiveBytes int64 + PersistentKeepaliveInterval int +} From ce5e57bd5cc48fd89404d8ff2a155726ec591b58 Mon Sep 17 00:00:00 2001 From: Vladimir Parfenov Date: Wed, 17 Dec 2025 11:16:13 +0200 Subject: [PATCH 2/3] feat: bring ip info back to status command --- pkg/vpn/vpn.go | 1 + 1 file changed, 1 insertion(+) diff --git a/pkg/vpn/vpn.go b/pkg/vpn/vpn.go index d1a36de..e74ef2b 100644 --- a/pkg/vpn/vpn.go +++ b/pkg/vpn/vpn.go @@ -437,6 +437,7 @@ func (vpn *DefaultVpn) Status() error { return errors.NewNetworkError("get network details", err) } + output.PrintMsg(fmt.Sprintf(" IP Address: %s", network.Ip), output.MsgOutput) output.PrintMsg(fmt.Sprintf(" Country: %s", network.Geo.Country), output.MsgOutput) output.PrintMsg(fmt.Sprintf(" City: %s", network.Geo.City), output.MsgOutput) From 72f2a1e0a9067b14ec5f011f7dc385c81801af74 Mon Sep 17 00:00:00 2001 From: Vladimir Parfenov Date: Wed, 17 Dec 2025 11:18:03 +0200 Subject: [PATCH 3/3]