Skip to content

Commit 30ad910

Browse files
committed
clean up and refactor tests
1 parent 97f4d42 commit 30ad910

1 file changed

Lines changed: 105 additions & 106 deletions

File tree

internal/api/proxy_test.go

Lines changed: 105 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -10,129 +10,103 @@ import (
1010
"net/http/httptest"
1111
"net/url"
1212
"strings"
13+
"sync"
1314
"testing"
1415
"time"
1516
)
1617

18+
type proxyOpts struct {
19+
useTLS bool
20+
username string
21+
password string
22+
observe func(*http.Request)
23+
}
24+
1725
// startProxy starts an HTTP or HTTPS CONNECT proxy on a random port.
18-
// It returns the proxy URL and a channel that receives the protocol observed by
19-
// the proxy handler for each CONNECT request.
20-
func startProxy(t *testing.T, useTLS bool) (proxyURL *url.URL, obsCh <-chan string) {
26+
// If opts.observe is set, it is called for each CONNECT request.
27+
// If opts.username is set, Proxy-Authorization is required.
28+
func startProxy(t *testing.T, opts proxyOpts) *url.URL {
2129
t.Helper()
2230

23-
ch := make(chan string, 10)
24-
2531
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
26-
select {
27-
case ch <- r.Proto:
28-
default:
32+
if opts.observe != nil {
33+
opts.observe(r)
2934
}
3035

3136
if r.Method != http.MethodConnect {
3237
http.Error(w, "expected CONNECT", http.StatusMethodNotAllowed)
3338
return
3439
}
3540

36-
destConn, err := net.DialTimeout("tcp", r.Host, 10*time.Second)
37-
if err != nil {
38-
http.Error(w, err.Error(), http.StatusBadGateway)
39-
return
40-
}
41-
defer destConn.Close()
42-
43-
hijacker, ok := w.(http.Hijacker)
44-
if !ok {
45-
http.Error(w, "hijacking not supported", http.StatusInternalServerError)
46-
return
41+
if opts.username != "" {
42+
wantAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(opts.username+":"+opts.password))
43+
if r.Header.Get("Proxy-Authorization") != wantAuth {
44+
http.Error(w, "proxy auth required", http.StatusProxyAuthRequired)
45+
return
46+
}
4747
}
4848

49-
w.WriteHeader(http.StatusOK)
50-
clientConn, bufrw, err := hijacker.Hijack()
51-
if err != nil {
52-
return
53-
}
54-
defer clientConn.Close()
55-
56-
done := make(chan struct{}, 2)
57-
// Read from bufrw (not clientConn) so any bytes already buffered
58-
// by the server's bufio.Reader are forwarded to the destination.
59-
go func() { io.Copy(destConn, bufrw); done <- struct{}{} }()
60-
go func() { io.Copy(clientConn, destConn); done <- struct{}{} }()
61-
<-done
62-
// Close both sides so the remaining goroutine unblocks.
63-
clientConn.Close()
64-
destConn.Close()
65-
<-done
49+
serveTunnel(w, r)
6650
}))
6751

68-
if useTLS {
52+
if opts.useTLS {
6953
srv.StartTLS()
7054
} else {
7155
srv.Start()
7256
}
7357
t.Cleanup(srv.Close)
7458

7559
pURL, _ := url.Parse(srv.URL)
76-
return pURL, ch
60+
if opts.username != "" {
61+
pURL.User = url.UserPassword(opts.username, opts.password)
62+
}
63+
return pURL
7764
}
7865

