63 lines
1.6 KiB
Rust
63 lines
1.6 KiB
Rust
use crate::Shard;
|
|
use crate::topology::Topology;
|
|
|
|
/// Strategy used to divide a model among the nodes of a [`Topology`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PartitionStrategy {
    /// Give every node one contiguous slice, weighted by its device memory.
    RingMemoryWeighted,
}
|
|
|
|
/// A contiguous slice of the model assigned to a single node.
///
/// `start` and `end` are fractional boundaries. [`shard_model_by_partition`] turns them into
/// layer indices by multiplying them with the model's total layer count.
#[derive(Debug, Clone, PartialEq)]
pub struct Partition {
    /// Identifier of the node that owns this slice.
    pub node_id: String,
    /// Inclusive lower boundary of the slice.
    pub start: f32,
    /// Exclusive upper boundary of the slice; equals the next partition's `start`.
    pub end: f32,
}
|
|
|
|
impl PartitionStrategy {
|
|
pub fn partition(&self, topology: &Topology) -> Vec<Partition> {
|
|
let mut entries = topology.nodes.iter().collect::<Vec<_>>();
|
|
entries.sort_by_key(|(node_id, device_capabilities)| {
|
|
(device_capabilities.memory, node_id.clone())
|
|
});
|
|
|
|
let mut start = 0.0;
|
|
let mut partitions = Vec::with_capacity(entries.len());
|
|
|
|
for (node_id, device_capabilities) in entries {
|
|
let end = ((start + device_capabilities.memory as f32) * 100000.0).round() / 100000.0;
|
|
partitions.push(Partition {
|
|
node_id: node_id.to_string(),
|
|
start,
|
|
end,
|
|
});
|
|
start = end;
|
|
}
|
|
|
|
partitions
|
|
}
|
|
}
|
|
|
|
pub fn shard_model_by_partition(
|
|
partition_set: &[Partition],
|
|
total_layers: u32,
|
|
model_id: &str,
|
|
) -> Vec<Shard> {
|
|
let mut shards: Vec<Shard> = Vec::with_capacity(partition_set.len());
|
|
|
|
for partition in partition_set {
|
|
let start_layer = (partition.start * total_layers as f32).round() as u32;
|
|
let mut end_layer = (partition.end * total_layers as f32).round() as u32 - 1;
|
|
|
|
if end_layer < start_layer {
|
|
end_layer = total_layers - 1;
|
|
}
|
|
|
|
shards.push(Shard {
|
|
model_id: model_id.to_string(),
|
|
start_layer,
|
|
end_layer,
|
|
total_layers,
|
|
});
|
|
}
|
|
|
|
shards
|
|
}
|