Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions config.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,5 @@
package proxy

import (
"net"
)

// Cidrs is a slice of IPNet addresses
type Cidrs []*net.IPNet

type Config struct {
// TrustedSubnets declare IP subnets which are allowed to set ip using X-Real-Ip and X-Forwarded-For
TrustedSubnets []string `mapstructure:"trusted_subnets"`
Expand Down
30 changes: 6 additions & 24 deletions plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,7 @@ const (
forwarded string = "Forwarded"
)

var (
forwardedRegex = regexp.MustCompile(`(?i)(?:for=)([^(;|,| )]+)`)
)
var forwardedRegex = regexp.MustCompile(`(?i)(?:for=)([^(;|,| )]+)`)

type Logger interface {
NamedLogger(name string) *slog.Logger
Expand Down Expand Up @@ -107,8 +105,8 @@ func (p *Plugin) Middleware(next http.Handler) http.Handler {
}

ip := net.ParseIP(host)
for i := range p.trusted {
if p.trusted[i].Contains(ip) {
for _, subnet := range p.trusted {
if subnet.Contains(ip) {
resolvedIP := p.resolveIP(r.Header)
if resolvedIP != "" {
r.RemoteAddr = resolvedIP
Expand Down Expand Up @@ -144,16 +142,9 @@ func (p *Plugin) resolveIP(headers http.Header) string {
}
// XFF parse
} else if fwd := headers.Get(xff); fwd != "" {
s := strings.Index(fwd, ",")
if s == -1 {
return fwd
}

if len(fwd) < s {
return ""
}

return fwd[:s]
// take the first address; Cut returns the whole string when no comma is present
before, _, _ := strings.Cut(fwd, ",")
return before
// next -> X-Real-Ip
} else if fwd := headers.Get(xrip); fwd != "" {
return fwd
Expand All @@ -174,12 +165,3 @@ func (p *Plugin) resolveIP(headers http.Header) string {

return ""
}

func inc(ip net.IP) {
for j := len(ip) - 1; j >= 0; j-- {
ip[j]++
if ip[j] > 0 {
break
}
}
}
9 changes: 9 additions & 0 deletions trusted_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,12 @@ func TestCidrsInRange(t *testing.T) {

require.Len(t, addrs, 1024)
}

func inc(ip net.IP) {
for j := len(ip) - 1; j >= 0; j-- {
ip[j]++
if ip[j] > 0 {
break
}
}
}
Loading