diff --git a/Cargo.lock b/Cargo.lock index 74957f7..f2ec38e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -169,6 +169,15 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "crc32fast" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" +dependencies = [ + "cfg-if", +] + [[package]] name = "either" version = "1.13.0" @@ -218,6 +227,16 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" +[[package]] +name = "flate2" +version = "1.0.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c936bfdafb507ebbf50b8074c54fa31c5be9a1e7e5f467dd659697041407d07c" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + [[package]] name = "fnv" version = "1.0.7" @@ -1037,6 +1056,7 @@ dependencies = [ "axum", "base64", "bytes", + "flate2", "h2", "http", "http-body", diff --git a/Cargo.toml b/Cargo.toml index 8e47e67..40f66df 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ prost = "0.13.4" serde = { version = "1.0.217", features = ["derive"] } serde_json = "1.0.138" tokio = { version = "1.43.0", features = ["full"] } -tonic = "0.12.3" +tonic = { version = "0.12.3", features = ["gzip"] } thiserror = "2.0" tracing = "0.1" tracing-subscriber = "0.3" diff --git a/src/discovery/mod.rs b/src/discovery/mod.rs deleted file mode 100644 index a9c54b1..0000000 --- a/src/discovery/mod.rs +++ /dev/null @@ -1,262 +0,0 @@ -mod udp; - -use crate::topology::DeviceCapabilities; -use serde::{Deserialize, Serialize}; -use std::{ - collections::HashMap, - net::{IpAddr, SocketAddr}, - sync::Arc, - time::{Duration, Instant}, -}; -use thiserror::Error; -use tokio::{net::UdpSocket, sync::RwLock, time}; -use tracing::{error}; - -#[derive(Error, Debug)] -pub enum DiscoveryError { - #[error("IO error: {0}")] - Io(#[from] std::io::Error), - - #[error("JSON serialization error: {0}")] - Json(#[from] serde_json::Error), - - #[error("UTF-8 conversion error: {0}")] - Utf8(#[from] std::string::FromUtf8Error), - - #[error("Address parse error: {0}")] - AddrParse(#[from] std::net::AddrParseError), - - #[error("Network interface error: {0}")] - NetworkInterface(String), - - #[error("Peer error: {0}")] - Peer(String), -} - -// Constants -const DEBUG: i32 = 0; -const BROADCAST_ADDR: &str = "255.255.255.255"; - -#[derive(Debug, Clone, Serialize, Deserialize)] -struct DiscoveryMessage { - message_type: String, - node_id: String, - grpc_port: u16, - device_capabilities: DeviceCapabilities, - priority: i32, - interface_name: String, - interface_type: String, -} - -struct PeerInfo { - peer_handle: PeerHandle, - connected_at: Instant, - last_seen: Instant, - priority: i32, -} - -struct UdpDiscovery { - node_id: String, - node_port: u16, - listen_port: u16, - broadcast_port: u16, - broadcast_interval: Duration, - discovery_timeout: Duration, - device_capabilities: DeviceCapabilities, - allowed_node_ids: Option>, - allowed_interface_types: Option>, - known_peers: Arc>>, -} - -impl UdpDiscovery { - pub fn new( - node_id: String, - node_port: u16, - listen_port: u16, - broadcast_port: u16, - broadcast_interval: Duration, - discovery_timeout: Duration, - device_capabilities: DeviceCapabilities, - allowed_node_ids: Option>, - allowed_interface_types: Option>, - ) -> Self { - Self { - node_id, - node_port, - listen_port, - broadcast_port, - broadcast_interval, - discovery_timeout, - device_capabilities, - allowed_node_ids, - allowed_interface_types, - known_peers: Arc::new(RwLock::new(HashMap::new())), - } - } - - pub async fn start(&self) -> Result<(), DiscoveryError> { - let broadcast_task = self.task_broadcast_presence(); - let listen_task = self.task_listen_for_peers(); - let cleanup_task = self.task_cleanup_peers(); - - tokio::try_join!(broadcast_task, listen_task, cleanup_task)?; - Ok(()) - } -} - -impl UdpDiscovery { - async fn task_listen_for_peers(&self) -> Result<(), DiscoveryError> { - let socket = UdpSocket::bind(format!("0.0.0.0:{}", self.listen_port)).await?; - let mut buf = vec![0u8; 65535]; - - loop { - let (len, addr) = socket.recv_from(&mut buf).await?; - if len == 0 { - continue; - } - - if let Ok(message) = String::from_utf8(buf[..len].to_vec()) { - if let Ok(discovery_message) = serde_json::from_str::(&message) { - self.handle_discovery_message(discovery_message, addr) - .await - .map_err(|e| DiscoveryError::Peer(e.to_string()))?; - } - } - } - } -} - -impl UdpDiscovery { - async fn task_broadcast_presence(&self) -> Result<(), DiscoveryError> { - let socket = UdpSocket::bind("0.0.0.0:0").await?; - socket.set_broadcast(true)?; - - loop { - let interfaces = get_all_ip_addresses_and_interfaces()?; - for (addr, interface_name) in interfaces { - let (interface_priority, interface_type) = - get_interface_priority_and_type(&interface_name) - .await - .map_err(|e| DiscoveryError::NetworkInterface(e.to_string()))?; - - let message = DiscoveryMessage { - message_type: "discovery".to_string(), - node_id: self.node_id.clone(), - grpc_port: self.node_port, - device_capabilities: self.device_capabilities.clone(), - priority: interface_priority, - interface_name: interface_name.clone(), - interface_type: interface_type.clone(), - }; - - let message_json = serde_json::to_string(&message)?; - let broadcast_addr = - SocketAddr::new(get_broadcast_address(&addr).parse()?, self.broadcast_port); - - if let Err(e) = socket - .send_to(message_json.as_bytes(), &broadcast_addr) - .await - { - error!("Error broadcasting to {}: {}", broadcast_addr, e); - } - } - - time::sleep(self.broadcast_interval).await; - } - } -} - -impl UdpDiscovery { - async fn task_cleanup_peers(&self) -> Result<(), DiscoveryError> { - loop { - let now = Instant::now(); - - // TODO: Do we really want a read then write lock or should we just take a write lock - // to begin with? - let peers_to_remove = { - let mut peers_to_remove = Vec::new(); - - let peers = self.known_peers.read().await; - for (peer_id, peer_info) in peers.iter() { - if self.should_remove_peer(peer_info, now).await { - peers_to_remove.push(peer_id.clone()); - } - } - - peers_to_remove - }; - - { - let mut peers = self.known_peers.write().await; - for peer_id in peers_to_remove { - peers.remove(&peer_id); - } - }; - - time::sleep(self.broadcast_interval).await; - } - } - - async fn should_remove_peer(&self, peer_info: &PeerInfo, now: Instant) -> bool { - let is_connected = peer_info - .peer_handle - .is_connected() - .await - .ok() - .unwrap_or(false); - - if !is_connected { - return true; - } - - if now.duration_since(peer_info.connected_at) > self.discovery_timeout { - return true; - } - - if now.duration_since(peer_info.last_seen) > self.discovery_timeout { - return true; - } - - peer_info - .peer_handle - .health_check() - .await - .ok() - .unwrap_or(false) - } -} - -fn get_broadcast_address(ip_addr: &IpAddr) -> String { - match ip_addr { - IpAddr::V4(addr) => { - let octets = addr.octets(); - format!("{}.{}.{}.255", octets[0], octets[1], octets[2]) - } - _ => BROADCAST_ADDR.to_string(), - } -} - -// You'll need to implement these functions based on your system -fn get_all_ip_addresses_and_interfaces() -> Result, DiscoveryError> { - // Implementation needed - Ok(vec![]) -} - -async fn get_interface_priority_and_type( - interface_name: &str, -) -> Result<(i32, String), DiscoveryError> { - // Implementation needed - Ok((0, "unknown".to_string())) -} - -struct PeerHandle; - -impl PeerHandle { - async fn is_connected(&self) -> Result { - Ok(true) - } - - async fn health_check(&self) -> Result { - Ok(true) - } -} diff --git a/src/discovery/udp.rs b/src/discovery/udp.rs deleted file mode 100644 index 8c3bf5c..0000000 --- a/src/discovery/udp.rs +++ /dev/null @@ -1,98 +0,0 @@ -use std::collections::HashMap; -use std::net::SocketAddr; -use crate::topology::DeviceCapabilities; -use serde::{Deserialize, Serialize}; -use tokio::sync::RwLock; -use tokio::task::JoinHandle; -use tracing::{debug, error, info}; -use crate::discovery::PeerInfo; - -struct UdpDiscovery { - listen_port: u16, - listen_buffer_size: usize, - allowed_interface_types: Option>, - known_peers: RwLock>, - - listen_handle: JoinHandle<()>, - presence_handle: JoinHandle<()>, -} - -#[derive(Debug, Serialize, Deserialize)] -struct DiscoveryMessage { - #[serde(rename = "type")] - message_type: String, - node_id: String, - grpc_port: u16, - device_capabilities: DeviceCapabilities, - priority: i32, - interface_name: String, - interface_type: String, -} - -impl UdpDiscovery { - async fn listen(&self) { - let listen_socket = tokio::net::UdpSocket::bind("0.0.0.0:42069").await.unwrap(); - let mut buf = vec![0; self.listen_buffer_size]; - - loop { - // This will block waiting for a message. - // If this fails it will end the loop and the task, which is what we want. - let Ok((len, addr)) = listen_socket.recv_from(&mut buf).await.unwrap(); - - let Some(message) = serde_json::from_slice::(&buf[..len]) else { - error!( - "Received invalid discovery message from {} {:?}", - addr, - str::from_utf8(&buf[..len]) - ); - continue; - }; - - if let Some(ref allowed_interface_types) = self.allowed_interface_types { - if !allowed_interface_types.contains(&message.interface_type) { - debug!("Ignoring message from {} because interface type {} is not allowed", addr, message.interface_type); - continue; - } - } - - self.on_discovery_message(message, addr).await; - } - } - - async fn on_discovery_message(&self, message: DiscoveryMessage, addr: SocketAddr) { - let known_peers = self.known_peers.write().await; - let existing = known_peers.get(&message.node_id); - - if let Some(existing) = existing { - - } - - /* - if peer_id not in self.known_peers or self.known_peers[peer_id][0].addr() != f"{peer_host}:{peer_port}": - if peer_id in self.known_peers: - existing_peer_prio = self.known_peers[peer_id][3] - if existing_peer_prio >= peer_prio: - if DEBUG >= 1: - print( - f"Ignoring peer {peer_id} at {peer_host}:{peer_port} with priority {peer_prio} because we already know about a peer with higher or equal priority: {existing_peer_prio}") - return - new_peer_handle = self.create_peer_handle(peer_id, f"{peer_host}:{peer_port}", - f"{peer_interface_type} ({peer_interface_name})", - device_capabilities) - if not await new_peer_handle.health_check(): - if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Skipping.") - return - if DEBUG >= 1: print( - f"Adding {peer_id=} at {peer_host}:{peer_port}. Replace existing peer_id: {peer_id in self.known_peers}") - self.known_peers[peer_id] = (new_peer_handle, time.time(), time.time(), peer_prio) - else: - if not await self.known_peers[peer_id][0].health_check(): - if DEBUG >= 1: print(f"Peer {peer_id} at {peer_host}:{peer_port} is not healthy. Removing.") - if peer_id in self.known_peers: del self.known_peers[peer_id] - return - if peer_id in self.known_peers: self.known_peers[peer_id] = ( - self.known_peers[peer_id][0], self.known_peers[peer_id][1], time.time(), peer_prio) - */ - } - } -} diff --git a/src/main.rs b/src/main.rs index 1ac08d0..a9542b1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,4 @@ mod topology; -mod discovery; mod orchestration; use serde::{Deserialize, Serialize}; diff --git a/src/orchestration.rs b/src/orchestration.rs index 60639cf..07fdf2d 100644 --- a/src/orchestration.rs +++ b/src/orchestration.rs @@ -1,20 +1,28 @@ use std::net::SocketAddr; +use tonic::codec::CompressionEncoding; +use crate::node_service::node_service_client::NodeServiceClient; use crate::topology::DeviceCapabilities; struct PeerHandle { node_id: String, address: SocketAddr, + description: Option, + client: NodeServiceClient, device_capabilities: DeviceCapabilities, } impl PeerHandle { - fn new(node_id: String, address: SocketAddr, device_capabilities: DeviceCapabilities) -> Self { - crate::node_service::node_service_client::NodeServiceClient::connect(address); + async fn new(node_id: String, address: SocketAddr, description: Option, device_capabilities: DeviceCapabilities) -> Result { + let endpoint = format!("http://{}", address); + let client = NodeServiceClient::connect(endpoint).await? + .accept_compressed(CompressionEncoding::Gzip); - Self { + Ok(Self { node_id, + description, address, + client, device_capabilities, - } + }) } -} \ No newline at end of file +}