This commit is contained in:
Joshua Coles 2025-03-17 15:50:28 +00:00
parent 8cfd28d3f6
commit 6ba2e9e8df
3 changed files with 34 additions and 26 deletions

View File

@ -58,13 +58,6 @@ impl McpClientRb {
})
}
async fn client<'a>(&'a self) -> Result<&'a McpClientConnection, Error> {
self.client.lock().await.ok_or(Error::new(
magnus::exception::runtime_error(),
"Client is not initialized".to_string(),
))
}
fn dispose(&self) {
RUNTIME.block_on(async {
self.client.lock().await.take();
@ -73,7 +66,14 @@ impl McpClientRb {
fn list_tools(&self) -> Result<Value, magnus::Error> {
RUNTIME.block_on(async {
let a = self.client().await?.list_tools().await;
let a = self
.client
.lock()
.await
.as_ref()
.unwrap()
.list_tools()
.await;
match a {
Ok(tools) => serialize::<_, Value>(&tools),
@ -98,8 +98,11 @@ impl McpClientRb {
RUNTIME.block_on(async {
let a = self
.client()
.await?
.client
.lock()
.await
.as_ref()
.unwrap()
.call_tool::<serde_json::Value>(CallToolRequestParams {
name,
arguments: kwargs,

View File

@ -1,9 +1,10 @@
use std::path::Path;
use std::sync::Arc;
use jsonrpsee::async_client::{Client, ClientBuilder};
use jsonrpsee::core::client::ClientT;
use keepcalm::SharedMut;
use tokio::io::{BufReader, Stdin};
use tokio::process::{Child, Command};
use tokio::io::{BufReader};
use tokio::process::{Child, Command, ChildStdin, ChildStdout};
use tokio::sync::Mutex;
use crate::rpc_helpers::{NoParams, ToRpcArg};
use crate::stdio_transport::Adapter;
use crate::types::{CallToolRequestParams, InitializeRequestParams, InitializeResult, ListToolsRequestParams, ListToolsResult, Tool};
@ -11,34 +12,36 @@ use crate::types::{CallToolRequestParams, InitializeRequestParams, InitializeRes
enum TransportHandle {
Stdio {
child: Child,
stdin: SharedMut<Stdin>,
stdout: SharedMut<BufReader<tokio::process::ChildStdout>>,
stdin: Arc<Mutex<ChildStdin>>,
stdout: Arc<Mutex<BufReader<ChildStdout>>>,
},
}
/// This represents a live MCP connection to an MCP server. It will close the connection when dropped on a best effort basis.
pub struct McpClientConnection {
pub(crate) transport: TransportHandle,
pub(crate) client: Client,
}
/// This represents a live MCP connection to an MCP server.
impl McpClientConnection {
pub async fn new_stdio(command: String, args: Vec<String>, init_params: InitializeRequestParams) -> Result<Self, anyhow::Error> {
let mut child = Command::new(command)
.args(args)
.stdin(std::process::Stdio::piped())
.stdout(std::process::Stdio::piped())
.kill_on_drop(true)
.spawn()?;
let stdin = SharedMut::new_mutex(child.stdin.take().unwrap());
let stdout = SharedMut::new_mutex(BufReader::new(child.stdout.take().unwrap()));
// We take ownership of the stdin and stdout here to pass them to the transport, wrapping them in an Arc and Mutex to allow them to be shared between threads in an async context.
let stdin = Arc::new(Mutex::new(child.stdin.take().unwrap()));
let stdout = Arc::new(Mutex::new(BufReader::new(child.stdout.take().unwrap())));
let client = ClientBuilder::default().build_with_tokio(
Adapter(stdin),
Adapter(stdout),
Adapter(stdin.clone()),
Adapter(stdout.clone()),
);
let new_client = Self { transport: TransportHandle::Stdio { child }, client };
let new_client = Self { transport: TransportHandle::Stdio { child, stdin, stdout }, client };
new_client.initialize(init_params).await?;
Ok(new_client)

View File

@ -7,16 +7,18 @@ use keepcalm::SharedMut;
use tokio::sync::Mutex;
use tracing::debug;
pub struct Adapter<T>(pub SharedMut<T>);
pub struct Adapter<T>(pub Arc<Mutex<T>>);
#[async_trait]
impl<T: AsyncWriteExt> TransportSenderT for Adapter<T> {
impl<T: Unpin + Send + 'static + AsyncWriteExt> TransportSenderT for Adapter<T> {
type Error = tokio::io::Error;
#[tracing::instrument(skip(self), level = "trace")]
async fn send(&mut self, msg: String) -> Result<(), Self::Error> {
self.0.write_all(msg.as_bytes()).await?;
self.0.write_all(b"\n").await?;
let mut guard = self.0.lock().await;
guard.write_all(msg.as_bytes()).await?;
guard.write_all(b"\n").await?;
Ok(())
}
}
@ -28,7 +30,7 @@ impl<T: AsyncBufRead + Unpin + Send + 'static> TransportReceiverT for Adapter<T>
#[tracing::instrument(skip(self), level = "trace")]
async fn receive(&mut self) -> Result<ReceivedMessage, Self::Error> {
let mut str = String::new();
self.0.read_line(&mut str).await?;
self.0.lock().await.read_line(&mut str).await?;
Ok(ReceivedMessage::Text(str))
}
}