@@ -13,6 +13,7 @@ import (
1313 "net/url"
1414 "sync/atomic"
1515 "testing"
16+ "time"
1617
1718 "github.com/hashicorp/go-retryablehttp"
1819 "github.com/stretchr/testify/assert"
@@ -33,11 +34,13 @@ func TestNoPrivateIPs(t *testing.T) {
3334
3435 allowedURL := "http://localhost:" + port + "/foobar"
3536 allowedGlob := "http://localhost:" + port + "/glob/*"
37+ allowedPrivateIP := "http://100.64.1.1:80" + "/private"
3638
3739 c := NewResilientClient (
38- ResilientClientWithMaxRetry (1 ),
40+ ResilientClientWithMaxRetry (0 ),
41+ ResilientClientWithConnectionTimeout (50 * time .Millisecond ),
3942 ResilientClientDisallowInternalIPs (),
40- ResilientClientAllowInternalIPRequestsTo (allowedURL , allowedGlob ),
43+ ResilientClientAllowInternalIPRequestsTo (allowedURL , allowedGlob , allowedPrivateIP ),
4144 )
4245
4346 for i := 0 ; i < 10 ; i ++ {
@@ -49,13 +52,51 @@ func TestNoPrivateIPs(t *testing.T) {
4952 "http://localhost:" + port + "/glob/bar" : true ,
5053 "http://localhost:" + port + "/glob/bar/baz" : false ,
5154 "http://localhost:" + port + "/FOOBAR" : false ,
55+ allowedPrivateIP : true ,
56+ "http://100.64.8.8:" + port + "/route" : false ,
5257 } {
5358 _ , err := c .Get (destination )
5459 if ! passes {
5560 require .Errorf (t , err , "dest = %s" , destination )
5661 assert .Containsf (t , err .Error (), "is not a permitted destination" , "dest = %s" , destination )
62+ } else if err != nil {
63+ assert .NotContainsf (t , err .Error (), "is not a permitted destination" , "dest = %s" , destination )
64+ }
65+ }
66+ }
67+ }
68+
69+ func TestAllowPrivateIPs (t * testing.T ) {
70+ ts := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , _ * http.Request ) {
71+ _ , _ = w .Write ([]byte ("Hello, world!" ))
72+ }))
73+ t .Cleanup (ts .Close )
74+
75+ target , err := url .ParseRequestURI (ts .URL )
76+ require .NoError (t , err )
77+
78+ _ , port , err := net .SplitHostPort (target .Host )
79+ require .NoError (t , err )
80+
81+ c := NewResilientClient (
82+ ResilientClientWithMaxRetry (0 ),
83+ ResilientClientWithConnectionTimeout (50 * time .Millisecond ),
84+ )
85+
86+ for i := 0 ; i < 10 ; i ++ {
87+ for destination , handled := range map [string ]bool {
88+ "http://127.0.0.1:" + port : true ,
89+ "http://localhost:" + port : true ,
90+ "http://192.168.178.5:" + port : false ,
91+ "http://localhost:" + port + "/glob/bar" : true ,
92+ "http://100.64.1.1:" + port + "/route" : false ,
93+ } {
94+ _ , err = c .Get (destination )
95+ if handled {
96+ require .NoError (t , err )
5797 } else {
58- require .NoErrorf (t , err , "dest = %s" , destination )
98+ require .Error (t , err )
99+ assert .NotContainsf (t , err .Error (), "is not a permitted destination" , "dest = %s" , destination )
59100 }
60101 }
61102 }
0 commit comments