diff --git a/src/device_capability_data.rs b/src/device_capability_data.rs index 098b0a9..a82aec6 100644 --- a/src/device_capability_data.rs +++ b/src/device_capability_data.rs @@ -7,7 +7,7 @@ const TFLOPS: f64 = 1.00; pub static CHIP_FLOPS: phf::Map<&'static str, DeviceFlops> = phf_map! { // Source: https://www.cpu-monkey.com // Note: currently no distinction between variants of M3 Max and M3 Pro, we pick the lower one to be conservative - /// M chips + // M chips "Apple M1" => DeviceFlops { fp32: 2.29*TFLOPS, fp16: 4.58*TFLOPS, int8: 9.16*TFLOPS }, "Apple M1 Pro" => DeviceFlops { fp32: 5.30*TFLOPS, fp16: 10.60*TFLOPS, int8: 21.20*TFLOPS }, "Apple M1 Max" => DeviceFlops { fp32: 10.60*TFLOPS, fp16: 21.20*TFLOPS, int8: 42.40*TFLOPS }, @@ -22,13 +22,13 @@ pub static CHIP_FLOPS: phf::Map<&'static str, DeviceFlops> = phf_map! { "Apple M4" => DeviceFlops { fp32: 4.26*TFLOPS, fp16: 8.52*TFLOPS, int8: 17.04*TFLOPS }, "Apple M4 Pro" => DeviceFlops { fp32: 5.72*TFLOPS, fp16: 11.44*TFLOPS, int8: 22.88*TFLOPS }, "Apple M4 Max" => DeviceFlops { fp32: 18.03*TFLOPS, fp16: 36.07*TFLOPS, int8: 72.14*TFLOPS }, - /// A chips + // A chips "Apple A13 Bionic" => DeviceFlops { fp32: 0.69*TFLOPS, fp16: 1.38*TFLOPS, int8: 2.76*TFLOPS }, "Apple A14 Bionic" => DeviceFlops { fp32: 0.75*TFLOPS, fp16: 1.50*TFLOPS, int8: 3.00*TFLOPS }, "Apple A15 Bionic" => DeviceFlops { fp32: 1.37*TFLOPS, fp16: 2.74*TFLOPS, int8: 5.48*TFLOPS }, "Apple A16 Bionic" => DeviceFlops { fp32: 1.79*TFLOPS, fp16: 3.58*TFLOPS, int8: 7.16*TFLOPS }, "Apple A17 Pro" => DeviceFlops { fp32: 2.15*TFLOPS, fp16: 4.30*TFLOPS, int8: 8.60*TFLOPS }, - /// NVIDIA GPUs + // NVIDIA GPUs // RTX 40 series "NVIDIA GEFORCE RTX 4090" => DeviceFlops { fp32: 82.58*TFLOPS, fp16: 165.16*TFLOPS, int8: 330.32*TFLOPS }, "NVIDIA GEFORCE RTX 4080" => DeviceFlops { fp32: 48.74*TFLOPS, fp16: 97.48*TFLOPS, int8: 194.96*TFLOPS }, @@ -82,7 +82,7 @@ pub static CHIP_FLOPS: phf::Map<&'static str, DeviceFlops> = phf_map! { "NVIDIA A800 80GB PCIE" => DeviceFlops { fp32: 19.5*TFLOPS, fp16: 312.0*TFLOPS, int8: 624.0*TFLOPS }, "NVIDIA A100 80GB SXM" => DeviceFlops { fp32: 19.5*TFLOPS, fp16: 312.0*TFLOPS, int8: 624.0*TFLOPS }, "NVIDIA A800 80GB SXM" => DeviceFlops { fp32: 19.5*TFLOPS, fp16: 312.0*TFLOPS, int8: 624.0*TFLOPS }, - /// AMD GPUs + // AMD GPUs // RX 6000 series "AMD Radeon RX 6900 XT" => DeviceFlops { fp32: 23.04*TFLOPS, fp16: 46.08*TFLOPS, int8: 92.16*TFLOPS }, "AMD Radeon RX 6800 XT" => DeviceFlops { fp32: 20.74*TFLOPS, fp16: 41.48*TFLOPS, int8: 82.96*TFLOPS }, diff --git a/src/discovery/mod.rs b/src/discovery/mod.rs index 7a4c677..6f9fe5e 100644 --- a/src/discovery/mod.rs +++ b/src/discovery/mod.rs @@ -1,12 +1,18 @@ +use std::cell::RefCell; +use std::collections::HashMap; use crate::network::get_broadcast_creation_info; use crate::topology::DeviceCapabilities; use serde::{Deserialize, Serialize}; use socket2::{Domain, Protocol, Socket, Type}; use std::net::SocketAddr; +use std::sync::Arc; use std::time::Duration; use tokio::net::UdpSocket; +use tokio::sync::Mutex; use tokio::task::JoinHandle; +use tracing::{debug, info}; use uuid::Uuid; +use crate::orchestration::PeerHandle; mod broadcast; mod udp_listen; @@ -52,13 +58,13 @@ impl Default for NodeInfo { fn default() -> Self { NodeInfo { node_id: Uuid::new_v4().to_string(), - discovery_listen_port: 0, - broadcast_port: 0, - broadcast_interval: Default::default(), - grpc_port: 0, + discovery_listen_port: 5678, + broadcast_port: 5678, + broadcast_interval: Duration::from_secs_f32(2.5), + grpc_port: 49152, allowed_peer_ids: None, allowed_interfaces: None, - discovery_timeout: Default::default(), + discovery_timeout: Duration::from_secs(30), device_capabilities: DeviceCapabilities::determine(), } } @@ -69,19 +75,25 @@ pub struct UdpDiscovery { discovery_handle: JoinHandle<()>, presence_handle: JoinHandle<()>, peer_manager_handle: JoinHandle<()>, + pub peers: Arc>>, } impl UdpDiscovery { pub fn new(node_info: NodeInfo) -> Self { let broadcast_creation_info = get_broadcast_creation_info(); + info!("Found addresses: {:?}", broadcast_creation_info); + + let peers = Arc::new(Mutex::new(HashMap::new())); + let discovery_handle = tokio::spawn(broadcast::listen_all(node_info.clone(), broadcast_creation_info)); - let (presence_handle, peer_manager_handle) = udp_listen::manage_discovery(node_info.clone()); + let (presence_handle, peer_manager_handle) = udp_listen::manage_discovery(node_info.clone(), peers.clone()); UdpDiscovery { node_info, discovery_handle, presence_handle, peer_manager_handle, + peers } } diff --git a/src/discovery/udp_listen.rs b/src/discovery/udp_listen.rs index 87fd92f..9c76bcd 100644 --- a/src/discovery/udp_listen.rs +++ b/src/discovery/udp_listen.rs @@ -1,13 +1,16 @@ +use std::cell::RefCell; use crate::discovery::{DiscoveryMessage, NodeInfo}; use crate::orchestration::PeerHandle; use crate::topology::DeviceCapabilities; use std::collections::HashMap; use std::net::SocketAddr; +use std::sync::Arc; use std::time::Duration; use system_configuration::sys::libc::disconnectx; use tokio::net::UdpSocket; use tokio::select; use tokio::sync::mpsc::UnboundedSender; +use tokio::sync::Mutex; use tokio::task::JoinHandle; use tonic::transport::Error; use tracing::{debug, error, info}; @@ -119,8 +122,7 @@ async fn handle_new_peer( peers.insert(message.node_id, new_peer); } -pub fn manage_discovery(node_info: NodeInfo) -> (JoinHandle<()>, JoinHandle<()>) { - let mut peers: HashMap = HashMap::new(); +pub fn manage_discovery(node_info: NodeInfo, peers: Arc>>) -> (JoinHandle<()>, JoinHandle<()>) { let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<(SocketAddr, DiscoveryMessage)>(); // TODO: How do we handle killing this? @@ -128,11 +130,13 @@ pub fn manage_discovery(node_info: NodeInfo) -> (JoinHandle<()>, JoinHandle<()>) let peer_manager_handle = tokio::spawn(async move { loop { - let action = select! { + let action: Action = select! { _ = tokio::time::sleep(node_info.discovery_timeout) => Action::HealthChecks, Some((addr, message)) = rx.recv() => Action::NewPeer(addr, message), }; + let mut peers = peers.lock().await; + match action { Action::NewPeer(addr, message) => handle_new_peer(&mut peers, addr, message).await, Action::HealthChecks => perform_health_checks(&mut peers).await, diff --git a/src/main.rs b/src/main.rs index f7bb9f0..d823fa4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,34 +1,41 @@ -mod topology; -mod orchestration; +mod device_capability_data; mod discovery; mod network; -mod device_capability_data; +mod orchestration; +mod topology; use serde::{Deserialize, Serialize}; use serde_json::Value; use tonic::{transport::Server, Request, Response, Status}; +use crate::discovery::{NodeInfo, UdpDiscovery}; use crate::node_service::{ CollectTopologyRequest, Empty, ExampleRequest, HealthCheckRequest, HealthCheckResponse, Loss, PromptRequest, SendOpaqueStatusRequest, SendResultRequest, Tensor, Topology as TopologyProto, }; use node_service::node_service_server::{NodeService, NodeServiceServer}; use node_service::TensorRequest; +use std::collections::{HashMap, HashSet}; use topology::Topology; -use crate::discovery::{NodeInfo, UdpDiscovery}; pub mod node_service { tonic::include_proto!("node_service"); // The string specified here must match the proto package name } struct Node { + node_info: NodeInfo, current_topology: Topology, + udp_discovery: UdpDiscovery, } impl Default for Node { fn default() -> Self { + let node_info = NodeInfo::default(); + Self { + node_info: node_info.clone(), current_topology: Topology::default(), + udp_discovery: UdpDiscovery::new(node_info), } } } @@ -49,7 +56,6 @@ struct Shard { pub n_layers: i32, } - #[derive(Debug, Deserialize, Serialize, Clone)] struct NodeStatus { node_id: String, @@ -151,11 +157,19 @@ impl NodeService for Node { todo!() } + // TODO: Why aren't we using the request? async fn collect_topology( &self, request: Request, ) -> Result, Status> { - todo!() + let request = request.into_inner(); + let max_depth = request.max_depth as u8; + let visited = request.visited; + + self.update_topology_inner(max_depth, visited.into_iter().collect()) + .await; + + Ok(Response::new(self.current_topology.clone().into())) } async fn send_result( @@ -188,14 +202,49 @@ impl NodeService for Node { } } +impl Node { + async fn update_topology(&mut self) { + let overall_max_depth = 4; + let visited: HashSet = HashSet::new(); + + self.current_topology = self.update_topology_inner(overall_max_depth, visited).await; + } + + async fn update_topology_inner(&self, max_depth: u8, mut visited: HashSet) -> Topology { + let mut new_topology = Topology::default(); + + new_topology.update_node( + self.node_info.node_id.clone(), + self.node_info.device_capabilities.clone(), + ); + + for peer in self.udp_discovery.peers.lock().await.values() { + new_topology.update_node(peer.node_id.clone(), peer.device_capabilities.clone()); + new_topology.update_edge( + self.node_info.node_id.clone(), + peer.node_id.clone(), + peer.description.clone(), + ); + + visited.insert(peer.node_id.clone()); + + if !visited.contains(&peer.node_id) { + let topology = peer.collect_topology(visited.clone(), max_depth - 1).await; + new_topology.merge_restricted(&peer.node_id, topology); + } + } + + new_topology + } +} + #[tokio::main] async fn main() -> Result<(), Box> { // install global collector configured based on RUST_LOG env var. tracing_subscriber::fmt::init(); let grpc_addr = "[::1]:50051".parse()?; - let node = Node::default(); - let udp_discovery = UdpDiscovery::new(NodeInfo::default()); + let node: Node = Node::default(); // TODO: Also implement discovery diff --git a/src/orchestration.rs b/src/orchestration.rs index 45dda57..aaa6342 100644 --- a/src/orchestration.rs +++ b/src/orchestration.rs @@ -1,6 +1,7 @@ +use std::collections::HashSet; use crate::node_service::node_service_client::NodeServiceClient; -use crate::node_service::HealthCheckRequest; -use crate::topology::DeviceCapabilities; +use crate::node_service::{CollectTopologyRequest, HealthCheckRequest, Topology as TopologyProto}; +use crate::topology::{DeviceCapabilities, Topology}; use std::net::SocketAddr; use tonic::codec::CompressionEncoding; @@ -46,4 +47,13 @@ impl PeerHandle { .map(|x| x.into_inner().is_healthy) .unwrap_or(false) } + + pub async fn collect_topology(&self, visited: HashSet, max_depth: u8) -> Topology { + let response = self.client.lock().await.collect_topology(CollectTopologyRequest { + visited: visited.clone().into_iter().collect(), + max_depth: max_depth as i32 + }).await.unwrap().into_inner(); + + response.into() + } } diff --git a/src/topology.rs b/src/topology.rs index de5e901..a9d22a6 100644 --- a/src/topology.rs +++ b/src/topology.rs @@ -1,13 +1,132 @@ +use crate::{device_capability_data, node_service}; +use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::process::Command; -use serde::{Deserialize, Serialize}; -use crate::device_capability_data; +use tonic::Response; #[derive(Debug, Deserialize, Serialize, Clone)] pub struct Topology { - nodes: HashMap, - peer_graph: HashMap>, - active_node_id: Option + pub nodes: HashMap, + pub peer_graph: HashMap>, + pub active_node_id: Option, +} + +impl Topology { + pub fn update_node(&mut self, node_id: String, device_capabilities: DeviceCapabilities) { + self.nodes.insert(node_id, device_capabilities); + } + + pub fn update_edge(&mut self, from_id: String, to_id: String, description: Option) { + let conn = PeerConnection { + from_id: from_id.clone(), + to_id, + description, + }; + + match self.peer_graph.get_mut(&from_id) { + None => { + self.peer_graph.insert(from_id, vec![conn]); + } + + Some(existing) => { + existing.push(conn); + } + } + } + + pub(crate) fn merge_restricted(&mut self, from_peer_id: &str, topology: Topology) { + if let Some(peer_capabilities) = topology.nodes.get(from_peer_id) { + self.nodes + .insert(from_peer_id.to_string(), peer_capabilities.clone()); + } + + self.peer_graph.extend( + topology + .peer_graph + .into_iter() + .filter(|(id, _)| id == from_peer_id), + ); + } +} + +impl From for Topology { + fn from(proto: crate::node_service::Topology) -> Self { + let nodes = proto + .nodes + .into_iter() + .map(|(k, v)| (k, v.into())) + .collect(); + + let peer_graph = proto + .peer_graph + .into_iter() + .map(|(from_id, connections)| { + ( + from_id.clone(), + connections + .connections + .into_iter() + .map(|pc| PeerConnection { + from_id: from_id.clone(), + to_id: pc.to_id, + description: pc.description, + }) + .collect(), + ) + }) + .collect(); + + Topology { + nodes, + peer_graph, + active_node_id: None, + } + } +} + +impl Into for Topology { + fn into(self) -> crate::node_service::Topology { + let nodes = self + .nodes + .iter() + .map(|(node_id, cap)| { + ( + node_id.clone(), + node_service::DeviceCapabilities { + model: cap.model.clone(), + chip: cap.chip.clone(), + memory: cap.memory as i32, + flops: Some(node_service::DeviceFlops { + fp32: cap.flops.fp32, + fp16: cap.flops.fp16, + int8: cap.flops.int8, + }), + }, + ) + }) + .collect::>(); + + let peer_graph = self + .peer_graph + .iter() + .map(|(node_id, connections)| { + ( + node_id.clone(), + node_service::PeerConnections { + connections: connections + .iter() + .map(|conn| node_service::PeerConnection { + to_id: conn.to_id.clone(), + description: conn.description.clone(), + }) + .collect(), + }, + ) + }) + .collect::>(); + + node_service::Topology { nodes, peer_graph } + } } impl Default for Topology { @@ -22,16 +141,16 @@ impl Default for Topology { #[derive(Debug, Deserialize, Serialize, Clone)] pub struct DeviceCapabilities { - model: String, - chip: String, - memory: u64, - flops: DeviceFlops, + pub model: String, + pub chip: String, + pub memory: u64, + pub flops: DeviceFlops, } #[derive(Debug, Deserialize, Serialize, Clone)] struct SystemProfilerOutputData { #[serde(rename = "SPHardwareDataType")] - hardware: Vec + hardware: Vec, } #[derive(Debug, Deserialize, Serialize, Clone)] @@ -51,7 +170,7 @@ struct SPHardwareDataType { platform_uuid: String, #[serde(rename = "provisioning_UDID")] provisioning_udid: String, - serial_number: String + serial_number: String, } impl DeviceCapabilities { @@ -83,7 +202,8 @@ impl DeviceCapabilities { }; DeviceCapabilities { - flops: device_capability_data::look_up(&chip).expect("Failed to find FLOPS data for chip"), + flops: device_capability_data::look_up(&chip) + .expect("Failed to find FLOPS data for chip"), model, chip, memory, @@ -91,6 +211,17 @@ impl DeviceCapabilities { } } +impl From for DeviceCapabilities { + fn from(value: crate::node_service::DeviceCapabilities) -> Self { + DeviceCapabilities { + model: value.model, + chip: value.chip, + memory: value.memory as u64, + flops: value.flops.map(|x| x.into()).unwrap_or_default(), + } + } +} + #[derive(Debug, Deserialize, Serialize, Clone, Default)] pub struct DeviceFlops { pub fp32: f64, @@ -98,9 +229,19 @@ pub struct DeviceFlops { pub int8: f64, } +impl From for DeviceFlops { + fn from(value: crate::node_service::DeviceFlops) -> Self { + DeviceFlops { + fp32: value.fp32, + fp16: value.fp16, + int8: value.int8, + } + } +} + #[derive(Debug, Deserialize, Serialize, Clone, Hash, Eq, PartialEq)] pub struct PeerConnection { - from_id: String, - to_id: String, - description: Option, + pub from_id: String, + pub to_id: String, + pub description: Option, }