diff --git a/ext/mcp/src/internal-test.rs b/ext/mcp/src/internal-test.rs index 37d27c8..dcb59a7 100644 --- a/ext/mcp/src/internal-test.rs +++ b/ext/mcp/src/internal-test.rs @@ -15,7 +15,7 @@ mod rpc_helpers; mod mcp_client; mod stdio_transport; use rpc_helpers::*; -use crate::mcp_client::McpClient; +use crate::mcp_client::McpClientConnection; #[tokio::main] async fn main() -> anyhow::Result<()> { @@ -29,7 +29,7 @@ async fn main() -> anyhow::Result<()> { .with_level(true) .init(); - let client = McpClient::new_stdio( + let client = McpClientConnection::new_stdio( "/Users/joshuacoles/.local/bin/mcp-server-fetch".to_string(), vec![], InitializeRequestParams { diff --git a/ext/mcp/src/lib.rs b/ext/mcp/src/lib.rs index b42d3b0..e0475e6 100644 --- a/ext/mcp/src/lib.rs +++ b/ext/mcp/src/lib.rs @@ -1,4 +1,4 @@ -use crate::mcp_client::McpClient; +use crate::mcp_client::McpClientConnection; use jsonrpsee::async_client::{Client, ClientBuilder}; use jsonrpsee::core::client::ClientT; use magnus::prelude::*; @@ -30,14 +30,14 @@ static RUNTIME: Lazy = #[magnus::wrap(class = "Mcp::Client", free_immediately, size)] struct McpClientRb { - client: Mutex>, + client: Mutex>, } impl McpClientRb { fn new(command: String, args: Vec) -> Result { let client = RUNTIME .block_on(async { - McpClient::new_stdio( + McpClientConnection::new_stdio( command, args, InitializeRequestParams { @@ -58,7 +58,7 @@ impl McpClientRb { }) } - async fn client<'a>(&'a self) -> Result<&'a McpClient, Error> { + async fn client<'a>(&'a self) -> Result<&'a McpClientConnection, Error> { self.client.lock().await.ok_or(Error::new( magnus::exception::runtime_error(), "Client is not initialized".to_string(), diff --git a/ext/mcp/src/mcp_client.rs b/ext/mcp/src/mcp_client.rs index c2e435f..14ddeb6 100644 --- a/ext/mcp/src/mcp_client.rs +++ b/ext/mcp/src/mcp_client.rs @@ -1,46 +1,50 @@ use std::path::Path; use jsonrpsee::async_client::{Client, ClientBuilder}; use jsonrpsee::core::client::ClientT; -use tokio::io::BufReader; +use keepcalm::SharedMut; +use tokio::io::{BufReader, Stdin}; use tokio::process::{Child, Command}; use crate::rpc_helpers::{NoParams, ToRpcArg}; -use crate::stdio_transport; use crate::stdio_transport::Adapter; use crate::types::{CallToolRequestParams, InitializeRequestParams, InitializeResult, ListToolsRequestParams, ListToolsResult, Tool}; enum TransportHandle { - Stdio(Child), + Stdio { + child: Child, + stdin: SharedMut, + stdout: SharedMut>, + }, } -pub struct McpClient { +pub struct McpClientConnection { pub(crate) transport: TransportHandle, pub(crate) client: Client, } -impl McpClient { +/// This represents a live MCP connection to an MCP server. +impl McpClientConnection { pub async fn new_stdio(command: String, args: Vec, init_params: InitializeRequestParams) -> Result { let mut child = Command::new(command) .args(args) .stdin(std::process::Stdio::piped()) .stdout(std::process::Stdio::piped()) - .spawn() - .unwrap(); + .spawn()?; - let stdin = child.stdin.take().unwrap(); - let stdout = BufReader::new(child.stdout.take().unwrap()); + let stdin = SharedMut::new_mutex(child.stdin.take().unwrap()); + let stdout = SharedMut::new_mutex(BufReader::new(child.stdout.take().unwrap())); let client = ClientBuilder::default().build_with_tokio( Adapter(stdin), Adapter(stdout), ); - let new_client = Self { transport: TransportHandle::Stdio(child), client }; + let new_client = Self { transport: TransportHandle::Stdio { child }, client }; new_client.initialize(init_params).await?; Ok(new_client) } - pub async fn initialize(&self, params: InitializeRequestParams) -> Result { + async fn initialize(&self, params: InitializeRequestParams) -> Result { let result: InitializeResult = self.client.request("initialize", params.to_rpc()).await?; self.client.notification("notifications/initialized", NoParams).await?; Ok(result) diff --git a/ext/mcp/src/stdio_transport.rs b/ext/mcp/src/stdio_transport.rs index e7e6e4f..ee71515 100644 --- a/ext/mcp/src/stdio_transport.rs +++ b/ext/mcp/src/stdio_transport.rs @@ -3,13 +3,14 @@ use tokio::process::{Child, ChildStdin, ChildStdout, Command}; use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWriteExt, BufReader}; use jsonrpsee::core::async_trait; use jsonrpsee::core::client::{ReceivedMessage, TransportReceiverT, TransportSenderT}; +use keepcalm::SharedMut; use tokio::sync::Mutex; use tracing::debug; -pub struct Adapter(pub T); +pub struct Adapter(pub SharedMut); #[async_trait] -impl TransportSenderT for Adapter { +impl TransportSenderT for Adapter { type Error = tokio::io::Error; #[tracing::instrument(skip(self), level = "trace")]