Skip to content

Commit a296b57

Browse files
committed
Implement WebSocket proxy. Closes #7
1 parent 139831e commit a296b57

3 files changed

Lines changed: 355 additions & 161 deletions

File tree

httpx/transport.go

Lines changed: 101 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -369,16 +369,16 @@ func ProxyURL(fixedURL *url.URL) func(*http.Request) (*url.URL, error) {
369369

370370
// transportRequest is a wrapper around a *http.Request that adds
371371
// optional extra headers to write.
372-
type transportRequest struct {
372+
type TransportRequest struct {
373373
*http.Request // original request, not to be mutated
374-
extra http.Header // extra headers to write, or nil
374+
Extra http.Header // extra headers to write, or nil
375375
}
376376

377-
func (tr *transportRequest) extraHeaders() http.Header {
378-
if tr.extra == nil {
379-
tr.extra = make(http.Header)
377+
func (tr *TransportRequest) extraHeaders() http.Header {
378+
if tr.Extra == nil {
379+
tr.Extra = make(http.Header)
380380
}
381-
return tr.extra
381+
return tr.Extra
382382
}
383383

384384
// RoundTrip implements the RoundTripper interface.
@@ -411,8 +411,8 @@ func (t *Transport) RoundTrip(req *http.Request) (resp *http.Response, err error
411411
closeBody(req)
412412
return nil, errors.New("http: no Host in request URL")
413413
}
414-
treq := &transportRequest{Request: req}
415-
cm, err := t.connectMethodForRequest(treq)
414+
treq := &TransportRequest{Request: req}
415+
cm, err := t.ConnectMethodForRequest(treq)
416416
if err != nil {
417417
closeBody(req)
418418
return nil, err
@@ -528,24 +528,24 @@ func (e *envOnce) reset() {
528528
e.val = ""
529529
}
530530

531-
func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectMethod, err error) {
532-
cm.targetScheme = treq.URL.Scheme
533-
cm.targetAddr = canonicalAddr(treq.URL)
531+
func (t *Transport) ConnectMethodForRequest(treq *TransportRequest) (cm ConnectMethod, err error) {
532+
cm.TargetScheme = treq.URL.Scheme
533+
cm.TargetAddr = canonicalAddr(treq.URL)
534534
if t.Proxy2 != nil {
535-
cm.proxyURL, cm.proxyTLSConfig, err = t.Proxy2(treq.Request)
535+
cm.ProxyURL, cm.ProxyTLSConfig, err = t.Proxy2(treq.Request)
536536
} else if t.Proxy != nil {
537-
cm.proxyURL, err = t.Proxy(treq.Request)
537+
cm.ProxyURL, err = t.Proxy(treq.Request)
538538
}
539539
return cm, err
540540
}
541541

542542
// proxyAuth returns the Proxy-Authorization header to set
543543
// on requests, if applicable.
544-
func (cm *connectMethod) proxyAuth() string {
545-
if cm.proxyURL == nil {
544+
func (cm *ConnectMethod) ProxyAuth() string {
545+
if cm.ProxyURL == nil {
546546
return ""
547547
}
548-
if u := cm.proxyURL.User; u != nil {
548+
if u := cm.ProxyURL.User; u != nil {
549549
username := u.Username()
550550
password, _ := u.Password()
551551
return "Basic " + basicAuth(username, password)
@@ -613,9 +613,9 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool {
613613
}
614614

615615
// getIdleConnCh returns a channel to receive and return idle
616-
// persistent connection for the given connectMethod.
616+
// persistent connection for the given ConnectMethod.
617617
// It may return nil, if persistent connections are not being used.
618-
func (t *Transport) getIdleConnCh(cm connectMethod) chan *persistConn {
618+
func (t *Transport) getIdleConnCh(cm ConnectMethod) chan *persistConn {
619619
if t.DisableKeepAlives {
620620
return nil
621621
}
@@ -634,7 +634,7 @@ func (t *Transport) getIdleConnCh(cm connectMethod) chan *persistConn {
634634
return ch
635635
}
636636

637-
func (t *Transport) getIdleConn(cm connectMethod) (pconn *persistConn) {
637+
func (t *Transport) getIdleConn(cm ConnectMethod) (pconn *persistConn) {
638638
key := cm.key()
639639
t.idleMu.Lock()
640640
defer t.idleMu.Unlock()
@@ -704,10 +704,10 @@ func (t *Transport) dial(network, addr string) (c net.Conn, err error) {
704704
var prePendingDial, postPendingDial func()
705705

706706
// getConn dials and creates a new persistConn to the target as
707-
// specified in the connectMethod. This includes doing a proxy CONNECT
707+
// specified in the ConnectMethod. This includes doing a proxy CONNECT
708708
// and/or setting up TLS. If this doesn't return an error, the persistConn
709709
// is ready to write requests to.
710-
func (t *Transport) getConn(req *http.Request, cm connectMethod) (*persistConn, error) {
710+
func (t *Transport) getConn(req *http.Request, cm ConnectMethod) (*persistConn, error) {
711711
if pc := t.getIdleConn(cm); pc != nil {
712712
// set request canceler to some non-nil function so we
713713
// can detect whether it was cleared between now and when
@@ -771,31 +771,17 @@ func (t *Transport) getConn(req *http.Request, cm connectMethod) (*persistConn,
771771
}
772772
}
773773

774-
func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) {
775-
pconn := &persistConn{
776-
t: t,
777-
cacheKey: cm.key(),
778-
reqch: make(chan requestAndChan, 1),
779-
writech: make(chan writeRequest, 1),
780-
closech: make(chan struct{}),
781-
writeErrCh: make(chan error, 1),
782-
}
783-
tlsDial := t.DialTLS != nil && cm.targetScheme == "https" && cm.proxyURL == nil
774+
func (t *Transport) DoDial(cm ConnectMethod) (conn net.Conn, isProxy bool, err error) {
775+
tlsDial := t.DialTLS != nil && cm.TargetScheme == "https" && cm.ProxyURL == nil
784776
if tlsDial {
785777
var err error
786-
pconn.conn, err = t.DialTLS("tcp", cm.addr())
778+
conn, err = t.DialTLS("tcp", cm.addr())
787779
if err != nil {
788-
return nil, err
789-
}
790-
if tc, ok := pconn.conn.(*tls.Conn); ok {
791-
cs := tc.ConnectionState()
792-
pconn.tlsState = &cs
780+
return nil, false, err
793781
}
794782
} else {
795-
var conn net.Conn
796-
var err error
797-
proxyTLSConfigTemlate := cm.proxyTLSConfig
798-
tlsProxy := cm.proxyURL != nil && cm.proxyURL.Scheme == "https"
783+
proxyTLSConfigTemlate := cm.ProxyTLSConfig
784+
tlsProxy := cm.ProxyURL != nil && cm.ProxyURL.Scheme == "https"
799785
if tlsProxy {
800786
if proxyTLSConfigTemlate == nil {
801787
if t.DialTLS != nil {
@@ -808,44 +794,37 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) {
808794
if conn == nil {
809795
conn, err = t.dial("tcp", cm.addr())
810796
if err != nil {
811-
if cm.proxyURL != nil {
812-
err = fmt.Errorf("http: error connecting to proxy %s: %v", cm.proxyURL, err)
797+
if cm.ProxyURL != nil {
798+
err = fmt.Errorf("http: error connecting to proxy %s: %v", cm.ProxyURL, err)
813799
}
814-
return nil, err
800+
return nil, false, err
815801
}
816802
if tlsProxy {
817803
proxyTLSConfig := new(tls.Config)
818804
*proxyTLSConfig = *proxyTLSConfigTemlate
819-
host, _, err := net.SplitHostPort(cm.proxyURL.Host)
805+
host, _, err := net.SplitHostPort(cm.ProxyURL.Host)
820806
if err == nil {
821807
proxyTLSConfig.ServerName = host
822808
}
823809
conn = tls.Client(conn, proxyTLSConfig)
824810
}
825811
}
826-
pconn.conn = conn
827812
}
828813

829814
// Proxy setup.
830815
switch {
831-
case cm.proxyURL == nil:
816+
case cm.ProxyURL == nil:
832817
// Do nothing. Not using a proxy.
833-
case cm.targetScheme == "http":
834-
pconn.isProxy = true
835-
if pa := cm.proxyAuth(); pa != "" {
836-
pconn.mutateHeaderFunc = func(h http.Header) {
837-
h.Set("Proxy-Authorization", pa)
838-
}
839-
}
840-
case cm.targetScheme == "https":
841-
conn := pconn.conn
818+
case cm.TargetScheme == "http":
819+
isProxy = true
820+
case cm.TargetScheme == "https":
842821
connectReq := &http.Request{
843822
Method: "CONNECT",
844-
URL: &url.URL{Opaque: cm.targetAddr},
845-
Host: cm.targetAddr,
823+
URL: &url.URL{Opaque: cm.TargetAddr},
824+
Host: cm.TargetAddr,
846825
Header: make(http.Header),
847826
}
848-
if pa := cm.proxyAuth(); pa != "" {
827+
if pa := cm.ProxyAuth(); pa != "" {
849828
connectReq.Header.Set("Proxy-Authorization", pa)
850829
}
851830
connectReq.Write(conn)
@@ -854,26 +833,28 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) {
854833
// Okay to use and discard buffered reader here, because
855834
// TLS server will not speak until spoken to.
856835
br := bufio.NewReader(conn)
857-
resp, err := http.ReadResponse(br, connectReq)
836+
var resp *http.Response
837+
resp, err = http.ReadResponse(br, connectReq)
858838
if err != nil {
859839
conn.Close()
860-
return nil, err
840+
return
861841
}
862842
if resp.StatusCode != 200 {
863843
f := strings.SplitN(resp.Status, " ", 2)
864844
conn.Close()
865-
return nil, errors.New(f[1])
845+
err = errors.New(f[1])
846+
return
866847
}
867848
}
868849

869-
if cm.targetScheme == "https" && !tlsDial {
850+
if cm.TargetScheme == "https" && !tlsDial {
870851
// Initiate TLS and check remote host name against certificate.
871852
cfg := cloneTLSClientConfig(t.TLSClientConfig)
872853
if cfg.ServerName == "" {
873854
cfg.ServerName = cm.tlsHost()
874855
}
875-
plainConn := pconn.conn
876-
tlsConn := tls.Client(plainConn, cfg)
856+
plainConn := conn
857+
tlsConn := tls.Client(conn, cfg)
877858
errc := make(chan error, 2)
878859
var timer *time.Timer // for canceling TLS handshake
879860
if d := t.TLSHandshakeTimeout; d != 0 {
@@ -888,21 +869,46 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) {
888869
}
889870
errc <- err
890871
}()
891-
if err := <-errc; err != nil {
872+
if err = <-errc; err != nil {
892873
plainConn.Close()
893-
return nil, err
874+
return
894875
}
895876
if !cfg.InsecureSkipVerify {
896-
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
877+
if err = tlsConn.VerifyHostname(cfg.ServerName); err != nil {
897878
plainConn.Close()
898-
return nil, err
879+
return
899880
}
900881
}
901-
cs := tlsConn.ConnectionState()
902-
pconn.tlsState = &cs
903-
pconn.conn = tlsConn
882+
conn = tlsConn
904883
}
884+
return
885+
}
905886

887+
func (t *Transport) dialConn(cm ConnectMethod) (*persistConn, error) {
888+
pconn := &persistConn{
889+
t: t,
890+
cacheKey: cm.key(),
891+
reqch: make(chan requestAndChan, 1),
892+
writech: make(chan writeRequest, 1),
893+
closech: make(chan struct{}),
894+
writeErrCh: make(chan error, 1),
895+
}
896+
var err error
897+
pconn.conn, pconn.isProxy, err = t.DoDial(cm)
898+
if err != nil {
899+
return nil, err
900+
}
901+
if tc, ok := pconn.conn.(*tls.Conn); ok {
902+
cs := tc.ConnectionState()
903+
pconn.tlsState = &cs
904+
}
905+
if pconn.isProxy {
906+
if pa := cm.ProxyAuth(); pa != "" {
907+
pconn.mutateHeaderFunc = func(h http.Header) {
908+
h.Set("Proxy-Authorization", pa)
909+
}
910+
}
911+
}
906912
pconn.br = bufio.NewReader(noteEOFReader{pconn.conn, &pconn.sawEOF})
907913
pconn.bw = bufio.NewWriter(pconn.conn)
908914
go pconn.readLoop()
@@ -963,7 +969,7 @@ func useProxy(addr string) bool {
963969
return true
964970
}
965971

966-
// connectMethod is the map key (in its String form) for keeping persistent
972+
// ConnectMethod is the map key (in its String form) for keeping persistent
967973
// TCP connections alive for subsequent HTTP requests.
968974
//
969975
// A connect method may be of the following types:
@@ -977,49 +983,49 @@ func useProxy(addr string) bool {
977983
//
978984
// Note: no support to https to the proxy yet.
979985
//
980-
type connectMethod struct {
981-
proxyURL *url.URL // nil for no proxy, else full proxy URL
982-
proxyTLSConfig *tls.Config // TLS config for proxy
983-
targetScheme string // "http" or "https"
984-
targetAddr string // Not used if proxy + http targetScheme (4th example in table)
986+
type ConnectMethod struct {
987+
ProxyURL *url.URL // nil for no proxy, else full proxy URL
988+
ProxyTLSConfig *tls.Config // TLS config for proxy
989+
TargetScheme string // "http" or "https"
990+
TargetAddr string // Not used if proxy + http targetScheme (4th example in table)
985991
}
986992

987-
func (cm *connectMethod) key() connectMethodKey {
993+
func (cm *ConnectMethod) key() connectMethodKey {
988994
proxyStr := ""
989-
targetAddr := cm.targetAddr
990-
if cm.proxyURL != nil {
991-
proxyStr = cm.proxyURL.String()
992-
if cm.targetScheme == "http" {
995+
targetAddr := cm.TargetAddr
996+
if cm.ProxyURL != nil {
997+
proxyStr = cm.ProxyURL.String()
998+
if cm.TargetScheme == "http" {
993999
targetAddr = ""
9941000
}
9951001
}
9961002
return connectMethodKey{
9971003
proxy: proxyStr,
998-
scheme: cm.targetScheme,
1004+
scheme: cm.TargetScheme,
9991005
addr: targetAddr,
1000-
tlsConfigAddr: uintptr(unsafe.Pointer(cm.proxyTLSConfig)),
1006+
tlsConfigAddr: uintptr(unsafe.Pointer(cm.ProxyTLSConfig)),
10011007
}
10021008
}
10031009

10041010
// addr returns the first hop "host:port" to which we need to TCP connect.
1005-
func (cm *connectMethod) addr() string {
1006-
if cm.proxyURL != nil {
1007-
return canonicalAddr(cm.proxyURL)
1011+
func (cm *ConnectMethod) addr() string {
1012+
if cm.ProxyURL != nil {
1013+
return canonicalAddr(cm.ProxyURL)
10081014
}
1009-
return cm.targetAddr
1015+
return cm.TargetAddr
10101016
}
10111017

10121018
// tlsHost returns the host name to match against the peer's
10131019
// TLS certificate.
1014-
func (cm *connectMethod) tlsHost() string {
1015-
h := cm.targetAddr
1020+
func (cm *ConnectMethod) tlsHost() string {
1021+
h := cm.TargetAddr
10161022
if hasPort(h) {
10171023
h = h[:strings.LastIndex(h, ":")]
10181024
}
10191025
return h
10201026
}
10211027

1022-
// connectMethodKey is the map key version of connectMethod, with a
1028+
// connectMethodKey is the map key version of ConnectMethod, with a
10231029
// stringified proxy URL (or the empty string) instead of a pointer to
10241030
// a URL.
10251031
type connectMethodKey struct {
@@ -1238,7 +1244,7 @@ func (pc *persistConn) writeLoop() {
12381244
wr.ch <- errors.New("http: can't write HTTP request on broken connection")
12391245
continue
12401246
}
1241-
err := doWriteRequest(wr.req.Request, pc.bw, pc.isProxy, wr.req.extra)
1247+
err := doWriteRequest(wr.req.Request, pc.bw, pc.isProxy, wr.req.Extra)
12421248
if err == nil {
12431249
err = pc.bw.Flush()
12441250
}
@@ -1302,7 +1308,7 @@ type requestAndChan struct {
13021308
// concurrently waits on both the write response and the server's
13031309
// reply.
13041310
type writeRequest struct {
1305-
req *transportRequest
1311+
req *TransportRequest
13061312
ch chan<- error
13071313
}
13081314

@@ -1327,7 +1333,7 @@ var (
13271333
testHookReadLoopBeforeNextRead func()
13281334
)
13291335

1330-
func (pc *persistConn) roundTrip(req *transportRequest) (resp *http.Response, err error) {
1336+
func (pc *persistConn) roundTrip(req *TransportRequest) (resp *http.Response, err error) {
13311337
if hook := testHookEnterRoundTrip; hook != nil {
13321338
hook()
13331339
}

0 commit comments

Comments
 (0)