diff --git a/ext/mcp/Cargo.toml b/ext/mcp/Cargo.toml index 3d6870c..762d86f 100644 --- a/ext/mcp/Cargo.toml +++ b/ext/mcp/Cargo.toml @@ -12,7 +12,7 @@ crate-type = ["cdylib"] [dependencies] jsonrpsee = { version = "0.24.8", features = ["async-client", "client-core", "tracing"] } -magnus = "0.7" +magnus = { version = "0.7", features = ["default"] } tokio = { version = "1.44.1", features = ["full"] } anyhow = "1.0.97" tracing = "0.1.41" diff --git a/ext/mcp/src/internal-test.rs b/ext/mcp/src/internal-test.rs index 0b5f29e..6a42019 100644 --- a/ext/mcp/src/internal-test.rs +++ b/ext/mcp/src/internal-test.rs @@ -1,64 +1,20 @@ use std::process::Stdio; -use std::sync::Arc; -use jsonrpsee::core::async_trait; use jsonrpsee::core::client::{ - Client, ClientBuilder, ClientT, ReceivedMessage, TransportReceiverT, TransportSenderT, + Client, ClientBuilder, ClientT, TransportReceiverT, TransportSenderT, }; use jsonrpsee::core::traits::ToRpcParams; -use tokio::process::{Child, ChildStdin, ChildStdout}; -use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt}; use serde::Serialize; -use serde_json::{json, Error}; -use serde_json::value::RawValue; -use tracing::debug; +use serde_json::json; use types::{Implementation, InitializeRequestParams, InitializeResult}; use crate::types::{CallToolRequestParams, ClientCapabilities, ListToolsRequestParams, ListToolsResult}; mod types; mod rpc_helpers; +mod stdio_transport; + use rpc_helpers::*; - -#[derive(Debug, Clone)] -struct StdioTransport { - stdin: Arc>, - stdout: Arc>>, -} - -impl StdioTransport { - fn new(mut child: Child) -> Self { - let stdin = Arc::new(tokio::sync::Mutex::new(child.stdin.take().unwrap())); - let stdout = Arc::new(tokio::sync::Mutex::new(BufReader::new(child.stdout.take().unwrap()))); - Self { stdin, stdout } - } -} - -#[async_trait] -impl TransportSenderT for StdioTransport { - type Error = tokio::io::Error; - - #[tracing::instrument(skip(self), level = "trace")] - async fn send(&mut self, msg: String) -> Result<(), Self::Error> { - debug!("Sending: {}", msg); - let mut stdin = self.stdin.lock().await; - stdin.write_all(msg.as_bytes()).await?; - stdin.write_all(b"\n").await?; - Ok(()) - } -} - -#[async_trait] -impl TransportReceiverT for StdioTransport { - type Error = tokio::io::Error; - - #[tracing::instrument(skip(self), level = "trace")] - async fn receive(&mut self) -> Result { - let mut stdout = self.stdout.lock().await; - let mut str = String::new(); - stdout.read_line(&mut str).await?; - debug!("Received: {}", str); - Ok(ReceivedMessage::Text(str)) - } -} +use stdio_transport::StdioTransport; #[tokio::main] async fn main() -> anyhow::Result<()> { diff --git a/ext/mcp/src/lib.rs b/ext/mcp/src/lib.rs index 400f6cc..6c32632 100644 --- a/ext/mcp/src/lib.rs +++ b/ext/mcp/src/lib.rs @@ -1,15 +1,76 @@ -use magnus::{function, Error, Ruby}; +use jsonrpsee::async_client::{Client, ClientBuilder}; +use tokio::process::Command; +use crate::mcp_client::McpClient; + mod mcp_client; mod types; mod rpc_helpers; +mod stdio_transport; -fn distance(a: (f64, f64), b: (f64, f64)) -> f64 { - ((b.0 - a.0).powi(2) + (b.1 - a.1).powi(2)).sqrt() +use std::{ + cell::RefCell, + fmt, + hash::{Hash, Hasher}, +}; + +use magnus::{function, method, prelude::*, scan_args::{get_kwargs, scan_args}, typed_data, Error, Ruby, Value}; + +#[magnus::wrap(class = "Mcp::Temperature", free_immediately, size)] +#[derive(Clone, Debug, Default, Eq, PartialEq, PartialOrd)] +struct Temperature { + microkelvin: RefCell, +} + +// can't derive this due to needing to use RefCell to get mutability +impl Hash for Temperature { + fn hash(&self, state: &mut H) { + self.microkelvin.borrow().hash(state) + } +} + +const FACTOR: f64 = 1000000.0; +const C_OFFSET: f64 = 273.15; + +fn f_to_c(f: f64) -> f64 { + (f - 32.0) * (5.0 / 9.0) +} + +fn c_to_f(c: f64) -> f64 { + c * (9.0 / 5.0) + 32.0 +} + +#[magnus::wrap(class = "Mcp::Client", free_immediately, size)] +struct McpClientRb { + client: McpClient, +} + +impl McpClientRb { + fn new(command: String, args: Vec) -> Result { + let child = Command::new(command) + .args(args) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .spawn() + .unwrap(); + + let transport = stdio_transport::StdioTransport::new(child); + + let client: Client = ClientBuilder::default().build_with_tokio( + transport.clone(), + transport.clone(), + ); + + let client = McpClient { client }; + Ok(Self { client }) + } } #[magnus::init] fn init(ruby: &Ruby) -> Result<(), Error> { - ruby.define_global_function("distance", function!(distance, 2)); + let module = ruby.define_module("Mcp")?; + let client_class = module.define_class("Client", ruby.class_object())?; + + client_class.define_singleton_method("new", function!(McpClientRb::new, 2))?; Ok(()) } diff --git a/ext/mcp/src/mcp_client.rs b/ext/mcp/src/mcp_client.rs index 5d9db67..5f204fc 100644 --- a/ext/mcp/src/mcp_client.rs +++ b/ext/mcp/src/mcp_client.rs @@ -1,10 +1,13 @@ use jsonrpsee::async_client::Client; use jsonrpsee::core::client::ClientT; +use tokio::process::Child; +use stdio_transport::StdioTransport; use crate::rpc_helpers::{NoParams, ToRpcArg}; -use crate::types::{InitializeRequestParams, InitializeResult}; +use crate::stdio_transport; +use crate::types::{CallToolRequestParams, InitializeRequestParams, InitializeResult, ListToolsRequestParams, ListToolsResult, Tool}; -struct McpClient { - client: Client, +pub struct McpClient { + pub(crate) client: Client, } impl McpClient { @@ -13,4 +16,21 @@ impl McpClient { self.client.notification("notifications/initialized", NoParams).await?; Ok(result) } + + async fn list_tools(&self) -> Result, anyhow::Error> { + let mut tools = vec![]; + + let result: ListToolsResult = self.client.request("tools/list", NoParams).await?; + + while let Some(cursor) = result.next_cursor.as_ref() { + let result: ListToolsResult = self.client.request("tools/list", ListToolsRequestParams { cursor: Some(cursor.clone()) }.to_rpc()).await?; + tools.extend(result.tools); + } + + Ok(tools) + } + + async fn call_tool(&self, params: CallToolRequestParams) -> Result { + Ok(self.client.request("tools/call", params.to_rpc()).await?) + } } diff --git a/ext/mcp/src/stdio_transport.rs b/ext/mcp/src/stdio_transport.rs new file mode 100644 index 0000000..0463908 --- /dev/null +++ b/ext/mcp/src/stdio_transport.rs @@ -0,0 +1,48 @@ +use std::sync::Arc; +use tokio::process::{Child, ChildStdin, ChildStdout}; +use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; +use jsonrpsee::core::async_trait; +use jsonrpsee::core::client::{ReceivedMessage, TransportReceiverT, TransportSenderT}; +use tracing::debug; + +#[derive(Debug, Clone)] +pub struct StdioTransport { + stdin: Arc>, + stdout: Arc>>, +} + +impl StdioTransport { + pub fn new(mut child: Child) -> Self { + let stdin = Arc::new(tokio::sync::Mutex::new(child.stdin.take().unwrap())); + let stdout = Arc::new(tokio::sync::Mutex::new(BufReader::new(child.stdout.take().unwrap()))); + Self { stdin, stdout } + } +} + +#[async_trait] +impl TransportSenderT for StdioTransport { + type Error = tokio::io::Error; + + #[tracing::instrument(skip(self), level = "trace")] + async fn send(&mut self, msg: String) -> Result<(), Self::Error> { + debug!("Sending: {}", msg); + let mut stdin = self.stdin.lock().await; + stdin.write_all(msg.as_bytes()).await?; + stdin.write_all(b"\n").await?; + Ok(()) + } +} + +#[async_trait] +impl TransportReceiverT for StdioTransport { + type Error = tokio::io::Error; + + #[tracing::instrument(skip(self), level = "trace")] + async fn receive(&mut self) -> Result { + let mut stdout = self.stdout.lock().await; + let mut str = String::new(); + stdout.read_line(&mut str).await?; + debug!("Received: {}", str); + Ok(ReceivedMessage::Text(str)) + } +} \ No newline at end of file diff --git a/lib/mcp.rb b/lib/mcp.rb index 00dbc0a..72cdc74 100644 --- a/lib/mcp.rb +++ b/lib/mcp.rb @@ -1,9 +1,10 @@ # frozen_string_literal: true require_relative "mcp/version" -require_relative "mcp/mcp" module Mcp class Error < StandardError; end # Your code goes here... end + +require_relative "mcp/mcp"