diff --git a/ext/mcp/src/internal-test.rs b/ext/mcp/src/internal-test.rs index 6a42019..511e791 100644 --- a/ext/mcp/src/internal-test.rs +++ b/ext/mcp/src/internal-test.rs @@ -6,6 +6,7 @@ use jsonrpsee::core::traits::ToRpcParams; use tokio::io::{AsyncBufReadExt, AsyncWriteExt}; use serde::Serialize; use serde_json::json; +use tokio::select; use types::{Implementation, InitializeRequestParams, InitializeResult}; use crate::types::{CallToolRequestParams, ClientCapabilities, ListToolsRequestParams, ListToolsResult}; @@ -28,12 +29,13 @@ async fn main() -> anyhow::Result<()> { .with_level(true) .init(); - let cmd = tokio::process::Command::new("/Users/joshuacoles/.local/bin/mcp-server-fetch") + + let mut cmd = tokio::process::Command::new("/Users/joshuacoles/.local/bin/mcp-server-fetch") .stdin(Stdio::piped()) .stdout(Stdio::piped()) .spawn()?; - let transport = StdioTransport::new(cmd); + let transport = StdioTransport::new(&mut cmd); let client: Client = ClientBuilder::default().build_with_tokio( transport.clone(), @@ -48,14 +50,30 @@ async fn main() -> anyhow::Result<()> { client.notification("notifications/initialized", NoParams).await?; - let response: ListToolsResult = client.request("tools/list", ListToolsRequestParams::default().to_rpc()).await?; + println!("Hey"); - let response: serde_json::Value = client.request("tools/call", CallToolRequestParams { - arguments: json!({ "url": "http://example.com" }).as_object().unwrap().clone(), - name: "fetch".to_string(), - }.to_rpc()).await?; + // drop(transport.stdin); - println!("Response: {:#?}", response); + select! { + _ = cmd.wait() => { + println!("Command exited"); + } + + _ = tokio::time::sleep(std::time::Duration::from_secs(1)) => { + cmd.kill().await?; + } + } + + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + + // let response: ListToolsResult = client.request("tools/list", ListToolsRequestParams::default().to_rpc()).await?; + + // let response: serde_json::Value = client.request("tools/call", CallToolRequestParams { + // arguments: json!({ "url": "http://example.com" }).as_object().unwrap().clone(), + // name: "fetch".to_string(), + // }.to_rpc()).await?; + + // println!("Response: {:#?}", response); Ok(()) } diff --git a/ext/mcp/src/mcp_client.rs b/ext/mcp/src/mcp_client.rs index 93ee2e1..3152a45 100644 --- a/ext/mcp/src/mcp_client.rs +++ b/ext/mcp/src/mcp_client.rs @@ -7,6 +7,7 @@ use crate::stdio_transport; use crate::types::{CallToolRequestParams, InitializeRequestParams, InitializeResult, ListToolsRequestParams, ListToolsResult, Tool}; pub struct McpClient { + pub(crate) transport: StdioTransport, pub(crate) client: Client, } @@ -17,6 +18,10 @@ impl McpClient { Ok(result) } + pub async fn shutdown(mut self) { + self.transport.shutdown(); + } + pub async fn list_tools(&self) -> Result, anyhow::Error> { let mut tools = vec![]; diff --git a/ext/mcp/src/stdio_transport.rs b/ext/mcp/src/stdio_transport.rs index 0463908..7d85b8d 100644 --- a/ext/mcp/src/stdio_transport.rs +++ b/ext/mcp/src/stdio_transport.rs @@ -1,22 +1,27 @@ use std::sync::Arc; -use tokio::process::{Child, ChildStdin, ChildStdout}; +use tokio::process::{Child, ChildStdin, ChildStdout, Command}; use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; use jsonrpsee::core::async_trait; use jsonrpsee::core::client::{ReceivedMessage, TransportReceiverT, TransportSenderT}; +use tokio::sync::Mutex; use tracing::debug; #[derive(Debug, Clone)] pub struct StdioTransport { - stdin: Arc>, - stdout: Arc>>, + pub stdin: Arc>, + pub stdout: Arc>>, } impl StdioTransport { - pub fn new(mut child: Child) -> Self { - let stdin = Arc::new(tokio::sync::Mutex::new(child.stdin.take().unwrap())); - let stdout = Arc::new(tokio::sync::Mutex::new(BufReader::new(child.stdout.take().unwrap()))); + pub fn new(child: &mut Child) -> Self { + let stdin = Arc::new(Mutex::new(child.stdin.take().unwrap())); + let stdout = Arc::new(Mutex::new(BufReader::new(child.stdout.take().unwrap()))); Self { stdin, stdout } } + + pub(crate) async fn shutdown(mut self) -> Result<(), tokio::io::Error> { + Ok(()) + } } #[async_trait]