Skip to content

Commit 77c751b

Browse files
committed
2 parents c48d95b + af47554 commit 77c751b

File tree

2 files changed

+50
-3
lines changed

2 files changed

+50
-3
lines changed

client.go

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"context"
1010
"crypto/tls"
1111
"errors"
12+
"fmt"
1213
"io"
1314
"io/ioutil"
1415
"net"
@@ -318,14 +319,14 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
318319
}
319320

320321
netConn, err := netDial("tcp", hostPort)
322+
if err != nil {
323+
return nil, nil, err
324+
}
321325
if trace != nil && trace.GotConn != nil {
322326
trace.GotConn(httptrace.GotConnInfo{
323327
Conn: netConn,
324328
})
325329
}
326-
if err != nil {
327-
return nil, nil, err
328-
}
329330

330331
defer func() {
331332
if netConn != nil {
@@ -370,6 +371,17 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
370371

371372
resp, err := http.ReadResponse(conn.br, req)
372373
if err != nil {
374+
if d.TLSClientConfig != nil {
375+
for _, proto := range d.TLSClientConfig.NextProtos {
376+
if proto != "http/1.1" {
377+
return nil, nil, fmt.Errorf(
378+
"websocket: protocol %q was given but is not supported;"+
379+
"sharing tls.Config with net/http Transport can cause this error: %w",
380+
proto, err,
381+
)
382+
}
383+
}
384+
}
373385
return nil, nil, err
374386
}
375387

client_server_test.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,3 +1098,38 @@ func TestNetDialConnect(t *testing.T) {
10981098
}
10991099
}
11001100
}
1101+
func TestNextProtos(t *testing.T) {
1102+
ts := httptest.NewUnstartedServer(
1103+
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
1104+
)
1105+
ts.EnableHTTP2 = true
1106+
ts.StartTLS()
1107+
defer ts.Close()
1108+
1109+
d := Dialer{
1110+
TLSClientConfig: ts.Client().Transport.(*http.Transport).TLSClientConfig,
1111+
}
1112+
1113+
r, err := ts.Client().Get(ts.URL)
1114+
if err != nil {
1115+
t.Fatalf("Get: %v", err)
1116+
}
1117+
r.Body.Close()
1118+
1119+
// Asserts that Dialer.TLSClientConfig.NextProtos contains "h2"
1120+
// after the Client.Get call from net/http above.
1121+
var containsHTTP2 bool = false
1122+
for _, proto := range d.TLSClientConfig.NextProtos {
1123+
if proto == "h2" {
1124+
containsHTTP2 = true
1125+
}
1126+
}
1127+
if !containsHTTP2 {
1128+
t.Fatalf("Dialer.TLSClientConfig.NextProtos does not contain \"h2\"")
1129+
}
1130+
1131+
_, _, err = d.Dial(makeWsProto(ts.URL), nil)
1132+
if err == nil {
1133+
t.Fatalf("Dial succeeded, expect fail ")
1134+
}
1135+
}

0 commit comments

Comments
 (0)