diff --git a/ext/mcp/Cargo.lock b/ext/mcp/Cargo.lock index 4e5a626..5fd324e 100644 --- a/ext/mcp/Cargo.lock +++ b/ext/mcp/Cargo.lock @@ -334,8 +334,10 @@ dependencies = [ "anyhow", "jsonrpsee", "magnus", + "once_cell", "serde", "serde_json", + "serde_magnus", "tokio", "tracing", "tracing-subscriber", @@ -638,6 +640,17 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_magnus" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51b8b945a2dadb221f1c5490cfb411cab6c3821446b8eca50ee07e5a3893ec51" +dependencies = [ + "magnus", + "serde", + "tap", +] + [[package]] name = "sharded-slab" version = "0.1.7" @@ -695,6 +708,12 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "tap" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" + [[package]] name = "thiserror" version = "1.0.69" diff --git a/ext/mcp/Cargo.toml b/ext/mcp/Cargo.toml index 762d86f..67f2d96 100644 --- a/ext/mcp/Cargo.toml +++ b/ext/mcp/Cargo.toml @@ -19,3 +19,5 @@ tracing = "0.1.41" serde = { version = "1.0.219", features = ["derive"] } serde_json = "1.0.140" tracing-subscriber = { version = "0.3.19", features = ["env-filter"] } +once_cell = "1.19.0" +serde_magnus = "0.9.0" diff --git a/ext/mcp/src/lib.rs b/ext/mcp/src/lib.rs index 6c32632..b15d128 100644 --- a/ext/mcp/src/lib.rs +++ b/ext/mcp/src/lib.rs @@ -1,6 +1,8 @@ use jsonrpsee::async_client::{Client, ClientBuilder}; use tokio::process::Command; use crate::mcp_client::McpClient; +use jsonrpsee::core::client::ClientT; +use once_cell::sync::Lazy; mod mcp_client; mod types; @@ -14,30 +16,14 @@ use std::{ }; use magnus::{function, method, prelude::*, scan_args::{get_kwargs, scan_args}, typed_data, Error, Ruby, Value}; +use serde_magnus::serialize; +use crate::types::{Implementation, InitializeRequestParams, InitializeResult}; +use crate::types::builder::Tool; -#[magnus::wrap(class = "Mcp::Temperature", free_immediately, size)] -#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd)] -struct Temperature { - microkelvin: RefCell, -} - -// can't derive this due to needing to use RefCell to get mutability -impl Hash for Temperature { - fn hash(&self, state: &mut H) { - self.microkelvin.borrow().hash(state) - } -} - -const FACTOR: f64 = 1000000.0; -const C_OFFSET: f64 = 273.15; - -fn f_to_c(f: f64) -> f64 { - (f - 32.0) * (5.0 / 9.0) -} - -fn c_to_f(c: f64) -> f64 { - c * (9.0 / 5.0) + 32.0 -} +// Create global runtime +static RUNTIME: Lazy = Lazy::new(|| { + tokio::runtime::Runtime::new().expect("Failed to create Tokio runtime") +}); #[magnus::wrap(class = "Mcp::Client", free_immediately, size)] struct McpClientRb { @@ -46,23 +32,63 @@ struct McpClientRb { impl McpClientRb { fn new(command: String, args: Vec) -> Result { - let child = Command::new(command) - .args(args) - .stdin(std::process::Stdio::piped()) - .stdout(std::process::Stdio::piped()) - .spawn() - .unwrap(); + let client = RUNTIME.block_on(async { + let child = Command::new(command) + .args(args) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .spawn() + .unwrap(); - let transport = stdio_transport::StdioTransport::new(child); + let transport = stdio_transport::StdioTransport::new(child); - let client: Client = ClientBuilder::default().build_with_tokio( - transport.clone(), - transport.clone(), - ); + ClientBuilder::default().build_with_tokio( + transport.clone(), + transport.clone(), + ) + }); - let client = McpClient { client }; - Ok(Self { client }) + Ok(Self { client: McpClient { client } }) } + + fn connect(&self) -> Result { + RUNTIME.block_on(async { + let a = self.client.initialize(InitializeRequestParams { + capabilities: Default::default(), + client_info: Implementation { name: "ABC".to_string(), version: "0.0.1".to_string() }, + protocol_version: "2024-11-05".to_string(), + }).await; + + match a { + Ok(_) => Ok(true), + Err(e) => Err(magnus::Error::new( + magnus::exception::runtime_error(), + e.to_string(), + )), + } + }) + } + + fn list_tools(&self) -> Result { + RUNTIME.block_on(async { + let a = self.client.list_tools().await; + + match a { + Ok(tools) => serialize::<_, Value>(&tools), + Err(e) => Err(Error::new( + magnus::exception::runtime_error(), + e.to_string(), + )), + } + }) + } + + // fn call_rpc(&self, method: &str, params: &[&str]) -> Result { + // RUNTIME.block_on(async { + // self.client.client.request(method, params).await + // .map_err(|e| magnus::Error::new(magnus::exception::runtime_error(), e.to_string())) + // }) + // } } #[magnus::init] @@ -71,6 +97,8 @@ fn init(ruby: &Ruby) -> Result<(), Error> { let client_class = module.define_class("Client", ruby.class_object())?; client_class.define_singleton_method("new", function!(McpClientRb::new, 2))?; + client_class.define_method("connect", method!(McpClientRb::connect, 0))?; + client_class.define_method("list_tools", method!(McpClientRb::list_tools, 0))?; Ok(()) } diff --git a/ext/mcp/src/mcp_client.rs b/ext/mcp/src/mcp_client.rs index 5f204fc..638dd4c 100644 --- a/ext/mcp/src/mcp_client.rs +++ b/ext/mcp/src/mcp_client.rs @@ -11,16 +11,17 @@ pub struct McpClient { } impl McpClient { - async fn initialize(&self, params: InitializeRequestParams) -> Result { + pub async fn initialize(&self, params: InitializeRequestParams) -> Result { let result: InitializeResult = self.client.request("initialize", params.to_rpc()).await?; self.client.notification("notifications/initialized", NoParams).await?; Ok(result) } - async fn list_tools(&self) -> Result, anyhow::Error> { + pub async fn list_tools(&self) -> Result, anyhow::Error> { let mut tools = vec![]; let result: ListToolsResult = self.client.request("tools/list", NoParams).await?; + tools.extend(result.tools); while let Some(cursor) = result.next_cursor.as_ref() { let result: ListToolsResult = self.client.request("tools/list", ListToolsRequestParams { cursor: Some(cursor.clone()) }.to_rpc()).await?;