exo-rs/src/partitioning.rs

63 lines
1.6 KiB
Rust

use crate::Shard;
use crate::topology::Topology;
/// Strategy used to divide work between the nodes of a [`Topology`].
///
/// Call [`PartitionStrategy::partition`] to turn a topology into a list of
/// [`Partition`]s.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PartitionStrategy {
    /// Nodes are ordered by their `memory` capability (ties broken by node
    /// id) and handed consecutive ranges sized from that memory.
    RingMemoryWeighted,
}
/// A contiguous range `[start, end)` assigned to a single node.
///
/// `shard_model_by_partition` multiplies both bounds by the model's layer
/// count, so they are meant to be fractions of the whole model.
#[derive(Debug, Clone, PartialEq)]
pub struct Partition {
    /// Identifier of the node that owns this range.
    pub node_id: String,
    /// Lower bound of the range.
    pub start: f32,
    /// Upper bound of the range; the next partition starts here.
    pub end: f32,
}
impl PartitionStrategy {
    /// Splits the unit interval `[0.0, 1.0]` between the nodes of `topology`,
    /// giving each node a share proportional to its `memory`.
    ///
    /// Nodes are ordered by ascending memory, with ties broken by node id, so
    /// the result is deterministic for a given topology. Partitions are
    /// contiguous: each one starts where the previous one ended, the first
    /// starts at `0.0` and the last ends at exactly `1.0`. Bounds are rounded
    /// to five decimal places.
    ///
    /// If the total memory is zero the interval is split evenly between the
    /// nodes instead of dividing by zero. An empty topology yields an empty
    /// vector.
    pub fn partition(&self, topology: &Topology) -> Vec<Partition> {
        let mut entries = topology.nodes.iter().collect::<Vec<_>>();
        entries.sort_by_key(|(node_id, device_capabilities)| {
            (device_capabilities.memory, node_id.clone())
        });

        let node_count = entries.len();
        // Accumulate in f64 so large byte counts do not lose precision before
        // being normalised.
        let total_memory: f64 = entries
            .iter()
            .map(|(_, device_capabilities)| device_capabilities.memory as f64)
            .sum();

        let mut cumulative_memory = 0.0_f64;
        let mut start = 0.0_f32;
        let mut partitions = Vec::with_capacity(node_count);
        for (index, (node_id, device_capabilities)) in entries.into_iter().enumerate() {
            cumulative_memory += device_capabilities.memory as f64;
            // Each bound is derived from the running total rather than by
            // adding rounded shares, so rounding error cannot accumulate and
            // the final bound is exactly 1.0.
            let fraction = if total_memory > 0.0 {
                cumulative_memory / total_memory
            } else {
                (index + 1) as f64 / node_count as f64
            };
            let end = ((fraction * 100_000.0).round() / 100_000.0) as f32;
            partitions.push(Partition {
                node_id: node_id.to_string(),
                start,
                end,
            });
            start = end;
        }
        partitions
    }
}
/// Converts fractional partitions into inclusive layer ranges of a model.
///
/// One [`Shard`] is produced per entry of `partition_set`, in the same order.
/// For each partition the bounds are scaled by `total_layers` and rounded to
/// the nearest layer; `start_layer` is the rounded start and `end_layer` is
/// the rounded end minus one (the range is inclusive).
///
/// When that range would be empty (the rounded end is zero or does not exceed
/// the rounded start), `end_layer` falls back to the last layer of the model,
/// `total_layers - 1`.
///
/// Returns an empty vector when `total_layers` is zero, since there are no
/// layers to assign.
pub fn shard_model_by_partition(
    partition_set: &[Partition],
    total_layers: u32,
    model_id: &str,
) -> Vec<Shard> {
    // Guard first: `total_layers - 1` below would underflow for an empty model.
    if total_layers == 0 {
        return Vec::new();
    }
    let last_layer = total_layers - 1;
    let layer_at = |fraction: f32| (fraction * total_layers as f32).round() as u32;

    partition_set
        .iter()
        .map(|partition| {
            let start_layer = layer_at(partition.start);
            // `checked_sub` avoids the u32 underflow that occurs when the end
            // fraction rounds down to layer 0; that case is treated like any
            // other empty range.
            let end_layer = layer_at(partition.end)
                .checked_sub(1)
                .filter(|&end| end >= start_layer)
                .unwrap_or(last_layer);
            Shard {
                model_id: model_id.to_string(),
                start_layer,
                end_layer,
                total_layers,
            }
        })
        .collect()
}