Add partition and shards
This commit is contained in:
parent
3aca4a22ae
commit
02694cbacc
@ -1,4 +1,3 @@
|
|||||||
use std::cell::RefCell;
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use crate::network::get_broadcast_creation_info;
|
use crate::network::get_broadcast_creation_info;
|
||||||
use crate::topology::DeviceCapabilities;
|
use crate::topology::DeviceCapabilities;
|
||||||
|
|||||||
@ -3,6 +3,7 @@ mod discovery;
|
|||||||
mod network;
|
mod network;
|
||||||
mod orchestration;
|
mod orchestration;
|
||||||
mod topology;
|
mod topology;
|
||||||
|
mod partitioning;
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
|||||||
68
src/partitioning.rs
Normal file
68
src/partitioning.rs
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
use crate::topology::Topology;
|
||||||
|
|
||||||
|
pub enum PartitionStrategy {
|
||||||
|
RingMemoryWeighted,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Partition {
|
||||||
|
node_id: String,
|
||||||
|
start: f32,
|
||||||
|
end: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartitionStrategy {
|
||||||
|
pub fn partition(&self, topology: &Topology) -> Vec<Partition> {
|
||||||
|
let mut entries = topology.nodes.iter().collect::<Vec<_>>();
|
||||||
|
entries.sort_by_key(|(node_id, device_capabilities)| {
|
||||||
|
(device_capabilities.memory, node_id.clone())
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut start = 0.0;
|
||||||
|
let mut partitions = Vec::with_capacity(entries.len());
|
||||||
|
|
||||||
|
for (node_id, device_capabilities) in entries {
|
||||||
|
let end = ((start + device_capabilities.memory as f32) * 100000.0).round() / 100000.0;
|
||||||
|
partitions.push(Partition {
|
||||||
|
node_id: node_id.to_string(),
|
||||||
|
start,
|
||||||
|
end,
|
||||||
|
});
|
||||||
|
start = end;
|
||||||
|
}
|
||||||
|
|
||||||
|
partitions
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ModelShard<'a> {
|
||||||
|
model_id: &'a str,
|
||||||
|
start_layer: u8,
|
||||||
|
end_layer: u8,
|
||||||
|
total_layers: u8,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn shard_model_by_partition(
|
||||||
|
partitions: Vec<Partition>,
|
||||||
|
total_layers: u8,
|
||||||
|
model_id: &str,
|
||||||
|
) -> Vec<ModelShard> {
|
||||||
|
let mut shards: Vec<ModelShard<'_>> = Vec::with_capacity(partitions.len());
|
||||||
|
|
||||||
|
for partition in partitions {
|
||||||
|
let start_layer = (partition.start * total_layers as f32).round() as u8;
|
||||||
|
let mut end_layer = (partition.end * total_layers as f32).round() as u8 - 1;
|
||||||
|
|
||||||
|
if end_layer < start_layer {
|
||||||
|
end_layer = total_layers - 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
shards.push(ModelShard {
|
||||||
|
model_id,
|
||||||
|
start_layer,
|
||||||
|
end_layer,
|
||||||
|
total_layers,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
shards
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue
Block a user