diff --git a/ext/mcp/src/lib.rs b/ext/mcp/src/lib.rs index e0475e6..69e6e13 100644 --- a/ext/mcp/src/lib.rs +++ b/ext/mcp/src/lib.rs @@ -58,13 +58,6 @@ impl McpClientRb { }) } - async fn client<'a>(&'a self) -> Result<&'a McpClientConnection, Error> { - self.client.lock().await.ok_or(Error::new( - magnus::exception::runtime_error(), - "Client is not initialized".to_string(), - )) - } - fn dispose(&self) { RUNTIME.block_on(async { self.client.lock().await.take(); @@ -73,7 +66,14 @@ impl McpClientRb { fn list_tools(&self) -> Result { RUNTIME.block_on(async { - let a = self.client().await?.list_tools().await; + let a = self + .client + .lock() + .await + .as_ref() + .unwrap() + .list_tools() + .await; match a { Ok(tools) => serialize::<_, Value>(&tools), @@ -98,8 +98,11 @@ impl McpClientRb { RUNTIME.block_on(async { let a = self - .client() - .await? + .client + .lock() + .await + .as_ref() + .unwrap() .call_tool::(CallToolRequestParams { name, arguments: kwargs, diff --git a/ext/mcp/src/mcp_client.rs b/ext/mcp/src/mcp_client.rs index 14ddeb6..d2916c3 100644 --- a/ext/mcp/src/mcp_client.rs +++ b/ext/mcp/src/mcp_client.rs @@ -1,9 +1,10 @@ use std::path::Path; +use std::sync::Arc; use jsonrpsee::async_client::{Client, ClientBuilder}; use jsonrpsee::core::client::ClientT; -use keepcalm::SharedMut; -use tokio::io::{BufReader, Stdin}; -use tokio::process::{Child, Command}; +use tokio::io::{BufReader}; +use tokio::process::{Child, Command, ChildStdin, ChildStdout}; +use tokio::sync::Mutex; use crate::rpc_helpers::{NoParams, ToRpcArg}; use crate::stdio_transport::Adapter; use crate::types::{CallToolRequestParams, InitializeRequestParams, InitializeResult, ListToolsRequestParams, ListToolsResult, Tool}; @@ -11,34 +12,36 @@ use crate::types::{CallToolRequestParams, InitializeRequestParams, InitializeRes enum TransportHandle { Stdio { child: Child, - stdin: SharedMut, - stdout: SharedMut>, + stdin: Arc>, + stdout: Arc>>, }, } +/// This represents a live MCP connection to an MCP server. It will close the connection when dropped on a best effort basis. pub struct McpClientConnection { pub(crate) transport: TransportHandle, pub(crate) client: Client, } -/// This represents a live MCP connection to an MCP server. impl McpClientConnection { pub async fn new_stdio(command: String, args: Vec, init_params: InitializeRequestParams) -> Result { let mut child = Command::new(command) .args(args) .stdin(std::process::Stdio::piped()) .stdout(std::process::Stdio::piped()) + .kill_on_drop(true) .spawn()?; - let stdin = SharedMut::new_mutex(child.stdin.take().unwrap()); - let stdout = SharedMut::new_mutex(BufReader::new(child.stdout.take().unwrap())); + // We take ownership of the stdin and stdout here to pass them to the transport, wrapping them in an Arc and Mutex to allow them to be shared between threads in an async context. + let stdin = Arc::new(Mutex::new(child.stdin.take().unwrap())); + let stdout = Arc::new(Mutex::new(BufReader::new(child.stdout.take().unwrap()))); let client = ClientBuilder::default().build_with_tokio( - Adapter(stdin), - Adapter(stdout), + Adapter(stdin.clone()), + Adapter(stdout.clone()), ); - let new_client = Self { transport: TransportHandle::Stdio { child }, client }; + let new_client = Self { transport: TransportHandle::Stdio { child, stdin, stdout }, client }; new_client.initialize(init_params).await?; Ok(new_client) diff --git a/ext/mcp/src/stdio_transport.rs b/ext/mcp/src/stdio_transport.rs index ee71515..49bd106 100644 --- a/ext/mcp/src/stdio_transport.rs +++ b/ext/mcp/src/stdio_transport.rs @@ -7,16 +7,18 @@ use keepcalm::SharedMut; use tokio::sync::Mutex; use tracing::debug; -pub struct Adapter(pub SharedMut); +pub struct Adapter(pub Arc>); #[async_trait] -impl TransportSenderT for Adapter { +impl TransportSenderT for Adapter { type Error = tokio::io::Error; #[tracing::instrument(skip(self), level = "trace")] async fn send(&mut self, msg: String) -> Result<(), Self::Error> { - self.0.write_all(msg.as_bytes()).await?; - self.0.write_all(b"\n").await?; + let mut guard = self.0.lock().await; + + guard.write_all(msg.as_bytes()).await?; + guard.write_all(b"\n").await?; Ok(()) } } @@ -28,7 +30,7 @@ impl TransportReceiverT for Adapter #[tracing::instrument(skip(self), level = "trace")] async fn receive(&mut self) -> Result { let mut str = String::new(); - self.0.read_line(&mut str).await?; + self.0.lock().await.read_line(&mut str).await?; Ok(ReceivedMessage::Text(str)) } }