@@ -14,6 +14,7 @@ use reqwest::Client;
14
14
use rmcp:: serde_json;
15
15
use rmcp:: transport:: auth:: {
16
16
AuthClient ,
17
+ OAuthClientConfig ,
17
18
OAuthState ,
18
19
OAuthTokenResponse ,
19
20
} ;
@@ -26,6 +27,10 @@ use rmcp::transport::{
26
27
StreamableHttpClientTransport ,
27
28
WorkerTransport ,
28
29
} ;
30
+ use serde:: {
31
+ Deserialize ,
32
+ Serialize ,
33
+ } ;
29
34
use sha2:: {
30
35
Digest ,
31
36
Sha256 ,
@@ -64,6 +69,8 @@ pub enum OauthUtilError {
64
69
Directory ( #[ from] DirectoryError ) ,
65
70
#[ error( transparent) ]
66
71
Reqwest ( #[ from] reqwest:: Error ) ,
72
+ #[ error( "Malformed directory" ) ]
73
+ MalformDirectory ,
67
74
}
68
75
69
76
/// A guard that automatically cancels the cancellation token when dropped.
@@ -79,6 +86,27 @@ impl Drop for LoopBackDropGuard {
79
86
}
80
87
}
81
88
89
+ /// This is modeled after [OAuthClientConfig]
90
+ /// It's only here because [OAuthClientConfig] does not implement Serialize and Deserialize
91
+ #[ derive( Clone , Serialize , Deserialize , Debug ) ]
92
+ pub struct Registration {
93
+ pub client_id : String ,
94
+ pub client_secret : Option < String > ,
95
+ pub scopes : Vec < String > ,
96
+ pub redirect_uri : String ,
97
+ }
98
+
99
+ impl From < OAuthClientConfig > for Registration {
100
+ fn from ( value : OAuthClientConfig ) -> Self {
101
+ Self {
102
+ client_id : value. client_id ,
103
+ client_secret : value. client_secret ,
104
+ scopes : value. scopes ,
105
+ redirect_uri : value. redirect_uri ,
106
+ }
107
+ }
108
+ }
109
+
82
110
/// A guard that manages the lifecycle of an authenticated MCP client and automatically
83
111
/// persists OAuth credentials when dropped.
84
112
///
@@ -164,6 +192,10 @@ pub enum HttpTransport {
164
192
WithoutAuth ( WorkerTransport < StreamableHttpClientWorker < Client > > ) ,
165
193
}
166
194
195
+ fn get_scopes ( ) -> & ' static [ & ' static str ] {
196
+ & [ "openid" , "mcp" , "email" , "profile" ]
197
+ }
198
+
167
199
pub async fn get_http_transport (
168
200
os : & Os ,
169
201
delete_cache : bool ,
@@ -175,6 +207,7 @@ pub async fn get_http_transport(
175
207
let url = Url :: from_str ( url) ?;
176
208
let key = compute_key ( & url) ;
177
209
let cred_full_path = cred_dir. join ( format ! ( "{key}.token.json" ) ) ;
210
+ let reg_full_path = cred_dir. join ( format ! ( "{key}.registration.json" ) ) ;
178
211
179
212
if delete_cache && cred_full_path. is_file ( ) {
180
213
tokio:: fs:: remove_file ( & cred_full_path) . await ?;
@@ -188,7 +221,8 @@ pub async fn get_http_transport(
188
221
let auth_client = match auth_client {
189
222
Some ( auth_client) => auth_client,
190
223
None => {
191
- let am = get_auth_manager ( url. clone ( ) , cred_full_path. clone ( ) , messenger) . await ?;
224
+ let am =
225
+ get_auth_manager ( url. clone ( ) , cred_full_path. clone ( ) , reg_full_path. clone ( ) , messenger) . await ?;
192
226
AuthClient :: new ( reqwest_client, am)
193
227
} ,
194
228
} ;
@@ -215,45 +249,67 @@ pub async fn get_http_transport(
215
249
async fn get_auth_manager (
216
250
url : Url ,
217
251
cred_full_path : PathBuf ,
252
+ reg_full_path : PathBuf ,
218
253
messenger : & dyn Messenger ,
219
254
) -> Result < AuthorizationManager , OauthUtilError > {
220
- let content_as_bytes = tokio:: fs:: read ( & cred_full_path) . await ;
255
+ let cred_as_bytes = tokio:: fs:: read ( & cred_full_path) . await ;
256
+ let reg_as_bytes = tokio:: fs:: read ( & reg_full_path) . await ;
221
257
let mut oauth_state = OAuthState :: new ( url, None ) . await ?;
222
258
223
- match content_as_bytes {
224
- Ok ( bytes) => {
225
- let token = serde_json:: from_slice :: < OAuthTokenResponse > ( & bytes) ?;
259
+ match ( cred_as_bytes, reg_as_bytes) {
260
+ ( Ok ( cred_as_bytes) , Ok ( reg_as_bytes) ) => {
261
+ let token = serde_json:: from_slice :: < OAuthTokenResponse > ( & cred_as_bytes) ?;
262
+ let reg = serde_json:: from_slice :: < Registration > ( & reg_as_bytes) ?;
226
263
227
- oauth_state. set_credentials ( "id" , token) . await ?;
264
+ oauth_state. set_credentials ( & reg . client_id , token) . await ?;
228
265
229
266
debug ! ( "## mcp: credentials set with cache" ) ;
230
267
231
268
Ok ( oauth_state
232
269
. into_authorization_manager ( )
233
270
. ok_or ( OauthUtilError :: MissingAuthorizationManager ) ?)
234
271
} ,
235
- Err ( e ) => {
236
- info ! ( "Error reading cached credentials: {e} " ) ;
272
+ _ => {
273
+ info ! ( "Error reading cached credentials" ) ;
237
274
debug ! ( "## mcp: cache read failed. constructing auth manager from scratch" ) ;
238
- get_auth_manager_impl ( oauth_state, messenger) . await
275
+ let ( am, redirect_uri) = get_auth_manager_impl ( oauth_state, messenger) . await ?;
276
+
277
+ // Client registration is done in [start_authorization]
278
+ // If we have gotten past that point that means we have the info to persist the
279
+ // registration on disk. These are info that we need to refresh stake
280
+ // tokens. This is in contrast to tokens, which we only persist when we drop
281
+ // the client (because that way we can write once and ensure what is on the
282
+ // disk always the most up to date)
283
+ let ( client_id, _credentials) = am. get_credentials ( ) . await ?;
284
+ let reg = Registration {
285
+ client_id,
286
+ client_secret : None ,
287
+ scopes : get_scopes ( ) . iter ( ) . map ( |s| ( * s) . to_string ( ) ) . collect :: < Vec < _ > > ( ) ,
288
+ redirect_uri,
289
+ } ;
290
+ let reg_as_str = serde_json:: to_string_pretty ( & reg) ?;
291
+ let reg_parent_path = reg_full_path. parent ( ) . ok_or ( OauthUtilError :: MalformDirectory ) ?;
292
+ tokio:: fs:: create_dir ( reg_parent_path) . await ?;
293
+ tokio:: fs:: write ( reg_full_path, & reg_as_str) . await ?;
294
+
295
+ Ok ( am)
239
296
} ,
240
297
}
241
298
}
242
299
243
300
async fn get_auth_manager_impl (
244
301
mut oauth_state : OAuthState ,
245
302
messenger : & dyn Messenger ,
246
- ) -> Result < AuthorizationManager , OauthUtilError > {
303
+ ) -> Result < ( AuthorizationManager , String ) , OauthUtilError > {
247
304
let socket_addr = SocketAddr :: from ( ( [ 127 , 0 , 0 , 1 ] , 0 ) ) ;
248
305
let cancellation_token = tokio_util:: sync:: CancellationToken :: new ( ) ;
249
306
let ( tx, rx) = tokio:: sync:: oneshot:: channel :: < String > ( ) ;
250
307
251
308
let ( actual_addr, _dg) = make_svc ( tx, socket_addr, cancellation_token) . await ?;
252
309
info ! ( "Listening on local host port {:?} for oauth" , actual_addr) ;
253
310
254
- oauth_state
255
- . start_authorization ( & [ "mcp" , "profile" , "email" ] , & format ! ( "http://{}" , actual_addr) )
256
- . await ?;
311
+ let redirect_uri = format ! ( "http://{}" , actual_addr) ;
312
+ oauth_state. start_authorization ( get_scopes ( ) , & redirect_uri) . await ?;
257
313
258
314
let auth_url = oauth_state. get_authorization_url ( ) . await ?;
259
315
_ = messenger. send_oauth_link ( auth_url) . await ;
@@ -264,7 +320,7 @@ async fn get_auth_manager_impl(
264
320
. into_authorization_manager ( )
265
321
. ok_or ( OauthUtilError :: MissingAuthorizationManager ) ?;
266
322
267
- Ok ( am )
323
+ Ok ( ( am , redirect_uri ) )
268
324
}
269
325
270
326
pub fn compute_key ( rs : & Url ) -> String {
@@ -320,7 +376,7 @@ async fn make_svc(
320
376
{
321
377
sender. send ( code) . map_err ( LoopBackError :: Send ) ?;
322
378
}
323
- mk_response ( "Auth code sent " . to_string ( ) )
379
+ mk_response ( "You can close this page now " . to_string ( ) )
324
380
} )
325
381
}
326
382
}
0 commit comments