Implement collect topology
This commit is contained in:
parent
cd0b4a1bbf
commit
3aca4a22ae
@ -7,7 +7,7 @@ const TFLOPS: f64 = 1.00;
|
|||||||
pub static CHIP_FLOPS: phf::Map<&'static str, DeviceFlops> = phf_map! {
|
pub static CHIP_FLOPS: phf::Map<&'static str, DeviceFlops> = phf_map! {
|
||||||
// Source: https://www.cpu-monkey.com
|
// Source: https://www.cpu-monkey.com
|
||||||
// Note: currently no distinction between variants of M3 Max and M3 Pro, we pick the lower one to be conservative
|
// Note: currently no distinction between variants of M3 Max and M3 Pro, we pick the lower one to be conservative
|
||||||
/// M chips
|
// M chips
|
||||||
"Apple M1" => DeviceFlops { fp32: 2.29*TFLOPS, fp16: 4.58*TFLOPS, int8: 9.16*TFLOPS },
|
"Apple M1" => DeviceFlops { fp32: 2.29*TFLOPS, fp16: 4.58*TFLOPS, int8: 9.16*TFLOPS },
|
||||||
"Apple M1 Pro" => DeviceFlops { fp32: 5.30*TFLOPS, fp16: 10.60*TFLOPS, int8: 21.20*TFLOPS },
|
"Apple M1 Pro" => DeviceFlops { fp32: 5.30*TFLOPS, fp16: 10.60*TFLOPS, int8: 21.20*TFLOPS },
|
||||||
"Apple M1 Max" => DeviceFlops { fp32: 10.60*TFLOPS, fp16: 21.20*TFLOPS, int8: 42.40*TFLOPS },
|
"Apple M1 Max" => DeviceFlops { fp32: 10.60*TFLOPS, fp16: 21.20*TFLOPS, int8: 42.40*TFLOPS },
|
||||||
@ -22,13 +22,13 @@ pub static CHIP_FLOPS: phf::Map<&'static str, DeviceFlops> = phf_map! {
|
|||||||
"Apple M4" => DeviceFlops { fp32: 4.26*TFLOPS, fp16: 8.52*TFLOPS, int8: 17.04*TFLOPS },
|
"Apple M4" => DeviceFlops { fp32: 4.26*TFLOPS, fp16: 8.52*TFLOPS, int8: 17.04*TFLOPS },
|
||||||
"Apple M4 Pro" => DeviceFlops { fp32: 5.72*TFLOPS, fp16: 11.44*TFLOPS, int8: 22.88*TFLOPS },
|
"Apple M4 Pro" => DeviceFlops { fp32: 5.72*TFLOPS, fp16: 11.44*TFLOPS, int8: 22.88*TFLOPS },
|
||||||
"Apple M4 Max" => DeviceFlops { fp32: 18.03*TFLOPS, fp16: 36.07*TFLOPS, int8: 72.14*TFLOPS },
|
"Apple M4 Max" => DeviceFlops { fp32: 18.03*TFLOPS, fp16: 36.07*TFLOPS, int8: 72.14*TFLOPS },
|
||||||
/// A chips
|
// A chips
|
||||||
"Apple A13 Bionic" => DeviceFlops { fp32: 0.69*TFLOPS, fp16: 1.38*TFLOPS, int8: 2.76*TFLOPS },
|
"Apple A13 Bionic" => DeviceFlops { fp32: 0.69*TFLOPS, fp16: 1.38*TFLOPS, int8: 2.76*TFLOPS },
|
||||||
"Apple A14 Bionic" => DeviceFlops { fp32: 0.75*TFLOPS, fp16: 1.50*TFLOPS, int8: 3.00*TFLOPS },
|
"Apple A14 Bionic" => DeviceFlops { fp32: 0.75*TFLOPS, fp16: 1.50*TFLOPS, int8: 3.00*TFLOPS },
|
||||||
"Apple A15 Bionic" => DeviceFlops { fp32: 1.37*TFLOPS, fp16: 2.74*TFLOPS, int8: 5.48*TFLOPS },
|
"Apple A15 Bionic" => DeviceFlops { fp32: 1.37*TFLOPS, fp16: 2.74*TFLOPS, int8: 5.48*TFLOPS },
|
||||||
"Apple A16 Bionic" => DeviceFlops { fp32: 1.79*TFLOPS, fp16: 3.58*TFLOPS, int8: 7.16*TFLOPS },
|
"Apple A16 Bionic" => DeviceFlops { fp32: 1.79*TFLOPS, fp16: 3.58*TFLOPS, int8: 7.16*TFLOPS },
|
||||||
"Apple A17 Pro" => DeviceFlops { fp32: 2.15*TFLOPS, fp16: 4.30*TFLOPS, int8: 8.60*TFLOPS },
|
"Apple A17 Pro" => DeviceFlops { fp32: 2.15*TFLOPS, fp16: 4.30*TFLOPS, int8: 8.60*TFLOPS },
|
||||||
/// NVIDIA GPUs
|
// NVIDIA GPUs
|
||||||
// RTX 40 series
|
// RTX 40 series
|
||||||
"NVIDIA GEFORCE RTX 4090" => DeviceFlops { fp32: 82.58*TFLOPS, fp16: 165.16*TFLOPS, int8: 330.32*TFLOPS },
|
"NVIDIA GEFORCE RTX 4090" => DeviceFlops { fp32: 82.58*TFLOPS, fp16: 165.16*TFLOPS, int8: 330.32*TFLOPS },
|
||||||
"NVIDIA GEFORCE RTX 4080" => DeviceFlops { fp32: 48.74*TFLOPS, fp16: 97.48*TFLOPS, int8: 194.96*TFLOPS },
|
"NVIDIA GEFORCE RTX 4080" => DeviceFlops { fp32: 48.74*TFLOPS, fp16: 97.48*TFLOPS, int8: 194.96*TFLOPS },
|
||||||
@ -82,7 +82,7 @@ pub static CHIP_FLOPS: phf::Map<&'static str, DeviceFlops> = phf_map! {
|
|||||||
"NVIDIA A800 80GB PCIE" => DeviceFlops { fp32: 19.5*TFLOPS, fp16: 312.0*TFLOPS, int8: 624.0*TFLOPS },
|
"NVIDIA A800 80GB PCIE" => DeviceFlops { fp32: 19.5*TFLOPS, fp16: 312.0*TFLOPS, int8: 624.0*TFLOPS },
|
||||||
"NVIDIA A100 80GB SXM" => DeviceFlops { fp32: 19.5*TFLOPS, fp16: 312.0*TFLOPS, int8: 624.0*TFLOPS },
|
"NVIDIA A100 80GB SXM" => DeviceFlops { fp32: 19.5*TFLOPS, fp16: 312.0*TFLOPS, int8: 624.0*TFLOPS },
|
||||||
"NVIDIA A800 80GB SXM" => DeviceFlops { fp32: 19.5*TFLOPS, fp16: 312.0*TFLOPS, int8: 624.0*TFLOPS },
|
"NVIDIA A800 80GB SXM" => DeviceFlops { fp32: 19.5*TFLOPS, fp16: 312.0*TFLOPS, int8: 624.0*TFLOPS },
|
||||||
/// AMD GPUs
|
// AMD GPUs
|
||||||
// RX 6000 series
|
// RX 6000 series
|
||||||
"AMD Radeon RX 6900 XT" => DeviceFlops { fp32: 23.04*TFLOPS, fp16: 46.08*TFLOPS, int8: 92.16*TFLOPS },
|
"AMD Radeon RX 6900 XT" => DeviceFlops { fp32: 23.04*TFLOPS, fp16: 46.08*TFLOPS, int8: 92.16*TFLOPS },
|
||||||
"AMD Radeon RX 6800 XT" => DeviceFlops { fp32: 20.74*TFLOPS, fp16: 41.48*TFLOPS, int8: 82.96*TFLOPS },
|
"AMD Radeon RX 6800 XT" => DeviceFlops { fp32: 20.74*TFLOPS, fp16: 41.48*TFLOPS, int8: 82.96*TFLOPS },
|
||||||
|
|||||||
@ -1,12 +1,18 @@
|
|||||||
|
use std::cell::RefCell;
|
||||||
|
use std::collections::HashMap;
|
||||||
use crate::network::get_broadcast_creation_info;
|
use crate::network::get_broadcast_creation_info;
|
||||||
use crate::topology::DeviceCapabilities;
|
use crate::topology::DeviceCapabilities;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use socket2::{Domain, Protocol, Socket, Type};
|
use socket2::{Domain, Protocol, Socket, Type};
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use tokio::net::UdpSocket;
|
use tokio::net::UdpSocket;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
use tokio::task::JoinHandle;
|
use tokio::task::JoinHandle;
|
||||||
|
use tracing::{debug, info};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
use crate::orchestration::PeerHandle;
|
||||||
|
|
||||||
mod broadcast;
|
mod broadcast;
|
||||||
mod udp_listen;
|
mod udp_listen;
|
||||||
@ -52,13 +58,13 @@ impl Default for NodeInfo {
|
|||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
NodeInfo {
|
NodeInfo {
|
||||||
node_id: Uuid::new_v4().to_string(),
|
node_id: Uuid::new_v4().to_string(),
|
||||||
discovery_listen_port: 0,
|
discovery_listen_port: 5678,
|
||||||
broadcast_port: 0,
|
broadcast_port: 5678,
|
||||||
broadcast_interval: Default::default(),
|
broadcast_interval: Duration::from_secs_f32(2.5),
|
||||||
grpc_port: 0,
|
grpc_port: 49152,
|
||||||
allowed_peer_ids: None,
|
allowed_peer_ids: None,
|
||||||
allowed_interfaces: None,
|
allowed_interfaces: None,
|
||||||
discovery_timeout: Default::default(),
|
discovery_timeout: Duration::from_secs(30),
|
||||||
device_capabilities: DeviceCapabilities::determine(),
|
device_capabilities: DeviceCapabilities::determine(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -69,19 +75,25 @@ pub struct UdpDiscovery {
|
|||||||
discovery_handle: JoinHandle<()>,
|
discovery_handle: JoinHandle<()>,
|
||||||
presence_handle: JoinHandle<()>,
|
presence_handle: JoinHandle<()>,
|
||||||
peer_manager_handle: JoinHandle<()>,
|
peer_manager_handle: JoinHandle<()>,
|
||||||
|
pub peers: Arc<Mutex<HashMap<String, PeerHandle>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl UdpDiscovery {
|
impl UdpDiscovery {
|
||||||
pub fn new(node_info: NodeInfo) -> Self {
|
pub fn new(node_info: NodeInfo) -> Self {
|
||||||
let broadcast_creation_info = get_broadcast_creation_info();
|
let broadcast_creation_info = get_broadcast_creation_info();
|
||||||
|
info!("Found addresses: {:?}", broadcast_creation_info);
|
||||||
|
|
||||||
|
let peers = Arc::new(Mutex::new(HashMap::new()));
|
||||||
|
|
||||||
let discovery_handle = tokio::spawn(broadcast::listen_all(node_info.clone(), broadcast_creation_info));
|
let discovery_handle = tokio::spawn(broadcast::listen_all(node_info.clone(), broadcast_creation_info));
|
||||||
let (presence_handle, peer_manager_handle) = udp_listen::manage_discovery(node_info.clone());
|
let (presence_handle, peer_manager_handle) = udp_listen::manage_discovery(node_info.clone(), peers.clone());
|
||||||
|
|
||||||
UdpDiscovery {
|
UdpDiscovery {
|
||||||
node_info,
|
node_info,
|
||||||
discovery_handle,
|
discovery_handle,
|
||||||
presence_handle,
|
presence_handle,
|
||||||
peer_manager_handle,
|
peer_manager_handle,
|
||||||
|
peers
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,13 +1,16 @@
|
|||||||
|
use std::cell::RefCell;
|
||||||
use crate::discovery::{DiscoveryMessage, NodeInfo};
|
use crate::discovery::{DiscoveryMessage, NodeInfo};
|
||||||
use crate::orchestration::PeerHandle;
|
use crate::orchestration::PeerHandle;
|
||||||
use crate::topology::DeviceCapabilities;
|
use crate::topology::DeviceCapabilities;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
use system_configuration::sys::libc::disconnectx;
|
use system_configuration::sys::libc::disconnectx;
|
||||||
use tokio::net::UdpSocket;
|
use tokio::net::UdpSocket;
|
||||||
use tokio::select;
|
use tokio::select;
|
||||||
use tokio::sync::mpsc::UnboundedSender;
|
use tokio::sync::mpsc::UnboundedSender;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
use tokio::task::JoinHandle;
|
use tokio::task::JoinHandle;
|
||||||
use tonic::transport::Error;
|
use tonic::transport::Error;
|
||||||
use tracing::{debug, error, info};
|
use tracing::{debug, error, info};
|
||||||
@ -119,8 +122,7 @@ async fn handle_new_peer(
|
|||||||
peers.insert(message.node_id, new_peer);
|
peers.insert(message.node_id, new_peer);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn manage_discovery(node_info: NodeInfo) -> (JoinHandle<()>, JoinHandle<()>) {
|
pub fn manage_discovery(node_info: NodeInfo, peers: Arc<Mutex<HashMap<String, PeerHandle>>>) -> (JoinHandle<()>, JoinHandle<()>) {
|
||||||
let mut peers: HashMap<String, PeerHandle> = HashMap::new();
|
|
||||||
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<(SocketAddr, DiscoveryMessage)>();
|
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<(SocketAddr, DiscoveryMessage)>();
|
||||||
|
|
||||||
// TODO: How do we handle killing this?
|
// TODO: How do we handle killing this?
|
||||||
@ -128,11 +130,13 @@ pub fn manage_discovery(node_info: NodeInfo) -> (JoinHandle<()>, JoinHandle<()>)
|
|||||||
|
|
||||||
let peer_manager_handle = tokio::spawn(async move {
|
let peer_manager_handle = tokio::spawn(async move {
|
||||||
loop {
|
loop {
|
||||||
let action = select! {
|
let action: Action = select! {
|
||||||
_ = tokio::time::sleep(node_info.discovery_timeout) => Action::HealthChecks,
|
_ = tokio::time::sleep(node_info.discovery_timeout) => Action::HealthChecks,
|
||||||
Some((addr, message)) = rx.recv() => Action::NewPeer(addr, message),
|
Some((addr, message)) = rx.recv() => Action::NewPeer(addr, message),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let mut peers = peers.lock().await;
|
||||||
|
|
||||||
match action {
|
match action {
|
||||||
Action::NewPeer(addr, message) => handle_new_peer(&mut peers, addr, message).await,
|
Action::NewPeer(addr, message) => handle_new_peer(&mut peers, addr, message).await,
|
||||||
Action::HealthChecks => perform_health_checks(&mut peers).await,
|
Action::HealthChecks => perform_health_checks(&mut peers).await,
|
||||||
|
|||||||
65
src/main.rs
65
src/main.rs
@ -1,34 +1,41 @@
|
|||||||
mod topology;
|
mod device_capability_data;
|
||||||
mod orchestration;
|
|
||||||
mod discovery;
|
mod discovery;
|
||||||
mod network;
|
mod network;
|
||||||
mod device_capability_data;
|
mod orchestration;
|
||||||
|
mod topology;
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use tonic::{transport::Server, Request, Response, Status};
|
use tonic::{transport::Server, Request, Response, Status};
|
||||||
|
|
||||||
|
use crate::discovery::{NodeInfo, UdpDiscovery};
|
||||||
use crate::node_service::{
|
use crate::node_service::{
|
||||||
CollectTopologyRequest, Empty, ExampleRequest, HealthCheckRequest, HealthCheckResponse, Loss,
|
CollectTopologyRequest, Empty, ExampleRequest, HealthCheckRequest, HealthCheckResponse, Loss,
|
||||||
PromptRequest, SendOpaqueStatusRequest, SendResultRequest, Tensor, Topology as TopologyProto,
|
PromptRequest, SendOpaqueStatusRequest, SendResultRequest, Tensor, Topology as TopologyProto,
|
||||||
};
|
};
|
||||||
use node_service::node_service_server::{NodeService, NodeServiceServer};
|
use node_service::node_service_server::{NodeService, NodeServiceServer};
|
||||||
use node_service::TensorRequest;
|
use node_service::TensorRequest;
|
||||||
|
use std::collections::{HashMap, HashSet};
|
||||||
use topology::Topology;
|
use topology::Topology;
|
||||||
use crate::discovery::{NodeInfo, UdpDiscovery};
|
|
||||||
|
|
||||||
pub mod node_service {
|
pub mod node_service {
|
||||||
tonic::include_proto!("node_service"); // The string specified here must match the proto package name
|
tonic::include_proto!("node_service"); // The string specified here must match the proto package name
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Node {
|
struct Node {
|
||||||
|
node_info: NodeInfo,
|
||||||
current_topology: Topology,
|
current_topology: Topology,
|
||||||
|
udp_discovery: UdpDiscovery,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for Node {
|
impl Default for Node {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
|
let node_info = NodeInfo::default();
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
|
node_info: node_info.clone(),
|
||||||
current_topology: Topology::default(),
|
current_topology: Topology::default(),
|
||||||
|
udp_discovery: UdpDiscovery::new(node_info),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -49,7 +56,6 @@ struct Shard {
|
|||||||
pub n_layers: i32,
|
pub n_layers: i32,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||||
struct NodeStatus {
|
struct NodeStatus {
|
||||||
node_id: String,
|
node_id: String,
|
||||||
@ -151,11 +157,19 @@ impl NodeService for Node {
|
|||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: Why aren't we using the request?
|
||||||
async fn collect_topology(
|
async fn collect_topology(
|
||||||
&self,
|
&self,
|
||||||
request: Request<CollectTopologyRequest>,
|
request: Request<CollectTopologyRequest>,
|
||||||
) -> Result<Response<TopologyProto>, Status> {
|
) -> Result<Response<TopologyProto>, Status> {
|
||||||
todo!()
|
let request = request.into_inner();
|
||||||
|
let max_depth = request.max_depth as u8;
|
||||||
|
let visited = request.visited;
|
||||||
|
|
||||||
|
self.update_topology_inner(max_depth, visited.into_iter().collect())
|
||||||
|
.await;
|
||||||
|
|
||||||
|
Ok(Response::new(self.current_topology.clone().into()))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn send_result(
|
async fn send_result(
|
||||||
@ -188,14 +202,49 @@ impl NodeService for Node {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Node {
|
||||||
|
async fn update_topology(&mut self) {
|
||||||
|
let overall_max_depth = 4;
|
||||||
|
let visited: HashSet<String> = HashSet::new();
|
||||||
|
|
||||||
|
self.current_topology = self.update_topology_inner(overall_max_depth, visited).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn update_topology_inner(&self, max_depth: u8, mut visited: HashSet<String>) -> Topology {
|
||||||
|
let mut new_topology = Topology::default();
|
||||||
|
|
||||||
|
new_topology.update_node(
|
||||||
|
self.node_info.node_id.clone(),
|
||||||
|
self.node_info.device_capabilities.clone(),
|
||||||
|
);
|
||||||
|
|
||||||
|
for peer in self.udp_discovery.peers.lock().await.values() {
|
||||||
|
new_topology.update_node(peer.node_id.clone(), peer.device_capabilities.clone());
|
||||||
|
new_topology.update_edge(
|
||||||
|
self.node_info.node_id.clone(),
|
||||||
|
peer.node_id.clone(),
|
||||||
|
peer.description.clone(),
|
||||||
|
);
|
||||||
|
|
||||||
|
visited.insert(peer.node_id.clone());
|
||||||
|
|
||||||
|
if !visited.contains(&peer.node_id) {
|
||||||
|
let topology = peer.collect_topology(visited.clone(), max_depth - 1).await;
|
||||||
|
new_topology.merge_restricted(&peer.node_id, topology);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
new_topology
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||||
// install global collector configured based on RUST_LOG env var.
|
// install global collector configured based on RUST_LOG env var.
|
||||||
tracing_subscriber::fmt::init();
|
tracing_subscriber::fmt::init();
|
||||||
|
|
||||||
let grpc_addr = "[::1]:50051".parse()?;
|
let grpc_addr = "[::1]:50051".parse()?;
|
||||||
let node = Node::default();
|
let node: Node = Node::default();
|
||||||
let udp_discovery = UdpDiscovery::new(NodeInfo::default());
|
|
||||||
|
|
||||||
// TODO: Also implement discovery
|
// TODO: Also implement discovery
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
|
use std::collections::HashSet;
|
||||||
use crate::node_service::node_service_client::NodeServiceClient;
|
use crate::node_service::node_service_client::NodeServiceClient;
|
||||||
use crate::node_service::HealthCheckRequest;
|
use crate::node_service::{CollectTopologyRequest, HealthCheckRequest, Topology as TopologyProto};
|
||||||
use crate::topology::DeviceCapabilities;
|
use crate::topology::{DeviceCapabilities, Topology};
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use tonic::codec::CompressionEncoding;
|
use tonic::codec::CompressionEncoding;
|
||||||
|
|
||||||
@ -46,4 +47,13 @@ impl PeerHandle {
|
|||||||
.map(|x| x.into_inner().is_healthy)
|
.map(|x| x.into_inner().is_healthy)
|
||||||
.unwrap_or(false)
|
.unwrap_or(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn collect_topology(&self, visited: HashSet<String>, max_depth: u8) -> Topology {
|
||||||
|
let response = self.client.lock().await.collect_topology(CollectTopologyRequest {
|
||||||
|
visited: visited.clone().into_iter().collect(),
|
||||||
|
max_depth: max_depth as i32
|
||||||
|
}).await.unwrap().into_inner();
|
||||||
|
|
||||||
|
response.into()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
171
src/topology.rs
171
src/topology.rs
@ -1,13 +1,132 @@
|
|||||||
|
use crate::{device_capability_data, node_service};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::process::Command;
|
use std::process::Command;
|
||||||
use serde::{Deserialize, Serialize};
|
use tonic::Response;
|
||||||
use crate::device_capability_data;
|
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||||
pub struct Topology {
|
pub struct Topology {
|
||||||
nodes: HashMap<String, DeviceCapabilities>,
|
pub nodes: HashMap<String, DeviceCapabilities>,
|
||||||
peer_graph: HashMap<String, Vec<PeerConnection>>,
|
pub peer_graph: HashMap<String, Vec<PeerConnection>>,
|
||||||
active_node_id: Option<String>
|
pub active_node_id: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Topology {
|
||||||
|
pub fn update_node(&mut self, node_id: String, device_capabilities: DeviceCapabilities) {
|
||||||
|
self.nodes.insert(node_id, device_capabilities);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn update_edge(&mut self, from_id: String, to_id: String, description: Option<String>) {
|
||||||
|
let conn = PeerConnection {
|
||||||
|
from_id: from_id.clone(),
|
||||||
|
to_id,
|
||||||
|
description,
|
||||||
|
};
|
||||||
|
|
||||||
|
match self.peer_graph.get_mut(&from_id) {
|
||||||
|
None => {
|
||||||
|
self.peer_graph.insert(from_id, vec![conn]);
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(existing) => {
|
||||||
|
existing.push(conn);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn merge_restricted(&mut self, from_peer_id: &str, topology: Topology) {
|
||||||
|
if let Some(peer_capabilities) = topology.nodes.get(from_peer_id) {
|
||||||
|
self.nodes
|
||||||
|
.insert(from_peer_id.to_string(), peer_capabilities.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
self.peer_graph.extend(
|
||||||
|
topology
|
||||||
|
.peer_graph
|
||||||
|
.into_iter()
|
||||||
|
.filter(|(id, _)| id == from_peer_id),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<crate::node_service::Topology> for Topology {
|
||||||
|
fn from(proto: crate::node_service::Topology) -> Self {
|
||||||
|
let nodes = proto
|
||||||
|
.nodes
|
||||||
|
.into_iter()
|
||||||
|
.map(|(k, v)| (k, v.into()))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let peer_graph = proto
|
||||||
|
.peer_graph
|
||||||
|
.into_iter()
|
||||||
|
.map(|(from_id, connections)| {
|
||||||
|
(
|
||||||
|
from_id.clone(),
|
||||||
|
connections
|
||||||
|
.connections
|
||||||
|
.into_iter()
|
||||||
|
.map(|pc| PeerConnection {
|
||||||
|
from_id: from_id.clone(),
|
||||||
|
to_id: pc.to_id,
|
||||||
|
description: pc.description,
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Topology {
|
||||||
|
nodes,
|
||||||
|
peer_graph,
|
||||||
|
active_node_id: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Into<crate::node_service::Topology> for Topology {
|
||||||
|
fn into(self) -> crate::node_service::Topology {
|
||||||
|
let nodes = self
|
||||||
|
.nodes
|
||||||
|
.iter()
|
||||||
|
.map(|(node_id, cap)| {
|
||||||
|
(
|
||||||
|
node_id.clone(),
|
||||||
|
node_service::DeviceCapabilities {
|
||||||
|
model: cap.model.clone(),
|
||||||
|
chip: cap.chip.clone(),
|
||||||
|
memory: cap.memory as i32,
|
||||||
|
flops: Some(node_service::DeviceFlops {
|
||||||
|
fp32: cap.flops.fp32,
|
||||||
|
fp16: cap.flops.fp16,
|
||||||
|
int8: cap.flops.int8,
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect::<HashMap<String, node_service::DeviceCapabilities>>();
|
||||||
|
|
||||||
|
let peer_graph = self
|
||||||
|
.peer_graph
|
||||||
|
.iter()
|
||||||
|
.map(|(node_id, connections)| {
|
||||||
|
(
|
||||||
|
node_id.clone(),
|
||||||
|
node_service::PeerConnections {
|
||||||
|
connections: connections
|
||||||
|
.iter()
|
||||||
|
.map(|conn| node_service::PeerConnection {
|
||||||
|
to_id: conn.to_id.clone(),
|
||||||
|
description: conn.description.clone(),
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect::<HashMap<String, node_service::PeerConnections>>();
|
||||||
|
|
||||||
|
node_service::Topology { nodes, peer_graph }
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for Topology {
|
impl Default for Topology {
|
||||||
@ -22,16 +141,16 @@ impl Default for Topology {
|
|||||||
|
|
||||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||||
pub struct DeviceCapabilities {
|
pub struct DeviceCapabilities {
|
||||||
model: String,
|
pub model: String,
|
||||||
chip: String,
|
pub chip: String,
|
||||||
memory: u64,
|
pub memory: u64,
|
||||||
flops: DeviceFlops,
|
pub flops: DeviceFlops,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||||
struct SystemProfilerOutputData {
|
struct SystemProfilerOutputData {
|
||||||
#[serde(rename = "SPHardwareDataType")]
|
#[serde(rename = "SPHardwareDataType")]
|
||||||
hardware: Vec<SPHardwareDataType>
|
hardware: Vec<SPHardwareDataType>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Serialize, Clone)]
|
#[derive(Debug, Deserialize, Serialize, Clone)]
|
||||||
@ -51,7 +170,7 @@ struct SPHardwareDataType {
|
|||||||
platform_uuid: String,
|
platform_uuid: String,
|
||||||
#[serde(rename = "provisioning_UDID")]
|
#[serde(rename = "provisioning_UDID")]
|
||||||
provisioning_udid: String,
|
provisioning_udid: String,
|
||||||
serial_number: String
|
serial_number: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DeviceCapabilities {
|
impl DeviceCapabilities {
|
||||||
@ -83,7 +202,8 @@ impl DeviceCapabilities {
|
|||||||
};
|
};
|
||||||
|
|
||||||
DeviceCapabilities {
|
DeviceCapabilities {
|
||||||
flops: device_capability_data::look_up(&chip).expect("Failed to find FLOPS data for chip"),
|
flops: device_capability_data::look_up(&chip)
|
||||||
|
.expect("Failed to find FLOPS data for chip"),
|
||||||
model,
|
model,
|
||||||
chip,
|
chip,
|
||||||
memory,
|
memory,
|
||||||
@ -91,6 +211,17 @@ impl DeviceCapabilities {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<crate::node_service::DeviceCapabilities> for DeviceCapabilities {
|
||||||
|
fn from(value: crate::node_service::DeviceCapabilities) -> Self {
|
||||||
|
DeviceCapabilities {
|
||||||
|
model: value.model,
|
||||||
|
chip: value.chip,
|
||||||
|
memory: value.memory as u64,
|
||||||
|
flops: value.flops.map(|x| x.into()).unwrap_or_default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
|
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
|
||||||
pub struct DeviceFlops {
|
pub struct DeviceFlops {
|
||||||
pub fp32: f64,
|
pub fp32: f64,
|
||||||
@ -98,9 +229,19 @@ pub struct DeviceFlops {
|
|||||||
pub int8: f64,
|
pub int8: f64,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<crate::node_service::DeviceFlops> for DeviceFlops {
    /// Copies the three throughput figures out of the wire-format message unchanged.
    fn from(value: crate::node_service::DeviceFlops) -> Self {
        let crate::node_service::DeviceFlops { fp32, fp16, int8 } = value;
        Self { fp32, fp16, int8 }
    }
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Serialize, Clone, Hash, Eq, PartialEq)]
|
#[derive(Debug, Deserialize, Serialize, Clone, Hash, Eq, PartialEq)]
|
||||||
pub struct PeerConnection {
|
pub struct PeerConnection {
|
||||||
from_id: String,
|
pub from_id: String,
|
||||||
to_id: String,
|
pub to_id: String,
|
||||||
description: Option<String>,
|
pub description: Option<String>,
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user