Implement collect topology

This commit is contained in:
Joshua Coles 2025-02-12 13:01:14 +00:00
parent cd0b4a1bbf
commit 3aca4a22ae
6 changed files with 254 additions and 38 deletions

View File

@ -7,7 +7,7 @@ const TFLOPS: f64 = 1.00;
pub static CHIP_FLOPS: phf::Map<&'static str, DeviceFlops> = phf_map! {
// Source: https://www.cpu-monkey.com
// Note: we currently do not distinguish between variants of the M3 Max and M3 Pro; we pick the lower figure to be conservative.
/// M chips
// M chips
"Apple M1" => DeviceFlops { fp32: 2.29*TFLOPS, fp16: 4.58*TFLOPS, int8: 9.16*TFLOPS },
"Apple M1 Pro" => DeviceFlops { fp32: 5.30*TFLOPS, fp16: 10.60*TFLOPS, int8: 21.20*TFLOPS },
"Apple M1 Max" => DeviceFlops { fp32: 10.60*TFLOPS, fp16: 21.20*TFLOPS, int8: 42.40*TFLOPS },
@ -22,13 +22,13 @@ pub static CHIP_FLOPS: phf::Map<&'static str, DeviceFlops> = phf_map! {
"Apple M4" => DeviceFlops { fp32: 4.26*TFLOPS, fp16: 8.52*TFLOPS, int8: 17.04*TFLOPS },
"Apple M4 Pro" => DeviceFlops { fp32: 5.72*TFLOPS, fp16: 11.44*TFLOPS, int8: 22.88*TFLOPS },
"Apple M4 Max" => DeviceFlops { fp32: 18.03*TFLOPS, fp16: 36.07*TFLOPS, int8: 72.14*TFLOPS },
/// A chips
// A chips
"Apple A13 Bionic" => DeviceFlops { fp32: 0.69*TFLOPS, fp16: 1.38*TFLOPS, int8: 2.76*TFLOPS },
"Apple A14 Bionic" => DeviceFlops { fp32: 0.75*TFLOPS, fp16: 1.50*TFLOPS, int8: 3.00*TFLOPS },
"Apple A15 Bionic" => DeviceFlops { fp32: 1.37*TFLOPS, fp16: 2.74*TFLOPS, int8: 5.48*TFLOPS },
"Apple A16 Bionic" => DeviceFlops { fp32: 1.79*TFLOPS, fp16: 3.58*TFLOPS, int8: 7.16*TFLOPS },
"Apple A17 Pro" => DeviceFlops { fp32: 2.15*TFLOPS, fp16: 4.30*TFLOPS, int8: 8.60*TFLOPS },
/// NVIDIA GPUs
// NVIDIA GPUs
// RTX 40 series
"NVIDIA GEFORCE RTX 4090" => DeviceFlops { fp32: 82.58*TFLOPS, fp16: 165.16*TFLOPS, int8: 330.32*TFLOPS },
"NVIDIA GEFORCE RTX 4080" => DeviceFlops { fp32: 48.74*TFLOPS, fp16: 97.48*TFLOPS, int8: 194.96*TFLOPS },
@ -82,7 +82,7 @@ pub static CHIP_FLOPS: phf::Map<&'static str, DeviceFlops> = phf_map! {
"NVIDIA A800 80GB PCIE" => DeviceFlops { fp32: 19.5*TFLOPS, fp16: 312.0*TFLOPS, int8: 624.0*TFLOPS },
"NVIDIA A100 80GB SXM" => DeviceFlops { fp32: 19.5*TFLOPS, fp16: 312.0*TFLOPS, int8: 624.0*TFLOPS },
"NVIDIA A800 80GB SXM" => DeviceFlops { fp32: 19.5*TFLOPS, fp16: 312.0*TFLOPS, int8: 624.0*TFLOPS },
/// AMD GPUs
// AMD GPUs
// RX 6000 series
"AMD Radeon RX 6900 XT" => DeviceFlops { fp32: 23.04*TFLOPS, fp16: 46.08*TFLOPS, int8: 92.16*TFLOPS },
"AMD Radeon RX 6800 XT" => DeviceFlops { fp32: 20.74*TFLOPS, fp16: 41.48*TFLOPS, int8: 82.96*TFLOPS },

View File

