From 203cb50380cc62580d13c2d11428099cc4a3da64 Mon Sep 17 00:00:00 2001 From: Joshua Coles Date: Wed, 12 Feb 2025 11:11:38 +0000 Subject: [PATCH] Support listening! --- src/discovery.rs | 16 ++--- src/main.rs | 1 + src/orchestration.rs | 18 +++--- src/udp_listen.rs | 140 +++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 159 insertions(+), 16 deletions(-) create mode 100644 src/udp_listen.rs diff --git a/src/discovery.rs b/src/discovery.rs index cb86949..bbbb507 100644 --- a/src/discovery.rs +++ b/src/discovery.rs @@ -13,15 +13,15 @@ struct NodeInfo { } #[derive(Debug, Serialize, Deserialize)] -struct DiscoveryMessage { +pub struct DiscoveryMessage { #[serde(rename = "type")] - message_type: String, - node_id: String, - grpc_port: u16, - device_capabilities: DeviceCapabilities, - priority: u8, - interface_name: String, - interface_type: String, + pub message_type: String, + pub node_id: String, + pub grpc_port: u16, + pub device_capabilities: DeviceCapabilities, + pub priority: u8, + pub interface_name: String, + pub interface_type: String, } fn bind_to_address(address: SocketAddr) -> UdpSocket { diff --git a/src/main.rs b/src/main.rs index 0453fd3..1130461 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,7 @@ mod topology; mod orchestration; mod discovery; +mod udp_listen; mod network; use serde::{Deserialize, Serialize}; diff --git a/src/orchestration.rs b/src/orchestration.rs index 07fdf2d..bdab95c 100644 --- a/src/orchestration.rs +++ b/src/orchestration.rs @@ -3,16 +3,17 @@ use tonic::codec::CompressionEncoding; use crate::node_service::node_service_client::NodeServiceClient; use crate::topology::DeviceCapabilities; -struct PeerHandle { - node_id: String, - address: SocketAddr, - description: Option, - client: NodeServiceClient, - device_capabilities: DeviceCapabilities, +pub struct PeerHandle { + pub node_id: String, + pub address: SocketAddr, + pub address_priority: u8, + pub description: Option, + pub client: tokio::sync::Mutex>, + pub device_capabilities: DeviceCapabilities, } impl PeerHandle { - async fn new(node_id: String, address: SocketAddr, description: Option, device_capabilities: DeviceCapabilities) -> Result { + pub async fn new(node_id: String, address: SocketAddr, address_priority: u8, description: Option, device_capabilities: DeviceCapabilities) -> Result { let endpoint = format!("http://{}", address); let client = NodeServiceClient::connect(endpoint).await? .accept_compressed(CompressionEncoding::Gzip); @@ -20,8 +21,9 @@ impl PeerHandle { Ok(Self { node_id, description, + address_priority, address, - client, + client: tokio::sync::Mutex::new(client), device_capabilities, }) } diff --git a/src/udp_listen.rs b/src/udp_listen.rs new file mode 100644 index 0000000..6bbc98c --- /dev/null +++ b/src/udp_listen.rs @@ -0,0 +1,140 @@ +use crate::orchestration::PeerHandle; +use crate::{discovery::DiscoveryMessage, node_service::HealthCheckRequest}; +use std::collections::HashMap; +use std::net::SocketAddr; +use system_configuration::sys::libc::disconnectx; +use tokio::net::UdpSocket; +use tokio::select; +use tokio::sync::mpsc::UnboundedSender; +use tonic::transport::Error; +use tracing::{debug, error, info}; + +#[derive(Debug, Clone)] +struct NodeInfo { + id: String, + listen_port: u16, + allowed_peer_ids: Option>, + allowed_interfaces: Option>, +} + +pub async fn listen_for_discovery( + node_info: NodeInfo, + tx: UnboundedSender<(SocketAddr, DiscoveryMessage)>, +) { + let socket = UdpSocket::bind(format!("0.0.0.0:{}", node_info.listen_port)) + .await + .unwrap(); + let mut buf = vec![0u8; 65535]; + + loop { + let (len, addr) = socket.recv_from(&mut buf).await.unwrap(); + if len == 0 { + continue; + } + + let Ok(message) = String::from_utf8(buf[..len].to_vec()) else { + error!("Invalid UTF-8 message from {}", addr); + continue; + }; + + let Ok(message) = serde_json::from_str::(&message) else { + error!("Invalid discovery message from {}", addr); + continue; + }; + + // Validate message + if message.message_type != "discovery" || message.node_id == node_info.id { + continue; + } + + if node_info + .allowed_peer_ids + .as_ref() + .map(|ids| !ids.contains(&message.node_id)) + .unwrap_or(false) + { + debug!( + "Ignoring peer {peer_id} as it's not in the allowed node IDs list", + peer_id = message.node_id + ); + continue; + } + + if node_info + .allowed_interfaces + .as_ref() + .map(|interfaces| !interfaces.contains(&message.interface_name)) + .unwrap_or(false) + { + debug!("Ignoring peer {peer_id} as it's interface {interface} is not in the allowed interfaces list", peer_id = message.node_id, interface = message.interface_name); + continue; + } + + tx.send((addr, message)).unwrap(); + } +} + +struct PeerInfo { + address: SocketAddr, + priority: u8, +} + +pub async fn manage_discovery(node_info: NodeInfo) { + let mut peers: HashMap = HashMap::new(); + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<(SocketAddr, DiscoveryMessage)>(); + tokio::spawn(listen_for_discovery(node_info.clone(), tx)); + + while let Some((addr, message)) = rx.recv().await { + info!("Received discovery message from {}", message.node_id); + let existing = peers.get(&message.node_id); + let insert_new = match existing { + None => true, + Some(existing) => { + existing.address != addr && existing.address_priority < message.priority + } + }; + + if !insert_new { + continue; + } + + let description = format!("{} ({})", message.interface_type, message.interface_name); + + let a = PeerHandle::new( + message.node_id.clone(), + addr.clone(), + message.priority, + Some(description), + message.device_capabilities.clone(), + ) + .await; + + let a = match a { + Ok(a) => a, + Err(error) => { + error!( + "Failed to connect to new peer {} at {}: {}", + message.node_id, addr, error + ); + continue; + } + }; + + let is_healthy = a + .client + .lock() + .await + .health_check(HealthCheckRequest::default()) + .await + .ok() + .map(|x| x.into_inner().is_healthy) + .unwrap_or(false); + + if !is_healthy { + error!("Peer {} is not healthy", message.node_id); + continue; + } + + peers.insert(message.node_id, a); + } +}