This commit is contained in:
Joshua Coles 2025-03-17 15:10:01 +00:00
parent 7ddc1196d5
commit 36eb457898
6 changed files with 164 additions and 128 deletions

11
ext/mcp/Cargo.lock generated
View File

@ -251,6 +251,16 @@ dependencies = [
"thiserror", "thiserror",
] ]
[[package]]
name = "keepcalm"
version = "0.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "031ddc7e27bbb011c78958881a3723873608397b8b10e146717fc05cf3364d78"
dependencies = [
"once_cell",
"parking_lot",
]
[[package]] [[package]]
name = "lazy_static" name = "lazy_static"
version = "1.5.0" version = "1.5.0"
@ -333,6 +343,7 @@ version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"jsonrpsee", "jsonrpsee",
"keepcalm",
"magnus", "magnus",
"once_cell", "once_cell",
"serde", "serde",

View File

@ -21,3 +21,4 @@ serde_json = "1.0.140"
tracing-subscriber = { version = "0.3.19", features = ["env-filter"] } tracing-subscriber = { version = "0.3.19", features = ["env-filter"] }
once_cell = "1.19.0" once_cell = "1.19.0"
serde_magnus = "0.9.0" serde_magnus = "0.9.0"
keepcalm = "0.3.5"

View File

@ -12,10 +12,10 @@ use crate::types::{CallToolRequestParams, ClientCapabilities, ListToolsRequestPa
mod types; mod types;
mod rpc_helpers; mod rpc_helpers;
mod mcp_client;
mod stdio_transport; mod stdio_transport;
use rpc_helpers::*; use rpc_helpers::*;
use stdio_transport::StdioTransport; use crate::mcp_client::McpClient;
#[tokio::main] #[tokio::main]
async fn main() -> anyhow::Result<()> { async fn main() -> anyhow::Result<()> {
@ -29,51 +29,53 @@ async fn main() -> anyhow::Result<()> {
.with_level(true) .with_level(true)
.init(); .init();
let client = McpClient::new_stdio(
let mut cmd = tokio::process::Command::new("/Users/joshuacoles/.local/bin/mcp-server-fetch") "/Users/joshuacoles/.local/bin/mcp-server-fetch".to_string(),
.stdin(Stdio::piped()) vec![],
.stdout(Stdio::piped()) InitializeRequestParams {
.spawn()?; capabilities: Default::default(),
client_info: Implementation {
let transport = StdioTransport::new(&mut cmd); name: "ABC".to_string(),
version: "0.0.1".to_string(),
let client: Client = ClientBuilder::default().build_with_tokio( },
transport.clone(), protocol_version: "2024-11-05".to_string(),
transport.clone(),
);
let response: InitializeResult = client.request("initialize", InitializeRequestParams {
capabilities: ClientCapabilities::default(),
client_info: Implementation { name: "Rust MCP".to_string(), version: "0.1.0".to_string() },
protocol_version: "2024-11-05".to_string(),
}.to_rpc()).await?;
client.notification("notifications/initialized", NoParams).await?;
println!("Hey");
// drop(transport.stdin);
select! {
_ = cmd.wait() => {
println!("Command exited");
} }
).await?;
_ = tokio::time::sleep(std::time::Duration::from_secs(1)) => { dbg!(client.list_tools());
cmd.kill().await?;
}
}
tokio::time::sleep(std::time::Duration::from_secs(1)).await; // let response: InitializeResult = client.request("initialize", InitializeRequestParams {
// capabilities: ClientCapabilities::default(),
// let response: ListToolsResult = client.request("tools/list", ListToolsRequestParams::default().to_rpc()).await?; // client_info: Implementation { name: "Rust MCP".to_string(), version: "0.1.0".to_string() },
// protocol_version: "2024-11-05".to_string(),
// let response: serde_json::Value = client.request("tools/call", CallToolRequestParams {
// arguments: json!({ "url": "http://example.com" }).as_object().unwrap().clone(),
// name: "fetch".to_string(),
// }.to_rpc()).await?; // }.to_rpc()).await?;
//
// println!("Response: {:#?}", response); // client.notification("notifications/initialized", NoParams).await?;
//
// println!("Hey");
//
// // drop(transport.stdin);
//
// select! {
// _ = cmd.wait() => {
// println!("Command exited");
// }
//
// _ = tokio::time::sleep(std::time::Duration::from_secs(1)) => {
// cmd.kill().await?;
// }
// }
//
// tokio::time::sleep(std::time::Duration::from_secs(1)).await;
//
// // let response: ListToolsResult = client.request("tools/list", ListToolsRequestParams::default().to_rpc()).await?;
//
// // let response: serde_json::Value = client.request("tools/call", CallToolRequestParams {
// // arguments: json!({ "url": "http://example.com" }).as_object().unwrap().clone(),
// // name: "fetch".to_string(),
// // }.to_rpc()).await?;
//
// // println!("Response: {:#?}", response);
Ok(()) Ok(())
} }

View File

@ -1,75 +1,79 @@
use jsonrpsee::async_client::{Client, ClientBuilder};
use tokio::process::Command;
use crate::mcp_client::McpClient; use crate::mcp_client::McpClient;
use jsonrpsee::async_client::{Client, ClientBuilder};
use jsonrpsee::core::client::ClientT; use jsonrpsee::core::client::ClientT;
use once_cell::sync::Lazy;
use magnus::prelude::*; use magnus::prelude::*;
use once_cell::sync::Lazy;
use tokio::process::Command;
mod mcp_client; mod mcp_client;
mod types;
mod rpc_helpers; mod rpc_helpers;
mod stdio_transport; mod stdio_transport;
mod types;
use std::{ use crate::types::{
hash::{Hash, Hasher}, CallToolRequestParams, Implementation, InitializeRequestParams, InitializeResult,
};
use magnus::{
function, method,
prelude::*,
scan_args::{get_kwargs, scan_args},
typed_data, Error, RHash, Ruby, Symbol, TryConvert, Value,
}; };
use magnus::{function, method, prelude::*, scan_args::{get_kwargs, scan_args}, typed_data, Error, RHash, Ruby, Symbol, TryConvert, Value};
use serde_magnus::serialize; use serde_magnus::serialize;
use crate::types::{CallToolRequestParams, Implementation, InitializeRequestParams, InitializeResult}; use std::hash::{Hash, Hasher};
use std::sync::{Arc, MutexGuard};
use tokio::sync::Mutex;
// Create global runtime // Create global runtime
static RUNTIME: Lazy<tokio::runtime::Runtime> = Lazy::new(|| { static RUNTIME: Lazy<tokio::runtime::Runtime> =
tokio::runtime::Runtime::new().expect("Failed to create Tokio runtime") Lazy::new(|| tokio::runtime::Runtime::new().expect("Failed to create Tokio runtime"));
});
#[magnus::wrap(class = "Mcp::Client", free_immediately, size)] #[magnus::wrap(class = "Mcp::Client", free_immediately, size)]
struct McpClientRb { struct McpClientRb {
client: McpClient, client: Mutex<Option<McpClient>>,
} }
impl McpClientRb { impl McpClientRb {
fn new(command: String, args: Vec<String>) -> Result<Self, magnus::Error> { fn new(command: String, args: Vec<String>) -> Result<Self, magnus::Error> {
let client = RUNTIME.block_on(async { let client = RUNTIME
let child = Command::new(command) .block_on(async {
.args(args) McpClient::new_stdio(
.stdin(std::process::Stdio::piped()) command,
.stdout(std::process::Stdio::piped()) args,
.spawn() InitializeRequestParams {
.unwrap(); capabilities: Default::default(),
client_info: Implementation {
name: "ABC".to_string(),
version: "0.0.1".to_string(),
},
protocol_version: "2024-11-05".to_string(),
},
)
.await
})
.map_err(|err| Error::new(magnus::exception::runtime_error(), err.to_string()))?;
let transport = stdio_transport::StdioTransport::new(child); Ok(Self {
client: Mutex::new(Some(client)),
ClientBuilder::default().build_with_tokio( })
transport.clone(),
transport.clone(),
)
});
Ok(Self { client: McpClient { client } })
} }
fn connect(&self) -> Result<bool, magnus::Error> { async fn client<'a>(&'a self) -> Result<&'a McpClient, Error> {
RUNTIME.block_on(async { self.client.lock().await.ok_or(Error::new(
let a = self.client.initialize(InitializeRequestParams { magnus::exception::runtime_error(),
capabilities: Default::default(), "Client is not initialized".to_string(),
client_info: Implementation { name: "ABC".to_string(), version: "0.0.1".to_string() }, ))
protocol_version: "2024-11-05".to_string(), }
}).await;
match a { fn dispose(&self) {
Ok(_) => Ok(true), RUNTIME.block_on(async {
Err(e) => Err(magnus::Error::new( self.client.lock().await.take();
magnus::exception::runtime_error(),
e.to_string(),
)),
}
}) })
} }
fn list_tools(&self) -> Result<Value, magnus::Error> { fn list_tools(&self) -> Result<Value, magnus::Error> {
RUNTIME.block_on(async { RUNTIME.block_on(async {
let a = self.client.list_tools().await; let a = self.client().await?.list_tools().await;
match a { match a {
Ok(tools) => serialize::<_, Value>(&tools), Ok(tools) => serialize::<_, Value>(&tools),
@ -93,10 +97,14 @@ impl McpClientRb {
}; };
RUNTIME.block_on(async { RUNTIME.block_on(async {
let a = self.client.call_tool::<serde_json::Value>(CallToolRequestParams { let a = self
name, .client()
arguments: kwargs, .await?
}).await; .call_tool::<serde_json::Value>(CallToolRequestParams {
name,
arguments: kwargs,
})
.await;
match a { match a {
Ok(a) => Ok(serde_magnus::serialize(&a)?), Ok(a) => Ok(serde_magnus::serialize(&a)?),
@ -111,13 +119,23 @@ impl McpClientRb {
#[magnus::init] #[magnus::init]
fn init(ruby: &Ruby) -> Result<(), Error> { fn init(ruby: &Ruby) -> Result<(), Error> {
tracing_subscriber::fmt()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.with_file(true)
.with_line_number(true)
.with_thread_ids(true)
.with_thread_names(true)
.with_target(true)
.with_level(true)
.init();
let module = ruby.define_module("Mcp")?; let module = ruby.define_module("Mcp")?;
let client_class = module.define_class("Client", ruby.class_object())?; let client_class = module.define_class("Client", ruby.class_object())?;
client_class.define_singleton_method("new", function!(McpClientRb::new, 2))?; client_class.define_singleton_method("new", function!(McpClientRb::new, 2))?;
client_class.define_method("connect", method!(McpClientRb::connect, 0))?;
client_class.define_method("list_tools", method!(McpClientRb::list_tools, 0))?; client_class.define_method("list_tools", method!(McpClientRb::list_tools, 0))?;
client_class.define_method("call_tool", method!(McpClientRb::call_tool, -1))?; client_class.define_method("call_tool", method!(McpClientRb::call_tool, -1))?;
client_class.define_method("dispose", method!(McpClientRb::dispose, 0))?;
Ok(()) Ok(())
} }

View File

@ -1,27 +1,51 @@
use jsonrpsee::async_client::Client; use std::path::Path;
use jsonrpsee::async_client::{Client, ClientBuilder};
use jsonrpsee::core::client::ClientT; use jsonrpsee::core::client::ClientT;
use tokio::process::Child; use tokio::io::BufReader;
use stdio_transport::StdioTransport; use tokio::process::{Child, Command};
use crate::rpc_helpers::{NoParams, ToRpcArg}; use crate::rpc_helpers::{NoParams, ToRpcArg};
use crate::stdio_transport; use crate::stdio_transport;
use crate::stdio_transport::Adapter;
use crate::types::{CallToolRequestParams, InitializeRequestParams, InitializeResult, ListToolsRequestParams, ListToolsResult, Tool}; use crate::types::{CallToolRequestParams, InitializeRequestParams, InitializeResult, ListToolsRequestParams, ListToolsResult, Tool};
enum TransportHandle {
Stdio(Child),
}
pub struct McpClient { pub struct McpClient {
pub(crate) transport: StdioTransport, pub(crate) transport: TransportHandle,
pub(crate) client: Client, pub(crate) client: Client,
} }
impl McpClient { impl McpClient {
pub async fn new_stdio(command: String, args: Vec<String>, init_params: InitializeRequestParams) -> Result<Self, anyhow::Error> {
let mut child = Command::new(command)
.args(args)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.spawn()
.unwrap();
let stdin = child.stdin.take().unwrap();
let stdout = BufReader::new(child.stdout.take().unwrap());
let client = ClientBuilder::default().build_with_tokio(
Adapter(stdin),
Adapter(stdout),
);
let new_client = Self { transport: TransportHandle::Stdio(child), client };
new_client.initialize(init_params).await?;
Ok(new_client)
}
pub async fn initialize(&self, params: InitializeRequestParams) -> Result<InitializeResult, anyhow::Error> { pub async fn initialize(&self, params: InitializeRequestParams) -> Result<InitializeResult, anyhow::Error> {
let result: InitializeResult = self.client.request("initialize", params.to_rpc()).await?; let result: InitializeResult = self.client.request("initialize", params.to_rpc()).await?;
self.client.notification("notifications/initialized", NoParams).await?; self.client.notification("notifications/initialized", NoParams).await?;
Ok(result) Ok(result)
} }
pub async fn shutdown(mut self) {
self.transport.shutdown();
}
pub async fn list_tools(&self) -> Result<Vec<Tool>, anyhow::Error> { pub async fn list_tools(&self) -> Result<Vec<Tool>, anyhow::Error> {
let mut tools = vec![]; let mut tools = vec![];

View File

@ -1,53 +1,33 @@
use std::sync::Arc; use std::sync::Arc;
use tokio::process::{Child, ChildStdin, ChildStdout, Command}; use tokio::process::{Child, ChildStdin, ChildStdout, Command};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWriteExt, BufReader};
use jsonrpsee::core::async_trait; use jsonrpsee::core::async_trait;
use jsonrpsee::core::client::{ReceivedMessage, TransportReceiverT, TransportSenderT}; use jsonrpsee::core::client::{ReceivedMessage, TransportReceiverT, TransportSenderT};
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tracing::debug; use tracing::debug;
#[derive(Debug, Clone)] pub struct Adapter<T>(pub T);
pub struct StdioTransport {
pub stdin: Arc<Mutex<ChildStdin>>,
pub stdout: Arc<Mutex<BufReader<ChildStdout>>>,
}
impl StdioTransport {
pub fn new(child: &mut Child) -> Self {
let stdin = Arc::new(Mutex::new(child.stdin.take().unwrap()));
let stdout = Arc::new(Mutex::new(BufReader::new(child.stdout.take().unwrap())));
Self { stdin, stdout }
}
pub(crate) async fn shutdown(mut self) -> Result<(), tokio::io::Error> {
Ok(())
}
}
#[async_trait] #[async_trait]
impl TransportSenderT for StdioTransport { impl<T: AsyncWriteExt + Unpin + Send + 'static> TransportSenderT for Adapter<T> {
type Error = tokio::io::Error; type Error = tokio::io::Error;
#[tracing::instrument(skip(self), level = "trace")] #[tracing::instrument(skip(self), level = "trace")]
async fn send(&mut self, msg: String) -> Result<(), Self::Error> { async fn send(&mut self, msg: String) -> Result<(), Self::Error> {
debug!("Sending: {}", msg); self.0.write_all(msg.as_bytes()).await?;
let mut stdin = self.stdin.lock().await; self.0.write_all(b"\n").await?;
stdin.write_all(msg.as_bytes()).await?;
stdin.write_all(b"\n").await?;
Ok(()) Ok(())
} }
} }
#[async_trait] #[async_trait]
impl TransportReceiverT for StdioTransport { impl<T: AsyncBufRead + Unpin + Send + 'static> TransportReceiverT for Adapter<T> {
type Error = tokio::io::Error; type Error = tokio::io::Error;
#[tracing::instrument(skip(self), level = "trace")] #[tracing::instrument(skip(self), level = "trace")]
async fn receive(&mut self) -> Result<ReceivedMessage, Self::Error> { async fn receive(&mut self) -> Result<ReceivedMessage, Self::Error> {
let mut stdout = self.stdout.lock().await;
let mut str = String::new(); let mut str = String::new();
stdout.read_line(&mut str).await?; self.0.read_line(&mut str).await?;
debug!("Received: {}", str);
Ok(ReceivedMessage::Text(str)) Ok(ReceivedMessage::Text(str))
} }
} }