exo-rs/src/topology.rs

268 lines
7.7 KiB
Rust

use crate::partitioning::{shard_model_by_partition, PartitionStrategy};
use crate::{device_capability_data, node_service, Shard};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::process::Command;
/// A view of the cluster: the known nodes and the directed peer connections
/// between them.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct Topology {
/// Capabilities of every known node, keyed by node id.
pub nodes: HashMap<String, DeviceCapabilities>,
/// Outgoing connections of each node, keyed by the id of the node the
/// connections originate from.
pub peer_graph: HashMap<String, Vec<PeerConnection>>,
/// Id of the currently active node, if any. This field has no counterpart
/// in `node_service::Topology`: it is set to `None` when converting from
/// the wire type and is not carried over when converting into it.
pub active_node_id: Option<String>,
}
impl Topology {
    /// Returns the shard of the model that `node_id` is responsible for.
    ///
    /// The topology is partitioned with `PartitionStrategy::RingMemoryWeighted`,
    /// the model described by `base_shard` (its `model_id` and `total_layers`)
    /// is split across those partitions, and the shard at the same position as
    /// `node_id`'s partition is returned.
    ///
    /// # Panics
    ///
    /// Panics if `node_id` is not part of the partition set, or if
    /// `base_shard.total_layers` cannot be converted to the layer-count type
    /// expected by `shard_model_by_partition`.
    pub fn get_shard_for_node(&self, base_shard: Shard, node_id: &str) -> Shard {
        let partitions = PartitionStrategy::RingMemoryWeighted.partition(self);

        // TODO: This feels like it could be a better data structure
        let index = partitions
            .iter()
            .enumerate()
            .find(|(_, partition)| partition.node_id == node_id)
            .map(|(index, _)| index)
            .expect("Did not find node in partition set");

        let shards = shard_model_by_partition(
            &partitions,
            base_shard.total_layers.try_into().unwrap(),
            base_shard.model_id.as_str(),
        );

        shards[index].clone()
    }
}
impl Topology {
pub fn update_node(&mut self, node_id: String, device_capabilities: DeviceCapabilities) {
self.nodes.insert(node_id, device_capabilities);
}
pub fn update_edge(&mut self, from_id: String, to_id: String, description: Option<String>) {
let conn = PeerConnection {
from_id: from_id.clone(),
to_id,
description,
};
match self.peer_graph.get_mut(&from_id) {
None => {
self.peer_graph.insert(from_id, vec![conn]);
}
Some(existing) => {
existing.push(conn);
}
}
}
pub(crate) fn merge_restricted(&mut self, from_peer_id: &str, topology: Topology) {
if let Some(peer_capabilities) = topology.nodes.get(from_peer_id) {
self.nodes
.insert(from_peer_id.to_string(), peer_capabilities.clone());
}
self.peer_graph.extend(
topology
.peer_graph
.into_iter()
.filter(|(id, _)| id == from_peer_id),
);
}
}
impl From<node_service::Topology> for Topology {
fn from(proto: node_service::Topology) -> Self {
let nodes = proto
.nodes
.into_iter()
.map(|(k, v)| (k, v.into()))
.collect();
let peer_graph = proto
.peer_graph
.into_iter()
.map(|(from_id, connections)| {
(
from_id.clone(),
connections
.connections
.into_iter()
.map(|pc| PeerConnection {
from_id: from_id.clone(),
to_id: pc.to_id,
description: pc.description,
})
.collect(),
)
})
.collect();
Topology {
nodes,
peer_graph,
active_node_id: None,
}
}
}
/// Converts a [`Topology`] into its `node_service` wire representation.
///
/// This is implemented as `From` rather than a hand-written `Into`: the
/// standard blanket impl still provides `topology.into()`, and
/// `node_service::Topology::from(topology)` becomes available as well.
///
/// `active_node_id` and each connection's `from_id` have no field on the wire
/// and are dropped (the origin of a connection is implied by the key of the
/// list it is stored under).
impl From<Topology> for node_service::Topology {
    fn from(topology: Topology) -> Self {
        // The topology is consumed, so fields are moved instead of cloned.
        let nodes = topology
            .nodes
            .into_iter()
            .map(|(node_id, cap)| {
                (
                    node_id,
                    node_service::DeviceCapabilities {
                        model: cap.model,
                        chip: cap.chip,
                        // The wire field is an i32; saturate instead of letting
                        // the cast wrap around for out-of-range values.
                        memory: cap.memory.min(i32::MAX as u64) as i32,
                        flops: Some(node_service::DeviceFlops {
                            fp32: cap.flops.fp32,
                            fp16: cap.flops.fp16,
                            int8: cap.flops.int8,
                        }),
                    },
                )
            })
            .collect::<HashMap<String, node_service::DeviceCapabilities>>();

        let peer_graph = topology
            .peer_graph
            .into_iter()
            .map(|(node_id, connections)| {
                (
                    node_id,
                    node_service::PeerConnections {
                        connections: connections
                            .into_iter()
                            .map(|conn| node_service::PeerConnection {
                                to_id: conn.to_id,
                                description: conn.description,
                            })
                            .collect(),
                    },
                )
            })
            .collect::<HashMap<String, node_service::PeerConnections>>();

        node_service::Topology { nodes, peer_graph }
    }
}
impl Default for Topology {
fn default() -> Self {
Self {
nodes: HashMap::new(),
peer_graph: HashMap::new(),
active_node_id: None,
}
}
}
/// Hardware description of a single node.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct DeviceCapabilities {
/// Machine model name (the `machine_name` reported by `system_profiler`
/// when the value comes from `DeviceCapabilities::determine`).
pub model: String,
/// Chip name; `determine` uses it as the key for the FLOPS lookup.
pub chip: String,
/// Amount of physical memory. `determine` multiplies "GB" values by 1024,
/// so the value is effectively expressed in megabytes.
pub memory: u64,
/// Compute throughput of the device per numeric precision.
pub flops: DeviceFlops,
}
/// Top-level JSON document parsed from the output of
/// `system_profiler SPHardwareDataType -json`.
#[derive(Debug, Deserialize, Serialize, Clone)]
struct SystemProfilerOutputData {
// The JSON key is the data-type name itself and maps to a list of entries;
// `DeviceCapabilities::determine` only looks at the first one.
#[serde(rename = "SPHardwareDataType")]
hardware: Vec<SPHardwareDataType>,
}
/// One hardware entry from `system_profiler SPHardwareDataType -json`.
///
/// Only `machine_name`, `chip_type` and `physical_memory` are read by
/// `DeviceCapabilities::determine`, so those stay required. Every other field
/// is informational and marked `#[serde(default)]`: if a machine's report
/// omits one of them, the field becomes an empty string instead of failing
/// deserialization of the whole document.
#[derive(Debug, Deserialize, Serialize, Clone)]
struct SPHardwareDataType {
    #[serde(rename = "_name", default)]
    name: String,
    #[serde(default)]
    activation_lock_status: String,
    #[serde(default)]
    boot_rom_version: String,
    /// Chip name; required, used as the key for the FLOPS lookup.
    chip_type: String,
    #[serde(default)]
    machine_model: String,
    /// Marketing name of the machine; required, becomes `DeviceCapabilities::model`.
    machine_name: String,
    #[serde(default)]
    model_number: String,
    #[serde(default)]
    number_processors: String,
    #[serde(default)]
    os_loader_version: String,
    /// Human-readable memory size such as `"16 GB"`; required.
    physical_memory: String,
    #[serde(rename = "platform_UUID", default)]
    platform_uuid: String,
    #[serde(rename = "provisioning_UDID", default)]
    provisioning_udid: String,
    #[serde(default)]
    serial_number: String,
}
impl DeviceCapabilities {
    /// Detects the capabilities of the local machine.
    ///
    /// Runs `system_profiler SPHardwareDataType -json`, takes the first
    /// hardware entry of the report and derives the model name, chip name and
    /// memory size from it. FLOPS figures are looked up by chip name through
    /// `device_capability_data::look_up`.
    ///
    /// # Panics
    ///
    /// Panics if `system_profiler` cannot be executed, if its output is not
    /// the expected JSON, if the report contains no hardware entry, or if no
    /// FLOPS data is known for the detected chip.
    pub fn determine() -> DeviceCapabilities {
        let stdout = Command::new("system_profiler")
            .arg("SPHardwareDataType")
            .arg("-json")
            .output()
            .expect("Failed to run system_profiler")
            .stdout;
        let data = serde_json::from_slice::<SystemProfilerOutputData>(&stdout)
            .expect("Failed to parse system_profiler output as JSON");
        // Take the first entry without indexing so that an empty list yields
        // a descriptive panic instead of an out-of-bounds one.
        let hardware = data
            .hardware
            .into_iter()
            .next()
            .expect("system_profiler reported no hardware entries");

        let memory = Self::parse_memory_mb(&hardware.physical_memory);
        let chip = hardware.chip_type;
        let flops = device_capability_data::look_up(&chip)
            .expect("Failed to find FLOPS data for chip");

        DeviceCapabilities {
            model: hardware.machine_name,
            chip,
            memory,
            flops,
        }
    }

