From 8d91f64dbf45a47de13e9f48015237b07f3f3b52 Mon Sep 17 00:00:00 2001 From: Joshua Coles Date: Wed, 12 Feb 2025 14:23:54 +0000 Subject: [PATCH] Reduce warning count --- src/device_capability_data.rs | 3 ++- src/discovery/mod.rs | 16 +++++++----- src/discovery/udp_listen.rs | 10 +++----- src/main.rs | 4 +-- src/network.rs | 48 ++++++++++++++++++++++------------- src/orchestration.rs | 30 ++++++++++++++-------- src/topology.rs | 1 - 7 files changed, 66 insertions(+), 46 deletions(-) diff --git a/src/device_capability_data.rs b/src/device_capability_data.rs index a82aec6..a70c370 100644 --- a/src/device_capability_data.rs +++ b/src/device_capability_data.rs @@ -103,7 +103,8 @@ pub static CHIP_FLOPS: phf::Map<&'static str, DeviceFlops> = phf_map! { }; pub fn look_up(chip: &str) -> Option { - CHIP_FLOPS.get(chip) + CHIP_FLOPS + .get(chip) .or_else(|| CHIP_FLOPS.get(&format!("Laptop GPU {}", chip))) .or_else(|| CHIP_FLOPS.get(&format!("{} Laptop GPU", chip))) .cloned() diff --git a/src/discovery/mod.rs b/src/discovery/mod.rs index 132f573..babf2c7 100644 --- a/src/discovery/mod.rs +++ b/src/discovery/mod.rs @@ -1,8 +1,9 @@ -use std::collections::HashMap; use crate::network::get_broadcast_creation_info; +use crate::orchestration::PeerHandle; use crate::topology::DeviceCapabilities; use serde::{Deserialize, Serialize}; use socket2::{Domain, Protocol, Socket, Type}; +use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; @@ -11,7 +12,6 @@ use tokio::sync::Mutex; use tokio::task::JoinHandle; use tracing::{debug, info}; use uuid::Uuid; -use crate::orchestration::PeerHandle; mod broadcast; mod udp_listen; @@ -70,7 +70,6 @@ impl Default for NodeInfo { } pub struct UdpDiscovery { - node_info: NodeInfo, discovery_handle: JoinHandle<()>, presence_handle: JoinHandle<()>, peer_manager_handle: JoinHandle<()>, @@ -84,15 +83,18 @@ impl UdpDiscovery { let peers = Arc::new(Mutex::new(HashMap::new())); - let discovery_handle = tokio::spawn(broadcast::listen_all(node_info.clone(), broadcast_creation_info)); - let (presence_handle, peer_manager_handle) = udp_listen::manage_discovery(node_info.clone(), peers.clone()); + let discovery_handle = tokio::spawn(broadcast::listen_all( + node_info.clone(), + broadcast_creation_info, + )); + let (presence_handle, peer_manager_handle) = + udp_listen::manage_discovery(node_info.clone(), peers.clone()); UdpDiscovery { - node_info, discovery_handle, presence_handle, peer_manager_handle, - peers + peers, } } diff --git a/src/discovery/udp_listen.rs b/src/discovery/udp_listen.rs index 9c76bcd..4aac36f 100644 --- a/src/discovery/udp_listen.rs +++ b/src/discovery/udp_listen.rs @@ -1,18 +1,13 @@ -use std::cell::RefCell; use crate::discovery::{DiscoveryMessage, NodeInfo}; use crate::orchestration::PeerHandle; -use crate::topology::DeviceCapabilities; use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; -use std::time::Duration; -use system_configuration::sys::libc::disconnectx; use tokio::net::UdpSocket; use tokio::select; use tokio::sync::mpsc::UnboundedSender; use tokio::sync::Mutex; use tokio::task::JoinHandle; -use tonic::transport::Error; use tracing::{debug, error, info}; async fn listen_for_discovery( @@ -122,7 +117,10 @@ async fn handle_new_peer( peers.insert(message.node_id, new_peer); } -pub fn manage_discovery(node_info: NodeInfo, peers: Arc>>) -> (JoinHandle<()>, JoinHandle<()>) { +pub fn manage_discovery( + node_info: NodeInfo, + peers: Arc>>, +) -> (JoinHandle<()>, JoinHandle<()>) { let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<(SocketAddr, DiscoveryMessage)>(); // TODO: How do we handle killing this? diff --git a/src/main.rs b/src/main.rs index 5029132..d52be8e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,8 +2,8 @@ mod device_capability_data; mod discovery; mod network; mod orchestration; -mod topology; mod partitioning; +mod topology; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -16,7 +16,7 @@ use crate::node_service::{ }; use node_service::node_service_server::{NodeService, NodeServiceServer}; use node_service::TensorRequest; -use std::collections::{HashMap, HashSet}; +use std::collections::HashSet; use topology::Topology; pub mod node_service { diff --git a/src/network.rs b/src/network.rs index 33f2655..c9003f1 100644 --- a/src/network.rs +++ b/src/network.rs @@ -1,13 +1,9 @@ -use crate::topology::DeviceCapabilities; use network_interface::{Addr, NetworkInterface, NetworkInterfaceConfig}; -use serde::{Deserialize, Serialize}; -use socket2::{Domain, Protocol, Socket, Type}; -use std::net::{IpAddr, Ipv4Addr, SocketAddr}; -use std::time::Duration; +use std::net::Ipv4Addr; use system_configuration::core_foundation::array::CFArray; -use system_configuration::network_configuration::{get_interfaces, SCNetworkInterface, SCNetworkInterfaceType}; -use system_configuration::sys::network_configuration::SCNetworkInterfaceRef; -use tokio::net::UdpSocket; +use system_configuration::network_configuration::{ + get_interfaces, SCNetworkInterface, SCNetworkInterfaceType, +}; #[derive(Copy, Clone, Debug)] pub enum InterfaceType { @@ -76,7 +72,8 @@ pub fn get_broadcast_creation_info() -> Vec { for (bind_address, broadcast_address) in map { let interface_type = if cfg!(target_os = "macos") { - get_sc_interface_type(&interface.name, &sc_interfaces).unwrap_or_else(|| determine_interface_type(&interface.name)) + get_sc_interface_type(&interface.name, &sc_interfaces) + .unwrap_or_else(|| determine_interface_type(&interface.name)) } else { determine_interface_type(&interface.name) }; @@ -94,7 +91,10 @@ pub fn get_broadcast_creation_info() -> Vec { } #[cfg(target_os = "macos")] -fn get_sc_interface_type(name: &str, sc_interfaces: &CFArray) -> Option { +fn get_sc_interface_type( + name: &str, + sc_interfaces: &CFArray, +) -> Option { sc_interfaces.iter().find_map(|sc_if| { sc_if .bsd_name() @@ -110,9 +110,14 @@ fn get_sc_interface_type(name: &str, sc_interfaces: &CFArray fn determine_interface_type(name: &str) -> InterfaceType { // Fallback to interface name pattern matching - if name.starts_with("docker") || name.starts_with("br-") || name.starts_with("veth") - || name.starts_with("cni") || name.starts_with("flannel") || name.starts_with("calico") - || name.starts_with("weave") || name.contains("bridge") + if name.starts_with("docker") + || name.starts_with("br-") + || name.starts_with("veth") + || name.starts_with("cni") + || name.starts_with("flannel") + || name.starts_with("calico") + || name.starts_with("weave") + || name.contains("bridge") { InterfaceType::ContainerVirtual } else if name.starts_with("lo") { @@ -122,13 +127,20 @@ fn determine_interface_type(name: &str) -> InterfaceType { } else if (name.starts_with("eth") || name.starts_with("en")) && !matches!(name, "en0" | "en1") { InterfaceType::Ethernet - } else if name.starts_with("wlan") || name.starts_with("wifi") || name.starts_with("wl") - || matches!(name, "en0" | "en1") + } else if name.starts_with("wlan") + || name.starts_with("wifi") + || name.starts_with("wl") + || matches!(name, "en0" | "en1") { InterfaceType::WiFi - } else if name.starts_with("tun") || name.starts_with("tap") || name.starts_with("vtun") - || name.starts_with("utun") || name.starts_with("gif") || name.starts_with("stf") - || name.starts_with("awdl") || name.starts_with("llw") + } else if name.starts_with("tun") + || name.starts_with("tap") + || name.starts_with("vtun") + || name.starts_with("utun") + || name.starts_with("gif") + || name.starts_with("stf") + || name.starts_with("awdl") + || name.starts_with("llw") { InterfaceType::ExternalVirtual } else { diff --git a/src/orchestration.rs b/src/orchestration.rs index aaa6342..b48e361 100644 --- a/src/orchestration.rs +++ b/src/orchestration.rs @@ -1,16 +1,17 @@ -use std::collections::HashSet; use crate::node_service::node_service_client::NodeServiceClient; -use crate::node_service::{CollectTopologyRequest, HealthCheckRequest, Topology as TopologyProto}; +use crate::node_service::{CollectTopologyRequest, HealthCheckRequest}; use crate::topology::{DeviceCapabilities, Topology}; +use std::collections::HashSet; use std::net::SocketAddr; use tonic::codec::CompressionEncoding; +#[derive(Debug, Clone)] pub struct PeerHandle { pub node_id: String, pub address: SocketAddr, pub address_priority: u8, pub description: Option, - pub client: tokio::sync::Mutex>, + client: NodeServiceClient, pub device_capabilities: DeviceCapabilities, } @@ -32,15 +33,17 @@ impl PeerHandle { description, address_priority, address, - client: tokio::sync::Mutex::new(client), + client, device_capabilities, }) } + pub fn client(&self) -> NodeServiceClient { + self.client.clone() + } + pub async fn is_healthy(&self) -> bool { - self.client - .lock() - .await + self.client() .health_check(HealthCheckRequest::default()) .await .ok() @@ -49,10 +52,15 @@ impl PeerHandle { } pub async fn collect_topology(&self, visited: HashSet, max_depth: u8) -> Topology { - let response = self.client.lock().await.collect_topology(CollectTopologyRequest { - visited: visited.clone().into_iter().collect(), - max_depth: max_depth as i32 - }).await.unwrap().into_inner(); + let response = self + .client() + .collect_topology(CollectTopologyRequest { + visited: visited.clone().into_iter().collect(), + max_depth: max_depth as i32, + }) + .await + .unwrap() + .into_inner(); response.into() } diff --git a/src/topology.rs b/src/topology.rs index a9d22a6..f45e543 100644 --- a/src/topology.rs +++ b/src/topology.rs @@ -2,7 +2,6 @@ use crate::{device_capability_data, node_service}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::process::Command; -use tonic::Response; #[derive(Debug, Deserialize, Serialize, Clone)] pub struct Topology {