Skip to content

Commit a4313f9

Browse files
committed
mcp: refactor the streamable client test to be more flexible
Use a fake streamable server to facilitate testing client behavior. For this commit, just update the existing test (moved to a new file for isolation). Subsequent CLs will add more tests. Improve one client error message that occurred while debuging tests. For #393
1 parent 3026172 commit a4313f9

File tree

3 files changed

+193
-72
lines changed

3 files changed

+193
-72
lines changed

mcp/streamable.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1234,7 +1234,14 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
12341234

12351235
default:
12361236
resp.Body.Close()
1237-
return fmt.Errorf("unsupported content type %q", ct)
1237+
switch msg := msg.(type) {
1238+
case *jsonrpc.Request:
1239+
return fmt.Errorf("unsupported content type %q when sending %q (status: %d)", ct, msg.Method, resp.StatusCode)
1240+
case *jsonrpc.Response:
1241+
return fmt.Errorf("unsupported content type %q when sending jsonrpc response #%d (status: %d)", ct, msg.ID, resp.StatusCode)
1242+
default:
1243+
panic("unreachable")
1244+
}
12381245
}
12391246
return nil
12401247
}

mcp/streamable_client_test.go

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
// Copyright 2025 The Go MCP SDK Authors. All rights reserved.
2+
// Use of this source code is governed by an MIT-style
3+
// license that can be found in the LICENSE file.
4+
5+
package mcp
6+
7+
import (
8+
"context"
9+
"io"
10+
"net/http"
11+
"net/http/httptest"
12+
"sync"
13+
"testing"
14+
15+
"github.com/google/go-cmp/cmp"
16+
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
17+
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
18+
)
19+
20+
type streamableRequestKey struct {
21+
httpMethod string // http method
22+
sessionID string // session ID header
23+
jsonrpcMethod string // jsonrpc method, or "" for non-requests
24+
}
25+
26+
type header map[string]string
27+
28+
type streamableResponse struct {
29+
header header
30+
status int // or http.StatusOK
31+
body string // or ""
32+
optional bool // if set, request need not be sent
33+
wantProtocolVersion string // if "", unchecked
34+
callback func() // if set, called after the request is handled
35+
}
36+
37+
type fakeResponses map[streamableRequestKey]*streamableResponse
38+
39+
type fakeStreamableServer struct {
40+
t *testing.T
41+
responses fakeResponses
42+
43+
callMu sync.Mutex
44+
calls map[streamableRequestKey]int
45+
}
46+
47+
func (s *fakeStreamableServer) missingRequests() []streamableRequestKey {
48+
s.callMu.Lock()
49+
defer s.callMu.Unlock()
50+
51+
var unused []streamableRequestKey
52+
for k, resp := range s.responses {
53+
if s.calls[k] == 0 && !resp.optional {
54+
unused = append(unused, k)
55+
}
56+
}
57+
return unused
58+
}
59+
60+
func (s *fakeStreamableServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
61+
key := streamableRequestKey{
62+
httpMethod: req.Method,
63+
sessionID: req.Header.Get(sessionIDHeader),
64+
}
65+
if req.Method == http.MethodPost {
66+
body, err := io.ReadAll(req.Body)
67+
if err != nil {
68+
s.t.Errorf("failed to read body: %v", err)
69+
http.Error(w, "failed to read body", http.StatusInternalServerError)
70+
return
71+
}
72+
msg, err := jsonrpc.DecodeMessage(body)
73+
if err != nil {
74+
s.t.Errorf("invalid body: %v", err)
75+
http.Error(w, "invalid body", http.StatusInternalServerError)
76+
return
77+
}
78+
if r, ok := msg.(*jsonrpc.Request); ok {
79+
key.jsonrpcMethod = r.Method
80+
}
81+
}
82+
83+
s.callMu.Lock()
84+
if s.calls == nil {
85+
s.calls = make(map[streamableRequestKey]int)
86+
}
87+
s.calls[key]++
88+
s.callMu.Unlock()
89+
90+
resp, ok := s.responses[key]
91+
if !ok {
92+
s.t.Errorf("missing response for %v", key)
93+
http.Error(w, "no response", http.StatusInternalServerError)
94+
return
95+
}
96+
if resp.callback != nil {
97+
defer resp.callback()
98+
}
99+
for k, v := range resp.header {
100+
w.Header().Set(k, v)
101+
}
102+
status := resp.status
103+
if status == 0 {
104+
status = http.StatusOK
105+
}
106+
w.WriteHeader(status)
107+
108+
if v := req.Header.Get(protocolVersionHeader); v != resp.wantProtocolVersion && resp.wantProtocolVersion != "" {
109+
s.t.Errorf("%v: bad protocol version header: got %q, want %q", key, v, resp.wantProtocolVersion)
110+
}
111+
w.Write([]byte(resp.body))
112+
}
113+
114+
var (
115+
initResult = &InitializeResult{
116+
Capabilities: &ServerCapabilities{
117+
Completions: &CompletionCapabilities{},
118+
Logging: &LoggingCapabilities{},
119+
Tools: &ToolCapabilities{ListChanged: true},
120+
},
121+
ProtocolVersion: latestProtocolVersion,
122+
ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"},
123+
}
124+
initResp = resp(1, initResult, nil)
125+
)
126+
127+
func jsonBody(t *testing.T, msg jsonrpc2.Message) string {
128+
data, err := jsonrpc2.EncodeMessage(msg)
129+
if err != nil {
130+
t.Fatalf("encoding failed: %v", err)
131+
}
132+
return string(data)
133+
}
134+
135+
func TestStreamableClientTransportLifecycle(t *testing.T) {
136+
ctx := context.Background()
137+
138+
// The lifecycle test verifies various behavior of the streamable client
139+
// initialization:
140+
// - check that it can handle application/json responses
141+
// - check that it sends the negotiated protocol version
142+
fake := &fakeStreamableServer{
143+
t: t,
144+
responses: fakeResponses{
145+
{"POST", "", methodInitialize}: {
146+
header: header{
147+
"Content-Type": "application/json",
148+
sessionIDHeader: "123",
149+
},
150+
body: jsonBody(t, initResp),
151+
},
152+
{"POST", "123", notificationInitialized}: {
153+
status: http.StatusAccepted,
154+
wantProtocolVersion: latestProtocolVersion,
155+
},
156+
{"GET", "123", ""}: {
157+
header: header{
158+
"Content-Type": "text/event-stream",
159+
},
160+
optional: true,
161+
wantProtocolVersion: latestProtocolVersion,
162+
},
163+
{"DELETE", "123", ""}: {},
164+
},
165+
}
166+
167+
httpServer := httptest.NewServer(fake)
168+
defer httpServer.Close()
169+
170+
transport := &StreamableClientTransport{Endpoint: httpServer.URL}
171+
client := NewClient(testImpl, nil)
172+
session, err := client.Connect(ctx, transport, nil)
173+
if err != nil {
174+
t.Fatalf("client.Connect() failed: %v", err)
175+
}
176+
if err := session.Close(); err != nil {
177+
t.Errorf("closing session: %v", err)
178+
}
179+
if missing := fake.missingRequests(); len(missing) > 0 {
180+
t.Errorf("did not receive expected requests: %v", missing)
181+
}
182+
if diff := cmp.Diff(initResult, session.state.InitializeResult); diff != "" {
183+
t.Errorf("mismatch (-want, +got):\n%s", diff)
184+
}
185+
}