    /// Parses a `"<amount> <unit>"` memory description into megabytes.
    ///
    /// `TB` and `GB` amounts are scaled by 1024² and 1024 respectively; any
    /// other unit is taken as already being in megabytes. Fractional amounts
    /// (e.g. `"1.5 TB"`) are accepted and truncated after scaling. Returns 0
    /// when the text does not have two fields or the amount is not a
    /// non-negative number.
    fn parse_memory_mb(raw: &str) -> u64 {
        let mut fields = raw.split_ascii_whitespace();
        let (amount, unit) = match (fields.next(), fields.next()) {
            (Some(amount), Some(unit)) => (amount, unit),
            _ => return 0,
        };
        let mb_per_unit: u64 = match unit {
            "TB" => 1024 * 1024,
            "GB" => 1024,
            _ => 1,
        };

        // Whole numbers take the exact integer path; saturate rather than
        // overflow on absurdly large inputs.
        if let Ok(whole) = amount.parse::<u64>() {
            return whole.saturating_mul(mb_per_unit);
        }
        match amount.parse::<f64>() {
            Ok(fractional) if fractional.is_finite() && fractional > 0.0 => {
                (fractional * mb_per_unit as f64) as u64
            }
            _ => 0,
        }
    }
}
impl From<node_service::DeviceCapabilities> for DeviceCapabilities {
fn from(value: node_service::DeviceCapabilities) -> Self {
DeviceCapabilities {
model: value.model,
chip: value.chip,
memory: value.memory as u64,
flops: value.flops.map(|x| x.into()).unwrap_or_default(),
}
}
}
/// Compute throughput of a device, one figure per numeric precision.
///
/// NOTE(review): the unit of these figures (e.g. TFLOPS) is defined by the
/// data in `device_capability_data` and is not visible here — confirm there.
/// The derived `Default` is all zeros.
#[derive(Debug, Deserialize, Serialize, Clone, Default)]
pub struct DeviceFlops {
/// Throughput for 32-bit floating point.
pub fp32: f64,
/// Throughput for 16-bit floating point.
pub fp16: f64,
/// Throughput for 8-bit integer.
pub int8: f64,
}
impl From<node_service::DeviceFlops> for DeviceFlops {
    /// Copies the three throughput figures from the `node_service` wire type.
    fn from(value: node_service::DeviceFlops) -> Self {
        let node_service::DeviceFlops { fp32, fp16, int8 } = value;
        Self { fp32, fp16, int8 }
    }
}
/// A directed connection between two nodes in the peer graph.
#[derive(Debug, Deserialize, Serialize, Clone, Hash, Eq, PartialEq)]
pub struct PeerConnection {
/// Id of the node the connection originates from; this is also the key the
/// connection is stored under in `Topology::peer_graph`.
pub from_id: String,
/// Id of the node the connection points to.
pub to_id: String,
/// Optional free-form description of the connection.
pub description: Option<String>,
}