Skip to content

Commit c4a0910

Browse files
authored
fix(mcp): not being able to refresh tokens for remote mcp (#2849)
* adds registration persistence for token refresh * truncates overly long tool descriptions * Modifies oauth success message * adds timestamps on mcp logs
1 parent 4ea78b9 commit c4a0910

File tree

4 files changed

+132
-38
lines changed

4 files changed

+132
-38
lines changed

crates/chat-cli/src/cli/chat/cli/mcp.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@ impl McpArgs {
5454
let msg = msg
5555
.iter()
5656
.map(|record| match record {
57-
LoadingRecord::Err(content) | LoadingRecord::Warn(content) | LoadingRecord::Success(content) => {
58-
content.clone()
59-
},
57+
LoadingRecord::Err(timestamp, content)
58+
| LoadingRecord::Warn(timestamp, content)
59+
| LoadingRecord::Success(timestamp, content) => format!("[{timestamp}]: {content}"),
6060
})
6161
.collect::<Vec<_>>()
6262
.join("\n--- tools refreshed ---\n");

crates/chat-cli/src/cli/chat/tool_manager.rs

Lines changed: 46 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,26 @@ enum LoadingMsg {
150150
/// surface (since we would only want to surface fatal errors in non-interactive mode).
151151
#[derive(Clone, Debug)]
152152
pub enum LoadingRecord {
153-
Success(String),
154-
Warn(String),
155-
Err(String),
153+
Success(String, String),
154+
Warn(String, String),
155+
Err(String, String),
156+
}
157+
158+
impl LoadingRecord {
159+
pub fn success(msg: String) -> Self {
160+
let timestamp = chrono::Local::now().format("%Y:%H:%S").to_string();
161+
LoadingRecord::Success(timestamp, msg)
162+
}
163+
164+
pub fn warn(msg: String) -> Self {
165+
let timestamp = chrono::Local::now().format("%Y:%H:%S").to_string();
166+
LoadingRecord::Warn(timestamp, msg)
167+
}
168+
169+
pub fn err(msg: String) -> Self {
170+
let timestamp = chrono::Local::now().format("%Y:%H:%S").to_string();
171+
LoadingRecord::Err(timestamp, msg)
172+
}
156173
}
157174

158175
pub struct ToolManagerBuilder {
@@ -473,10 +490,11 @@ pub enum PromptQueryResult {
473490
/// - `IllegalChar`: The tool name contains characters that are not allowed
474491
/// - `EmptyDescription`: The tool description is empty or missing
/// - `DescriptionTooLong`: The tool description exceeds the maximum length and has been truncated
475492
#[allow(dead_code)]
476-
enum OutOfSpecName {
493+
enum ToolValidationViolation {
477494
TooLong(String),
478495
IllegalChar(String),
479496
EmptyDescription(String),
497+
DescriptionTooLong(String),
480498
}
481499

482500
#[derive(Clone, Default, Debug, Eq, PartialEq)]
@@ -814,7 +832,7 @@ impl ToolManager {
814832
.lock()
815833
.await
816834
.iter()
817-
.any(|(_, records)| records.iter().any(|record| matches!(record, LoadingRecord::Err(_))))
835+
.any(|(_, records)| records.iter().any(|record| matches!(record, LoadingRecord::Err(..))))
818836
{
819837
queue!(
820838
stderr,
@@ -962,7 +980,7 @@ impl ToolManager {
962980
if !conflicts.is_empty() {
963981
let mut record_lock = self.mcp_load_record.lock().await;
964982
for (server_name, msg) in conflicts {
965-
let record = LoadingRecord::Err(msg);
983+
let record = LoadingRecord::err(msg);
966984
record_lock
967985
.entry(server_name)
968986
.and_modify(|v| v.push(record.clone()))
@@ -1494,9 +1512,9 @@ fn spawn_orchestrator_task(
14941512
drop(buf_writer);
14951513
let record = String::from_utf8_lossy(record_temp_buf).to_string();
14961514
let record = if process_result.is_err() {
1497-
LoadingRecord::Warn(record)
1515+
LoadingRecord::warn(record)
14981516
} else {
1499-
LoadingRecord::Success(record)
1517+
LoadingRecord::success(record)
15001518
};
15011519
load_record
15021520
.lock()
@@ -1522,7 +1540,7 @@ fn spawn_orchestrator_task(
15221540
let _ = buf_writer.flush();
15231541
drop(buf_writer);
15241542
let record = String::from_utf8_lossy(record_temp_buf).to_string();
1525-
let record = LoadingRecord::Err(record);
1543+
let record = LoadingRecord::err(record);
15261544
load_record
15271545
.lock()
15281546
.await
@@ -1606,7 +1624,7 @@ fn spawn_orchestrator_task(
16061624
let _ = buf_writer.flush();
16071625
drop(buf_writer);
16081626
let record = String::from_utf8_lossy(record_temp_buf).to_string();
1609-
let record = LoadingRecord::Err(record);
1627+
let record = LoadingRecord::err(record);
16101628
load_record
16111629
.lock()
16121630
.await
@@ -1626,7 +1644,7 @@ fn spawn_orchestrator_task(
16261644
let _ = buf_writer.flush();
16271645
drop(buf_writer);
16281646
let record_str = String::from_utf8_lossy(record_temp_buf).to_string();
1629-
let record = LoadingRecord::Warn(record_str.clone());
1647+
let record = LoadingRecord::warn(record_str.clone());
16301648
load_record
16311649
.lock()
16321650
.await
@@ -1720,7 +1738,7 @@ async fn process_tool_specs(
17201738
//
17211739
// For non-compliance due to point 1, we shall change it on behalf of the users.
17221740
// For the rest, we simply throw a warning and reject the tool.
1723-
let mut out_of_spec_tool_names = Vec::<OutOfSpecName>::new();
1741+
let mut out_of_spec_tool_names = Vec::<ToolValidationViolation>::new();
17241742
let mut hasher = DefaultHasher::new();
17251743
let mut number_of_tools = 0_usize;
17261744

@@ -1745,12 +1763,18 @@ async fn process_tool_specs(
17451763
}
17461764
});
17471765
if model_tool_name.len() > 64 {
1748-
out_of_spec_tool_names.push(OutOfSpecName::TooLong(spec.name.clone()));
1766+
out_of_spec_tool_names.push(ToolValidationViolation::TooLong(spec.name.clone()));
17491767
continue;
17501768
} else if spec.description.is_empty() {
1751-
out_of_spec_tool_names.push(OutOfSpecName::EmptyDescription(spec.name.clone()));
1769+
out_of_spec_tool_names.push(ToolValidationViolation::EmptyDescription(spec.name.clone()));
17521770
continue;
17531771
}
1772+
1773+
if spec.description.len() > 10_004 {
1774+
spec.description.truncate(10_004);
1775+
out_of_spec_tool_names.push(ToolValidationViolation::DescriptionTooLong(spec.name.clone()));
1776+
}
1777+
17541778
tn_map.insert(model_tool_name.clone(), ToolInfo {
17551779
server_name: server_name.to_string(),
17561780
host_tool_name: spec.name.clone(),
@@ -1788,21 +1812,25 @@ async fn process_tool_specs(
17881812
if !out_of_spec_tool_names.is_empty() {
17891813
Err(eyre::eyre!(out_of_spec_tool_names.iter().fold(
17901814
String::from(
1791-
"The following tools are out of spec. They will be excluded from the list of available tools:\n",
1815+
"The following tools are out of spec. They may have been excluded from the list of available tools:\n",
17921816
),
17931817
|mut acc, name| {
17941818
let (tool_name, msg) = match name {
1795-
OutOfSpecName::TooLong(tool_name) => (
1819+
ToolValidationViolation::TooLong(tool_name) => (
17961820
tool_name.as_str(),
17971821
"tool name exceeds max length of 64 when combined with server name",
17981822
),
1799-
OutOfSpecName::IllegalChar(tool_name) => (
1823+
ToolValidationViolation::IllegalChar(tool_name) => (
18001824
tool_name.as_str(),
18011825
"tool name must be compliant with ^[a-zA-Z][a-zA-Z0-9_]*$",
18021826
),
1803-
OutOfSpecName::EmptyDescription(tool_name) => {
1827+
ToolValidationViolation::EmptyDescription(tool_name) => {
18041828
(tool_name.as_str(), "tool schema contains empty description")
18051829
},
1830+
ToolValidationViolation::DescriptionTooLong(tool_name) => (
1831+
tool_name.as_str(),
1832+
"tool description is longer than 10024 characters and has been truncated",
1833+
),
18061834
};
18071835
acc.push_str(format!(" - {} ({})\n", tool_name, msg).as_str());
18081836
acc

crates/chat-cli/src/mcp_client/client.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,16 @@ pub enum McpClientError {
152152
Auth(#[from] crate::auth::AuthError),
153153
}
154154

155+
/// Decorates the method passed in with retry logic, but only if the [RunningService] has an
156+
/// instance of [AuthClientDropGuard].
157+
/// The various methods provided by RMCP to interact with the mcp server supposedly refresh the
158+
/// token once the token expires but that logic would require us to also note down the time at
159+
/// which a token is obtained since the only time related information in the token is the duration
160+
/// for which a token is valid. However, if we do solely rely on the internals of these methods to
161+
/// refresh tokens, we would have no way of knowing when a token is obtained. (Maybe there is a
162+
/// method that would allow us to configure what extra info to include in the token. If you find it,
163+
/// feel free to remove this. That would also enable us to simplify the definition of
164+
/// [RunningService])
155165
macro_rules! decorate_with_auth_retry {
156166
($param_type:ty, $method_name:ident, $return_type:ty) => {
157167
pub async fn $method_name(&self, param: $param_type) -> Result<$return_type, rmcp::ServiceError> {
@@ -166,7 +176,7 @@ macro_rules! decorate_with_auth_retry {
166176
// TODO: discern error type prior to retrying
167177
// Not entirely sure what is thrown when auth is required
168178
if let Some(auth_client) = self.get_auth_client() {
169-
let refresh_result = auth_client.get_access_token().await;
179+
let refresh_result = auth_client.auth_manager.lock().await.refresh_token().await;
170180
match refresh_result {
171181
Ok(_) => {
172182
// Retry the operation after token refresh
@@ -340,7 +350,7 @@ impl McpClientService {
340350
Err(e) if matches!(*e, ClientInitializeError::ConnectionClosed(_)) => {
341351
debug!("## mcp: first hand shake attempt failed: {:?}", e);
342352
let refresh_res =
343-
auth_dg.auth_client.get_access_token().await;
353+
auth_dg.auth_client.auth_manager.lock().await.refresh_token().await;
344354
let new_self = McpClientService::new(
345355
server_name.clone(),
346356
backup_config,

crates/chat-cli/src/mcp_client/oauth_util.rs

Lines changed: 71 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ use reqwest::Client;
1414
use rmcp::serde_json;
1515
use rmcp::transport::auth::{
1616
AuthClient,
17+
OAuthClientConfig,
1718
OAuthState,
1819
OAuthTokenResponse,
1920
};
@@ -26,6 +27,10 @@ use rmcp::transport::{
2627
StreamableHttpClientTransport,
2728
WorkerTransport,
2829
};
30+
use serde::{
31+
Deserialize,
32+
Serialize,
33+
};
2934
use sha2::{
3035
Digest,
3136
Sha256,
@@ -64,6 +69,8 @@ pub enum OauthUtilError {
6469
Directory(#[from] DirectoryError),
6570
#[error(transparent)]
6671
Reqwest(#[from] reqwest::Error),
72+
#[error("Malformed directory")]
73+
MalformDirectory,
6774
}
6875

6976
/// A guard that automatically cancels the cancellation token when dropped.
@@ -79,6 +86,27 @@ impl Drop for LoopBackDropGuard {
7986
}
8087
}
8188

89+
/// This is modeled after [OAuthClientConfig]
90+
/// It's only here because [OAuthClientConfig] does not implement Serialize and Deserialize
91+
#[derive(Clone, Serialize, Deserialize, Debug)]
92+
pub struct Registration {
93+
pub client_id: String,
94+
pub client_secret: Option<String>,
95+
pub scopes: Vec<String>,
96+
pub redirect_uri: String,
97+
}
98+
99+
impl From<OAuthClientConfig> for Registration {
100+
fn from(value: OAuthClientConfig) -> Self {
101+
Self {
102+
client_id: value.client_id,
103+
client_secret: value.client_secret,
104+
scopes: value.scopes,
105+
redirect_uri: value.redirect_uri,
106+
}
107+
}
108+
}
109+
82110
/// A guard that manages the lifecycle of an authenticated MCP client and automatically
83111
/// persists OAuth credentials when dropped.
84112
///
@@ -164,6 +192,10 @@ pub enum HttpTransport {
164192
WithoutAuth(WorkerTransport<StreamableHttpClientWorker<Client>>),
165193
}
166194

195+
/// Returns the fixed list of OAuth scope names used by this module.
fn get_scopes() -> &'static [&'static str] {
    const SCOPES: &[&str] = &["openid", "mcp", "email", "profile"];
    SCOPES
}
198+
167199
pub async fn get_http_transport(
168200
os: &Os,
169201
delete_cache: bool,
@@ -175,6 +207,7 @@ pub async fn get_http_transport(
175207
let url = Url::from_str(url)?;
176208
let key = compute_key(&url);
177209
let cred_full_path = cred_dir.join(format!("{key}.token.json"));
210+
let reg_full_path = cred_dir.join(format!("{key}.registration.json"));
178211

179212
if delete_cache && cred_full_path.is_file() {
180213
tokio::fs::remove_file(&cred_full_path).await?;
@@ -188,7 +221,8 @@ pub async fn get_http_transport(
188221
let auth_client = match auth_client {
189222
Some(auth_client) => auth_client,
190223
None => {
191-
let am = get_auth_manager(url.clone(), cred_full_path.clone(), messenger).await?;
224+
let am =
225+
get_auth_manager(url.clone(), cred_full_path.clone(), reg_full_path.clone(), messenger).await?;
192226
AuthClient::new(reqwest_client, am)
193227
},
194228
};
@@ -215,45 +249,67 @@ pub async fn get_http_transport(
215249
async fn get_auth_manager(
216250
url: Url,
217251
cred_full_path: PathBuf,
252+
reg_full_path: PathBuf,
218253
messenger: &dyn Messenger,
219254
) -> Result<AuthorizationManager, OauthUtilError> {
220-
let content_as_bytes = tokio::fs::read(&cred_full_path).await;
255+
let cred_as_bytes = tokio::fs::read(&cred_full_path).await;
256+
let reg_as_bytes = tokio::fs::read(&reg_full_path).await;
221257
let mut oauth_state = OAuthState::new(url, None).await?;
222258

223-
match content_as_bytes {
224-
Ok(bytes) => {
225-
let token = serde_json::from_slice::<OAuthTokenResponse>(&bytes)?;
259+
match (cred_as_bytes, reg_as_bytes) {
260+
(Ok(cred_as_bytes), Ok(reg_as_bytes)) => {
261+
let token = serde_json::from_slice::<OAuthTokenResponse>(&cred_as_bytes)?;
262+
let reg = serde_json::from_slice::<Registration>(&reg_as_bytes)?;
226263

227-
oauth_state.set_credentials("id", token).await?;
264+
oauth_state.set_credentials(&reg.client_id, token).await?;
228265

229266
debug!("## mcp: credentials set with cache");
230267

231268
Ok(oauth_state
232269
.into_authorization_manager()
233270
.ok_or(OauthUtilError::MissingAuthorizationManager)?)
234271
},
235-
Err(e) => {
236-
info!("Error reading cached credentials: {e}");
272+
_ => {
273+
info!("Error reading cached credentials");
237274
debug!("## mcp: cache read failed. constructing auth manager from scratch");
238-
get_auth_manager_impl(oauth_state, messenger).await
275+
let (am, redirect_uri) = get_auth_manager_impl(oauth_state, messenger).await?;
276+
277+
// Client registration is done in [start_authorization]
278+
// If we have gotten past that point that means we have the info to persist the
279+
// registration on disk. These are info that we need to refresh stake
280+
// tokens. This is in contrast to tokens, which we only persist when we drop
281+
// the client (because that way we can write once and ensure what is on the
282+
// disk always the most up to date)
283+
let (client_id, _credentials) = am.get_credentials().await?;
284+
let reg = Registration {
285+
client_id,
286+
client_secret: None,
287+
scopes: get_scopes().iter().map(|s| (*s).to_string()).collect::<Vec<_>>(),
288+
redirect_uri,
289+
};
290+
let reg_as_str = serde_json::to_string_pretty(&reg)?;
291+
let reg_parent_path = reg_full_path.parent().ok_or(OauthUtilError::MalformDirectory)?;
292+
tokio::fs::create_dir(reg_parent_path).await?;
293+
tokio::fs::write(reg_full_path, &reg_as_str).await?;
294+
295+
Ok(am)
239296
},
240297
}
241298
}
242299

243300
async fn get_auth_manager_impl(
244301
mut oauth_state: OAuthState,
245302
messenger: &dyn Messenger,
246-
) -> Result<AuthorizationManager, OauthUtilError> {
303+
) -> Result<(AuthorizationManager, String), OauthUtilError> {
247304
let socket_addr = SocketAddr::from(([127, 0, 0, 1], 0));
248305
let cancellation_token = tokio_util::sync::CancellationToken::new();
249306
let (tx, rx) = tokio::sync::oneshot::channel::<String>();
250307

251308
let (actual_addr, _dg) = make_svc(tx, socket_addr, cancellation_token).await?;
252309
info!("Listening on local host port {:?} for oauth", actual_addr);
253310

254-
oauth_state
255-
.start_authorization(&["mcp", "profile", "email"], &format!("http://{}", actual_addr))
256-
.await?;
311+
let redirect_uri = format!("http://{}", actual_addr);
312+
oauth_state.start_authorization(get_scopes(), &redirect_uri).await?;
257313

258314
let auth_url = oauth_state.get_authorization_url().await?;
259315
_ = messenger.send_oauth_link(auth_url).await;
@@ -264,7 +320,7 @@ async fn get_auth_manager_impl(
264320
.into_authorization_manager()
265321
.ok_or(OauthUtilError::MissingAuthorizationManager)?;
266322

267-
Ok(am)
323+
Ok((am, redirect_uri))
268324
}
269325

270326
pub fn compute_key(rs: &Url) -> String {
@@ -320,7 +376,7 @@ async fn make_svc(
320376
{
321377
sender.send(code).map_err(LoopBackError::Send)?;
322378
}
323-
mk_response("Auth code sent".to_string())
379+
mk_response("You can close this page now".to_string())
324380
})
325381
}
326382
}

0 commit comments

Comments
 (0)