mod device_capability_data;
mod discovery;
mod network;
mod orchestration;
mod partitioning;
mod topology;

use serde::{Deserialize, Serialize};
use serde_json::Value;
use tonic::{transport::Server, Request, Response, Status};

use crate::discovery::{NodeInfo, UdpDiscovery};
use crate::node_service::{
    CollectTopologyRequest, Empty, ExampleRequest, HealthCheckRequest, HealthCheckResponse,
    InferenceState, Loss, PromptRequest, SendOpaqueStatusRequest, SendResultRequest, Tensor,
    Topology as TopologyProto,
};
use node_service::node_service_server::{NodeService, NodeServiceServer};
use node_service::TensorRequest;
use std::collections::HashSet;
use topology::Topology;
use uuid::Uuid;

pub mod node_service {
    tonic::include_proto!("node_service"); // The string specified here must match the proto package name
}

#[derive(Debug)]
struct Node {
    node_info: NodeInfo,
    current_topology: Topology,
    udp_discovery: UdpDiscovery,
}

impl Node {
    #[tracing::instrument]
    pub async fn process_prompt(
        &self,
        base_shard: Shard,
        prompt: String,
        request_id: String,
        inference_state: Option<InferenceState>,
    ) {
        let shard = self
            .current_topology
            .get_shard_for_node(base_shard, &self.node_info.node_id);
        todo!();
        // if shard.is_first_layer() {
        //     let result = self
        //         .inference_engine
        //         .infer_prompt(request_id, shard, prompt, inference_state)
        //         .await;
        //     self.process_inference_result(shard, result, request_id, inference_state)
        // } else {
        //     self.forward_prompt(shard, prompt, request_id, inference_state)
        // }
    }
}

impl Default for Node {
    fn default() -> Self {
        let node_info = NodeInfo::default();
        Self {
            node_info: node_info.clone(),
            current_topology: Topology::default(),
            udp_discovery: UdpDiscovery::new(node_info),
        }
    }
}

#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(tag = "type")]
enum OpaqueStatus {
    NodeStatus(NodeStatus),
    DownloadProgress(DownloadProgress),
    SupportedInferenceEngines(SupportedInferenceEngines),
}

#[derive(Debug, Deserialize, Serialize, Clone)]
struct Shard {
    pub model_id: String,
    pub start_layer: u32,
    pub end_layer: u32,
    #[serde(rename = "n_layers")]
    pub total_layers: u32,
}

impl Shard {
    pub fn is_first_layer(&self) -> bool {
        self.start_layer == 0
    }

    pub fn is_last_layer(&self) -> bool {
        self.end_layer == self.total_layers - 1
    }

    pub fn len(&self) -> u32 {
        self.end_layer - self.start_layer + 1
    }
}

impl From<node_service::Shard> for Shard {
    fn from(proto: node_service::Shard) -> Self {
        Self {
            model_id: proto.model_id,
            start_layer: proto.start_layer as u32,
            end_layer: proto.end_layer as u32,
            total_layers: proto.n_layers as u32,
        }
    }
}

#[derive(Debug, Deserialize, Serialize, Clone)]
struct NodeStatus {
    node_id: String,
    status: String,
    base_shard: Shard,
    shard: Shard,
    prompt: String,
    request_id: String,
}

impl NodeStatus {
    fn is_start(&self) -> bool {
        self.status.starts_with("start_")
    }

    fn is_end(&self) -> bool {
        self.status.starts_with("end_")
    }
}

#[derive(Debug, Deserialize, Serialize, Clone)]
struct DownloadProgress {
    node_id: String,
    progress: Value,
}

#[derive(Debug, Deserialize, Serialize, Clone)]
struct SupportedInferenceEngines {
    node_id: String,
    engines: Vec<String>,
}

impl Node {
    fn on_opaque_status(&self, _request_id: String, status: String) {
        let status = serde_json::from_str::<OpaqueStatus>(&status).unwrap();
        match status {
            OpaqueStatus::NodeStatus(node_status) => self.on_node_status(node_status),
            OpaqueStatus::DownloadProgress(download_progress) => {
                self.on_download_progress(download_progress)
            }
            OpaqueStatus::SupportedInferenceEngines(supported_inference_engines) => {
                self.on_supported_inference_engines(supported_inference_engines)
            }
        }
    }
    fn on_node_status(&self, node_status: NodeStatus) {
        println!("Received NodeStatus: {}", node_status.status);
        // This seems to only be used for visualization, so we can ignore it for now.
        // if node_status.is_start() {
        //     self.current_topology.active_node_id = node_status.node_id;
        // } else if node_status.is_end() {
        //     if node_status.node_id == self.current_topology.active_node_id {
        //         self.current_topology.active_node_id = None;
        //     }
        // }
    }

    fn on_download_progress(&self, _download_progress: DownloadProgress) {
        // This is only used for visualization, so we can ignore it for now.
    }

    fn on_supported_inference_engines(
        &self,
        supported_inference_engines: SupportedInferenceEngines,
    ) {
        println!(
            "Received SupportedInferenceEngines: {}",
            supported_inference_engines.engines.join(", ")
        );
        // let node_id = supported_inference_engines.node_id;
        // let engines = supported_inference_engines.engines;
        // self.topology_inference_engines_pool.append(engines);
        todo!();
    }
}

#[tonic::async_trait]
impl NodeService for Node {
    async fn send_prompt(
        &self,
        request: Request<PromptRequest>,
    ) -> Result<Response<Tensor>, Status> {
        let request = request.into_inner();
        let request_id = request
            .request_id
            .unwrap_or_else(|| Uuid::new_v4().to_string());
        let result = self
            .process_prompt(
                request
                    .shard
                    .expect("No shard given. ExoPy does not allow this")
                    .into(),
                request.prompt,
                request_id,
                request.inference_state,
            )
            .await;
        todo!();
    }

    async fn send_tensor(
        &self,
        _request: Request<TensorRequest>,
    ) -> Result<Response<Tensor>, Status> {
        todo!()
    }

    async fn send_example(
        &self,
        _request: Request<ExampleRequest>,
    ) -> Result<Response<Loss>, Status> {
        todo!()
    }

    // TODO: Why aren't we using the request?
    async fn collect_topology(
        &self,
        request: Request<CollectTopologyRequest>,
    ) -> Result<Response<TopologyProto>, Status> {
        let request = request.into_inner();
        let max_depth = request.max_depth as u8;
        let visited = request.visited;
        // Respond with the topology collected for this request; `&self` prevents
        // updating `current_topology` here.
        let topology = self
            .update_topology_inner(max_depth, visited.into_iter().collect())
            .await;
        Ok(Response::new(topology.into()))
    }

    async fn send_result(
        &self,
        _request: Request<SendResultRequest>,
    ) -> Result<Response<Empty>, Status> {
        todo!()
    }

    async fn send_opaque_status(
        &self,
        request: Request<SendOpaqueStatusRequest>,
    ) -> Result<Response<Empty>, Status> {
        let request = request.into_inner();
        let request_id = request.request_id;
        let status = request.status;
        println!(
            "Received SendOpaqueStatus request: {} {}",
            request_id, status
        );
        Ok(Response::new(Empty {}))
    }

    async fn health_check(
        &self,
        _request: Request<HealthCheckRequest>,
    ) -> Result<Response<HealthCheckResponse>, Status> {
        Ok(Response::new(HealthCheckResponse { is_healthy: true }))
    }
}

impl Node {
    async fn update_topology(&mut self) {
        let overall_max_depth = 4;
        let visited: HashSet<String> = HashSet::new();
        self.current_topology = self.update_topology_inner(overall_max_depth, visited).await;
    }

    async fn update_topology_inner(&self, max_depth: u8, mut visited: HashSet<String>) -> Topology {
        let mut new_topology = Topology::default();
        new_topology.update_node(
            self.node_info.node_id.clone(),
            self.node_info.device_capabilities.clone(),
        );

        for peer in self.udp_discovery.peers.lock().await.values() {
            new_topology.update_node(peer.node_id.clone(), peer.device_capabilities.clone());
            new_topology.update_edge(
                self.node_info.node_id.clone(),
                peer.node_id.clone(),
                peer.description.clone(),
            );

            // Only recurse into peers that have not been visited yet; `insert`
            // returns false if the peer was already in the set.
            let newly_visited = visited.insert(peer.node_id.clone());
            if newly_visited && max_depth > 0 {
                let topology = peer.collect_topology(visited.clone(), max_depth - 1).await;
                new_topology.merge_restricted(&peer.node_id, topology);
            }
        }

        new_topology
    }
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Install the global collector configured based on the RUST_LOG env var.
    tracing_subscriber::fmt::init();

    let grpc_addr = "[::1]:50051".parse()?;
    let node: Node = Node::default();

    // TODO: Also implement discovery
    Server::builder()
        .add_service(NodeServiceServer::new(node))
        .serve(grpc_addr)
        .await?;

    Ok(())
}
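
// A minimal test sketch for the pure data types above, assuming only the serde
// shapes declared in this file (the internally tagged `OpaqueStatus` enum and
// the `n_layers` rename on `Shard`). Model names and tag values in the sample
// data are illustrative, not taken from a real deployment.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn shard_layer_helpers() {
        // A hypothetical shard covering layers 0..=3 of an 8-layer model.
        let shard = Shard {
            model_id: "example-model".to_string(),
            start_layer: 0,
            end_layer: 3,
            total_layers: 8,
        };
        assert!(shard.is_first_layer());
        assert!(!shard.is_last_layer());
        assert_eq!(shard.len(), 4);

        // `total_layers` serializes under the wire name `n_layers`.
        let json = serde_json::to_value(&shard).unwrap();
        assert_eq!(json["n_layers"], 8);
    }

    #[test]
    fn opaque_status_uses_type_tag() {
        // The `type` field selects the variant, as declared by `#[serde(tag = "type")]`.
        let raw = r#"{"type":"SupportedInferenceEngines","node_id":"node-1","engines":["mlx","tinygrad"]}"#;
        let parsed: OpaqueStatus = serde_json::from_str(raw).unwrap();
        assert!(matches!(
            parsed,
            OpaqueStatus::SupportedInferenceEngines(ref s) if s.engines == vec!["mlx", "tinygrad"]
        ));
    }

    // An end-to-end health check against a locally running node. Ignored by
    // default because it needs the server started by `main` to be listening on
    // [::1]:50051, and it assumes `tonic_build` also generated the client
    // module (the default behaviour).
    #[tokio::test]
    #[ignore]
    async fn health_check_roundtrip() -> Result<(), Box<dyn std::error::Error>> {
        use crate::node_service::node_service_client::NodeServiceClient;

        let mut client = NodeServiceClient::connect("http://[::1]:50051").await?;
        let response = client.health_check(HealthCheckRequest::default()).await?;
        assert!(response.into_inner().is_healthy);
        Ok(())
    }
}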