use crate::partitioning::{shard_model_by_partition, PartitionStrategy}; use crate::{device_capability_data, node_service, Shard}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::process::Command; #[derive(Debug, Deserialize, Serialize, Clone)] pub struct Topology { pub nodes: HashMap, pub peer_graph: HashMap>, pub active_node_id: Option, } impl Topology { pub fn get_shard_for_node(&self, base_shard: Shard, node_id: &str) -> Shard { let partition_set = PartitionStrategy::RingMemoryWeighted.partition(&self); // TODO: This feels like it could be a better data structure let partition_index = partition_set .iter() .position(|s| s.node_id == node_id) .expect("Did not find node in partition set"); let shards = shard_model_by_partition( &partition_set, base_shard.total_layers.try_into().unwrap(), base_shard.model_id.as_str(), ); shards[partition_index].clone() } } impl Topology { pub fn update_node(&mut self, node_id: String, device_capabilities: DeviceCapabilities) { self.nodes.insert(node_id, device_capabilities); } pub fn update_edge(&mut self, from_id: String, to_id: String, description: Option) { let conn = PeerConnection { from_id: from_id.clone(), to_id, description, }; match self.peer_graph.get_mut(&from_id) { None => { self.peer_graph.insert(from_id, vec![conn]); } Some(existing) => { existing.push(conn); } } } pub(crate) fn merge_restricted(&mut self, from_peer_id: &str, topology: Topology) { if let Some(peer_capabilities) = topology.nodes.get(from_peer_id) { self.nodes .insert(from_peer_id.to_string(), peer_capabilities.clone()); } self.peer_graph.extend( topology .peer_graph .into_iter() .filter(|(id, _)| id == from_peer_id), ); } } impl From for Topology { fn from(proto: node_service::Topology) -> Self { let nodes = proto .nodes .into_iter() .map(|(k, v)| (k, v.into())) .collect(); let peer_graph = proto .peer_graph .into_iter() .map(|(from_id, connections)| { ( from_id.clone(), connections .connections .into_iter() .map(|pc| PeerConnection { from_id: from_id.clone(), to_id: pc.to_id, description: pc.description, }) .collect(), ) }) .collect(); Topology { nodes, peer_graph, active_node_id: None, } } } impl Into for Topology { fn into(self) -> node_service::Topology { let nodes = self .nodes .iter() .map(|(node_id, cap)| { ( node_id.clone(), node_service::DeviceCapabilities { model: cap.model.clone(), chip: cap.chip.clone(), memory: cap.memory as i32, flops: Some(node_service::DeviceFlops { fp32: cap.flops.fp32, fp16: cap.flops.fp16, int8: cap.flops.int8, }), }, ) }) .collect::>(); let peer_graph = self .peer_graph .iter() .map(|(node_id, connections)| { ( node_id.clone(), node_service::PeerConnections { connections: connections .iter() .map(|conn| node_service::PeerConnection { to_id: conn.to_id.clone(), description: conn.description.clone(), }) .collect(), }, ) }) .collect::>(); node_service::Topology { nodes, peer_graph } } } impl Default for Topology { fn default() -> Self { Self { nodes: HashMap::new(), peer_graph: HashMap::new(), active_node_id: None, } } } #[derive(Debug, Deserialize, Serialize, Clone)] pub struct DeviceCapabilities { pub model: String, pub chip: String, pub memory: u64, pub flops: DeviceFlops, } #[derive(Debug, Deserialize, Serialize, Clone)] struct SystemProfilerOutputData { #[serde(rename = "SPHardwareDataType")] hardware: Vec, } #[derive(Debug, Deserialize, Serialize, Clone)] struct SPHardwareDataType { #[serde(rename = "_name")] name: String, activation_lock_status: String, boot_rom_version: String, chip_type: String, machine_model: String, machine_name: String, model_number: String, number_processors: String, os_loader_version: String, physical_memory: String, #[serde(rename = "platform_UUID")] platform_uuid: String, #[serde(rename = "provisioning_UDID")] provisioning_udid: String, serial_number: String, } impl DeviceCapabilities { pub fn determine() -> DeviceCapabilities { let s = Command::new("system_profiler") .arg("SPHardwareDataType") .arg("-json") .output() .unwrap() .stdout; let mut data = serde_json::from_slice::(&s).unwrap(); let hardware = data.hardware.remove(0); let model = hardware.machine_name; let chip = hardware.chip_type; let memory = { let parts: Vec<&str> = hardware.physical_memory.split_ascii_whitespace().collect(); if parts.len() >= 2 { let value = parts[0].parse::().unwrap_or(0); if parts[1] == "GB" { value * 1024 } else { value } } else { 0 } }; DeviceCapabilities { flops: device_capability_data::look_up(&chip) .expect("Failed to find FLOPS data for chip"), model, chip, memory, } } } impl From for DeviceCapabilities { fn from(value: node_service::DeviceCapabilities) -> Self { DeviceCapabilities { model: value.model, chip: value.chip, memory: value.memory as u64, flops: value.flops.map(|x| x.into()).unwrap_or_default(), } } } #[derive(Debug, Deserialize, Serialize, Clone, Default)] pub struct DeviceFlops { pub fp32: f64, pub fp16: f64, pub int8: f64, } impl From for DeviceFlops { fn from(value: node_service::DeviceFlops) -> Self { DeviceFlops { fp32: value.fp32, fp16: value.fp16, int8: value.int8, } } } #[derive(Debug, Deserialize, Serialize, Clone, Hash, Eq, PartialEq)] pub struct PeerConnection { pub from_id: String, pub to_id: String, pub description: Option, }