diff --git a/Cargo.lock b/Cargo.lock
index 400f78b..74957f7 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -198,9 +198,12 @@ dependencies = [
  "prost",
  "serde",
  "serde_json",
+ "thiserror",
  "tokio",
  "tonic",
  "tonic-build",
+ "tracing",
+ "tracing-subscriber",
 ]
 
 [[package]]
@@ -460,6 +463,12 @@ version = "1.0.14"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674"
 
+[[package]]
+name = "lazy_static"
+version = "1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
+
 [[package]]
 name = "libc"
 version = "0.2.169"
@@ -532,6 +541,16 @@ version = "0.10.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03"
 
+[[package]]
+name = "nu-ansi-term"
+version = "0.46.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84"
+dependencies = [
+ "overload",
+ "winapi",
+]
+
 [[package]]
 name = "object"
 version = "0.36.7"
@@ -547,6 +566,12 @@ version = "1.20.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e"
 
+[[package]]
+name = "overload"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39"
+
 [[package]]
 name = "parking_lot"
 version = "0.12.3"
@@ -844,6 +869,15 @@ dependencies = [
  "serde",
 ]
 
+[[package]]
+name = "sharded-slab"
+version = "0.1.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6"
+dependencies = [
+ "lazy_static",
+]
+
 [[package]]
 name = "signal-hook-registry"
 version = "1.4.2"
@@ -909,6 +943,36 @@ dependencies = [
  "windows-sys",
 ]
 
+[[package]]
+name = "thiserror"
+version = "2.0.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc"
+dependencies = [
+ "thiserror-impl",
+]
+
+[[package]]
+name = "thiserror-impl"
+version = "2.0.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
+[[package]]
+name = "thread_local"
+version = "1.1.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8b9ef9bad013ada3808854ceac7b46812a6465ba368859a37e2100283d2d719c"
+dependencies = [
+ "cfg-if",
+ "once_cell",
+]
+
 [[package]]
 name = "tokio"
 version = "1.43.0"
@@ -1081,6 +1145,32 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "e672c95779cf947c5311f83787af4fa8fffd12fb27e4993211a84bdfd9610f9c"
 dependencies = [
  "once_cell",
+ "valuable",
+]
+
+[[package]]
+name = "tracing-log"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3"
+dependencies = [
+ "log",
+ "once_cell",
+ "tracing-core",
+]
+
+[[package]]
+name = "tracing-subscriber"
+version = "0.3.19"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008"
+dependencies = [
+ "nu-ansi-term",
+ "sharded-slab",
+ "smallvec",
+ "thread_local",
+ "tracing-core",
+ "tracing-log",
 ]
 
 [[package]]
@@ -1095,6 +1185,12 @@ version = "1.0.16"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034"
 
+[[package]]
+name = "valuable"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65"
+
 [[package]]
 name = "want"
 version = "0.3.1"
@@ -1119,6 +1215,28 @@ dependencies = [
  "wit-bindgen-rt",
 ]
 
+[[package]]
+name = "winapi"
+version = "0.3.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
+dependencies = [
+ "winapi-i686-pc-windows-gnu",
+ "winapi-x86_64-pc-windows-gnu",
+]
+
+[[package]]
+name = "winapi-i686-pc-windows-gnu"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
+
+[[package]]
+name = "winapi-x86_64-pc-windows-gnu"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
+
 [[package]]
 name = "windows-sys"
 version = "0.52.0"
diff --git a/Cargo.toml b/Cargo.toml
index 8f28140..8e47e67 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -9,6 +9,9 @@ serde = { version = "1.0.217", features = ["derive"] }
 serde_json = "1.0.138"
 tokio = { version = "1.43.0", features = ["full"] }
 tonic = "0.12.3"
+thiserror = "2.0"
+tracing = "0.1"
+tracing-subscriber = "0.3"
 
 [build-dependencies]
 tonic-build = "0.12.3"
