Add partition and shards

This commit is contained in:
Joshua Coles 2025-02-12 14:08:10 +00:00
parent 3aca4a22ae
commit 02694cbacc
3 changed files with 69 additions and 1 deletions

View File

@ -1,4 +1,3 @@
use std::cell::RefCell;
use std::collections::HashMap;
use crate::network::get_broadcast_creation_info;
use crate::topology::DeviceCapabilities;

View File

@ -3,6 +3,7 @@ mod discovery;
mod network;
mod orchestration;
mod topology;
mod partitioning;
use serde::{Deserialize, Serialize};
use serde_json::Value;

68
src/partitioning.rs Normal file
View File

@ -0,0 +1,68 @@
use crate::topology::Topology;
/// Strategy used to divide work between the nodes of a topology.
///
/// The common traits are derived so a strategy can be logged, copied and
/// compared without extra boilerplate at call sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PartitionStrategy {
    /// Orders nodes by their reported memory and sizes each node's partition
    /// from that memory figure.
    RingMemoryWeighted,
}
/// The range assigned to a single node by [`PartitionStrategy::partition`].
///
/// `Debug`, `Clone` and `PartialEq` are derived so partitions can be logged
/// and compared in tests (`Eq` is not available because the bounds are `f32`).
#[derive(Debug, Clone, PartialEq)]
pub struct Partition {
    /// Identifier of the node that owns this range.
    node_id: String,
    /// Lower bound of the range; scaled by the layer count when the partition
    /// is converted into a model shard.
    start: f32,
    /// Upper bound of the range; treated as exclusive when the partition is
    /// converted into a model shard.
    end: f32,
}
impl PartitionStrategy {
    /// Splits the unit interval `[0, 1]` between the nodes of `topology`,
    /// giving each node a slice proportional to its share of the total memory.
    ///
    /// Nodes are ordered by ascending memory, with the node id breaking ties,
    /// so the result is deterministic regardless of map iteration order.
    /// Every boundary is rounded to five decimal places.
    ///
    /// Returns one [`Partition`] per node, each starting where the previous
    /// one ended. When the total memory is zero every node receives the empty
    /// range `[0, 0]`.
    pub fn partition(&self, topology: &Topology) -> Vec<Partition> {
        let mut entries = topology.nodes.iter().collect::<Vec<_>>();
        entries.sort_by_key(|(node_id, device_capabilities)| {
            (device_capabilities.memory, node_id.clone())
        });

        // Bounds must be fractions of the whole because
        // `shard_model_by_partition` multiplies them by the layer count.
        let total_memory: f32 = entries
            .iter()
            .map(|(_, device_capabilities)| device_capabilities.memory as f32)
            .sum();

        let mut start = 0.0;
        let mut cumulative_memory = 0.0_f32;
        let mut partitions = Vec::with_capacity(entries.len());
        for (node_id, device_capabilities) in entries {
            cumulative_memory += device_capabilities.memory as f32;
            // Deriving each end from the cumulative share (rather than adding
            // rounded slices) makes the last partition end at exactly 1.0.
            let fraction = if total_memory > 0.0 {
                cumulative_memory / total_memory
            } else {
                // No memory reported at all: avoid 0/0 = NaN bounds.
                0.0
            };
            let end = (fraction * 100000.0).round() / 100000.0;
            partitions.push(Partition {
                node_id: node_id.to_string(),
                start,
                end,
            });
            start = end;
        }
        partitions
    }
}
/// A contiguous run of model layers produced from one partition.
///
/// The shard only borrows the model identifier, so it is cheap to copy; the
/// common traits are derived to make shards easy to log and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ModelShard<'a> {
    /// Identifier of the model this shard belongs to.
    model_id: &'a str,
    /// Index of the first layer in the shard.
    start_layer: u8,
    /// Index of the last layer in the shard (inclusive).
    end_layer: u8,
    /// Total number of layers in the whole model.
    total_layers: u8,
}
/// Converts fractional partitions into inclusive layer ranges of a model.
///
/// Each partition's `start` and `end` are multiplied by `total_layers` and
/// rounded to the nearest layer; `end` is treated as exclusive, so the shard's
/// `end_layer` is the rounded end minus one.
///
/// Returns one [`ModelShard`] per partition, in the same order, each borrowing
/// `model_id`. A model with zero layers yields no shards.
pub fn shard_model_by_partition(
    partitions: Vec<Partition>,
    total_layers: u8,
    model_id: &str,
) -> Vec<ModelShard<'_>> {
    // With no layers there is no valid layer index to hand out, and
    // `total_layers - 1` below would underflow.
    if total_layers == 0 {
        return Vec::new();
    }
    let last_layer = total_layers - 1;
    let layer_count = f32::from(total_layers);

    partitions
        .into_iter()
        .map(|partition| {
            // Float-to-integer `as` casts saturate, so out-of-range bounds
            // clamp to 0..=255 instead of wrapping.
            let start_layer = (partition.start * layer_count).round() as u8;
            let end_exclusive = (partition.end * layer_count).round() as u8;

            // Compare before subtracting: a partition whose end rounds to
            // layer 0 would otherwise underflow the unsigned subtraction.
            // NOTE(review): an empty partition falls back to ending at the
            // final layer, which overlaps later shards; confirm this is the
            // intended behaviour.
            let end_layer = if end_exclusive > start_layer {
                end_exclusive - 1
            } else {
                last_layer
            };

            ModelShard {
                model_id,
                start_layer,
                end_layer,
                total_layers,
            }
        })
        .collect()
}