Implement collect topology
This commit is contained in:
parent
cd0b4a1bbf
commit
3aca4a22ae
@ -7,7 +7,7 @@ const TFLOPS: f64 = 1.00;
|
||||
pub static CHIP_FLOPS: phf::Map<&'static str, DeviceFlops> = phf_map! {
|
||||
// Source: https://www.cpu-monkey.com
|
||||
// Note: currently no distinction between variants of M3 Max and M3 Pro, we pick the lower one to be conservative
|
||||
/// M chips
|
||||
// M chips
|
||||
"Apple M1" => DeviceFlops { fp32: 2.29*TFLOPS, fp16: 4.58*TFLOPS, int8: 9.16*TFLOPS },
|
||||
"Apple M1 Pro" => DeviceFlops { fp32: 5.30*TFLOPS, fp16: 10.60*TFLOPS, int8: 21.20*TFLOPS },
|
||||
"Apple M1 Max" => DeviceFlops { fp32: 10.60*TFLOPS, fp16: 21.20*TFLOPS, int8: 42.40*TFLOPS },
|
||||
@ -22,13 +22,13 @@ pub static CHIP_FLOPS: phf::Map<&'static str, DeviceFlops> = phf_map! {
|
||||
"Apple M4" => DeviceFlops { fp32: 4.26*TFLOPS, fp16: 8.52*TFLOPS, int8: 17.04*TFLOPS },
|
||||
"Apple M4 Pro" => DeviceFlops { fp32: 5.72*TFLOPS, fp16: 11.44*TFLOPS, int8: 22.88*TFLOPS },
|
||||
"Apple M4 Max" => DeviceFlops { fp32: 18.03*TFLOPS, fp16: 36.07*TFLOPS, int8: 72.14*TFLOPS },
|
||||
/// A chips
|
||||
// A chips
|
||||
"Apple A13 Bionic" => DeviceFlops { fp32: 0.69*TFLOPS, fp16: 1.38*TFLOPS, int8: 2.76*TFLOPS },
|
||||
"Apple A14 Bionic" => DeviceFlops { fp32: 0.75*TFLOPS, fp16: 1.50*TFLOPS, int8: 3.00*TFLOPS },
|
||||
"Apple A15 Bionic" => DeviceFlops { fp32: 1.37*TFLOPS, fp16: 2.74*TFLOPS, int8: 5.48*TFLOPS },
|
||||
"Apple A16 Bionic" => DeviceFlops { fp32: 1.79*TFLOPS, fp16: 3.58*TFLOPS, int8: 7.16*TFLOPS },
|
||||
"Apple A17 Pro" => DeviceFlops { fp32: 2.15*TFLOPS, fp16: 4.30*TFLOPS, int8: 8.60*TFLOPS },
|
||||
/// NVIDIA GPUs
|
||||
// NVIDIA GPUs
|
||||
// RTX 40 series
|
||||
"NVIDIA GEFORCE RTX 4090" => DeviceFlops { fp32: 82.58*TFLOPS, fp16: 165.16*TFLOPS, int8: 330.32*TFLOPS },
|
||||
"NVIDIA GEFORCE RTX 4080" => DeviceFlops { fp32: 48.74*TFLOPS, fp16: 97.48*TFLOPS, int8: 194.96*TFLOPS },
|
||||
@ -82,7 +82,7 @@ pub static CHIP_FLOPS: phf::Map<&'static str, DeviceFlops> = phf_map! {
|
||||
"NVIDIA A800 80GB PCIE" => DeviceFlops { fp32: 19.5*TFLOPS, fp16: 312.0*TFLOPS, int8: 624.0*TFLOPS },
|
||||
"NVIDIA A100 80GB SXM" => DeviceFlops { fp32: 19.5*TFLOPS, fp16: 312.0*TFLOPS, int8: 624.0*TFLOPS },
|
||||
"NVIDIA A800 80GB SXM" => DeviceFlops { fp32: 19.5*TFLOPS, fp16: 312.0*TFLOPS, int8: 624.0*TFLOPS },
|
||||
/// AMD GPUs
|
||||
// AMD GPUs
|
||||
// RX 6000 series
|
||||
"AMD Radeon RX 6900 XT" => DeviceFlops { fp32: 23.04*TFLOPS, fp16: 46.08*TFLOPS, int8: 92.16*TFLOPS },
|
||||
"AMD Radeon RX 6800 XT" => DeviceFlops { fp32: 20.74*TFLOPS, fp16: 41.48*TFLOPS, int8: 82.96*TFLOPS },
|
||||
|
||||
@ -1,12 +1,18 @@
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use crate::network::get_broadcast_creation_info;
|
||||
use crate::topology::DeviceCapabilities;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use socket2::{Domain, Protocol, Socket, Type};
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::net::UdpSocket;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::task::JoinHandle;
|
||||
use tracing::{debug, info};
|
||||
use uuid::Uuid;
|
||||
use crate::orchestration::PeerHandle;
|
||||
|
||||
mod broadcast;
|
||||
mod udp_listen;
|
||||
@ -52,13 +58,13 @@ impl Default for NodeInfo {
|
||||
fn default() -> Self {
|
||||
NodeInfo {
|
||||
node_id: Uuid::new_v4().to_string(),
|
||||
discovery_listen_port: 0,
|
||||
broadcast_port: 0,
|
||||
broadcast_interval: Default::default(),
|
||||
grpc_port: 0,
|
||||
discovery_listen_port: 5678,
|
||||
broadcast_port: 5678,
|
||||
broadcast_interval: Duration::from_secs_f32(2.5),
|
||||
grpc_port: 49152,
|
||||
allowed_peer_ids: None,
|
||||
allowed_interfaces: None,
|
||||
discovery_timeout: Default::default(),
|
||||
discovery_timeout: Duration::from_secs(30),
|
||||
device_capabilities: DeviceCapabilities::determine(),
|
||||
}
|
||||
}
|
||||
@ -69,19 +75,25 @@ pub struct UdpDiscovery {
|
||||
discovery_handle: JoinHandle<()>,
|
||||
presence_handle: JoinHandle<()>,
|
||||
peer_manager_handle: JoinHandle<()>,
|
||||
pub peers: Arc<Mutex<HashMap<String, PeerHandle>>>,
|
||||
}
|
||||
|
||||
impl UdpDiscovery {
|
||||
pub fn new(node_info: NodeInfo) -> Self {
|
||||
let broadcast_creation_info = get_broadcast_creation_info();
|
||||
info!("Found addresses: {:?}", broadcast_creation_info);
|
||||
|
||||
let peers = Arc::new(Mutex::new(HashMap::new()));
|
||||
|
||||
let discovery_handle = tokio::spawn(broadcast::listen_all(node_info.clone(), broadcast_creation_info));
|
||||
let (presence_handle, peer_manager_handle) = udp_listen::manage_discovery(node_info.clone());
|
||||
let (presence_handle, peer_manager_handle) = udp_listen::manage_discovery(node_info.clone(), peers.clone());
|
||||
|
||||
UdpDiscovery {
|
||||
node_info,
|
||||
discovery_handle,
|
||||
presence_handle,
|
||||
peer_manager_handle,
|
||||
peers
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -1,13 +1,16 @@
|
||||
use std::cell::RefCell;
|
||||
use crate::discovery::{DiscoveryMessage, NodeInfo};
|
||||
use crate::orchestration::PeerHandle;
|
||||
use crate::topology::DeviceCapabilities;
|
||||
use std::collections::HashMap;
|
||||
use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use system_configuration::sys::libc::disconnectx;
|
||||
use tokio::net::UdpSocket;
|
||||
use tokio::select;
|
||||
use tokio::sync::mpsc::UnboundedSender;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::task::JoinHandle;
|
||||
use tonic::transport::Error;
|
||||
use tracing::{debug, error, info};
|
||||
@ -119,8 +122,7 @@ async fn handle_new_peer(
|
||||
peers.insert(message.node_id, new_peer);
|
||||
}
|
||||
|
||||
pub fn manage_discovery(node_info: NodeInfo) -> (JoinHandle<()>, JoinHandle<()>) {
|
||||
let mut peers: HashMap<String, PeerHandle> = HashMap::new();
|
||||
pub fn manage_discovery(node_info: NodeInfo, peers: Arc<Mutex<HashMap<String, PeerHandle>>>) -> (JoinHandle<()>, JoinHandle<()>) {
|
||||
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<(SocketAddr, DiscoveryMessage)>();
|
||||
|
||||
// TODO: How do we handle killing this?
|
||||
@ -128,11 +130,13 @@ pub fn manage_discovery(node_info: NodeInfo) -> (JoinHandle<()>, JoinHandle<()>)
|
||||
|
||||
let peer_manager_handle = tokio::spawn(async move {
|
||||
loop {
|
||||
let action = select! {
|
||||
let action: Action = select! {
|
||||
_ = tokio::time::sleep(node_info.discovery_timeout) => Action::HealthChecks,
|
||||
Some((addr, message)) = rx.recv() => Action::NewPeer(addr, message),
|
||||
};
|
||||
|
||||
let mut peers = peers.lock().await;
|
||||
|
||||
match action {
|
||||
Action::NewPeer(addr, message) => handle_new_peer(&mut peers, addr, message).await,
|
||||
Action::HealthChecks => perform_health_checks(&mut peers).await,
|
||||
|
||||
65
src/main.rs
65
src/main.rs
@ -1,34 +1,41 @@
|
||||
mod topology;
|
||||
mod orchestration;
|
||||
mod device_capability_data;
|
||||
mod discovery;
|
||||
mod network;
|
||||
mod device_capability_data;
|
||||
mod orchestration;
|
||||
mod topology;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use tonic::{transport::Server, Request, Response, Status};
|
||||
|
||||
use crate::discovery::{NodeInfo, UdpDiscovery};
|
||||
use crate::node_service::{
|
||||
CollectTopologyRequest, Empty, ExampleRequest, HealthCheckRequest, HealthCheckResponse, Loss,
|
||||
PromptRequest, SendOpaqueStatusRequest, SendResultRequest, Tensor, Topology as TopologyProto,
|
||||
};
|
||||
use node_service::node_service_server::{NodeService, NodeServiceServer};
|
||||
use node_service::TensorRequest;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use topology::Topology;
|
||||
use crate::discovery::{NodeInfo, UdpDiscovery};
|
||||
|
||||
pub mod node_service {
|
||||
tonic::include_proto!("node_service"); // The string specified here must match the proto package name
|
||||
}
|
||||
|
||||
struct Node {
|
||||
node_info: NodeInfo,
|
||||
current_topology: Topology,
|
||||
udp_discovery: UdpDiscovery,
|
||||
}
|
||||
|
||||
impl Default for Node {
|
||||
fn default() -> Self {
|
||||
let node_info = NodeInfo::default();
|
||||
|
||||
Self {
|
||||
node_info: node_info.clone(),
|
||||
current_topology: Topology::default(),
|
||||
udp_discovery: UdpDiscovery::new(node_info),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -49,7 +56,6 @@ struct Shard {
|
||||
pub n_layers: i32,
|
||||
}
|
||||
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
struct NodeStatus {
|
||||
node_id: String,
|
||||
@ -151,11 +157,19 @@ impl NodeService for Node {
|
||||
todo!()
|
||||
}
|
||||
|
||||
// TODO: Why aren't we using the request?
|
||||
async fn collect_topology(
|
||||
&self,
|
||||
request: Request<CollectTopologyRequest>,
|
||||
) -> Result<Response<TopologyProto>, Status> {
|
||||
todo!()
|
||||
let request = request.into_inner();
|
||||
let max_depth = request.max_depth as u8;
|
||||
let visited = request.visited;
|
||||
|
||||
self.update_topology_inner(max_depth, visited.into_iter().collect())
|
||||
.await;
|
||||
|
||||
Ok(Response::new(self.current_topology.clone().into()))
|
||||
}
|
||||
|
||||
async fn send_result(
|
||||
@ -188,14 +202,49 @@ impl NodeService for Node {
|
||||
}
|
||||
}
|
||||
|
||||
impl Node {
|
||||
async fn update_topology(&mut self) {
|
||||
let overall_max_depth = 4;
|
||||
let visited: HashSet<String> = HashSet::new();
|
||||
|
||||
self.current_topology = self.update_topology_inner(overall_max_depth, visited).await;
|
||||
}
|
||||
|
||||
async fn update_topology_inner(&self, max_depth: u8, mut visited: HashSet<String>) -> Topology {
|
||||
let mut new_topology = Topology::default();
|
||||
|
||||
new_topology.update_node(
|
||||
self.node_info.node_id.clone(),
|
||||
self.node_info.device_capabilities.clone(),
|
||||
);
|
||||
|
||||
for peer in self.udp_discovery.peers.lock().await.values() {
|
||||
new_topology.update_node(peer.node_id.clone(), peer.device_capabilities.clone());
|
||||
new_topology.update_edge(
|
||||
self.node_info.node_id.clone(),
|
||||
peer.node_id.clone(),
|
||||
peer.description.clone(),
|
||||
);
|
||||
|
||||
visited.insert(peer.node_id.clone());
|
||||
|
||||
if !visited.contains(&peer.node_id) {
|
||||
let topology = peer.collect_topology(visited.clone(), max_depth - 1).await;
|
||||
new_topology.merge_restricted(&peer.node_id, topology);
|
||||
}
|
||||
}
|
||||
|
||||
new_topology
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
// install global collector configured based on RUST_LOG env var.
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
let grpc_addr = "[::1]:50051".parse()?;
|
||||
let node = Node::default();
|
||||
let udp_discovery = UdpDiscovery::new(NodeInfo::default());
|
||||
let node: Node = Node::default();
|
||||
|
||||
// TODO: Also implement discovery
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
use std::collections::HashSet;
|
||||
use crate::node_service::node_service_client::NodeServiceClient;
|
||||
use crate::node_service::HealthCheckRequest;
|
||||
use crate::topology::DeviceCapabilities;
|
||||
use crate::node_service::{CollectTopologyRequest, HealthCheckRequest, Topology as TopologyProto};
|
||||
use crate::topology::{DeviceCapabilities, Topology};
|
||||
use std::net::SocketAddr;
|
||||
use tonic::codec::CompressionEncoding;
|
||||
|
||||
@ -46,4 +47,13 @@ impl PeerHandle {
|
||||
.map(|x| x.into_inner().is_healthy)
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
pub async fn collect_topology(&self, visited: HashSet<String>, max_depth: u8) -> Topology {
|
||||
let response = self.client.lock().await.collect_topology(CollectTopologyRequest {
|
||||
visited: visited.clone().into_iter().collect(),
|
||||
max_depth: max_depth as i32
|
||||
}).await.unwrap().into_inner();
|
||||
|
||||
response.into()
|
||||
}
|
||||
}
|
||||
|
||||
171
src/topology.rs
171
src/topology.rs
@ -1,13 +1,132 @@
|
||||
use crate::{device_capability_data, node_service};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::process::Command;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use crate::device_capability_data;
|
||||
use tonic::Response;
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
pub struct Topology {
|
||||
nodes: HashMap<String, DeviceCapabilities>,
|
||||
peer_graph: HashMap<String, Vec<PeerConnection>>,
|
||||
active_node_id: Option<String>
|
||||
pub nodes: HashMap<String, DeviceCapabilities>,
|
||||
pub peer_graph: HashMap<String, Vec<PeerConnection>>,
|
||||
pub active_node_id: Option<String>,
|
||||
}
|
||||
|
||||
impl Topology {
|
||||
pub fn update_node(&mut self, node_id: String, device_capabilities: DeviceCapabilities) {
|
||||
self.nodes.insert(node_id, device_capabilities);
|
||||
}
|
||||
|
||||
pub fn update_edge(&mut self, from_id: String, to_id: String, description: Option<String>) {
|
||||
let conn = PeerConnection {
|
||||
from_id: from_id.clone(),
|
||||
to_id,
|
||||
description,
|
||||
};
|
||||
|
||||
match self.peer_graph.get_mut(&from_id) {
|
||||
None => {
|
||||
self.peer_graph.insert(from_id, vec![conn]);
|
||||
}
|
||||
|
||||
Some(existing) => {
|
||||
existing.push(conn);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn merge_restricted(&mut self, from_peer_id: &str, topology: Topology) {
|
||||
if let Some(peer_capabilities) = topology.nodes.get(from_peer_id) {
|
||||
self.nodes
|
||||
.insert(from_peer_id.to_string(), peer_capabilities.clone());
|
||||
}
|
||||
|
||||
self.peer_graph.extend(
|
||||
topology
|
||||
.peer_graph
|
||||
.into_iter()
|
||||
.filter(|(id, _)| id == from_peer_id),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl From<crate::node_service::Topology> for Topology {
|
||||
fn from(proto: crate::node_service::Topology) -> Self {
|
||||
let nodes = proto
|
||||
.nodes
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, v.into()))
|
||||
.collect();
|
||||
|
||||
let peer_graph = proto
|
||||
.peer_graph
|
||||
.into_iter()
|
||||
.map(|(from_id, connections)| {
|
||||
(
|
||||
from_id.clone(),
|
||||
connections
|
||||
.connections
|
||||
.into_iter()
|
||||
.map(|pc| PeerConnection {
|
||||
from_id: from_id.clone(),
|
||||
to_id: pc.to_id,
|
||||
description: pc.description,
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
Topology {
|
||||
nodes,
|
||||
peer_graph,
|
||||
active_node_id: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<crate::node_service::Topology> for Topology {
|
||||
fn into(self) -> crate::node_service::Topology {
|
||||
let nodes = self
|
||||
.nodes
|
||||
.iter()
|
||||
.map(|(node_id, cap)| {
|
||||
(
|
||||
node_id.clone(),
|
||||
node_service::DeviceCapabilities {
|
||||
model: cap.model.clone(),
|
||||
chip: cap.chip.clone(),
|
||||
memory: cap.memory as i32,
|
||||
flops: Some(node_service::DeviceFlops {
|
||||
fp32: cap.flops.fp32,
|
||||
fp16: cap.flops.fp16,
|
||||
int8: cap.flops.int8,
|
||||
}),
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect::<HashMap<String, node_service::DeviceCapabilities>>();
|
||||
|
||||
let peer_graph = self
|
||||
.peer_graph
|
||||
.iter()
|
||||
.map(|(node_id, connections)| {
|
||||
(
|
||||
node_id.clone(),
|
||||
node_service::PeerConnections {
|
||||
connections: connections
|
||||
.iter()
|
||||
.map(|conn| node_service::PeerConnection {
|
||||
to_id: conn.to_id.clone(),
|
||||
description: conn.description.clone(),
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect::<HashMap<String, node_service::PeerConnections>>();
|
||||
|
||||
node_service::Topology { nodes, peer_graph }
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Topology {
|
||||
@ -22,16 +141,16 @@ impl Default for Topology {
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
pub struct DeviceCapabilities {
|
||||
model: String,
|
||||
chip: String,
|
||||
memory: u64,
|
||||
flops: DeviceFlops,
|
||||
pub model: String,
|
||||
pub chip: String,
|
||||
pub memory: u64,
|
||||
pub flops: DeviceFlops,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
struct SystemProfilerOutputData {
|
||||
#[serde(rename = "SPHardwareDataType")]
|
||||
hardware: Vec<SPHardwareDataType>
|
||||
hardware: Vec<SPHardwareDataType>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||
@ -51,7 +170,7 @@ struct SPHardwareDataType {
|
||||
platform_uuid: String,
|
||||
#[serde(rename = "provisioning_UDID")]
|
||||
provisioning_udid: String,
|
||||
serial_number: String
|
||||
serial_number: String,
|
||||
}
|
||||
|
||||
impl DeviceCapabilities {
|
||||
@ -83,7 +202,8 @@ impl DeviceCapabilities {
|
||||
};
|
||||
|
||||
DeviceCapabilities {
|
||||
flops: device_capability_data::look_up(&chip).expect("Failed to find FLOPS data for chip"),
|
||||
flops: device_capability_data::look_up(&chip)
|
||||
.expect("Failed to find FLOPS data for chip"),
|
||||
model,
|
||||
chip,
|
||||
memory,
|
||||
@ -91,6 +211,17 @@ impl DeviceCapabilities {
|
||||
}
|
||||
}
|
||||
|
||||
impl From<crate::node_service::DeviceCapabilities> for DeviceCapabilities {
|
||||
fn from(value: crate::node_service::DeviceCapabilities) -> Self {
|
||||
DeviceCapabilities {
|
||||
model: value.model,
|
||||
chip: value.chip,
|
||||
memory: value.memory as u64,
|
||||
flops: value.flops.map(|x| x.into()).unwrap_or_default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
|
||||
pub struct DeviceFlops {
|
||||
pub fp32: f64,
|
||||
@ -98,9 +229,19 @@ pub struct DeviceFlops {
|
||||
pub int8: f64,
|
||||
}
|
||||
|
||||
impl From<crate::node_service::DeviceFlops> for DeviceFlops {
|
||||
fn from(value: crate::node_service::DeviceFlops) -> Self {
|
||||
DeviceFlops {
|
||||
fp32: value.fp32,
|
||||
fp16: value.fp16,
|
||||
int8: value.int8,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize, Clone, Hash, Eq, PartialEq)]
|
||||
pub struct PeerConnection {
|
||||
from_id: String,
|
||||
to_id: String,
|
||||
description: Option<String>,
|
||||
pub from_id: String,
|
||||
pub to_id: String,
|
||||
pub description: Option<String>,
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user