Reduce warning count

This commit is contained in:
Joshua Coles 2025-02-12 14:23:54 +00:00
parent 02694cbacc
commit 8d91f64dbf
7 changed files with 66 additions and 46 deletions

View File

@ -103,7 +103,8 @@ pub static CHIP_FLOPS: phf::Map<&'static str, DeviceFlops> = phf_map! {
}; };
pub fn look_up(chip: &str) -> Option<DeviceFlops> { pub fn look_up(chip: &str) -> Option<DeviceFlops> {
CHIP_FLOPS.get(chip) CHIP_FLOPS
.get(chip)
.or_else(|| CHIP_FLOPS.get(&format!("Laptop GPU {}", chip))) .or_else(|| CHIP_FLOPS.get(&format!("Laptop GPU {}", chip)))
.or_else(|| CHIP_FLOPS.get(&format!("{} Laptop GPU", chip))) .or_else(|| CHIP_FLOPS.get(&format!("{} Laptop GPU", chip)))
.cloned() .cloned()

View File

@ -1,8 +1,9 @@
use std::collections::HashMap;
use crate::network::get_broadcast_creation_info; use crate::network::get_broadcast_creation_info;
use crate::orchestration::PeerHandle;
use crate::topology::DeviceCapabilities; use crate::topology::DeviceCapabilities;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use socket2::{Domain, Protocol, Socket, Type}; use socket2::{Domain, Protocol, Socket, Type};
use std::collections::HashMap;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
@ -11,7 +12,6 @@ use tokio::sync::Mutex;
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
use tracing::{debug, info}; use tracing::{debug, info};
use uuid::Uuid; use uuid::Uuid;
use crate::orchestration::PeerHandle;
mod broadcast; mod broadcast;
mod udp_listen; mod udp_listen;
@ -70,7 +70,6 @@ impl Default for NodeInfo {
} }
pub struct UdpDiscovery { pub struct UdpDiscovery {
node_info: NodeInfo,
discovery_handle: JoinHandle<()>, discovery_handle: JoinHandle<()>,
presence_handle: JoinHandle<()>, presence_handle: JoinHandle<()>,
peer_manager_handle: JoinHandle<()>, peer_manager_handle: JoinHandle<()>,
@ -84,15 +83,18 @@ impl UdpDiscovery {
let peers = Arc::new(Mutex::new(HashMap::new())); let peers = Arc::new(Mutex::new(HashMap::new()));
let discovery_handle = tokio::spawn(broadcast::listen_all(node_info.clone(), broadcast_creation_info)); let discovery_handle = tokio::spawn(broadcast::listen_all(
let (presence_handle, peer_manager_handle) = udp_listen::manage_discovery(node_info.clone(), peers.clone()); node_info.clone(),
broadcast_creation_info,
));
let (presence_handle, peer_manager_handle) =
udp_listen::manage_discovery(node_info.clone(), peers.clone());
UdpDiscovery { UdpDiscovery {
node_info,
discovery_handle, discovery_handle,
presence_handle, presence_handle,
peer_manager_handle, peer_manager_handle,
peers peers,
} }
} }

View File

@ -1,18 +1,13 @@
use std::cell::RefCell;
use crate::discovery::{DiscoveryMessage, NodeInfo}; use crate::discovery::{DiscoveryMessage, NodeInfo};
use crate::orchestration::PeerHandle; use crate::orchestration::PeerHandle;
use crate::topology::DeviceCapabilities;
use std::collections::HashMap; use std::collections::HashMap;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use system_configuration::sys::libc::disconnectx;
use tokio::net::UdpSocket; use tokio::net::UdpSocket;
use tokio::select; use tokio::select;
use tokio::sync::mpsc::UnboundedSender; use tokio::sync::mpsc::UnboundedSender;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
use tonic::transport::Error;
use tracing::{debug, error, info}; use tracing::{debug, error, info};
async fn listen_for_discovery( async fn listen_for_discovery(
@ -122,7 +117,10 @@ async fn handle_new_peer(
peers.insert(message.node_id, new_peer); peers.insert(message.node_id, new_peer);
} }
pub fn manage_discovery(node_info: NodeInfo, peers: Arc<Mutex<HashMap<String, PeerHandle>>>) -> (JoinHandle<()>, JoinHandle<()>) { pub fn manage_discovery(
node_info: NodeInfo,
peers: Arc<Mutex<HashMap<String, PeerHandle>>>,
) -> (JoinHandle<()>, JoinHandle<()>) {
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<(SocketAddr, DiscoveryMessage)>(); let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<(SocketAddr, DiscoveryMessage)>();
// TODO: How do we handle killing this? // TODO: How do we handle killing this?

View File

@ -2,8 +2,8 @@ mod device_capability_data;
mod discovery; mod discovery;
mod network; mod network;
mod orchestration; mod orchestration;
mod topology;
mod partitioning; mod partitioning;
mod topology;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::Value; use serde_json::Value;
@ -16,7 +16,7 @@ use crate::node_service::{
}; };
use node_service::node_service_server::{NodeService, NodeServiceServer}; use node_service::node_service_server::{NodeService, NodeServiceServer};
use node_service::TensorRequest; use node_service::TensorRequest;
use std::collections::{HashMap, HashSet}; use std::collections::HashSet;
use topology::Topology; use topology::Topology;
pub mod node_service { pub mod node_service {

View File

@ -1,13 +1,9 @@
use crate::topology::DeviceCapabilities;
use network_interface::{Addr, NetworkInterface, NetworkInterfaceConfig}; use network_interface::{Addr, NetworkInterface, NetworkInterfaceConfig};
use serde::{Deserialize, Serialize}; use std::net::Ipv4Addr;
use socket2::{Domain, Protocol, Socket, Type};
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::time::Duration;
use system_configuration::core_foundation::array::CFArray; use system_configuration::core_foundation::array::CFArray;
use system_configuration::network_configuration::{get_interfaces, SCNetworkInterface, SCNetworkInterfaceType}; use system_configuration::network_configuration::{
use system_configuration::sys::network_configuration::SCNetworkInterfaceRef; get_interfaces, SCNetworkInterface, SCNetworkInterfaceType,
use tokio::net::UdpSocket; };
#[derive(Copy, Clone, Debug)] #[derive(Copy, Clone, Debug)]
pub enum InterfaceType { pub enum InterfaceType {
@ -76,7 +72,8 @@ pub fn get_broadcast_creation_info() -> Vec<BroadcastCreationInfo> {
for (bind_address, broadcast_address) in map { for (bind_address, broadcast_address) in map {
let interface_type = if cfg!(target_os = "macos") { let interface_type = if cfg!(target_os = "macos") {
get_sc_interface_type(&interface.name, &sc_interfaces).unwrap_or_else(|| determine_interface_type(&interface.name)) get_sc_interface_type(&interface.name, &sc_interfaces)
.unwrap_or_else(|| determine_interface_type(&interface.name))
} else { } else {
determine_interface_type(&interface.name) determine_interface_type(&interface.name)
}; };
@ -94,7 +91,10 @@ pub fn get_broadcast_creation_info() -> Vec<BroadcastCreationInfo> {
} }
#[cfg(target_os = "macos")] #[cfg(target_os = "macos")]
fn get_sc_interface_type(name: &str, sc_interfaces: &CFArray<SCNetworkInterface>) -> Option<InterfaceType> { fn get_sc_interface_type(
name: &str,
sc_interfaces: &CFArray<SCNetworkInterface>,
) -> Option<InterfaceType> {
sc_interfaces.iter().find_map(|sc_if| { sc_interfaces.iter().find_map(|sc_if| {
sc_if sc_if
.bsd_name() .bsd_name()
@ -110,9 +110,14 @@ fn get_sc_interface_type(name: &str, sc_interfaces: &CFArray<SCNetworkInterface>
fn determine_interface_type(name: &str) -> InterfaceType { fn determine_interface_type(name: &str) -> InterfaceType {
// Fallback to interface name pattern matching // Fallback to interface name pattern matching
if name.starts_with("docker") || name.starts_with("br-") || name.starts_with("veth") if name.starts_with("docker")
|| name.starts_with("cni") || name.starts_with("flannel") || name.starts_with("calico") || name.starts_with("br-")
|| name.starts_with("weave") || name.contains("bridge") || name.starts_with("veth")
|| name.starts_with("cni")
|| name.starts_with("flannel")
|| name.starts_with("calico")
|| name.starts_with("weave")
|| name.contains("bridge")
{ {
InterfaceType::ContainerVirtual InterfaceType::ContainerVirtual
} else if name.starts_with("lo") { } else if name.starts_with("lo") {
@ -122,13 +127,20 @@ fn determine_interface_type(name: &str) -> InterfaceType {
} else if (name.starts_with("eth") || name.starts_with("en")) && !matches!(name, "en0" | "en1") } else if (name.starts_with("eth") || name.starts_with("en")) && !matches!(name, "en0" | "en1")
{ {
InterfaceType::Ethernet InterfaceType::Ethernet
} else if name.starts_with("wlan") || name.starts_with("wifi") || name.starts_with("wl") } else if name.starts_with("wlan")
|| name.starts_with("wifi")
|| name.starts_with("wl")
|| matches!(name, "en0" | "en1") || matches!(name, "en0" | "en1")
{ {
InterfaceType::WiFi InterfaceType::WiFi
} else if name.starts_with("tun") || name.starts_with("tap") || name.starts_with("vtun") } else if name.starts_with("tun")
|| name.starts_with("utun") || name.starts_with("gif") || name.starts_with("stf") || name.starts_with("tap")
|| name.starts_with("awdl") || name.starts_with("llw") || name.starts_with("vtun")
|| name.starts_with("utun")
|| name.starts_with("gif")
|| name.starts_with("stf")
|| name.starts_with("awdl")
|| name.starts_with("llw")
{ {
InterfaceType::ExternalVirtual InterfaceType::ExternalVirtual
} else { } else {

View File

@ -1,16 +1,17 @@
use std::collections::HashSet;
use crate::node_service::node_service_client::NodeServiceClient; use crate::node_service::node_service_client::NodeServiceClient;
use crate::node_service::{CollectTopologyRequest, HealthCheckRequest, Topology as TopologyProto}; use crate::node_service::{CollectTopologyRequest, HealthCheckRequest};
use crate::topology::{DeviceCapabilities, Topology}; use crate::topology::{DeviceCapabilities, Topology};
use std::collections::HashSet;
use std::net::SocketAddr; use std::net::SocketAddr;
use tonic::codec::CompressionEncoding; use tonic::codec::CompressionEncoding;
#[derive(Debug, Clone)]
pub struct PeerHandle { pub struct PeerHandle {
pub node_id: String, pub node_id: String,
pub address: SocketAddr, pub address: SocketAddr,
pub address_priority: u8, pub address_priority: u8,
pub description: Option<String>, pub description: Option<String>,
pub client: tokio::sync::Mutex<NodeServiceClient<tonic::transport::Channel>>, client: NodeServiceClient<tonic::transport::Channel>,
pub device_capabilities: DeviceCapabilities, pub device_capabilities: DeviceCapabilities,
} }
@ -32,15 +33,17 @@ impl PeerHandle {
description, description,
address_priority, address_priority,
address, address,
client: tokio::sync::Mutex::new(client), client,
device_capabilities, device_capabilities,
}) })
} }
pub fn client(&self) -> NodeServiceClient<tonic::transport::Channel> {
self.client.clone()
}
pub async fn is_healthy(&self) -> bool { pub async fn is_healthy(&self) -> bool {
self.client self.client()
.lock()
.await
.health_check(HealthCheckRequest::default()) .health_check(HealthCheckRequest::default())
.await .await
.ok() .ok()
@ -49,10 +52,15 @@ impl PeerHandle {
} }
pub async fn collect_topology(&self, visited: HashSet<String>, max_depth: u8) -> Topology { pub async fn collect_topology(&self, visited: HashSet<String>, max_depth: u8) -> Topology {
let response = self.client.lock().await.collect_topology(CollectTopologyRequest { let response = self
.client()
.collect_topology(CollectTopologyRequest {
visited: visited.clone().into_iter().collect(), visited: visited.clone().into_iter().collect(),
max_depth: max_depth as i32 max_depth: max_depth as i32,
}).await.unwrap().into_inner(); })
.await
.unwrap()
.into_inner();
response.into() response.into()
} }

View File

@ -2,7 +2,6 @@ use crate::{device_capability_data, node_service};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::HashMap; use std::collections::HashMap;
use std::process::Command; use std::process::Command;
use tonic::Response;
#[derive(Debug, Deserialize, Serialize, Clone)] #[derive(Debug, Deserialize, Serialize, Clone)]
pub struct Topology { pub struct Topology {