-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathplugin.go
More file actions
239 lines (201 loc) · 6.18 KB
/
Copy pathplugin.go
File metadata and controls
239 lines (201 loc) · 6.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
package proxy
import (
"context"
"log/slog"
"net"
"net/http"
"regexp"
"strings"
rrcontext "github.com/roadrunner-server/context"
"github.com/roadrunner-server/errors"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
jprop "go.opentelemetry.io/contrib/propagators/jaeger"
"go.opentelemetry.io/otel/propagation"
semconv "go.opentelemetry.io/otel/semconv/v1.20.0"
"go.opentelemetry.io/otel/trace"
)
// Plugin identity and the configuration keys read by Init.
const (
	name       string = "proxy_ip_parser"
	configKey  string = "http.trusted_subnets"
	headersKey string = "http.trusted_headers"
)

// Client-IP headers known to the default resolver chain. The values are
// spelled in http.CanonicalHeaderKey form on purpose: parserFor compares
// canonicalized header names against them.
const (
	forwarded string = "Forwarded"
	xff       string = "X-Forwarded-For"
	xrip      string = "X-Real-Ip"
	tcip      string = "True-Client-Ip"
	cfip      string = "Cf-Connecting-Ip"
)
// forwardedRegex captures the value of the first "for=" parameter in an
// RFC 7239 Forwarded header. Parameter names are case-insensitive, hence (?i).
// The value runs up to the next pair separator (";"), element separator (",")
// or whitespace (OWS is SP/HTAB). The previous class `[^(;|,| )]` also
// excluded the literal characters '(', '|' and ')' (alternation has no meaning
// inside a bracket expression). That cut off valid token values such as
// "_node|x" and let a trailing tab leak into the result.
var forwardedRegex = regexp.MustCompile(`(?i)for=([^;,\s]+)`)
type (
	// Logger is the dependency Init uses to obtain this plugin's logger.
	Logger interface {
		// NamedLogger returns a *slog.Logger for the given name.
		NamedLogger(name string) *slog.Logger
	}

	// Configurer is the subset of the configuration provider that Init needs.
	Configurer interface {
		// UnmarshalKey decodes the configuration section stored under name
		// into out.
		UnmarshalKey(name string, out any) error
		// Has reports whether a configuration section called name exists.
		Has(name string) bool
	}
)
// Plugin rewrites r.RemoteAddr with the client IP taken from forwarding
// headers, but only for requests whose direct peer lies in a trusted subnet.
type Plugin struct {
	// cfg holds the values read from http.trusted_subnets and
	// http.trusted_headers during Init.
	cfg *Config

	// log is the plugin's named logger, set in Init.
	// NOTE(review): not referenced in the code visible here — confirm it is
	// still needed.
	log *slog.Logger

	// trusted holds the parsed trusted subnets; only peers inside one of them
	// may supply client-IP headers.
	trusted []*net.IPNet

	// resolvers is the ordered header chain that resolveIP walks.
	resolvers []resolver

	// prop injects the middleware span's context into the request headers.
	prop propagation.TextMapPropagator
}
// Init configures the plugin.
//
// It returns an errors.Disabled error when the http.trusted_subnets section is
// missing or lists no subnets. Every listed subnet must be valid CIDR notation;
// the first invalid entry aborts Init with the net.ParseCIDR error. The
// optional http.trusted_headers list selects which headers resolveIP
// consults, and in what order.
func (p *Plugin) Init(cfg Configurer, l Logger) error {
	const op = errors.Op("proxy_ip_parser_init")

	if !cfg.Has(configKey) {
		return errors.E(errors.Disabled)
	}

	p.cfg = &Config{}
	if err := cfg.UnmarshalKey(configKey, &p.cfg.TrustedSubnets); err != nil {
		return errors.E(op, err)
	}
	if len(p.cfg.TrustedSubnets) == 0 {
		return errors.E(errors.Disabled)
	}

	p.log = l.NamedLogger(name)
	p.prop = propagation.NewCompositeTextMapPropagator(
		propagation.TraceContext{},
		propagation.Baggage{},
		jprop.Jaeger{},
	)

	p.trusted = make([]*net.IPNet, len(p.cfg.TrustedSubnets))
	for idx, cidr := range p.cfg.TrustedSubnets {
		_, subnet, parseErr := net.ParseCIDR(cidr)
		if parseErr != nil {
			return errors.E(op, parseErr)
		}
		p.trusted[idx] = subnet
	}

	// The header allowlist is optional. When absent, TrustedHeaders keeps the
	// zero value from &Config{} and buildResolvers falls back to the default
	// chain.
	if cfg.Has(headersKey) {
		if err := cfg.UnmarshalKey(headersKey, &p.cfg.TrustedHeaders); err != nil {
			return errors.E(op, err)
		}
	}
	p.resolvers = buildResolvers(p.cfg.TrustedHeaders)

	return nil
}
// Middleware wraps next so that r.RemoteAddr is replaced by the client IP from
// the resolver chain. This only happens when the direct peer's address falls
// inside one of the trusted subnets and a header yields a non-empty value.
// A RemoteAddr that cannot be split into host and port is rejected with
// 400 Bad Request.
//
// When the request context carries an OpenTelemetry tracer name, an internal
// span is started and its context is injected into the request headers. The
// span covers only this middleware's own work.
func (p *Plugin) Middleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var span trace.Span
		finishSpan := func() {
			if span != nil {
				span.End()
			}
		}

		if tracerName, ok := r.Context().Value(rrcontext.OtelTracerNameKey).(string); ok {
			tracer := trace.SpanFromContext(r.Context()).TracerProvider().Tracer(
				tracerName,
				trace.WithSchemaURL(semconv.SchemaURL),
				trace.WithInstrumentationVersion(otelhttp.Version),
			)
			var spanCtx context.Context
			spanCtx, span = tracer.Start(r.Context(), name, trace.WithSpanKind(trace.SpanKindInternal))
			p.prop.Inject(spanCtx, propagation.HeaderCarrier(r.Header))
			r = r.WithContext(spanCtx)
		}

		peerHost, _, splitErr := net.SplitHostPort(r.RemoteAddr)
		if splitErr != nil {
			finishSpan()
			http.Error(w, splitErr.Error(), http.StatusBadRequest)
			return
		}

		// An unparsable host gives a nil IP, which no subnet contains.
		peerIP := net.ParseIP(peerHost)
		fromTrustedProxy := false
		for _, subnet := range p.trusted {
			if subnet.Contains(peerIP) {
				fromTrustedProxy = true
				break
			}
		}
		if fromTrustedProxy {
			if clientIP := p.resolveIP(r.Header); clientIP != "" {
				r.RemoteAddr = clientIP
			}
		}

		// End the span before handing off so it times this middleware only.
		finishSpan()
		next.ServeHTTP(w, r)
	})
}
// Name returns the plugin identifier, "proxy_ip_parser".
func (p *Plugin) Name() string { return name }
// resolver pairs one header with the function that extracts a candidate client
// IP from that header's value. resolveIP treats an empty result from parse as
// "nothing usable here" and moves on to the next resolver.
type resolver struct {
	name  string              // canonical header name, e.g. "X-Forwarded-For"
	parse func(string) string // extracts the candidate IP; "" when none found
}
// defaultResolvers builds the chain used when http.trusted_headers is not
// configured (or contains nothing usable). Headers are consulted in this
// order: Forwarded, X-Forwarded-For, X-Real-Ip, True-Client-Ip and finally
// Cf-Connecting-Ip (CloudFlare), matching their historical priority. Each
// header's parser is chosen by parserFor so the mapping lives in one place.
func defaultResolvers() []resolver {
	order := [...]string{forwarded, xff, xrip, tcip, cfip}

	out := make([]resolver, 0, len(order))
	for _, header := range order {
		out = append(out, resolver{name: header, parse: parserFor(header)})
	}
	return out
}
// buildResolvers converts the http.trusted_headers allowlist into the ordered
// chain consulted by resolveIP.
//
// Each entry is whitespace-trimmed and canonicalized with
// http.CanonicalHeaderKey. Blank entries are skipped, and so are repeats of a
// header already listed (compared after canonicalization), so the first
// occurrence decides a header's position. If nothing usable remains (a nil or
// empty allowlist included), the default chain is returned.
func buildResolvers(headers []string) []resolver {
	chain := make([]resolver, 0, len(headers))
	listed := make(map[string]bool, len(headers))

	for _, entry := range headers {
		trimmed := strings.TrimSpace(entry)
		if trimmed == "" {
			continue
		}
		key := http.CanonicalHeaderKey(trimmed)
		if listed[key] {
			continue
		}
		listed[key] = true
		chain = append(chain, resolver{name: key, parse: parserFor(key)})
	}

	if len(chain) > 0 {
		return chain
	}
	return defaultResolvers()
}
// parserFor maps a canonical header name to the function that extracts a
// client IP from its value. Forwarded and X-Forwarded-For carry structured
// lists and get dedicated parsers. Every other header, custom ones included,
// is assumed to hold a single address and is used verbatim.
func parserFor(canon string) func(string) string {
	if canon == forwarded {
		return parseForwarded
	}
	if canon == xff {
		return parseXFF
	}
	return parseVerbatim
}
// parseForwarded returns the first "for=" node of an RFC 7239 Forwarded header
// value, or "" when there is none. Surrounding double quotes are stripped:
// an IPv6 literal (and any node carrying a port) must be quoted.
// https://datatracker.ietf.org/doc/html/rfc7239
func parseForwarded(value string) string {
	match := forwardedRegex.FindStringSubmatch(value)
	if len(match) < 2 {
		return ""
	}
	return strings.Trim(match[1], `"`)
}
// parseXFF returns the left-most entry (the originating client) of an
// X-Forwarded-For list such as "203.0.113.7, 10.0.0.1". The HTTP list syntax
// allows optional whitespace on both sides of the comma, so the entry is
// trimmed. Otherwise a value like "203.0.113.7 , 10.0.0.1" would leak a
// trailing space into RemoteAddr. An empty result means the first entry was
// blank.
func parseXFF(v string) string {
	// Cut returns the whole string when no comma is present.
	first, _, _ := strings.Cut(v, ",")
	return strings.TrimSpace(first)
}
// parseVerbatim is the parser for single-address headers (X-Real-Ip,
// True-Client-Ip, Cf-Connecting-Ip) and for any custom header. The header
// value is used exactly as received, with no trimming or validation.
func parseVerbatim(value string) string {
	return value
}
// resolveIP walks p.resolvers in order and returns the first non-empty address
// produced by a header's parser. A header that is absent or empty is skipped,
// and so is one whose parser yields nothing. Only the first value of a
// repeated header is examined (http.Header.Get). It returns "" when no header
// produces an address.
func (p *Plugin) resolveIP(headers http.Header) string {
	for i := range p.resolvers {
		res := &p.resolvers[i]
		value := headers.Get(res.name)
		if value == "" {
			continue
		}
		if ip := res.parse(value); ip != "" {
			return ip
		}
	}
	return ""
}