79-
// startProxyWithAuth is like startProxy but requires
80-
// Proxy-Authorization with the given username and password.
81-
func startProxyWithAuth(t *testing.T, useTLS bool, wantUser, wantPass string) (proxyURL *url.URL) {
82-
t.Helper()
83-
84-
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
85-
if r.Method != http.MethodConnect {
86-
http.Error(w, "expected CONNECT", http.StatusMethodNotAllowed)
87-
return
88-
}
89-
90-
authHeader := r.Header.Get("Proxy-Authorization")
91-
wantAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte(wantUser+":"+wantPass))
92-
if authHeader != wantAuth {
93-
http.Error(w, "proxy auth required", http.StatusProxyAuthRequired)
94-
return
95-
}
96-
97-
destConn, err := net.DialTimeout("tcp", r.Host, 10*time.Second)
98-
if err != nil {
99-
http.Error(w, err.Error(), http.StatusBadGateway)
100-
return
101-
}
102-
defer destConn.Close()
66+
// serveTunnel implements the CONNECT tunnel: dials the target, hijacks the
67+
// client connection, and copies bytes bidirectionally.
68+
func serveTunnel(w http.ResponseWriter, r *http.Request) {
69+
destConn, err := net.DialTimeout("tcp", r.Host, 10*time.Second)
70+
if err != nil {
71+
http.Error(w, err.Error(), http.StatusBadGateway)
72+
return
73+
}
74+
defer destConn.Close()
10375

104-
hijacker, ok := w.(http.Hijacker)
105-
if !ok {
106-
http.Error(w, "hijacking not supported", http.StatusInternalServerError)
107-
return
108-
}
76+
hijacker, ok := w.(http.Hijacker)
77+
if !ok {
78+
http.Error(w, "hijacking not supported", http.StatusInternalServerError)
79+
return
80+
}
10981

110-
w.WriteHeader(http.StatusOK)
111-
clientConn, bufrw, err := hijacker.Hijack()
112-
if err != nil {
113-
return
114-
}
115-
defer clientConn.Close()
82+
w.WriteHeader(http.StatusOK)
83+
clientConn, bufrw, err := hijacker.Hijack()
84+
if err != nil {
85+
return
86+
}
11687

117-
done := make(chan struct{}, 2)
118-
go func() { io.Copy(destConn, bufrw); done <- struct{}{} }()
119-
go func() { io.Copy(clientConn, destConn); done <- struct{}{} }()
120-
<-done
88+
var wg sync.WaitGroup
89+
var once sync.Once
90+
closeBoth := func() {
12191
clientConn.Close()
12292
destConn.Close()
123-
<-done
124-
}))
125-
126-
if useTLS {
127-
srv.StartTLS()
128-
} else {
129-
srv.Start()
13093
}
131-
t.Cleanup(srv.Close)
94+
defer once.Do(closeBoth)
13295

133-
pURL, _ := url.Parse(srv.URL)
134-
pURL.User = url.UserPassword(wantUser, wantPass)
135-
return pURL
96+
wg.Add(2)
97+
// Read from bufrw (not clientConn) so any bytes already buffered
98+
// by the server's bufio.Reader are forwarded to the destination.
99+
go func() {
100+
defer wg.Done()
101+
io.Copy(destConn, bufrw)
102+
once.Do(closeBoth)
103+
}()
104+
go func() {
105+
defer wg.Done()
106+
io.Copy(clientConn, destConn)
107+
once.Do(closeBoth)
108+
}()
109+
wg.Wait()
136110
}
137111

