use crate::Shard; use crate::topology::Topology; pub enum PartitionStrategy { RingMemoryWeighted, } pub struct Partition { pub node_id: String, pub start: f32, pub end: f32, } impl PartitionStrategy { pub fn partition(&self, topology: &Topology) -> Vec { let mut entries = topology.nodes.iter().collect::>(); entries.sort_by_key(|(node_id, device_capabilities)| { (device_capabilities.memory, node_id.clone()) }); let mut start = 0.0; let mut partitions = Vec::with_capacity(entries.len()); for (node_id, device_capabilities) in entries { let end = ((start + device_capabilities.memory as f32) * 100000.0).round() / 100000.0; partitions.push(Partition { node_id: node_id.to_string(), start, end, }); start = end; } partitions } } pub fn shard_model_by_partition( partition_set: &[Partition], total_layers: u32, model_id: &str, ) -> Vec { let mut shards: Vec = Vec::with_capacity(partition_set.len()); for partition in partition_set { let start_layer = (partition.start * total_layers as f32).round() as u32; let mut end_layer = (partition.end * total_layers as f32).round() as u32 - 1; if end_layer < start_layer { end_layer = total_layers - 1; } shards.push(Shard { model_id: model_id.to_string(), start_layer, end_layer, total_layers, }); } shards }