mcp/streamable_test.go

Lines changed: 0 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,77 +1035,6 @@ func mustMarshal(v any) json.RawMessage {
10351035
return data
10361036
}
10371037

1038-
func TestStreamableClientTransport(t *testing.T) {
1039-
// This test verifies various behavior of the streamable client transport:
1040-
// - check that it can handle application/json responses
1041-
// - check that it sends the negotiated protocol version
1042-
//
1043-
// TODO(rfindley): make this test more comprehensive, similar to
1044-
// [TestStreamableServerTransport].
1045-
ctx := context.Background()
1046-
resp := func(id int64, result any, err error) *jsonrpc.Response {
1047-
return &jsonrpc.Response{
1048-
ID: jsonrpc2.Int64ID(id),
1049-
Result: mustMarshal(result),
1050-
Error: err,
1051-
}
1052-
}
1053-
initResult := &InitializeResult{
1054-
Capabilities: &ServerCapabilities{
1055-
Completions: &CompletionCapabilities{},
1056-
Logging: &LoggingCapabilities{},
1057-
Tools: &ToolCapabilities{ListChanged: true},
1058-
},
1059-
ProtocolVersion: latestProtocolVersion,
1060-
ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"},
1061-
}
1062-
initResp := resp(1, initResult, nil)
1063-
1064-
var reqN atomic.Int32 // request count
1065-
serverHandler := func(w http.ResponseWriter, r *http.Request) {
1066-
rN := reqN.Add(1)
1067-
1068-
// TODO(rfindley): if the status code is NoContent or Accepted, we should
1069-
// probably be tolerant of when the content type is not application/json.
1070-
w.Header().Set("Content-Type", "application/json")
1071-
if rN == 1 {
1072-
data, err := jsonrpc2.EncodeMessage(initResp)
1073-
if err != nil {
1074-
t.Errorf("encoding failed: %v", err)
1075-
}
1076-
w.Header().Set("Mcp-Session-Id", "123")
1077-
w.Write(data)
1078-
} else {
1079-
if v := r.Header.Get(protocolVersionHeader); v != latestProtocolVersion {
1080-
t.Errorf("bad protocol version header: got %q, want %q", v, latestProtocolVersion)
1081-
}
1082-
}
1083-
}
1084-
1085-
httpServer := httptest.NewServer(http.HandlerFunc(serverHandler))
1086-
defer httpServer.Close()
1087-
1088-
transport := &StreamableClientTransport{Endpoint: httpServer.URL}
1089-
client := NewClient(testImpl, nil)
1090-
session, err := client.Connect(ctx, transport, nil)
1091-
if err != nil {
1092-
t.Fatalf("client.Connect() failed: %v", err)
1093-
}
1094-
if err := session.Close(); err != nil {
1095-
t.Errorf("closing session: %v", err)
1096-
}
1097-
1098-
if got, want := reqN.Load(), int32(3); got < want {
1099-
// Expect at least 3 requests: initialize, initialized, and DELETE.
1100-
// We may or may not observe the GET, depending on timing.
1101-
t.Errorf("unexpected number of requests: got %d, want at least %d", got, want)
1102-
}
1103-
1104-
if diff := cmp.Diff(initResult, session.state.InitializeResult); diff != "" {
1105-
t.Errorf("mismatch (-want, +got):\n%s", diff)
1106-
}
1107-
}
1108-
11091038
func TestEventID(t *testing.T) {
11101039
tests := []struct {
11111040
sid StreamID

0 commit comments

Comments
 (0)