diff --git a/ext/mcp/Cargo.lock b/ext/mcp/Cargo.lock index 5fd324e..91bc271 100644 --- a/ext/mcp/Cargo.lock +++ b/ext/mcp/Cargo.lock @@ -251,6 +251,16 @@ dependencies = [ "thiserror", ] +[[package]] +name = "keepcalm" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "031ddc7e27bbb011c78958881a3723873608397b8b10e146717fc05cf3364d78" +dependencies = [ + "once_cell", + "parking_lot", +] + [[package]] name = "lazy_static" version = "1.5.0" @@ -333,6 +343,7 @@ version = "0.1.0" dependencies = [ "anyhow", "jsonrpsee", + "keepcalm", "magnus", "once_cell", "serde", diff --git a/ext/mcp/Cargo.toml b/ext/mcp/Cargo.toml index 67f2d96..0d6e11c 100644 --- a/ext/mcp/Cargo.toml +++ b/ext/mcp/Cargo.toml @@ -21,3 +21,4 @@ serde_json = "1.0.140" tracing-subscriber = { version = "0.3.19", features = ["env-filter"] } once_cell = "1.19.0" serde_magnus = "0.9.0" +keepcalm = "0.3.5" diff --git a/ext/mcp/src/internal-test.rs b/ext/mcp/src/internal-test.rs index 511e791..37d27c8 100644 --- a/ext/mcp/src/internal-test.rs +++ b/ext/mcp/src/internal-test.rs @@ -12,10 +12,10 @@ use crate::types::{CallToolRequestParams, ClientCapabilities, ListToolsRequestPa mod types; mod rpc_helpers; +mod mcp_client; mod stdio_transport; - use rpc_helpers::*; -use stdio_transport::StdioTransport; +use crate::mcp_client::McpClient; #[tokio::main] async fn main() -> anyhow::Result<()> { @@ -29,51 +29,53 @@ async fn main() -> anyhow::Result<()> { .with_level(true) .init(); - - let mut cmd = tokio::process::Command::new("/Users/joshuacoles/.local/bin/mcp-server-fetch") - .stdin(Stdio::piped()) - .stdout(Stdio::piped()) - .spawn()?; - - let transport = StdioTransport::new(&mut cmd); - - let client: Client = ClientBuilder::default().build_with_tokio( - transport.clone(), - transport.clone(), - ); - - let response: InitializeResult = client.request("initialize", InitializeRequestParams { - capabilities: ClientCapabilities::default(), - client_info: Implementation { name: "Rust MCP".to_string(), version: "0.1.0".to_string() }, - protocol_version: "2024-11-05".to_string(), - }.to_rpc()).await?; - - client.notification("notifications/initialized", NoParams).await?; - - println!("Hey"); - - // drop(transport.stdin); - - select! { - _ = cmd.wait() => { - println!("Command exited"); + let client = McpClient::new_stdio( + "/Users/joshuacoles/.local/bin/mcp-server-fetch".to_string(), + vec![], + InitializeRequestParams { + capabilities: Default::default(), + client_info: Implementation { + name: "ABC".to_string(), + version: "0.0.1".to_string(), + }, + protocol_version: "2024-11-05".to_string(), } + ).await?; - _ = tokio::time::sleep(std::time::Duration::from_secs(1)) => { - cmd.kill().await?; - } - } + dbg!(client.list_tools()); - tokio::time::sleep(std::time::Duration::from_secs(1)).await; - - // let response: ListToolsResult = client.request("tools/list", ListToolsRequestParams::default().to_rpc()).await?; - - // let response: serde_json::Value = client.request("tools/call", CallToolRequestParams { - // arguments: json!({ "url": "http://example.com" }).as_object().unwrap().clone(), - // name: "fetch".to_string(), + // let response: InitializeResult = client.request("initialize", InitializeRequestParams { + // capabilities: ClientCapabilities::default(), + // client_info: Implementation { name: "Rust MCP".to_string(), version: "0.1.0".to_string() }, + // protocol_version: "2024-11-05".to_string(), // }.to_rpc()).await?; - - // println!("Response: {:#?}", response); + // + // client.notification("notifications/initialized", NoParams).await?; + // + // println!("Hey"); + // + // // drop(transport.stdin); + // + // select! { + // _ = cmd.wait() => { + // println!("Command exited"); + // } + // + // _ = tokio::time::sleep(std::time::Duration::from_secs(1)) => { + // cmd.kill().await?; + // } + // } + // + // tokio::time::sleep(std::time::Duration::from_secs(1)).await; + // + // // let response: ListToolsResult = client.request("tools/list", ListToolsRequestParams::default().to_rpc()).await?; + // + // // let response: serde_json::Value = client.request("tools/call", CallToolRequestParams { + // // arguments: json!({ "url": "http://example.com" }).as_object().unwrap().clone(), + // // name: "fetch".to_string(), + // // }.to_rpc()).await?; + // + // // println!("Response: {:#?}", response); Ok(()) } diff --git a/ext/mcp/src/lib.rs b/ext/mcp/src/lib.rs index 9ade5f2..b42d3b0 100644 --- a/ext/mcp/src/lib.rs +++ b/ext/mcp/src/lib.rs @@ -1,75 +1,79 @@ -use jsonrpsee::async_client::{Client, ClientBuilder}; -use tokio::process::Command; use crate::mcp_client::McpClient; +use jsonrpsee::async_client::{Client, ClientBuilder}; use jsonrpsee::core::client::ClientT; -use once_cell::sync::Lazy; use magnus::prelude::*; +use once_cell::sync::Lazy; +use tokio::process::Command; mod mcp_client; -mod types; mod rpc_helpers; mod stdio_transport; +mod types; -use std::{ - hash::{Hash, Hasher}, +use crate::types::{ + CallToolRequestParams, Implementation, InitializeRequestParams, InitializeResult, +}; +use magnus::{ + function, method, + prelude::*, + scan_args::{get_kwargs, scan_args}, + typed_data, Error, RHash, Ruby, Symbol, TryConvert, Value, }; - -use magnus::{function, method, prelude::*, scan_args::{get_kwargs, scan_args}, typed_data, Error, RHash, Ruby, Symbol, TryConvert, Value}; use serde_magnus::serialize; -use crate::types::{CallToolRequestParams, Implementation, InitializeRequestParams, InitializeResult}; +use std::hash::{Hash, Hasher}; +use std::sync::{Arc, MutexGuard}; +use tokio::sync::Mutex; // Create global runtime -static RUNTIME: Lazy = Lazy::new(|| { - tokio::runtime::Runtime::new().expect("Failed to create Tokio runtime") -}); +static RUNTIME: Lazy = + Lazy::new(|| tokio::runtime::Runtime::new().expect("Failed to create Tokio runtime")); #[magnus::wrap(class = "Mcp::Client", free_immediately, size)] struct McpClientRb { - client: McpClient, + client: Mutex>, } impl McpClientRb { fn new(command: String, args: Vec) -> Result { - let client = RUNTIME.block_on(async { - let child = Command::new(command) - .args(args) - .stdin(std::process::Stdio::piped()) - .stdout(std::process::Stdio::piped()) - .spawn() - .unwrap(); + let client = RUNTIME + .block_on(async { + McpClient::new_stdio( + command, + args, + InitializeRequestParams { + capabilities: Default::default(), + client_info: Implementation { + name: "ABC".to_string(), + version: "0.0.1".to_string(), + }, + protocol_version: "2024-11-05".to_string(), + }, + ) + .await + }) + .map_err(|err| Error::new(magnus::exception::runtime_error(), err.to_string()))?; - let transport = stdio_transport::StdioTransport::new(child); - - ClientBuilder::default().build_with_tokio( - transport.clone(), - transport.clone(), - ) - }); - - Ok(Self { client: McpClient { client } }) + Ok(Self { + client: Mutex::new(Some(client)), + }) } - fn connect(&self) -> Result { - RUNTIME.block_on(async { - let a = self.client.initialize(InitializeRequestParams { - capabilities: Default::default(), - client_info: Implementation { name: "ABC".to_string(), version: "0.0.1".to_string() }, - protocol_version: "2024-11-05".to_string(), - }).await; + async fn client<'a>(&'a self) -> Result<&'a McpClient, Error> { + self.client.lock().await.ok_or(Error::new( + magnus::exception::runtime_error(), + "Client is not initialized".to_string(), + )) + } - match a { - Ok(_) => Ok(true), - Err(e) => Err(magnus::Error::new( - magnus::exception::runtime_error(), - e.to_string(), - )), - } + fn dispose(&self) { + RUNTIME.block_on(async { + self.client.lock().await.take(); }) } fn list_tools(&self) -> Result { RUNTIME.block_on(async { - let a = self.client.list_tools().await; + let a = self.client().await?.list_tools().await; match a { Ok(tools) => serialize::<_, Value>(&tools), @@ -93,10 +97,14 @@ impl McpClientRb { }; RUNTIME.block_on(async { - let a = self.client.call_tool::(CallToolRequestParams { - name, - arguments: kwargs, - }).await; + let a = self + .client() + .await? + .call_tool::(CallToolRequestParams { + name, + arguments: kwargs, + }) + .await; match a { Ok(a) => Ok(serde_magnus::serialize(&a)?), @@ -111,13 +119,23 @@ impl McpClientRb { #[magnus::init] fn init(ruby: &Ruby) -> Result<(), Error> { + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_file(true) + .with_line_number(true) + .with_thread_ids(true) + .with_thread_names(true) + .with_target(true) + .with_level(true) + .init(); + let module = ruby.define_module("Mcp")?; let client_class = module.define_class("Client", ruby.class_object())?; client_class.define_singleton_method("new", function!(McpClientRb::new, 2))?; - client_class.define_method("connect", method!(McpClientRb::connect, 0))?; client_class.define_method("list_tools", method!(McpClientRb::list_tools, 0))?; client_class.define_method("call_tool", method!(McpClientRb::call_tool, -1))?; + client_class.define_method("dispose", method!(McpClientRb::dispose, 0))?; Ok(()) } diff --git a/ext/mcp/src/mcp_client.rs b/ext/mcp/src/mcp_client.rs index 3152a45..c2e435f 100644 --- a/ext/mcp/src/mcp_client.rs +++ b/ext/mcp/src/mcp_client.rs @@ -1,27 +1,51 @@ -use jsonrpsee::async_client::Client; +use std::path::Path; +use jsonrpsee::async_client::{Client, ClientBuilder}; use jsonrpsee::core::client::ClientT; -use tokio::process::Child; -use stdio_transport::StdioTransport; +use tokio::io::BufReader; +use tokio::process::{Child, Command}; use crate::rpc_helpers::{NoParams, ToRpcArg}; use crate::stdio_transport; +use crate::stdio_transport::Adapter; use crate::types::{CallToolRequestParams, InitializeRequestParams, InitializeResult, ListToolsRequestParams, ListToolsResult, Tool}; +enum TransportHandle { + Stdio(Child), +} + pub struct McpClient { - pub(crate) transport: StdioTransport, + pub(crate) transport: TransportHandle, pub(crate) client: Client, } impl McpClient { + pub async fn new_stdio(command: String, args: Vec, init_params: InitializeRequestParams) -> Result { + let mut child = Command::new(command) + .args(args) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .spawn() + .unwrap(); + + let stdin = child.stdin.take().unwrap(); + let stdout = BufReader::new(child.stdout.take().unwrap()); + + let client = ClientBuilder::default().build_with_tokio( + Adapter(stdin), + Adapter(stdout), + ); + + let new_client = Self { transport: TransportHandle::Stdio(child), client }; + + new_client.initialize(init_params).await?; + Ok(new_client) + } + pub async fn initialize(&self, params: InitializeRequestParams) -> Result { let result: InitializeResult = self.client.request("initialize", params.to_rpc()).await?; self.client.notification("notifications/initialized", NoParams).await?; Ok(result) } - pub async fn shutdown(mut self) { - self.transport.shutdown(); - } - pub async fn list_tools(&self) -> Result, anyhow::Error> { let mut tools = vec![]; diff --git a/ext/mcp/src/stdio_transport.rs b/ext/mcp/src/stdio_transport.rs index 7d85b8d..e7e6e4f 100644 --- a/ext/mcp/src/stdio_transport.rs +++ b/ext/mcp/src/stdio_transport.rs @@ -1,53 +1,33 @@ use std::sync::Arc; use tokio::process::{Child, ChildStdin, ChildStdout, Command}; -use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWriteExt, BufReader}; use jsonrpsee::core::async_trait; use jsonrpsee::core::client::{ReceivedMessage, TransportReceiverT, TransportSenderT}; use tokio::sync::Mutex; use tracing::debug; -#[derive(Debug, Clone)] -pub struct StdioTransport { - pub stdin: Arc>, - pub stdout: Arc>>, -} - -impl StdioTransport { - pub fn new(child: &mut Child) -> Self { - let stdin = Arc::new(Mutex::new(child.stdin.take().unwrap())); - let stdout = Arc::new(Mutex::new(BufReader::new(child.stdout.take().unwrap()))); - Self { stdin, stdout } - } - - pub(crate) async fn shutdown(mut self) -> Result<(), tokio::io::Error> { - Ok(()) - } -} +pub struct Adapter(pub T); #[async_trait] -impl TransportSenderT for StdioTransport { +impl TransportSenderT for Adapter { type Error = tokio::io::Error; #[tracing::instrument(skip(self), level = "trace")] async fn send(&mut self, msg: String) -> Result<(), Self::Error> { - debug!("Sending: {}", msg); - let mut stdin = self.stdin.lock().await; - stdin.write_all(msg.as_bytes()).await?; - stdin.write_all(b"\n").await?; + self.0.write_all(msg.as_bytes()).await?; + self.0.write_all(b"\n").await?; Ok(()) } } #[async_trait] -impl TransportReceiverT for StdioTransport { +impl TransportReceiverT for Adapter { type Error = tokio::io::Error; #[tracing::instrument(skip(self), level = "trace")] async fn receive(&mut self) -> Result { - let mut stdout = self.stdout.lock().await; let mut str = String::new(); - stdout.read_line(&mut str).await?; - debug!("Received: {}", str); + self.0.read_line(&mut str).await?; Ok(ReceivedMessage::Text(str)) } -} \ No newline at end of file +}