Add partition and shards
This commit is contained in:
parent
3aca4a22ae
commit
02694cbacc
@ -1,4 +1,3 @@
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use crate::network::get_broadcast_creation_info;
|
||||
use crate::topology::DeviceCapabilities;
|
||||
|
||||
@ -3,6 +3,7 @@ mod discovery;
|
||||
mod network;
|
||||
mod orchestration;
|
||||
mod topology;
|
||||
mod partitioning;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
68
src/partitioning.rs
Normal file
68
src/partitioning.rs
Normal file
@ -0,0 +1,68 @@
|
||||
use crate::topology::Topology;
|
||||
|
||||
pub enum PartitionStrategy {
|
||||
RingMemoryWeighted,
|
||||
}
|
||||
|
||||
pub struct Partition {
|
||||
node_id: String,
|
||||
start: f32,
|
||||
end: f32,
|
||||
}
|
||||
|
||||
impl PartitionStrategy {
|
||||
pub fn partition(&self, topology: &Topology) -> Vec<Partition> {
|
||||
let mut entries = topology.nodes.iter().collect::<Vec<_>>();
|
||||
entries.sort_by_key(|(node_id, device_capabilities)| {
|
||||
(device_capabilities.memory, node_id.clone())
|
||||
});
|
||||
|
||||
let mut start = 0.0;
|
||||
let mut partitions = Vec::with_capacity(entries.len());
|
||||
|
||||
for (node_id, device_capabilities) in entries {
|
||||
let end = ((start + device_capabilities.memory as f32) * 100000.0).round() / 100000.0;
|
||||
partitions.push(Partition {
|
||||
node_id: node_id.to_string(),
|
||||
start,
|
||||
end,
|
||||
});
|
||||
start = end;
|
||||
}
|
||||
|
||||
partitions
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ModelShard<'a> {
|
||||
model_id: &'a str,
|
||||
start_layer: u8,
|
||||
end_layer: u8,
|
||||
total_layers: u8,
|
||||
}
|
||||
|
||||
pub fn shard_model_by_partition(
|
||||
partitions: Vec<Partition>,
|
||||
total_layers: u8,
|
||||
model_id: &str,
|
||||
) -> Vec<ModelShard> {
|
||||
let mut shards: Vec<ModelShard<'_>> = Vec::with_capacity(partitions.len());
|
||||
|
||||
for partition in partitions {
|
||||
let start_layer = (partition.start * total_layers as f32).round() as u8;
|
||||
let mut end_layer = (partition.end * total_layers as f32).round() as u8 - 1;
|
||||
|
||||
if end_layer < start_layer {
|
||||
end_layer = total_layers - 1;
|
||||
}
|
||||
|
||||
shards.push(ModelShard {
|
||||
model_id,
|
||||
start_layer,
|
||||
end_layer,
|
||||
total_layers,
|
||||
});
|
||||
}
|
||||
|
||||
shards
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user