diff --git a/src/discovery/mod.rs b/src/discovery/mod.rs index 6f9fe5e..132f573 100644 --- a/src/discovery/mod.rs +++ b/src/discovery/mod.rs @@ -1,4 +1,3 @@ -use std::cell::RefCell; use std::collections::HashMap; use crate::network::get_broadcast_creation_info; use crate::topology::DeviceCapabilities; diff --git a/src/main.rs b/src/main.rs index d823fa4..5029132 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,6 +3,7 @@ mod discovery; mod network; mod orchestration; mod topology; +mod partitioning; use serde::{Deserialize, Serialize}; use serde_json::Value; diff --git a/src/partitioning.rs b/src/partitioning.rs new file mode 100644 index 0000000..7db3b91 --- /dev/null +++ b/src/partitioning.rs @@ -0,0 +1,68 @@ +use crate::topology::Topology; + +pub enum PartitionStrategy { + RingMemoryWeighted, +} + +pub struct Partition { + node_id: String, + start: f32, + end: f32, +} + +impl PartitionStrategy { + pub fn partition(&self, topology: &Topology) -> Vec { + let mut entries = topology.nodes.iter().collect::>(); + entries.sort_by_key(|(node_id, device_capabilities)| { + (device_capabilities.memory, node_id.clone()) + }); + + let mut start = 0.0; + let mut partitions = Vec::with_capacity(entries.len()); + + for (node_id, device_capabilities) in entries { + let end = ((start + device_capabilities.memory as f32) * 100000.0).round() / 100000.0; + partitions.push(Partition { + node_id: node_id.to_string(), + start, + end, + }); + start = end; + } + + partitions + } +} + +pub struct ModelShard<'a> { + model_id: &'a str, + start_layer: u8, + end_layer: u8, + total_layers: u8, +} + +pub fn shard_model_by_partition( + partitions: Vec, + total_layers: u8, + model_id: &str, +) -> Vec { + let mut shards: Vec> = Vec::with_capacity(partitions.len()); + + for partition in partitions { + let start_layer = (partition.start * total_layers as f32).round() as u8; + let mut end_layer = (partition.end * total_layers as f32).round() as u8 - 1; + + if end_layer < start_layer { + end_layer = total_layers - 1; + } + + shards.push(ModelShard { + model_id, + start_layer, + end_layer, + total_layers, + }); + } + + shards +}