138112
// newTestTransport creates a base transport suitable for proxy tests.
@@ -157,7 +131,19 @@ func startTargetServer(t *testing.T) *httptest.Server {
157131

158132
func TestWithProxyTransport_HTTPProxy(t *testing.T) {
159133
target := startTargetServer(t)
160-
proxyURL, obsCh := startProxy(t, false)
134+
135+
var mu sync.Mutex
136+
var used bool
137+
var proto string
138+
139+
proxyURL := startProxy(t, proxyOpts{
140+
observe: func(r *http.Request) {
141+
mu.Lock()
142+
defer mu.Unlock()
143+
used = true
144+
proto = r.Proto
145+
},
146+
})
161147

162148
transport := withProxyTransport(newTestTransport(), proxyURL, "")
163149
t.Cleanup(transport.CloseIdleConnections)
@@ -180,19 +166,32 @@ func TestWithProxyTransport_HTTPProxy(t *testing.T) {
180166
t.Errorf("expected body 'ok', got %q", got)
181167
}
182168

183-
select {
184-
case proto := <-obsCh:
185-
if proto != "HTTP/1.1" {
186-
t.Errorf("expected proxy to see HTTP/1.1 CONNECT, got %s", proto)
187-
}
188-
case <-time.After(2 * time.Second):
169+
mu.Lock()
170+
defer mu.Unlock()
171+
if !used {
189172
t.Fatal("proxy handler was never invoked")
190173
}
174+
if proto != "HTTP/1.1" {
175+
t.Errorf("expected proxy to see HTTP/1.1 CONNECT, got %s", proto)
176+
}
191177
}
192178

193179
func TestWithProxyTransport_HTTPSProxy(t *testing.T) {
194180
target := startTargetServer(t)
195-
proxyURL, obsCh := startProxy(t, true)
181+
182+
var mu sync.Mutex
183+
var used bool
184+
var proto string
185+
186+
proxyURL := startProxy(t, proxyOpts{
187+
useTLS: true,
188+
observe: func(r *http.Request) {
189+
mu.Lock()
190+
defer mu.Unlock()
191+
used = true
192+
proto = r.Proto
193+
},
194+
})
196195

197196
transport := withProxyTransport(newTestTransport(), proxyURL, "")
198197
t.Cleanup(transport.CloseIdleConnections)
@@ -215,21 +214,21 @@ func TestWithProxyTransport_HTTPSProxy(t *testing.T) {
215214
t.Errorf("expected body 'ok', got %q", got)
216215
}
217216

218-
select {
219-
case proto := <-obsCh:
220-
if proto != "HTTP/1.1" {
221-
t.Errorf("expected proxy to see HTTP/1.1 CONNECT, got %s", proto)
222-
}
223-
case <-time.After(2 * time.Second):
217+
mu.Lock()
218+
defer mu.Unlock()
219+
if !used {
224220
t.Fatal("proxy handler was never invoked")
225221
}
222+
if proto != "HTTP/1.1" {
223+
t.Errorf("expected proxy to see HTTP/1.1 CONNECT, got %s", proto)
224+
}
226225
}
227226

228227
func TestWithProxyTransport_ProxyAuth(t *testing.T) {
229228
target := startTargetServer(t)
230229

231230
t.Run("http proxy with auth", func(t *testing.T) {
232-
proxyURL := startProxyWithAuth(t, false, "user", "pass")
231+
proxyURL := startProxy(t, proxyOpts{username: "user", password: "pass"})
233232
transport := withProxyTransport(newTestTransport(), proxyURL, "")
234233
t.Cleanup(transport.CloseIdleConnections)
235234
client := &http.Client{Transport: transport, Timeout: 10 * time.Second}
@@ -249,7 +248,7 @@ func TestWithProxyTransport_ProxyAuth(t *testing.T) {
249248
})
250249

251250
t.Run("https proxy with auth", func(t *testing.T) {
252-
proxyURL := startProxyWithAuth(t, true, "user", "s3cret")
251+
proxyURL := startProxy(t, proxyOpts{useTLS: true, username: "user", password: "s3cret"})
253252
transport := withProxyTransport(newTestTransport(), proxyURL, "")
254253
t.Cleanup(transport.CloseIdleConnections)
255254
client := &http.Client{Transport: transport, Timeout: 10 * time.Second}
@@ -273,7 +272,7 @@ func TestWithProxyTransport_HTTPSProxy_HTTP2ToOrigin(t *testing.T) {
273272
// Verify that when tunneling through an HTTPS proxy, the connection to
274273
// the origin target still negotiates HTTP/2 (not downgraded to HTTP/1.1).
275274
target := startTargetServer(t)
276-
proxyURL, _ := startProxy(t, true)
275+
proxyURL := startProxy(t, proxyOpts{useTLS: true})
277276

278277
transport := withProxyTransport(newTestTransport(), proxyURL, "")
279278
t.Cleanup(transport.CloseIdleConnections)
@@ -322,7 +321,7 @@ func TestWithProxyTransport_HandshakeFailureClosesConn(t *testing.T) {
322321
close(connClosed)
323322
}()
324323

325-
proxyURL, _ := startProxy(t, true)
324+
proxyURL := startProxy(t, proxyOpts{useTLS: true})
326325
transport := withProxyTransport(newTestTransport(), proxyURL, "")
327326
t.Cleanup(transport.CloseIdleConnections)
328327
client := &http.Client{Transport: transport, Timeout: 5 * time.Second}

0 commit comments

Comments
 (0)