diff --git a/xff.go b/xff.go index 02f6e69..fc0e1d1 100644 --- a/xff.go +++ b/xff.go @@ -49,13 +49,25 @@ func IsPublicIP(ip net.IP) bool { } // Parse parses the value of the X-Forwarded-For Header and returns the IP address. -func Parse(ipList string) string { - for _, ip := range strings.Split(ipList, ",") { +func Parse(ipList string, allowed func(sip string) bool) string { + ipSplit := strings.Split(ipList, ",") + var lastValidIP net.IP + for i := len(ipSplit) - 1; i >= 0; i-- { + ip := ipSplit[i] ip = strings.TrimSpace(ip) - if IP := net.ParseIP(ip); IP != nil && IsPublicIP(IP) { - return ip + parsedIP := net.ParseIP(ip) + if parsedIP != nil { + lastValidIP = parsedIP + if !allowed(ip) { + break + } + } else { + break } } + if lastValidIP != nil { + return lastValidIP.String() + } return "" } @@ -71,7 +83,7 @@ func GetRemoteAddrIfAllowed(r *http.Request, allowed func(sip string) bool) stri if xffh := r.Header.Get("X-Forwarded-For"); xffh != "" { if sip, sport, err := net.SplitHostPort(r.RemoteAddr); err == nil && sip != "" { if allowed(sip) { - if xip := Parse(xffh); xip != "" { + if xip := Parse(xffh, allowed); xip != "" { return net.JoinHostPort(xip, sport) } } diff --git a/xff_test.go b/xff_test.go index 56297fd..f3e3ec1 100644 --- a/xff_test.go +++ b/xff_test.go @@ -8,63 +8,86 @@ import ( "github.com/stretchr/testify/assert" ) +func allowAll(string) bool { + return true +} + func TestParse_none(t *testing.T) { - res := Parse("") + res := Parse("", allowAll) assert.Equal(t, "", res) } func TestParse_localhost(t *testing.T) { - res := Parse("127.0.0.1") - assert.Equal(t, "", res) + res := Parse("127.0.0.1", allowAll) + assert.Equal(t, "127.0.0.1", res) } func TestParse_invalid(t *testing.T) { - res := Parse("invalid") + res := Parse("invalid", allowAll) assert.Equal(t, "", res) } func TestParse_invalid_sioux(t *testing.T) { - res := Parse("123#1#2#3") + res := Parse("123#1#2#3", allowAll) assert.Equal(t, "", res) } func TestParse_invalid_private_lookalike(t *testing.T) { - res := Parse("102.3.2.1") + res := Parse("102.3.2.1", allowAll) assert.Equal(t, "102.3.2.1", res) } func TestParse_valid(t *testing.T) { - res := Parse("68.45.152.220") + res := Parse("68.45.152.220", allowAll) assert.Equal(t, "68.45.152.220", res) } func TestParse_multi_first(t *testing.T) { - res := Parse("12.13.14.15, 68.45.152.220") + res := Parse("12.13.14.15, 68.45.152.220", allowAll) assert.Equal(t, "12.13.14.15", res) } func TestParse_multi_last(t *testing.T) { - res := Parse("192.168.110.162, 190.57.149.90") - assert.Equal(t, "190.57.149.90", res) + res := Parse("192.168.110.162, 190.57.149.90", allowAll) + assert.Equal(t, "192.168.110.162", res) +} + +func TestParse_multi_accept(t *testing.T) { + res := Parse("1.0.0.1, 1.0.0.2, 1.0.0.3", func(ip string) bool { + return ip == "1.0.0.3" + }) + assert.Equal(t, "1.0.0.2", res) +} + +func TestParse_multi_accept_intermediate_private(t *testing.T) { + res := Parse("1.0.0.1, 1.0.0.2, 10.0.0.1, 1.0.0.3", func(ip string) bool { + return ip == "1.0.0.3" || ip == "10.0.0.1" + }) + assert.Equal(t, "1.0.0.2", res) +} + +func TestParse_multi_accept_final_private(t *testing.T) { + res := Parse("10.0.0.1, 1.0.0.3", allowAll) + assert.Equal(t, "10.0.0.1", res) } func TestParse_multi_with_invalid(t *testing.T) { - res := Parse("192.168.110.162, invalid, 190.57.149.90") + res := Parse("192.168.110.162, invalid, 190.57.149.90", allowAll) assert.Equal(t, "190.57.149.90", res) } func TestParse_multi_with_invalid2(t *testing.T) { - res := Parse("192.168.110.162, 190.57.149.90, invalid") - assert.Equal(t, "190.57.149.90", res) + res := Parse("192.168.110.162, 190.57.149.90, invalid", allowAll) + assert.Equal(t, "", res) } func TestParse_multi_with_invalid_sioux(t *testing.T) { - res := Parse("192.168.110.162, 190.57.149.90, 123#1#2#3") - assert.Equal(t, "190.57.149.90", res) + res := Parse("192.168.110.162, 190.57.149.90, 123#1#2#3", allowAll) + assert.Equal(t, "", res) } func TestParse_ipv6_with_port(t *testing.T) { - res := Parse("2604:2000:71a9:bf00:f178:a500:9a2d:670d") + res := Parse("2604:2000:71a9:bf00:f178:a500:9a2d:670d", allowAll) assert.Equal(t, "2604:2000:71a9:bf00:f178:a500:9a2d:670d", res) } @@ -106,6 +129,32 @@ func TestGetRemoteAddr_ipv6_with_xff(t *testing.T) { assert.Equal(t, "[2001:db8:0:1:1:1:1:1]:1234", ra) } +func TestGetRemoteAddrIfAllowed_ipv4_with_xff(t *testing.T) { + r := &http.Request{ + RemoteAddr: "1.2.3.4:1234", + Header: http.Header{ + "X-Forwarded-For": []string{"100.1.0.1, 100.0.0.1"}, + }, + } + ra := GetRemoteAddrIfAllowed(r, func(ip string) bool { + return ip == "1.2.3.4" || ip == "100.0.0.1" + }) + assert.Equal(t, "100.1.0.1:1234", ra) +} + +func TestGetRemoteAddrIfAllowed_ipv6_with_xff(t *testing.T) { + r := &http.Request{ + RemoteAddr: "1.2.3.4:1234", + Header: http.Header{ + "X-Forwarded-For": []string{"2001:db8:cafe::17, 2001:db8:0:1:1:1:1:1"}, + }, + } + ra := GetRemoteAddrIfAllowed(r, func(ip string) bool { + return ip == "1.2.3.4" || ip == "2001:db8:0:1:1:1:1:1" + }) + assert.Equal(t, "[2001:db8:cafe::17]:1234", ra) +} + func TestToMasks_empty(t *testing.T) { ips := []string{} masks, err := toMasks(ips)