268 lines
7.7 KiB
Rust
268 lines
7.7 KiB
Rust
use crate::partitioning::{shard_model_by_partition, PartitionStrategy};
|
|
use crate::{device_capability_data, node_service, Shard};
|
|
use serde::{Deserialize, Serialize};
|
|
use std::collections::HashMap;
|
|
use std::process::Command;
|
|
|
|
#[derive(Debug, Deserialize, Serialize, Clone)]
|
|
pub struct Topology {
|
|
pub nodes: HashMap<String, DeviceCapabilities>,
|
|
pub peer_graph: HashMap<String, Vec<PeerConnection>>,
|
|
pub active_node_id: Option<String>,
|
|
}
|
|
|
|
impl Topology {
|
|
pub fn get_shard_for_node(&self, base_shard: Shard, node_id: &str) -> Shard {
|
|
let partition_set = PartitionStrategy::RingMemoryWeighted.partition(&self);
|
|
|
|
// TODO: This feels like it could be a better data structure
|
|
let partition_index = partition_set
|
|
.iter()
|
|
.position(|s| s.node_id == node_id)
|
|
.expect("Did not find node in partition set");
|
|
|
|
let shards = shard_model_by_partition(
|
|
&partition_set,
|
|
base_shard.total_layers.try_into().unwrap(),
|
|
base_shard.model_id.as_str(),
|
|
);
|
|
|
|
shards[partition_index].clone()
|
|
}
|
|
}
|
|
|
|
impl Topology {
|
|
pub fn update_node(&mut self, node_id: String, device_capabilities: DeviceCapabilities) {
|
|
self.nodes.insert(node_id, device_capabilities);
|
|
}
|
|
|
|
pub fn update_edge(&mut self, from_id: String, to_id: String, description: Option<String>) {
|
|
let conn = PeerConnection {
|
|
from_id: from_id.clone(),
|
|
to_id,
|
|
description,
|
|
};
|
|
|
|
match self.peer_graph.get_mut(&from_id) {
|
|
None => {
|
|
self.peer_graph.insert(from_id, vec![conn]);
|
|
}
|
|
|
|
Some(existing) => {
|
|
existing.push(conn);
|
|
}
|
|
}
|
|
}
|
|
|
|
pub(crate) fn merge_restricted(&mut self, from_peer_id: &str, topology: Topology) {
|
|
if let Some(peer_capabilities) = topology.nodes.get(from_peer_id) {
|
|
self.nodes
|
|
.insert(from_peer_id.to_string(), peer_capabilities.clone());
|
|
}
|
|
|
|
self.peer_graph.extend(
|
|
topology
|
|
.peer_graph
|
|
.into_iter()
|
|
.filter(|(id, _)| id == from_peer_id),
|
|
);
|
|
}
|
|
}
|
|
|
|
impl From<node_service::Topology> for Topology {
|
|
fn from(proto: node_service::Topology) -> Self {
|
|
let nodes = proto
|
|
.nodes
|
|
.into_iter()
|
|
.map(|(k, v)| (k, v.into()))
|
|
.collect();
|
|
|
|
let peer_graph = proto
|
|
.peer_graph
|
|
.into_iter()
|
|
.map(|(from_id, connections)| {
|
|
(
|
|
from_id.clone(),
|
|
connections
|
|
.connections
|
|
.into_iter()
|
|
.map(|pc| PeerConnection {
|
|
from_id: from_id.clone(),
|
|
to_id: pc.to_id,
|
|
description: pc.description,
|
|
})
|
|
.collect(),
|
|
)
|
|
})
|
|
.collect();
|
|
|
|
Topology {
|
|
nodes,
|
|
peer_graph,
|
|
active_node_id: None,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Into<node_service::Topology> for Topology {
|
|
fn into(self) -> node_service::Topology {
|
|
let nodes = self
|
|
.nodes
|
|
.iter()
|
|
.map(|(node_id, cap)| {
|
|
(
|
|
node_id.clone(),
|
|
node_service::DeviceCapabilities {
|
|
model: cap.model.clone(),
|
|
chip: cap.chip.clone(),
|
|
memory: cap.memory as i32,
|
|
flops: Some(node_service::DeviceFlops {
|
|
fp32: cap.flops.fp32,
|
|
fp16: cap.flops.fp16,
|
|
int8: cap.flops.int8,
|
|
}),
|
|
},
|
|
)
|
|
})
|
|
.collect::<HashMap<String, node_service::DeviceCapabilities>>();
|
|
|
|
let peer_graph = self
|
|
.peer_graph
|
|
.iter()
|
|
.map(|(node_id, connections)| {
|
|
(
|
|
node_id.clone(),
|
|
node_service::PeerConnections {
|
|
connections: connections
|
|
.iter()
|
|
.map(|conn| node_service::PeerConnection {
|
|
to_id: conn.to_id.clone(),
|
|
description: conn.description.clone(),
|
|
})
|
|
.collect(),
|
|
},
|
|
)
|
|
})
|
|
.collect::<HashMap<String, node_service::PeerConnections>>();
|
|
|
|
node_service::Topology { nodes, peer_graph }
|
|
}
|
|
}
|
|
|
|
impl Default for Topology {
|
|
fn default() -> Self {
|
|
Self {
|
|
nodes: HashMap::new(),
|
|
peer_graph: HashMap::new(),
|
|
active_node_id: None,
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Deserialize, Serialize, Clone)]
|
|
pub struct DeviceCapabilities {
|
|
pub model: String,
|
|
pub chip: String,
|
|
pub memory: u64,
|
|
pub flops: DeviceFlops,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize, Serialize, Clone)]
|
|
struct SystemProfilerOutputData {
|
|
#[serde(rename = "SPHardwareDataType")]
|
|
hardware: Vec<SPHardwareDataType>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize, Serialize, Clone)]
|
|
struct SPHardwareDataType {
|
|
#[serde(rename = "_name")]
|
|
name: String,
|
|
activation_lock_status: String,
|
|
boot_rom_version: String,
|
|
chip_type: String,
|
|
machine_model: String,
|
|
machine_name: String,
|
|
model_number: String,
|
|
number_processors: String,
|
|
os_loader_version: String,
|
|
physical_memory: String,
|
|
#[serde(rename = "platform_UUID")]
|
|
platform_uuid: String,
|
|
#[serde(rename = "provisioning_UDID")]
|
|
provisioning_udid: String,
|
|
serial_number: String,
|
|
}
|
|
|
|
impl DeviceCapabilities {
|
|
pub fn determine() -> DeviceCapabilities {
|
|
let s = Command::new("system_profiler")
|
|
.arg("SPHardwareDataType")
|
|
.arg("-json")
|
|
.output()
|
|
.unwrap()
|
|
.stdout;
|
|
|
|
let mut data = serde_json::from_slice::<SystemProfilerOutputData>(&s).unwrap();
|
|
let hardware = data.hardware.remove(0);
|
|
|
|
let model = hardware.machine_name;
|
|
let chip = hardware.chip_type;
|
|
let memory = {
|
|
let parts: Vec<&str> = hardware.physical_memory.split_ascii_whitespace().collect();
|
|
if parts.len() >= 2 {
|
|
let value = parts[0].parse::<u64>().unwrap_or(0);
|
|
if parts[1] == "GB" {
|
|
value * 1024
|
|
} else {
|
|
value
|
|
}
|
|
} else {
|
|
0
|
|
}
|
|
};
|
|
|
|
DeviceCapabilities {
|
|
flops: device_capability_data::look_up(&chip)
|
|
.expect("Failed to find FLOPS data for chip"),
|
|
model,
|
|
chip,
|
|
memory,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl From<node_service::DeviceCapabilities> for DeviceCapabilities {
|
|
fn from(value: node_service::DeviceCapabilities) -> Self {
|
|
DeviceCapabilities {
|
|
model: value.model,
|
|
chip: value.chip,
|
|
memory: value.memory as u64,
|
|
flops: value.flops.map(|x| x.into()).unwrap_or_default(),
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
|
|
pub struct DeviceFlops {
|
|
pub fp32: f64,
|
|
pub fp16: f64,
|
|
pub int8: f64,
|
|
}
|
|
|
|
impl From<node_service::DeviceFlops> for DeviceFlops {
    /// Copies the three throughput figures out of the wire-format message
    /// unchanged.
    fn from(value: node_service::DeviceFlops) -> Self {
        let node_service::DeviceFlops { fp32, fp16, int8 } = value;
        Self { fp32, fp16, int8 }
    }
}
|
|
|
|
#[derive(Debug, Deserialize, Serialize, Clone, Hash, Eq, PartialEq)]
|
|
pub struct PeerConnection {
|
|
pub from_id: String,
|
|
pub to_id: String,
|
|
pub description: Option<String>,
|
|
}
|