diff --git a/ext/mcp/src/internal-test.rs b/ext/mcp/src/internal-test.rs index 5130b92..0552815 100644 --- a/ext/mcp/src/internal-test.rs +++ b/ext/mcp/src/internal-test.rs @@ -12,6 +12,9 @@ use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader, Stdin, use serde::{Deserialize, Serialize}; use serde_json::Error; use serde_json::value::RawValue; +use types::{Implementation, InitializeRequestParams, InitializeResult}; +use crate::types::ClientCapabilities; + mod types; #[derive(Debug, Clone)] @@ -34,12 +37,9 @@ impl TransportSenderT for StdioTransport { #[tracing::instrument(skip(self), level = "trace")] async fn send(&mut self, msg: String) -> Result<(), Self::Error> { - tracing::debug!("Sending message: {}", msg); let mut stdin = self.stdin.lock().await; - tracing::debug!("Locked stdin"); stdin.write_all(msg.as_bytes()).await?; stdin.write_all(b"\n").await?; - tracing::debug!("Wrote to stdin"); Ok(()) } } @@ -50,58 +50,22 @@ impl TransportReceiverT for StdioTransport { #[tracing::instrument(skip(self), level = "trace")] async fn receive(&mut self) -> Result { - tracing::debug!("Receiving message"); let mut stdout = self.stdout.lock().await; - tracing::debug!("Locked stdout"); let mut str = String::new(); - tracing::debug!("Reading from stdout"); stdout.read_line(&mut str).await?; - tracing::debug!("Read from stdout: {:?}", str); Ok(ReceivedMessage::Text(str)) } } -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -struct InitializeRequest { - protocol_version: String, - capabilities: Capabilities, - client_info: ClientInfo, -} +struct RpcArg(T); -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -struct Capabilities { - roots: Roots, - sampling: HashMap, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -struct Sampling { - sampling_interval: u64, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -struct Roots { - list_changed: bool, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct ClientInfo { - name: String, - version: String, -} - -impl ToRpcParams for InitializeRequest { - fn to_rpc_params(self) -> Result>, serde_json::Error> { - let s = String::from_utf8(serde_json::to_vec(&self)?).expect("Valid UTF8 format"); +impl ToRpcParams for RpcArg { + fn to_rpc_params(self) -> Result>, Error> { + let s = String::from_utf8(serde_json::to_vec(&self.0)?).expect("Valid UTF8 format"); RawValue::from_string(s).map(Some) } } - #[tokio::main] async fn main() -> anyhow::Result<()> { tracing_subscriber::fmt() @@ -114,17 +78,6 @@ async fn main() -> anyhow::Result<()> { .with_level(true) .init(); - tracing::debug!("Hello"); - - types::InitializeRequest { - method: "".to_string(), - params: types::InitializeRequestParams { - capabilities: Default::default(), - client_info: types::Implementation { name: "".to_string(), version: "".to_string() }, - protocol_version: "".to_string(), - }, - } - let cmd = tokio::process::Command::new("/Users/joshuacoles/.local/bin/mcp-server-fetch") .stdin(Stdio::piped()) .stdout(Stdio::piped()) @@ -137,16 +90,13 @@ async fn main() -> anyhow::Result<()> { transport.clone(), ); - let response: serde_json::Value = client.request("initialize", InitializeRequest { + let response: InitializeResult = client.request("initialize", RpcArg(InitializeRequestParams { + capabilities: ClientCapabilities::default(), + client_info: Implementation { name: "Rust MCP".to_string(), version: "0.1.0".to_string() }, protocol_version: "2024-11-05".to_string(), - capabilities: Capabilities { - roots: Roots { list_changed: true }, - sampling: HashMap::default(), - }, - client_info: ClientInfo { name: "ExampleClient".to_string(), version: "1.0.0".to_string() }, - }).await?; + })).await?; - println!("response: {:?}", response); + println!("Response: {:?}", response); Ok(()) }