mod topology;
mod orchestration;
mod discovery;
mod network;

use serde::{Deserialize, Serialize};
use serde_json::Value;
use tonic::{transport::Server, Request, Response, Status};

use crate::node_service::{
    CollectTopologyRequest, Empty, ExampleRequest, HealthCheckRequest, HealthCheckResponse, Loss,
    PromptRequest, SendOpaqueStatusRequest, SendResultRequest, Tensor, Topology as TopologyProto,
};
use node_service::node_service_server::{NodeService, NodeServiceServer};
use node_service::TensorRequest;
use topology::Topology;

pub mod node_service {
    tonic::include_proto!("node_service"); // The string specified here must match the proto package name
}

/// A node in the cluster; for now it only tracks the topology it knows about.
struct Node {
    current_topology: Topology,
}

impl Default for Node {
    fn default() -> Self {
        Self {
            current_topology: Topology::default(),
        }
    }
}

/// Internally tagged JSON payloads carried by `SendOpaqueStatus`; the `type`
/// field selects the variant.
#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(tag = "type")]
enum OpaqueStatus {
    NodeStatus(NodeStatus),
    DownloadProgress(DownloadProgress),
    SupportedInferenceEngines(SupportedInferenceEngines),
}

#[derive(Debug, Deserialize, Serialize, Clone)]
struct Shard {
    pub model_id: String,
    pub start_layer: i32,
    pub end_layer: i32,
    pub n_layers: i32,
}

#[derive(Debug, Deserialize, Serialize, Clone)]
struct NodeStatus {
    node_id: String,
    status: String,
    base_shard: Shard,
    shard: Shard,
    prompt: String,
    request_id: String,
}

impl NodeStatus {
    fn is_start(&self) -> bool {
        self.status.starts_with("start_")
    }

    fn is_end(&self) -> bool {
        self.status.starts_with("end_")
    }
}

#[derive(Debug, Deserialize, Serialize, Clone)]
struct DownloadProgress {
    node_id: String,
    progress: Value,
}

#[derive(Debug, Deserialize, Serialize, Clone)]
struct SupportedInferenceEngines {
    node_id: String,
    engines: Vec<String>,
}

impl Node {
    /// Dispatch an opaque status payload (JSON) to the matching handler.
    fn on_opaque_status(&self, _request_id: String, status: String) {
        let status = serde_json::from_str::<OpaqueStatus>(&status).unwrap();
        match status {
            OpaqueStatus::NodeStatus(node_status) => self.on_node_status(node_status),
            OpaqueStatus::DownloadProgress(download_progress) => {
                self.on_download_progress(download_progress)
            }
            OpaqueStatus::SupportedInferenceEngines(supported_inference_engines) => {
                self.on_supported_inference_engines(supported_inference_engines)
            }
        }
    }

    fn on_node_status(&self, node_status: NodeStatus) {
        println!("Received NodeStatus: {}", node_status.status);
        // This seems to only be used for visualization so we can ignore it for now
        // if node_status.is_start() {
        //     self.current_topology.active_node_id = node_status.node_id;
        // } else if node_status.is_end() {
        //     if node_status.node_id == self.current_topology.active_node_id {
        //         self.current_topology.active_node_id = None;
        //     }
        // }
    }

    fn on_download_progress(&self, _download_progress: DownloadProgress) {
        // This is only used for visualization so we can ignore it for now
    }

    fn on_supported_inference_engines(
        &self,
        supported_inference_engines: SupportedInferenceEngines,
    ) {
        println!(
            "Received SupportedInferenceEngines: {}",
            supported_inference_engines.engines.join(", ")
        );
        // let node_id = supported_inference_engines.node_id;
        // let engines = supported_inference_engines.engines;
        // self.topology_inference_engines_pool.append(engines);
        todo!();
    }
}

#[tonic::async_trait]
impl NodeService for Node {
    async fn send_prompt(
        &self,
        _request: Request<PromptRequest>,
    ) -> Result<Response<Tensor>, Status> {
        todo!()
    }

    async fn send_tensor(
        &self,
        _request: Request<TensorRequest>,
    ) -> Result<Response<Tensor>, Status> {
        todo!()
    }

    async fn send_example(
        &self,
        _request: Request<ExampleRequest>,
    ) -> Result<Response<Loss>, Status> {
        todo!()
    }

    async fn collect_topology(
        &self,
        _request: Request<CollectTopologyRequest>,
    ) -> Result<Response<TopologyProto>, Status> {
        todo!()
    }

    async fn send_result(
        &self,
        _request: Request<SendResultRequest>,
    ) -> Result<Response<Empty>, Status> {
        todo!()
    }

    async fn send_opaque_status(
        &self,
        request: Request<SendOpaqueStatusRequest>,
    ) -> Result<Response<Empty>, Status> {
        let request = request.into_inner();
        let request_id = request.request_id;
        let status = request.status;
        println!(
            "Received SendOpaqueStatus request: {} {}",
            request_id, status
        );
        Ok(Response::new(Empty {}))
    }

    async fn health_check(
        &self,
        _request: Request<HealthCheckRequest>,
    ) -> Result<Response<HealthCheckResponse>, Status> {
        Ok(Response::new(HealthCheckResponse { is_healthy: true }))
    }
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // install global collector configured based on RUST_LOG env var.
    tracing_subscriber::fmt::init();

    let grpc_addr = "[::1]:50051".parse()?;
    let node = Node::default();

    // TODO: Also implement discovery
    Server::builder()
        .add_service(NodeServiceServer::new(node))
        .serve(grpc_addr)
        .await?;

    Ok(())
}
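
// A minimal sketch (not part of the original file) showing how the internally
// tagged `OpaqueStatus` enum is expected to deserialize. The JSON field values
// below ("node-1", "start_process_prompt", the shard layout, etc.) are invented
// purely for illustration and are not taken from the real protocol traffic.
#[cfg(test)]
mod opaque_status_tests {
    use super::*;

    #[test]
    fn deserializes_node_status_payload() {
        let json = r#"{
            "type": "NodeStatus",
            "node_id": "node-1",
            "status": "start_process_prompt",
            "base_shard": {"model_id": "example-model", "start_layer": 0, "end_layer": 15, "n_layers": 32},
            "shard": {"model_id": "example-model", "start_layer": 0, "end_layer": 15, "n_layers": 32},
            "prompt": "hello",
            "request_id": "req-1"
        }"#;

        // The "type" field picks the variant; the remaining fields fill NodeStatus.
        match serde_json::from_str::<OpaqueStatus>(json).expect("valid NodeStatus JSON") {
            OpaqueStatus::NodeStatus(node_status) => {
                assert!(node_status.is_start());
                assert!(!node_status.is_end());
            }
            other => panic!("expected NodeStatus, got {:?}", other),
        }
    }
}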