diff --git a/src/discovery.rs b/src/discovery.rs
new file mode 100644
index 0000000..e562c21
--- /dev/null
+++ b/src/discovery.rs
@@ -0,0 +1,358 @@
+//! UDP broadcast peer discovery.
+//!
+//! Every node periodically broadcasts a JSON `DiscoveryMessage` on each local
+//! interface, listens for the announcements of other nodes, and evicts peers
+//! that have gone silent or unhealthy.
+
+use crate::topology::DeviceCapabilities;
+use serde::{Deserialize, Serialize};
+use std::{
+    collections::HashMap,
+    net::{IpAddr, SocketAddr},
+    sync::Arc,
+    time::{Duration, Instant},
+};
+use thiserror::Error;
+use tokio::{net::UdpSocket, sync::RwLock, time};
+use tracing::{debug, error};
+
+/// Errors produced by the discovery subsystem.
+#[derive(Error, Debug)]
+pub enum DiscoveryError {
+    #[error("IO error: {0}")]
+    Io(#[from] std::io::Error),
+
+    #[error("JSON serialization error: {0}")]
+    Json(#[from] serde_json::Error),
+
+    #[error("UTF-8 conversion error: {0}")]
+    Utf8(#[from] std::string::FromUtf8Error),
+
+    #[error("Address parse error: {0}")]
+    AddrParse(#[from] std::net::AddrParseError),
+
+    #[error("Network interface error: {0}")]
+    NetworkInterface(String),
+
+    #[error("Peer error: {0}")]
+    Peer(String),
+}
+
+// Constants
+const DEBUG: i32 = 0;
+/// Limited-broadcast address used when no subnet broadcast can be derived.
+const BROADCAST_ADDR: &str = "255.255.255.255";
+
+/// Presence announcement, serialized as JSON and sent over UDP.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+struct DiscoveryMessage {
+    message_type: String,
+    node_id: String,
+    grpc_port: u16,
+    device_capabilities: DeviceCapabilities,
+    priority: i32,
+    interface_name: String,
+    interface_type: String,
+}
+
+/// Bookkeeping for one discovered peer.
+struct PeerInfo {
+    peer_handle: PeerHandle,
+    /// When the peer was first recorded.
+    connected_at: Instant,
+    /// When the most recent announcement from the peer arrived.
+    last_seen: Instant,
+    /// Highest interface priority the peer has announced so far.
+    priority: i32,
+}
+
+/// Broadcasts this node's presence, listens for peers and prunes stale ones.
+struct UdpDiscovery {
+    node_id: String,
+    node_port: u16,
+    listen_port: u16,
+    broadcast_port: u16,
+    broadcast_interval: Duration,
+    discovery_timeout: Duration,
+    device_capabilities: DeviceCapabilities,
+    /// Optional allow-list of node ids; `None` accepts every node.
+    allowed_node_ids: Option<Vec<String>>,
+    /// Optional allow-list of interface types; `None` accepts every type.
+    allowed_interface_types: Option<Vec<String>>,
+    /// Discovered peers keyed by node id.
+    known_peers: Arc<RwLock<HashMap<String, PeerInfo>>>,
+}
+
+impl UdpDiscovery {
+    /// Creates a discovery service with an empty peer table.
+    pub fn new(
+        node_id: String,
+        node_port: u16,
+        listen_port: u16,
+        broadcast_port: u16,
+        broadcast_interval: Duration,
+        discovery_timeout: Duration,
+        device_capabilities: DeviceCapabilities,
+        allowed_node_ids: Option<Vec<String>>,
+        allowed_interface_types: Option<Vec<String>>,
+    ) -> Self {
+        Self {
+            node_id,
+            node_port,
+            listen_port,
+            broadcast_port,
+            broadcast_interval,
+            discovery_timeout,
+            device_capabilities,
+            allowed_node_ids,
+            allowed_interface_types,
+            known_peers: Arc::new(RwLock::new(HashMap::new())),
+        }
+    }
+
+    /// Runs the broadcast, listen and cleanup loops concurrently.
+    ///
+    /// # Errors
+    ///
+    /// Returns the first error raised by any loop; the remaining loops are
+    /// dropped at that point.
+    pub async fn start(&self) -> Result<(), DiscoveryError> {
+        let broadcast_task = self.task_broadcast_presence();
+        let listen_task = self.task_listen_for_peers();
+        let cleanup_task = self.task_cleanup_peers();
+
+        tokio::try_join!(broadcast_task, listen_task, cleanup_task)?;
+        Ok(())
+    }
+}
+
+impl UdpDiscovery {
+    /// Receives announcements on `listen_port` and records the senders.
+    ///
+    /// Datagrams that are empty or not a valid `DiscoveryMessage` are ignored.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the socket cannot be bound or a receive fails.
+    async fn task_listen_for_peers(&self) -> Result<(), DiscoveryError> {
+        let socket = UdpSocket::bind(format!("0.0.0.0:{}", self.listen_port)).await?;
+        let mut buf = vec![0u8; 65535];
+
+        loop {
+            let (len, addr) = socket.recv_from(&mut buf).await?;
+            if len == 0 {
+                continue;
+            }
+
+            // Parse straight from the received bytes: serde_json validates UTF-8
+            // itself, so no intermediate `String` copy is needed.
+            if let Ok(message) = serde_json::from_slice::<DiscoveryMessage>(&buf[..len]) {
+                // One bad message must not take the whole listener down.
+                if let Err(e) = self.handle_discovery_message(message, addr).await {
+                    error!("Error handling discovery message from {}: {}", addr, e);
+                }
+            }
+        }
+    }
+
+    /// Records or refreshes the peer that sent `message`.
+    ///
+    /// Messages that are not of type `"discovery"`, that originate from this
+    /// node, or that fail an allow-list are ignored.
+    async fn handle_discovery_message(
+        &self,
+        message: DiscoveryMessage,
+        addr: SocketAddr,
+    ) -> Result<(), DiscoveryError> {
+        if message.message_type != "discovery" || message.node_id == self.node_id {
+            return Ok(());
+        }
+
+        let node_allowed = self
+            .allowed_node_ids
+            .as_ref()
+            .map_or(true, |ids| ids.contains(&message.node_id));
+        let interface_allowed = self
+            .allowed_interface_types
+            .as_ref()
+            .map_or(true, |types| types.contains(&message.interface_type));
+        if !node_allowed || !interface_allowed {
+            return Ok(());
+        }
+
+        debug!(
+            "Discovery message from {} ({}:{}) via {}",
+            message.node_id,
+            addr.ip(),
+            message.grpc_port,
+            message.interface_name
+        );
+
+        let now = Instant::now();
+        let priority = message.priority;
+        self.known_peers
+            .write()
+            .await
+            .entry(message.node_id)
+            .and_modify(|peer| {
+                peer.last_seen = now;
+                peer.priority = peer.priority.max(priority);
+            })
+            .or_insert_with(|| PeerInfo {
+                peer_handle: PeerHandle,
+                connected_at: now,
+                last_seen: now,
+                priority,
+            });
+        Ok(())
+    }
+}
+
+impl UdpDiscovery {
+    /// Announces this node on every local interface once per `broadcast_interval`.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the socket cannot be set up, interfaces cannot be
+    /// enumerated, or a message cannot be serialized. Per-interface failures
+    /// are logged and skipped.
+    async fn task_broadcast_presence(&self) -> Result<(), DiscoveryError> {
+        let socket = UdpSocket::bind("0.0.0.0:0").await?;
+        socket.set_broadcast(true)?;
+
+        loop {
+            let interfaces = get_all_ip_addresses_and_interfaces()?;
+            for (addr, interface_name) in interfaces {
+                // A single misbehaving interface must not silence the others.
+                let (interface_priority, interface_type) =
+                    match get_interface_priority_and_type(&interface_name).await {
+                        Ok(info) => info,
+                        Err(e) => {
+                            error!("Error inspecting interface {}: {}", interface_name, e);
+                            continue;
+                        }
+                    };
+
+                let message = DiscoveryMessage {
+                    message_type: "discovery".to_string(),
+                    node_id: self.node_id.clone(),
+                    grpc_port: self.node_port,
+                    device_capabilities: self.device_capabilities.clone(),
+                    priority: interface_priority,
+                    interface_name,
+                    interface_type,
+                };
+
+                let message_json = serde_json::to_string(&message)?;
+                let broadcast_addr =
+                    SocketAddr::new(get_broadcast_address(&addr).parse()?, self.broadcast_port);
+
+                if let Err(e) = socket
+                    .send_to(message_json.as_bytes(), &broadcast_addr)
+                    .await
+                {
+                    error!("Error broadcasting to {}: {}", broadcast_addr, e);
+                }
+            }
+
+            time::sleep(self.broadcast_interval).await;
+        }
+    }
+}
+
+impl UdpDiscovery {
+    /// Evicts stale or unhealthy peers once per `broadcast_interval`.
+    async fn task_cleanup_peers(&self) -> Result<(), DiscoveryError> {
+        loop {
+            let now = Instant::now();
+
+            // Probe peers under a read lock so other readers are not blocked,
+            // then take the write lock only if something has to be removed.
+            let peers_to_remove = {
+                let peers = self.known_peers.read().await;
+                let mut stale = Vec::new();
+                for (peer_id, peer_info) in peers.iter() {
+                    if self.should_remove_peer(peer_info, now).await {
+                        stale.push(peer_id.clone());
+                    }
+                }
+                stale
+            };
+
+            if !peers_to_remove.is_empty() {
+                let mut peers = self.known_peers.write().await;
+                for peer_id in &peers_to_remove {
+                    peers.remove(peer_id);
+                }
+            }
+
+            time::sleep(self.broadcast_interval).await;
+        }
+    }
+
+    /// Decides whether `peer_info` should be evicted from the peer table.
+    ///
+    /// A peer is removed when any of these holds:
+    /// * no announcement arrived within `discovery_timeout`;
+    /// * it is not connected and has been known for longer than
+    ///   `discovery_timeout`;
+    /// * its health check fails or returns an error.
+    async fn should_remove_peer(&self, peer_info: &PeerInfo, now: Instant) -> bool {
+        // `saturating_duration_since` yields zero if the peer was refreshed
+        // after `now` was captured.
+        if now.saturating_duration_since(peer_info.last_seen) > self.discovery_timeout {
+            return true;
+        }
+
+        // A failed connectivity probe counts as "not connected".
+        let is_connected = peer_info.peer_handle.is_connected().await.unwrap_or(false);
+        let known_for = now.saturating_duration_since(peer_info.connected_at);
+        if !is_connected && known_for > self.discovery_timeout {
+            return true;
+        }
+
+        // Remove only when the health check does NOT pass.
+        let is_healthy = peer_info.peer_handle.health_check().await.unwrap_or(false);
+        !is_healthy
+    }
+}
+
+/// Derives the broadcast address for `ip_addr`.
+///
+/// IPv4 addresses are assumed to sit in a /24 network (last octet replaced by
+/// 255); everything else falls back to `BROADCAST_ADDR`.
+fn get_broadcast_address(ip_addr: &IpAddr) -> String {
+    match ip_addr {
+        IpAddr::V4(addr) => {
+            let octets = addr.octets();
+            format!("{}.{}.{}.255", octets[0], octets[1], octets[2])
+        }
+        IpAddr::V6(_) => BROADCAST_ADDR.to_string(),
+    }
+}
+
+// You'll need to implement these functions based on your system
+fn get_all_ip_addresses_and_interfaces() -> Result<Vec<(IpAddr, String)>, DiscoveryError> {
+    // Implementation needed
+    Ok(vec![])
+}
+
+async fn get_interface_priority_and_type(
+    _interface_name: &str,
+) -> Result<(i32, String), DiscoveryError> {
+    // Implementation needed
+    Ok((0, "unknown".to_string()))
+}
+
+/// Placeholder for a connection to a peer node.
+struct PeerHandle;
+
+impl PeerHandle {
+    async fn is_connected(&self) -> Result<bool, DiscoveryError> {
+        Ok(true)
+    }
+
+    async fn health_check(&self) -> Result<bool, DiscoveryError> {
+        Ok(true)
+    }
+}
diff --git a/src/main.rs b/src/main.rs
index feb4e2f..5fb1fec 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,4 +1,5 @@
 mod topology;
+mod discovery;
 
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
@@ -9,7 +10,7 @@ use crate::node_service::{
PromptRequest, SendOpaqueStatusRequest, SendResultRequest, Tensor, Topology as TopologyProto, }; use node_service::node_service_server::{NodeService, NodeServiceServer}; -use node_service::{Shard, TensorRequest}; +use node_service::{TensorRequest}; use topology::Topology; pub mod node_service { @@ -36,6 +37,15 @@ enum OpaqueStatus { SupportedInferenceEngines(SupportedInferenceEngines), } +#[derive(Debug, Deserialize, Serialize, Clone)] +struct Shard { + pub model_id: String, + pub start_layer: i32, + pub end_layer: i32, + pub n_layers: i32, +} + + #[derive(Debug, Deserialize, Serialize, Clone)] struct NodeStatus { node_id: String, @@ -69,7 +79,7 @@ struct SupportedInferenceEngines { } impl Node { - fn on_opaque_status(&self, request_id: String, status: String) { + fn on_opaque_status(&self, _request_id: String, status: String) { let status = serde_json::from_str::(&status).unwrap(); match status { @@ -140,7 +150,7 @@ impl NodeService for Node { async fn collect_topology( &self, request: Request, - ) -> Result, Status> { + ) -> Result, Status> { todo!() } @@ -155,8 +165,10 @@ impl NodeService for Node { &self, request: Request, ) -> Result, Status> { - let request_id = request.into_inner().request_id; - let status = request.into_inner().status; + let request = request.into_inner(); + let request_id = request.request_id; + let status = request.status; + println!( "Received SendOpaqueStatus request: {} {}", request_id, status @@ -174,6 +186,9 @@ impl NodeService for Node { #[tokio::main] async fn main() -> Result<(), Box> { + // install global collector configured based on RUST_LOG env var. + tracing_subscriber::fmt::init(); + let grpc_addr = "[::1]:50051".parse()?; let node = Node::default();