@ -1,12 +1,18 @@
use std::cell::RefCell;
use std::collections::HashMap;
use crate::network::get_broadcast_creation_info;
use crate::topology::DeviceCapabilities;
use serde::{Deserialize, Serialize};
use socket2::{Domain, Protocol, Socket, Type};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::UdpSocket;
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use tracing::{debug, info};
use uuid::Uuid;
use crate::orchestration::PeerHandle;
mod broadcast;
mod udp_listen;
@ -52,13 +58,13 @@ impl Default for NodeInfo {
/// Builds the configuration a node uses when nothing is overridden.
///
/// Every call generates a fresh random UUID for `node_id` and runs
/// `DeviceCapabilities::determine()` to describe the local hardware.
fn default() -> Self {
    NodeInfo {
        node_id: Uuid::new_v4().to_string(),
        // Each field is initialised exactly once; listening and broadcasting
        // share the same port by default.
        discovery_listen_port: 5678,
        broadcast_port: 5678,
        broadcast_interval: Duration::from_secs_f32(2.5),
        grpc_port: 49152,
        // NOTE(review): `None` presumably means "no restriction" — confirm against the consumers.
        allowed_peer_ids: None,
        allowed_interfaces: None,
        discovery_timeout: Duration::from_secs(30),
        device_capabilities: DeviceCapabilities::determine(),
    }
}
@ -69,19 +75,25 @@ pub struct UdpDiscovery {
discovery_handle: JoinHandle<()>,
presence_handle: JoinHandle<()>,
peer_manager_handle: JoinHandle<()>,
pub peers: Arc<Mutex<HashMap<String, PeerHandle>>>,
}
impl UdpDiscovery {
/// Starts the discovery background tasks for `node_info` and returns their handles.
///
/// Tasks are started with `tokio::spawn`, so this must be called from inside a
/// Tokio runtime. The returned `peers` map is the same instance that is handed
/// to `udp_listen::manage_discovery`, which lets callers read the peer set.
pub fn new(node_info: NodeInfo) -> Self {
    let broadcast_creation_info = get_broadcast_creation_info();
    info!("Found addresses: {:?}", broadcast_creation_info);

    // Created up front so one map can be shared between the peer manager and this handle.
    let peers = Arc::new(Mutex::new(HashMap::new()));

    let discovery_handle = tokio::spawn(broadcast::listen_all(
        node_info.clone(),
        broadcast_creation_info,
    ));

    // `manage_discovery` is invoked exactly once, with the shared peer map.
    let (presence_handle, peer_manager_handle) =
        udp_listen::manage_discovery(node_info.clone(), Arc::clone(&peers));

    UdpDiscovery {
        node_info,
        discovery_handle,
        presence_handle,
        peer_manager_handle,
        peers,
    }
}

View File

@ -1,13 +1,16 @@
use std::cell::RefCell;
use crate::discovery::{DiscoveryMessage, NodeInfo};
use crate::orchestration::PeerHandle;
use crate::topology::DeviceCapabilities;
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use system_configuration::sys::libc::disconnectx;
use tokio::net::UdpSocket;
use tokio::select;
use tokio::sync::mpsc::UnboundedSender;
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use tonic::transport::Error;
use tracing::{debug, error, info};
@ -119,8 +122,7 @@ async fn handle_new_peer(
peers.insert(message.node_id, new_peer);
}
pub fn manage_discovery(node_info: NodeInfo) -> (JoinHandle<()>, JoinHandle<()>) {
let mut peers: HashMap<String, PeerHandle> = HashMap::new();
pub fn manage_discovery(node_info: NodeInfo, peers: Arc<Mutex<HashMap<String, PeerHandle>>>) -> (JoinHandle<()>, JoinHandle<()>) {
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<(SocketAddr, DiscoveryMessage)>();
// TODO: How do we handle killing this?
@ -128,11 +130,13 @@ pub fn manage_discovery(node_info: NodeInfo) -> (JoinHandle<()>, JoinHandle<()>)
let peer_manager_handle = tokio::spawn(async move {
loop {
let action = select! {
let action: Action = select! {
_ = tokio::time::sleep(node_info.discovery_timeout) => Action::HealthChecks,
Some((addr, message)) = rx.recv() => Action::NewPeer(addr, message),
};
let mut peers = peers.lock().await;
match action {
Action::NewPeer(addr, message) => handle_new_peer(&mut peers, addr, message).await,
Action::HealthChecks => perform_health_checks(&mut peers).await,

View File

@ -1,34 +1,41 @@
mod topology;
mod orchestration;
mod device_capability_data;
mod discovery;
mod network;
mod device_capability_data;
mod orchestration;
mod topology;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tonic::{transport::Server, Request, Response, Status};
use crate::discovery::{NodeInfo, UdpDiscovery};
use crate::node_service::{
CollectTopologyRequest, Empty, ExampleRequest, HealthCheckRequest, HealthCheckResponse, Loss,
PromptRequest, SendOpaqueStatusRequest, SendResultRequest, Tensor, Topology as TopologyProto,
};
use node_service::node_service_server::{NodeService, NodeServiceServer};
use node_service::TensorRequest;
use std::collections::{HashMap, HashSet};
use topology::Topology;
use crate::discovery::{NodeInfo, UdpDiscovery};
/// Items pulled in by `tonic::include_proto!` for the `node_service` proto package.
pub mod node_service {
tonic::include_proto!("node_service"); // The string specified here must match the proto package name
}
/// State owned by this process's node.
struct Node {
/// This node's own configuration, including its id and device capabilities.
node_info: NodeInfo,
/// Topology most recently stored by `update_topology`.
current_topology: Topology,
/// Handle to the UDP discovery tasks; its `peers` map is read when collecting topology.
udp_discovery: UdpDiscovery,
}
impl Default for Node {
fn default() -> Self {
let node_info = NodeInfo::default();
Self {
node_info: node_info.clone(),
current_topology: Topology::default(),
udp_discovery: UdpDiscovery::new(node_info),
}
}
}
@ -49,7 +56,6 @@ struct Shard {
pub n_layers: i32,
}
#[derive(Debug, Deserialize, Serialize, Clone)]
struct NodeStatus {
node_id: String,
@ -151,11 +157,19 @@ impl NodeService for Node {
todo!()
}
// TODO: Why aren't we using the request?
async fn collect_topology(
&self,
request: Request<CollectTopologyRequest>,
) -> Result<Response<TopologyProto>, Status> {
todo!()
let request = request.into_inner();
let max_depth = request.max_depth as u8;
let visited = request.visited;
self.update_topology_inner(max_depth, visited.into_iter().collect())
.await;
Ok(Response::new(self.current_topology.clone().into()))
}
async fn send_result(
@ -188,14 +202,49 @@ impl NodeService for Node {
}
}
impl Node {
    /// Recomputes the topology starting from this node and stores the result
    /// in `current_topology`.
    async fn update_topology(&mut self) {
        let overall_max_depth = 4;
        self.current_topology = self
            .update_topology_inner(overall_max_depth, HashSet::new())
            .await;
    }

    /// Builds the topology visible from this node.
    ///
    /// This node and every currently known peer are always recorded, together
    /// with an edge from this node to each peer. A peer is additionally asked
    /// for its own view (via `PeerHandle::collect_topology`) only when
    ///
    /// * it was not already in the `visited` set supplied by the caller, and
    /// * `max_depth` is greater than zero.
    ///
    /// The set forwarded to peers contains the caller's `visited` ids plus this
    /// node's id and the ids of all of its peers, so a queried peer does not
    /// call back into nodes that are already covered.
    async fn update_topology_inner(&self, max_depth: u8, mut visited: HashSet<String>) -> Topology {
        let mut new_topology = Topology::default();
        new_topology.update_node(
            self.node_info.node_id.clone(),
            self.node_info.device_capabilities.clone(),
        );

        // Snapshot of what the caller had already covered; the recursion
        // decision must use this, not the set extended below.
        let previously_visited = visited.clone();

        // NOTE(review): the peer map stays locked while the peer RPCs below are
        // awaited, as before; snapshotting the peers first would shorten the
        // critical section if `PeerHandle` can be cloned — confirm.
        let peers = self.udp_discovery.peers.lock().await;

        visited.insert(self.node_info.node_id.clone());
        visited.extend(peers.values().map(|peer| peer.node_id.clone()));

        for peer in peers.values() {
            new_topology.update_node(peer.node_id.clone(), peer.device_capabilities.clone());
            new_topology.update_edge(
                self.node_info.node_id.clone(),
                peer.node_id.clone(),
                peer.description.clone(),
            );

            if previously_visited.contains(&peer.node_id) {
                continue;
            }

            // Depth budget exhausted: record the peer but do not recurse, and
            // never evaluate `0 - 1` on a u8.
            if max_depth == 0 {
                continue;
            }

            let topology = peer.collect_topology(visited.clone(), max_depth - 1).await;
            new_topology.merge_restricted(&peer.node_id, topology);
        }

        new_topology
    }
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
// install global collector configured based on RUST_LOG env var.
tracing_subscriber::fmt::init();
let grpc_addr = "[::1]:50051".parse()?;
let node = Node::default();
let udp_discovery = UdpDiscovery::new(NodeInfo::default());
let node: Node = Node::default();
// TODO: Also implement discovery

View File

@ -1,6 +1,7 @@
use std::collections::HashSet;
use crate::node_service::node_service_client::NodeServiceClient;
use crate::node_service::HealthCheckRequest;
use crate::topology::DeviceCapabilities;
use crate::node_service::{CollectTopologyRequest, HealthCheckRequest, Topology as TopologyProto};
use crate::topology::{DeviceCapabilities, Topology};
use std::net::SocketAddr;
use tonic::codec::CompressionEncoding;
@ -46,4 +47,13 @@ impl PeerHandle {
.map(|x| x.into_inner().is_healthy)
.unwrap_or(false)
}
pub async fn collect_topology(&self, visited: HashSet<String>, max_depth: u8) -> Topology {
let response = self.client.lock().await.collect_topology(CollectTopologyRequest {
visited: visited.clone().into_iter().collect(),
max_depth: max_depth as i32
}).await.unwrap().into_inner();
response.into()
}
}

View File

@ -1,13 +1,132 @@
use crate::{device_capability_data, node_service};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::process::Command;
use serde::{Deserialize, Serialize};
use crate::device_capability_data;
use tonic::Response;
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct Topology {
nodes: HashMap<String, DeviceCapabilities>,
peer_graph: HashMap<String, Vec<PeerConnection>>,
active_node_id: Option<String>
pub nodes: HashMap<String, DeviceCapabilities>,
pub peer_graph: HashMap<String, Vec<PeerConnection>>,
pub active_node_id: Option<String>,
}
impl Topology {
    /// Stores `device_capabilities` for `node_id`, replacing any previous entry.
    pub fn update_node(&mut self, node_id: String, device_capabilities: DeviceCapabilities) {
        self.nodes.insert(node_id, device_capabilities);
    }

    /// Appends a connection `from_id -> to_id` to the list of edges leaving `from_id`.
    ///
    /// No de-duplication is performed: recording the same edge twice stores it twice.
    pub fn update_edge(&mut self, from_id: String, to_id: String, description: Option<String>) {
        let edge = PeerConnection {
            from_id: from_id.clone(),
            to_id,
            description,
        };

        self.peer_graph.entry(from_id).or_default().push(edge);
    }

    /// Copies only what `topology` says about `from_peer_id` itself into `self`.
    ///
    /// The peer's capabilities entry and its list of outgoing edges each
    /// replace whatever `self` held for that id; everything else in
    /// `topology` is ignored.
    pub(crate) fn merge_restricted(&mut self, from_peer_id: &str, mut topology: Topology) {
        if let Some(capabilities) = topology.nodes.remove(from_peer_id) {
            self.nodes.insert(from_peer_id.to_string(), capabilities);
        }

        if let Some(edges) = topology.peer_graph.remove(from_peer_id) {
            self.peer_graph.insert(from_peer_id.to_string(), edges);
        }
    }
}
impl From<crate::node_service::Topology> for Topology {
fn from(proto: crate::node_service::Topology) -> Self {
let nodes = proto
.nodes
.into_iter()
.map(|(k, v)| (k, v.into()))
.collect();
let peer_graph = proto
.peer_graph
.into_iter()
.map(|(from_id, connections)| {
(
from_id.clone(),
connections
.connections
.into_iter()
.map(|pc| PeerConnection {
from_id: from_id.clone(),
to_id: pc.to_id,
description: pc.description,
})
.collect(),
)
})
.collect();
Topology {
nodes,
peer_graph,
active_node_id: None,
}
}
}
/// Converts the in-memory topology into its wire representation.
///
/// Implemented as `From` so that `Into` comes from the standard blanket impl;
/// existing `topology.into()` call sites keep working. `active_node_id` has no
/// counterpart in the wire type and is dropped.
impl From<Topology> for crate::node_service::Topology {
    fn from(topology: Topology) -> Self {
        // The topology is consumed, so strings are moved rather than cloned.
        let nodes = topology
            .nodes
            .into_iter()
            .map(|(node_id, cap)| {
                (
                    node_id,
                    node_service::DeviceCapabilities {
                        model: cap.model,
                        chip: cap.chip,
                        // The wire field is i32: saturate rather than wrap for
                        // values that do not fit.
                        memory: i32::try_from(cap.memory).unwrap_or(i32::MAX),
                        flops: Some(node_service::DeviceFlops {
                            fp32: cap.flops.fp32,
                            fp16: cap.flops.fp16,
                            int8: cap.flops.int8,
                        }),
                    },
                )
            })
            .collect::<HashMap<String, node_service::DeviceCapabilities>>();

        let peer_graph = topology
            .peer_graph
            .into_iter()
            .map(|(node_id, connections)| {
                (
                    node_id,
                    node_service::PeerConnections {
                        // `from_id` is implied by the map key and is not sent per edge.
                        connections: connections
                            .into_iter()
                            .map(|conn| node_service::PeerConnection {
                                to_id: conn.to_id,
                                description: conn.description,
                            })
                            .collect(),
                    },
                )
            })
            .collect::<HashMap<String, node_service::PeerConnections>>();

        node_service::Topology { nodes, peer_graph }
    }
}
impl Default for Topology {
@ -22,16 +141,16 @@ impl Default for Topology {
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct DeviceCapabilities {
model: String,
chip: String,
memory: u64,
flops: DeviceFlops,
pub model: String,
pub chip: String,
pub memory: u64,
pub flops: DeviceFlops,
}
#[derive(Debug, Deserialize, Serialize, Clone)]
struct SystemProfilerOutputData {
#[serde(rename = "SPHardwareDataType")]
hardware: Vec<SPHardwareDataType>
hardware: Vec<SPHardwareDataType>,
}
#[derive(Debug, Deserialize, Serialize, Clone)]
@ -51,7 +170,7 @@ struct SPHardwareDataType {
platform_uuid: String,
#[serde(rename = "provisioning_UDID")]
provisioning_udid: String,
serial_number: String
serial_number: String,
}
impl DeviceCapabilities {
@ -83,7 +202,8 @@ impl DeviceCapabilities {
};
DeviceCapabilities {
flops: device_capability_data::look_up(&chip).expect("Failed to find FLOPS data for chip"),
flops: device_capability_data::look_up(&chip)
.expect("Failed to find FLOPS data for chip"),
model,
chip,
memory,
@ -91,6 +211,17 @@ impl DeviceCapabilities {
}
}
impl From<crate::node_service::DeviceCapabilities> for DeviceCapabilities {
fn from(value: crate::node_service::DeviceCapabilities) -> Self {
DeviceCapabilities {
model: value.model,
chip: value.chip,
memory: value.memory as u64,
flops: value.flops.map(|x| x.into()).unwrap_or_default(),
}
}
}
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
pub struct DeviceFlops {
pub fp32: f64,
@ -98,9 +229,19 @@ pub struct DeviceFlops {
pub int8: f64,
}
impl From<crate::node_service::DeviceFlops> for DeviceFlops {
    /// Copies the three throughput figures from the wire type, unchanged.
    fn from(value: crate::node_service::DeviceFlops) -> Self {
        let crate::node_service::DeviceFlops { fp32, fp16, int8 } = value;

        Self { fp32, fp16, int8 }
    }
}
/// A directed connection between two nodes in the peer graph.
#[derive(Debug, Deserialize, Serialize, Clone, Hash, Eq, PartialEq)]
pub struct PeerConnection {
    /// Id of the node the connection starts from.
    pub from_id: String,
    /// Id of the node the connection points to.
    pub to_id: String,
    /// Optional free-form description of the connection.
    pub description: Option<String>,
}