diff --git a/src/discovery/broadcast.rs b/src/discovery/broadcast.rs new file mode 100644 index 0000000..6f981ec --- /dev/null +++ b/src/discovery/broadcast.rs @@ -0,0 +1,44 @@ +use crate::discovery; +use crate::discovery::DiscoveryMessage; +use crate::discovery::NodeInfo; +use crate::network::BroadcastCreationInfo; +use std::net::{IpAddr, SocketAddr}; + +pub async fn listen_all(node_info: NodeInfo, broadcast_creation_infos: Vec) { + let sockets_and_messages = + broadcast_creation_infos + .iter() + .map(|broadcast_creation_info: &BroadcastCreationInfo| { + let socket_addr = + SocketAddr::new(IpAddr::V4(broadcast_creation_info.bind_address), 0); + + let socket = discovery::bind_to_address(socket_addr); + + let message = serde_json::to_vec(&DiscoveryMessage { + message_type: "discovery".to_string(), + node_id: node_info.node_id.clone(), + grpc_port: node_info.grpc_port, + device_capabilities: node_info.device_capabilities.clone(), + priority: broadcast_creation_info.interface_type.priority(), + interface_name: broadcast_creation_info.interface_name.clone(), + interface_type: broadcast_creation_info.interface_type.to_string(), + }) + .unwrap(); + + (socket, broadcast_creation_info.broadcast_address, message) + }); + + loop { + for (socket, broadcast_address, message) in sockets_and_messages.clone() { + socket + .send_to( + &message, + SocketAddr::new(IpAddr::V4(broadcast_address), node_info.broadcast_port), + ) + .await + .unwrap(); + } + + tokio::time::sleep(node_info.broadcast_interval).await; + } +} diff --git a/src/discovery/mod.rs b/src/discovery/mod.rs index cea2eec..12b27c3 100644 --- a/src/discovery/mod.rs +++ b/src/discovery/mod.rs @@ -1,12 +1,13 @@ -use crate::network::BroadcastCreationInfo; +use crate::network::get_broadcast_creation_info; use crate::topology::DeviceCapabilities; use serde::{Deserialize, Serialize}; use socket2::{Domain, Protocol, Socket, Type}; -use std::net::{IpAddr, SocketAddr}; +use std::net::SocketAddr; use std::time::Duration; use tokio::net::UdpSocket; -use udp_listen::NodeInfo; +use tokio::task::JoinHandle; +mod broadcast; mod udp_listen; #[derive(Debug, Serialize, Deserialize)] @@ -21,7 +22,7 @@ pub struct DiscoveryMessage { pub interface_type: String, } -fn bind_to_address(address: SocketAddr) -> UdpSocket { +pub fn bind_to_address(address: SocketAddr) -> UdpSocket { let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP)).unwrap(); socket.set_broadcast(true).unwrap(); socket.set_reuse_address(true).unwrap(); @@ -33,46 +34,43 @@ fn bind_to_address(address: SocketAddr) -> UdpSocket { UdpSocket::from_std(socket.into()).unwrap() } -pub async fn listen_all( +#[derive(Debug, Clone)] +pub struct NodeInfo { + pub node_id: String, + pub discovery_listen_port: u16, + pub broadcast_port: u16, + pub broadcast_interval: Duration, + pub grpc_port: u16, + pub allowed_peer_ids: Option>, + pub allowed_interfaces: Option>, + pub discovery_timeout: Duration, + pub device_capabilities: DeviceCapabilities, +} + +pub struct UdpDiscovery { node_info: NodeInfo, - broadcast_port: u16, - broadcast_interval: Duration, - broadcast_creation_infos: Vec, -) { - let sockets_and_messages = - broadcast_creation_infos - .iter() - .map(|broadcast_creation_info: &BroadcastCreationInfo| { - let socket_addr = - SocketAddr::new(IpAddr::V4(broadcast_creation_info.bind_address), 0); + discovery_handle: JoinHandle<()>, + presence_handle: JoinHandle<()>, + peer_manager_handle: JoinHandle<()>, +} - let socket = bind_to_address(socket_addr); +impl UdpDiscovery { + pub fn new(node_info: NodeInfo) -> Self { + let broadcast_creation_info = get_broadcast_creation_info(); + let discovery_handle = tokio::spawn(broadcast::listen_all(node_info.clone(), broadcast_creation_info)); + let (presence_handle, peer_manager_handle) = udp_listen::manage_discovery(node_info.clone()); - let message = serde_json::to_vec(&DiscoveryMessage { - message_type: "discovery".to_string(), - node_id: node_info.node_id.clone(), - grpc_port: node_info.grpc_port, - device_capabilities: node_info.device_capabilities.clone(), - priority: broadcast_creation_info.interface_type.priority(), - interface_name: broadcast_creation_info.interface_name.clone(), - interface_type: broadcast_creation_info.interface_type.to_string(), - }) - .unwrap(); - - (socket, broadcast_creation_info.broadcast_address, message) - }); - - loop { - for (socket, broadcast_address, message) in sockets_and_messages.clone() { - socket - .send_to( - &message, - SocketAddr::new(IpAddr::V4(broadcast_address), broadcast_port), - ) - .await - .unwrap(); + UdpDiscovery { + node_info, + discovery_handle, + presence_handle, + peer_manager_handle, } + } - tokio::time::sleep(broadcast_interval).await; + pub fn stop(&self) { + self.discovery_handle.abort(); + self.presence_handle.abort(); + self.peer_manager_handle.abort(); } } diff --git a/src/discovery/udp_listen.rs b/src/discovery/udp_listen.rs index 9fc5abe..87fd92f 100644 --- a/src/discovery/udp_listen.rs +++ b/src/discovery/udp_listen.rs @@ -1,4 +1,6 @@ +use crate::discovery::{DiscoveryMessage, NodeInfo}; use crate::orchestration::PeerHandle; +use crate::topology::DeviceCapabilities; use std::collections::HashMap; use std::net::SocketAddr; use std::time::Duration; @@ -6,23 +8,11 @@ use system_configuration::sys::libc::disconnectx; use tokio::net::UdpSocket; use tokio::select; use tokio::sync::mpsc::UnboundedSender; +use tokio::task::JoinHandle; use tonic::transport::Error; use tracing::{debug, error, info}; -use crate::discovery::DiscoveryMessage; -use crate::topology::DeviceCapabilities; -#[derive(Debug, Clone)] -pub struct NodeInfo { - pub node_id: String, - pub discovery_listen_port: u16, - pub grpc_port: u16, - pub allowed_peer_ids: Option>, - pub allowed_interfaces: Option>, - pub discovery_timeout: Duration, - pub device_capabilities: DeviceCapabilities, -} - -pub async fn listen_for_discovery( +async fn listen_for_discovery( node_info: NodeInfo, tx: UnboundedSender<(SocketAddr, DiscoveryMessage)>, ) { @@ -129,22 +119,28 @@ async fn handle_new_peer( peers.insert(message.node_id, new_peer); } -pub async fn manage_discovery(node_info: NodeInfo) { +pub fn manage_discovery(node_info: NodeInfo) -> (JoinHandle<()>, JoinHandle<()>) { let mut peers: HashMap = HashMap::new(); let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<(SocketAddr, DiscoveryMessage)>(); - tokio::spawn(listen_for_discovery(node_info.clone(), tx)); - loop { - let action = select! { - _ = tokio::time::sleep(node_info.discovery_timeout) => Action::HealthChecks, - Some((addr, message)) = rx.recv() => Action::NewPeer(addr, message), - }; + // TODO: How do we handle killing this? + let listen_handle = tokio::spawn(listen_for_discovery(node_info.clone(), tx)); - match action { - Action::NewPeer(addr, message) => handle_new_peer(&mut peers, addr, message).await, - Action::HealthChecks => perform_health_checks(&mut peers).await, + let peer_manager_handle = tokio::spawn(async move { + loop { + let action = select! { + _ = tokio::time::sleep(node_info.discovery_timeout) => Action::HealthChecks, + Some((addr, message)) = rx.recv() => Action::NewPeer(addr, message), + }; + + match action { + Action::NewPeer(addr, message) => handle_new_peer(&mut peers, addr, message).await, + Action::HealthChecks => perform_health_checks(&mut peers).await, + } } - } + }); + + (listen_handle, peer_manager_handle) } async fn perform_health_checks(peers: &mut HashMap) {