diff --git a/src/network.rs b/src/network.rs index 567b5f4..553f995 100644 --- a/src/network.rs +++ b/src/network.rs @@ -1,10 +1,13 @@ +use crate::topology::DeviceCapabilities; use network_interface::{Addr, NetworkInterface, NetworkInterfaceConfig}; +use serde::{Deserialize, Serialize}; use socket2::{Domain, Protocol, Socket, Type}; use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::time::Duration; -use serde::{Deserialize, Serialize}; +use system_configuration::core_foundation::array::CFArray; +use system_configuration::network_configuration::{get_interfaces, SCNetworkInterface, SCNetworkInterfaceType}; +use system_configuration::sys::network_configuration::SCNetworkInterfaceRef; use tokio::net::UdpSocket; -use crate::topology::DeviceCapabilities; #[derive(Copy, Clone, Debug)] enum InterfaceType { @@ -45,7 +48,6 @@ impl ToString for InterfaceType { } } - struct BroadcastCreationInfo { interface_name: String, interface_type: InterfaceType, @@ -54,7 +56,84 @@ struct BroadcastCreationInfo { broadcast_address: Ipv4Addr, } -fn get_broadcast_creation_info() -> Vec {} +fn get_broadcast_creation_info() -> Vec { + let interfaces = NetworkInterface::show().unwrap(); + let mut broadcast_info = Vec::new(); + + #[cfg(target_os = "macos")] + let sc_interfaces = get_interfaces(); + + for interface in interfaces { + // Skip interfaces without IPv4 addresses with broadcast + let map = interface.addr.iter().filter_map(|a| { + if let Addr::V4(v4) = a { + v4.broadcast.map(|b| (v4.ip, b)) + } else { + None + } + }); + + for (bind_address, broadcast_address) in map { + let interface_type = if cfg!(target_os = "macos") { + get_sc_interface_type(&interface.name, &sc_interfaces).unwrap_or_else(|| determine_interface_type(&interface.name)) + } else { + determine_interface_type(&interface.name) + }; + + broadcast_info.push(BroadcastCreationInfo { + interface_name: interface.name.clone(), + interface_type, + bind_address, + broadcast_address, + }); + } + } + + broadcast_info +} + +#[cfg(target_os = "macos")] +fn get_sc_interface_type(name: &str, sc_interfaces: &CFArray) -> Option { + sc_interfaces.iter().find_map(|sc_if| { + sc_if + .bsd_name() + .and_then(|bsd| (bsd.to_string() == name).then(|| sc_if.interface_type())) + .flatten() + .map(|t| match t { + SCNetworkInterfaceType::Ethernet => InterfaceType::Ethernet, + SCNetworkInterfaceType::IEEE80211 => InterfaceType::WiFi, + _ => InterfaceType::Other, + }) + }) +} + +fn determine_interface_type(name: &str) -> InterfaceType { + // Fallback to interface name pattern matching + if name.starts_with("docker") || name.starts_with("br-") || name.starts_with("veth") + || name.starts_with("cni") || name.starts_with("flannel") || name.starts_with("calico") + || name.starts_with("weave") || name.contains("bridge") + { + InterfaceType::ContainerVirtual + } else if name.starts_with("lo") { + InterfaceType::Loopback + } else if name.starts_with("tb") || name.starts_with("nx") || name.starts_with("ten") { + InterfaceType::Thunderbolt + } else if (name.starts_with("eth") || name.starts_with("en")) && !matches!(name, "en0" | "en1") + { + InterfaceType::Ethernet + } else if name.starts_with("wlan") || name.starts_with("wifi") || name.starts_with("wl") + || matches!(name, "en0" | "en1") + { + InterfaceType::WiFi + } else if name.starts_with("tun") || name.starts_with("tap") || name.starts_with("vtun") + || name.starts_with("utun") || name.starts_with("gif") || name.starts_with("stf") + || name.starts_with("awdl") || name.starts_with("llw") + { + InterfaceType::ExternalVirtual + } else { + InterfaceType::Other + } +} struct NodeInfo { node_id: String, @@ -74,11 +153,13 @@ struct DiscoveryMessage { interface_type: String, } -async fn listen(broadcast_creation_info: BroadcastCreationInfo, node_info: NodeInfo, broadcast_port: u16, broadcast_interval: Duration) { - let socket_addr = SocketAddr::new( - IpAddr::V4(broadcast_creation_info.bind_address), - 0, - ); +async fn listen( + broadcast_creation_info: BroadcastCreationInfo, + node_info: NodeInfo, + broadcast_port: u16, + broadcast_interval: Duration, +) { + let socket_addr = SocketAddr::new(IpAddr::V4(broadcast_creation_info.bind_address), 0); let socket = bind_to_address(socket_addr); @@ -90,13 +171,20 @@ async fn listen(broadcast_creation_info: BroadcastCreationInfo, node_info: NodeI priority: broadcast_creation_info.interface_type.priority(), interface_name: broadcast_creation_info.interface_name, interface_type: broadcast_creation_info.interface_type.to_string(), - }).unwrap(); + }) + .unwrap(); loop { - socket.send_to( - &message, - SocketAddr::new(IpAddr::V4(broadcast_creation_info.broadcast_address), broadcast_port), - ).await.unwrap(); + socket + .send_to( + &message, + SocketAddr::new( + IpAddr::V4(broadcast_creation_info.broadcast_address), + broadcast_port, + ), + ) + .await + .unwrap(); tokio::time::sleep(broadcast_interval).await; } @@ -116,17 +204,5 @@ fn bind_to_address(address: SocketAddr) -> UdpSocket { #[test] fn test_interfaces() { - let raw = NetworkInterface::show().unwrap(); - let names_and_addrs = raw - .iter() - .flat_map(|network_interface| { - let v4_addrs = network_interface - .addr - .iter() - .filter(|addr: &&Addr| matches!(addr, Addr::V4(..))) - .map(|addr: &Addr| (network_interface.name.clone(), *addr)); - - v4_addrs - }) - .collect::>(); + dbg!(get_interfaces()); }