-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathsocks.go
More file actions
137 lines (119 loc) · 3.25 KB
/
socks.go
File metadata and controls
137 lines (119 loc) · 3.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
package main
import (
	"context"
	"fmt"
	"net"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/armon/go-socks5"
)
// SOCKS5Server is a SOCKS5 proxy listening on an ephemeral port on 127.0.0.1,
// created by NewSOCKS5Server. Outbound dials are made by the function from
// buildSOCKS5Dial using tunnel. The address of each dial request is also
// reported on dials so that WaitForDial can observe it.
type SOCKS5Server struct {
	// server is the go-socks5 instance serving connections accepted on listener.
	server   *socks5.Server
	// listener accepts client connections; Close closes it to stop serving.
	listener net.Listener
	// port is the TCP port listener is bound to. It is also passed to the
	// dialer so that dials back into this proxy's own port are refused.
	port     int
	// tunnel supplies the router and WireGuard dialer used for outbound dials.
	tunnel   *Tunnel
	// dials receives the destination address of every dial request.
	// NewSOCKS5Server gives it a buffer of 32 and sends without blocking, so
	// events are dropped while the buffer is full.
	dials    chan string
	// wg tracks the background goroutine running server.Serve; Close waits on it.
	wg       sync.WaitGroup
}
// buildSOCKS5Dial returns the dial function installed in the SOCKS5 server's
// config. For each request the returned function:
//
//   - reports network and addr to onDial first, when onDial is non-nil;
//   - rejects addr values that net.SplitHostPort cannot parse;
//   - refuses any dial to socksPort on this machine (a loopback or
//     unspecified IP, or the host name "localhost"). Such a dial would
//     otherwise connect the proxy back to its own listener and recurse;
//   - dials other loopback IP literals directly with baseDial;
//   - dials IP literals that tunnel.router maps to a WireGuard peer through
//     tunnel.DialWireGuard;
//   - dials everything else (including host names) with baseDial.
//
// A nil tunnel (or nil tunnel.router) disables WireGuard routing. A nil
// baseDial falls back to the DialContext of a zero-value net.Dialer.
func buildSOCKS5Dial(tunnel *Tunnel, socksPort int, baseDial func(context.Context, string, string) (net.Conn, error), onDial func(string, string)) func(context.Context, string, string) (net.Conn, error) {
	if baseDial == nil {
		baseDial = (&net.Dialer{}).DialContext
	}
	return func(ctx context.Context, network, addr string) (net.Conn, error) {
		if onDial != nil {
			onDial(network, addr)
		}
		logger.Debugf("SOCKS5 dial request: %s %s", network, addr)

		host, port, err := net.SplitHostPort(addr)
		if err != nil {
			return nil, fmt.Errorf("invalid address format: %w", err)
		}
		// A malformed port leaves portNum at 0. A bound listener never has
		// port 0, so the guard below cannot fire on it. The router receives
		// the same 0 it always did.
		portNum, _ := strconv.Atoi(port)
		ip := net.ParseIP(host)

		// Connecting to the unspecified address (0.0.0.0 / ::) reaches local
		// listeners on Linux and most other systems, and "localhost" names the
		// loopback host. Both must be caught by the recursion guard, not just
		// loopback IP literals.
		selfTarget := ip != nil && (ip.IsLoopback() || ip.IsUnspecified())
		if ip == nil {
			selfTarget = strings.EqualFold(strings.TrimSuffix(host, "."), "localhost")
		}
		if selfTarget && portNum == socksPort {
			return nil, fmt.Errorf("refusing recursive SOCKS dial to localhost:%d", socksPort)
		}

		// Loopback destinations never go through the tunnel.
		if ip != nil && ip.IsLoopback() {
			return baseDial(ctx, network, addr)
		}

		// Only IP literals can be matched against the router's rules.
		if tunnel != nil && tunnel.router != nil && ip != nil {
			peer, peerIdx := tunnel.router.FindPeerForDestination(ip, portNum, normalizeNetworkProtocol(network))
			if peer != nil {
				logger.Debugf("Routing %s through WireGuard tunnel via peer %d (endpoint: %s)", addr, peerIdx, peer.Endpoint)
				return tunnel.DialWireGuard(ctx, network, host, port)
			}
		}

		logger.Debugf("Using normal dial for %s", addr)
		conn, err := baseDial(ctx, network, addr)
		if err != nil {
			logger.Debugf("SOCKS5 dial failed for %s: %v", addr, err)
			return nil, err
		}
		logger.Debugf("SOCKS5 dial succeeded for %s", addr)
		return conn, nil
	}
}
// NewSOCKS5Server starts a SOCKS5 proxy on an ephemeral 127.0.0.1 port.
// Outbound connections are made by the dial function from buildSOCKS5Dial:
// destinations the tunnel's router assigns to a peer go over WireGuard, and
// the rest are dialed directly. The proxy serves from a background goroutine
// until Close is called.
//
// It returns an error if tunnel is nil, if the listener cannot be opened, or
// if the go-socks5 server cannot be created. In the last case the listener is
// closed before returning.
func NewSOCKS5Server(tunnel *Tunnel) (*SOCKS5Server, error) {
	if tunnel == nil {
		return nil, fmt.Errorf("tunnel is required")
	}

	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return nil, fmt.Errorf("failed to listen for SOCKS5 connections: %w", err)
	}
	boundPort := ln.Addr().(*net.TCPAddr).Port

	srv := &SOCKS5Server{
		listener: ln,
		port:     boundPort,
		tunnel:   tunnel,
		dials:    make(chan string, 32),
	}

	// Publish each requested destination for WaitForDial without ever
	// blocking the dial path: when the buffer is full the event is dropped.
	recordDial := func(_ string, addr string) {
		select {
		case srv.dials <- addr:
		default:
		}
	}
	cfg := &socks5.Config{
		Dial: buildSOCKS5Dial(tunnel, boundPort, (&net.Dialer{}).DialContext, recordDial),
	}

	proxy, err := socks5.New(cfg)
	if err != nil {
		_ = ln.Close()
		return nil, fmt.Errorf("failed to create SOCKS5 server: %w", err)
	}
	srv.server = proxy

	srv.wg.Add(1)
	go func() {
		defer srv.wg.Done()
		// An error here usually just means the server is shutting down
		// (listener closed), so it is logged rather than treated as fatal.
		if serveErr := proxy.Serve(ln); serveErr != nil {
			logger.Debugf("SOCKS5 server stopped: %v", serveErr)
		}
	}()

	return srv, nil
}
// Port returns the TCP port the proxy's listener is bound to. For a server
// built by NewSOCKS5Server, clients reach the proxy at 127.0.0.1 on this port.
func (s *SOCKS5Server) Port() (port int) {
	port = s.port
	return port
}
// WaitForDial blocks until a dial to exactly addr is reported on s.dials. It
// returns nil when that happens, or an error once timeout elapses.
//
// Reported dials for other addresses are received and discarded while
// waiting, so concurrent waiters compete for the same events.
func (s *SOCKS5Server) WaitForDial(addr string, timeout time.Duration) error {
	deadline := time.NewTimer(timeout)
	defer deadline.Stop()

	for {
		select {
		case <-deadline.C:
			return fmt.Errorf("timed out waiting for SOCKS dial to %s", addr)
		case dialed := <-s.dials:
			if dialed != addr {
				continue
			}
			return nil
		}
	}
}
// Close stops accepting new SOCKS5 connections by closing the listener. It
// then waits for every goroutine tracked by s.wg (the serve loop) to finish,
// and returns the listener's Close error. It does nothing and returns nil
// when no listener is set. Close itself does not touch connections that are
// already being proxied.
func (s *SOCKS5Server) Close() error {
	if s.listener == nil {
		return nil
	}
	closeErr := s.listener.Close()
	// Wait only after closing, since closing the listener is what lets the
	// serve loop return.
	s.wg.Wait()
	return closeErr
}