Stash
This commit is contained in:
parent
8cfd28d3f6
commit
6ba2e9e8df
@ -58,13 +58,6 @@ impl McpClientRb {
|
||||
})
|
||||
}
|
||||
|
||||
async fn client<'a>(&'a self) -> Result<&'a McpClientConnection, Error> {
|
||||
self.client.lock().await.ok_or(Error::new(
|
||||
magnus::exception::runtime_error(),
|
||||
"Client is not initialized".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
fn dispose(&self) {
|
||||
RUNTIME.block_on(async {
|
||||
self.client.lock().await.take();
|
||||
@ -73,7 +66,14 @@ impl McpClientRb {
|
||||
|
||||
fn list_tools(&self) -> Result<Value, magnus::Error> {
|
||||
RUNTIME.block_on(async {
|
||||
let a = self.client().await?.list_tools().await;
|
||||
let a = self
|
||||
.client
|
||||
.lock()
|
||||
.await
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.list_tools()
|
||||
.await;
|
||||
|
||||
match a {
|
||||
Ok(tools) => serialize::<_, Value>(&tools),
|
||||
@ -98,8 +98,11 @@ impl McpClientRb {
|
||||
|
||||
RUNTIME.block_on(async {
|
||||
let a = self
|
||||
.client()
|
||||
.await?
|
||||
.client
|
||||
.lock()
|
||||
.await
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.call_tool::<serde_json::Value>(CallToolRequestParams {
|
||||
name,
|
||||
arguments: kwargs,
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use jsonrpsee::async_client::{Client, ClientBuilder};
|
||||
use jsonrpsee::core::client::ClientT;
|
||||
use keepcalm::SharedMut;
|
||||
use tokio::io::{BufReader, Stdin};
|
||||
use tokio::process::{Child, Command};
|
||||
use tokio::io::{BufReader};
|
||||
use tokio::process::{Child, Command, ChildStdin, ChildStdout};
|
||||
use tokio::sync::Mutex;
|
||||
use crate::rpc_helpers::{NoParams, ToRpcArg};
|
||||
use crate::stdio_transport::Adapter;
|
||||
use crate::types::{CallToolRequestParams, InitializeRequestParams, InitializeResult, ListToolsRequestParams, ListToolsResult, Tool};
|
||||
@ -11,34 +12,36 @@ use crate::types::{CallToolRequestParams, InitializeRequestParams, InitializeRes
|
||||
enum TransportHandle {
|
||||
Stdio {
|
||||
child: Child,
|
||||
stdin: SharedMut<Stdin>,
|
||||
stdout: SharedMut<BufReader<tokio::process::ChildStdout>>,
|
||||
stdin: Arc<Mutex<ChildStdin>>,
|
||||
stdout: Arc<Mutex<BufReader<ChildStdout>>>,
|
||||
},
|
||||
}
|
||||
|
||||
/// This represents a live MCP connection to an MCP server. It will close the connection when dropped on a best effort basis.
|
||||
pub struct McpClientConnection {
|
||||
pub(crate) transport: TransportHandle,
|
||||
pub(crate) client: Client,
|
||||
}
|
||||
|
||||
/// This represents a live MCP connection to an MCP server.
|
||||
impl McpClientConnection {
|
||||
pub async fn new_stdio(command: String, args: Vec<String>, init_params: InitializeRequestParams) -> Result<Self, anyhow::Error> {
|
||||
let mut child = Command::new(command)
|
||||
.args(args)
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.stdout(std::process::Stdio::piped())
|
||||
.kill_on_drop(true)
|
||||
.spawn()?;
|
||||
|
||||
let stdin = SharedMut::new_mutex(child.stdin.take().unwrap());
|
||||
let stdout = SharedMut::new_mutex(BufReader::new(child.stdout.take().unwrap()));
|
||||
// We take ownership of the stdin and stdout here to pass them to the transport, wrapping them in an Arc and Mutex to allow them to be shared between threads in an async context.
|
||||
let stdin = Arc::new(Mutex::new(child.stdin.take().unwrap()));
|
||||
let stdout = Arc::new(Mutex::new(BufReader::new(child.stdout.take().unwrap())));
|
||||
|
||||
let client = ClientBuilder::default().build_with_tokio(
|
||||
Adapter(stdin),
|
||||
Adapter(stdout),
|
||||
Adapter(stdin.clone()),
|
||||
Adapter(stdout.clone()),
|
||||
);
|
||||
|
||||
let new_client = Self { transport: TransportHandle::Stdio { child }, client };
|
||||
let new_client = Self { transport: TransportHandle::Stdio { child, stdin, stdout }, client };
|
||||
|
||||
new_client.initialize(init_params).await?;
|
||||
Ok(new_client)
|
||||
|
||||
@ -7,16 +7,18 @@ use keepcalm::SharedMut;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::debug;
|
||||
|
||||
pub struct Adapter<T>(pub SharedMut<T>);
|
||||
pub struct Adapter<T>(pub Arc<Mutex<T>>);
|
||||
|
||||
#[async_trait]
|
||||
impl<T: AsyncWriteExt> TransportSenderT for Adapter<T> {
|
||||
impl<T: Unpin + Send + 'static + AsyncWriteExt> TransportSenderT for Adapter<T> {
|
||||
type Error = tokio::io::Error;
|
||||
|
||||
#[tracing::instrument(skip(self), level = "trace")]
|
||||
async fn send(&mut self, msg: String) -> Result<(), Self::Error> {
|
||||
self.0.write_all(msg.as_bytes()).await?;
|
||||
self.0.write_all(b"\n").await?;
|
||||
let mut guard = self.0.lock().await;
|
||||
|
||||
guard.write_all(msg.as_bytes()).await?;
|
||||
guard.write_all(b"\n").await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@ -28,7 +30,7 @@ impl<T: AsyncBufRead + Unpin + Send + 'static> TransportReceiverT for Adapter<T>
|
||||
#[tracing::instrument(skip(self), level = "trace")]
|
||||
async fn receive(&mut self) -> Result<ReceivedMessage, Self::Error> {
|
||||
let mut str = String::new();
|
||||
self.0.read_line(&mut str).await?;
|
||||
self.0.lock().await.read_line(&mut str).await?;
|
||||
Ok(ReceivedMessage::Text(str))
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user