67 lines
2.3 KiB
Rust
67 lines
2.3 KiB
Rust
use std::path::Path;
|
|
use jsonrpsee::async_client::{Client, ClientBuilder};
|
|
use jsonrpsee::core::client::ClientT;
|
|
use tokio::io::BufReader;
|
|
use tokio::process::{Child, Command};
|
|
use crate::rpc_helpers::{NoParams, ToRpcArg};
|
|
use crate::stdio_transport;
|
|
use crate::stdio_transport::Adapter;
|
|
use crate::types::{CallToolRequestParams, InitializeRequestParams, InitializeResult, ListToolsRequestParams, ListToolsResult, Tool};
|
|
|
|
enum TransportHandle {
|
|
Stdio(Child),
|
|
}
|
|
|
|
pub struct McpClient {
|
|
pub(crate) transport: TransportHandle,
|
|
pub(crate) client: Client,
|
|
}
|
|
|
|
impl McpClient {
|
|
pub async fn new_stdio(command: String, args: Vec<String>, init_params: InitializeRequestParams) -> Result<Self, anyhow::Error> {
|
|
let mut child = Command::new(command)
|
|
.args(args)
|
|
.stdin(std::process::Stdio::piped())
|
|
.stdout(std::process::Stdio::piped())
|
|
.spawn()
|
|
.unwrap();
|
|
|
|
let stdin = child.stdin.take().unwrap();
|
|
let stdout = BufReader::new(child.stdout.take().unwrap());
|
|
|
|
let client = ClientBuilder::default().build_with_tokio(
|
|
Adapter(stdin),
|
|
Adapter(stdout),
|
|
);
|
|
|
|
let new_client = Self { transport: TransportHandle::Stdio(child), client };
|
|
|
|
new_client.initialize(init_params).await?;
|
|
Ok(new_client)
|
|
}
|
|
|
|
pub async fn initialize(&self, params: InitializeRequestParams) -> Result<InitializeResult, anyhow::Error> {
|
|
let result: InitializeResult = self.client.request("initialize", params.to_rpc()).await?;
|
|
self.client.notification("notifications/initialized", NoParams).await?;
|
|
Ok(result)
|
|
}
|
|
|
|
pub async fn list_tools(&self) -> Result<Vec<Tool>, anyhow::Error> {
|
|
let mut tools = vec![];
|
|
|
|
let result: ListToolsResult = self.client.request("tools/list", NoParams).await?;
|
|
tools.extend(result.tools);
|
|
|
|
while let Some(cursor) = result.next_cursor.as_ref() {
|
|
let result: ListToolsResult = self.client.request("tools/list", ListToolsRequestParams { cursor: Some(cursor.clone()) }.to_rpc()).await?;
|
|
tools.extend(result.tools);
|
|
}
|
|
|
|
Ok(tools)
|
|
}
|
|
|
|
pub async fn call_tool<T: serde::de::DeserializeOwned>(&self, params: CallToolRequestParams) -> Result<T, anyhow::Error> {
|
|
Ok(self.client.request("tools/call", params.to_rpc()).await?)
|
|
}
|
|
}
|