diff --git a/Cargo.lock b/Cargo.lock index 55e086b..3a33733 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -145,12 +145,38 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "bindgen" +version = "0.70.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f49d8fed880d473ea71efb9bf597651e77201bdd4893efe54c9e5d65ae04ce6f" +dependencies = [ + "bitflags", + "cexpr", + "clang-sys", + "itertools 0.13.0", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash", + "shlex", + "syn", +] + [[package]] name = "bitflags" version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" +[[package]] +name = "bytemuck" +version = "1.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef657dfab802224e671f5818e9a4935f9b1957ed18e58292690cc39e7a4092a3" + [[package]] name = "byteorder" version = "1.5.0" @@ -172,12 +198,41 @@ dependencies = [ "shlex", ] +[[package]] +name = "cexpr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom", +] + [[package]] name = "cfg-if" version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "clang-sys" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" +dependencies = [ + "glob", + "libc", + "libloading", +] + +[[package]] +name = "cmake" +version = "0.1.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" +dependencies = [ + "cc", +] + [[package]] name = "core-foundation" version = "0.9.4" @@ -203,6 +258,53 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crunchy" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" + +[[package]] +name = "darling" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989" +dependencies = [ + "darling_core", + "darling_macro", +] + +[[package]] +name = "darling_core" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn", +] + +[[package]] +name = "darling_macro" +version = "0.20.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" +dependencies = [ + "darling_core", + "quote", + "syn", +] + +[[package]] +name = "dyn-clone" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "feeef44e73baff3a26d371801df019877a9866a8c493d315ab00177843314f35" + [[package]] name = "either" version = "1.13.0" @@ -229,6 +331,7 @@ dependencies = [ name = "exo-rs" version = "0.1.0" dependencies = [ + "mlx-rs", "network-interface", "phf", "prost", @@ -342,6 +445,12 @@ version = "0.31.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" +[[package]] +name = "glob" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" + [[package]] name = "h2" version = "0.4.7" @@ -361,6 +470,16 @@ dependencies = [ "tracing", ] +[[package]] +name = "half" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" +dependencies = [ + "cfg-if", + "crunchy", +] + [[package]] name = "hashbrown" version = "0.12.3" @@ -478,6 +597,12 @@ dependencies = [ "tracing", ] +[[package]] +name = "ident_case" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" + [[package]] name = "indexmap" version = "1.9.3" @@ -507,6 +632,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.14" @@ -525,6 +659,16 @@ version = "0.2.169" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" +[[package]] +name = "libloading" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34" +dependencies = [ + "cfg-if", + "windows-targets", +] + [[package]] name = "linux-raw-sys" version = "0.4.15" @@ -547,6 +691,15 @@ version = "0.4.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" +[[package]] +name = "mach-sys" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48460c2e82a3a0de197152fdf8d2c2d5e43adc501501553e439bf2156e6f87c7" +dependencies = [ + "fastrand", +] + [[package]] name = "matchit" version = "0.7.3" @@ -565,6 +718,12 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + [[package]] name = "miniz_oxide" version = "0.8.4" @@ -585,6 +744,67 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "mlx-internal-macros" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0756e4528d38dfd2c30551e3cb05f42b346d4b9fd14a867767d86353232056d" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "mlx-macros" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "177ff342309c789defa1552e763ea8fbb5548e3ec17134a45009a27fbddb6c26" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "mlx-rs" +version = "0.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c686ead28a57db28004d2c72f940bb3b4366ad01649899cacd06bca495f93ca" +dependencies = [ + "bytemuck", + "dyn-clone", + "half", + "itertools 0.14.0", + "libc", + "mach-sys", + "mlx-internal-macros", + "mlx-macros", + "mlx-sys", + "num-complex", + "num-traits", + "num_enum", + "parking_lot", + "paste", + "safetensors", + "smallvec", + "strum", + "thiserror 1.0.69", +] + +[[package]] +name = "mlx-sys" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af33a6b662998e5bb4099b1a191b4352fcb11d97706e82e4c8922fe200bb11f2" +dependencies = [ + "bindgen", + "cc", + "cmake", +] + [[package]] name = "multimap" version = "0.10.0" @@ -603,6 +823,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -613,6 +843,45 @@ dependencies = [ "winapi", ] +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "num_enum" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e613fc340b2220f734a8595782c551f1250e969d87d3be1ae0579e8d4065179" +dependencies = [ + "num_enum_derive", +] + +[[package]] +name = "num_enum_derive" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af1844ef2428cc3e1cb900be36181049ef3d3193c63e43026cfe202983b27a56" +dependencies = [ + "proc-macro-crate", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "object" version = "0.36.7" @@ -657,6 +926,12 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + [[package]] name = "percent-encoding" version = "2.3.1" @@ -766,6 +1041,15 @@ dependencies = [ "syn", ] +[[package]] +name = "proc-macro-crate" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecf48c7ca261d60b74ab1a7b20da18bede46776b2e55535cb958eb595c5fa7b" +dependencies = [ + "toml_edit", +] + [[package]] name = "proc-macro2" version = "1.0.93" @@ -792,7 +1076,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0f3e5beed80eb580c68e2c600937ac2c4eedabdfd5ef1e5b7ea4f3fba84497b" dependencies = [ "heck", - "itertools", + "itertools 0.13.0", "log", "multimap", "once_cell", @@ -812,7 +1096,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "157c5a9d7ea5c2ed2d9fb8f495b64759f7816c7eaea54ba3978f0d63000162e3" dependencies = [ "anyhow", - "itertools", + "itertools 0.13.0", "proc-macro2", "quote", "syn", @@ -910,6 +1194,12 @@ version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + [[package]] name = "rustix" version = "0.38.44" @@ -935,6 +1225,16 @@ version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" +[[package]] +name = "safetensors" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0436dbfa2778e4ec1a00801b0ae24a1dd619499247d48b0589b679103379d0d4" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -1028,6 +1328,34 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + +[[package]] +name = "strum" +version = "0.26.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06" +dependencies = [ + "strum_macros", +] + +[[package]] +name = "strum_macros" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "rustversion", + "syn", +] + [[package]] name = "syn" version = "2.0.98" @@ -1183,6 +1511,23 @@ dependencies = [ "tokio", ] +[[package]] +name = "toml_datetime" +version = "0.6.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" + +[[package]] +name = "toml_edit" +version = "0.22.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17b4795ff5edd201c7cd6dca065ae59972ce77d1b80fa0a84d94950ece7d1474" +dependencies = [ + "indexmap 2.7.1", + "toml_datetime", + "winnow", +] + [[package]] name = "tonic" version = "0.12.3" @@ -1477,6 +1822,15 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "winnow" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59690dea168f2198d1a3b0cac23b8063efcd11012f10ae4698f284808c8ef603" +dependencies = [ + "memchr", +] + [[package]] name = "wit-bindgen-rt" version = "0.33.0" diff --git a/Cargo.toml b/Cargo.toml index 6025a36..7c094f2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,7 @@ network-interface = "2.0.0" uuid = { version = "1.13.1", features = ["v4"] } regex = "1.11.1" phf = { version = "0.11.3", features = ["macros"] } +mlx-rs = { version = "0.21.0", features = ["metal", "accelerate", "safetensors"] } [build-dependencies] tonic-build = "0.12.3" diff --git a/src/discovery/mod.rs b/src/discovery/mod.rs index babf2c7..63f8fcf 100644 --- a/src/discovery/mod.rs +++ b/src/discovery/mod.rs @@ -69,6 +69,7 @@ impl Default for NodeInfo { } } +#[derive(Debug)] pub struct UdpDiscovery { discovery_handle: JoinHandle<()>, presence_handle: JoinHandle<()>, diff --git a/src/main.rs b/src/main.rs index d52be8e..50c27f8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -11,24 +11,53 @@ use tonic::{transport::Server, Request, Response, Status}; use crate::discovery::{NodeInfo, UdpDiscovery}; use crate::node_service::{ - CollectTopologyRequest, Empty, ExampleRequest, HealthCheckRequest, HealthCheckResponse, Loss, - PromptRequest, SendOpaqueStatusRequest, SendResultRequest, Tensor, Topology as TopologyProto, + CollectTopologyRequest, Empty, ExampleRequest, HealthCheckRequest, HealthCheckResponse, + InferenceState, Loss, PromptRequest, SendOpaqueStatusRequest, SendResultRequest, Tensor, + Topology as TopologyProto, }; use node_service::node_service_server::{NodeService, NodeServiceServer}; use node_service::TensorRequest; use std::collections::HashSet; use topology::Topology; +use uuid::Uuid; pub mod node_service { tonic::include_proto!("node_service"); // The string specified here must match the proto package name } +#[derive(Debug)] struct Node { node_info: NodeInfo, current_topology: Topology, udp_discovery: UdpDiscovery, } +impl Node { + #[tracing::instrument] + pub async fn process_prompt( + &self, + base_shard: Shard, + prompt: String, + request_id: String, + inference_state: Option, + ) { + let shard = self + .current_topology + .get_shard_for_node(base_shard, &self.node_info.node_id); + + todo!(); + // if shard.is_first_layer() { + // let result = self + // .inference_engine + // .infer_prompt(request_id, shard, prompt, inference_state) + // .await; + // self.process_inference_result(shard, result, request_id, inference_state) + // } else { + // self.forward_prompt(shard, prompt, request_id, inference_state) + // } + } +} + impl Default for Node { fn default() -> Self { let node_info = NodeInfo::default(); @@ -52,9 +81,35 @@ enum OpaqueStatus { #[derive(Debug, Deserialize, Serialize, Clone)] struct Shard { pub model_id: String, - pub start_layer: i32, - pub end_layer: i32, - pub n_layers: i32, + pub start_layer: u32, + pub end_layer: u32, + #[serde(rename = "n_layers")] + pub total_layers: u32, +} + +impl Shard { + pub fn is_first_layer(&self) -> bool { + self.start_layer == 0 + } + + pub fn is_last_layer(&self) -> bool { + self.end_layer == self.total_layers - 1 + } + + pub fn len(&self) -> u32 { + self.end_layer - self.start_layer + 1 + } +} + +impl From for Shard { + fn from(proto: node_service::Shard) -> Self { + Self { + model_id: proto.model_id, + start_layer: proto.start_layer as u32, + end_layer: proto.end_layer as u32, + total_layers: proto.n_layers as u32, + } + } } #[derive(Debug, Deserialize, Serialize, Clone)] @@ -141,7 +196,22 @@ impl NodeService for Node { &self, request: Request, ) -> Result, Status> { - todo!() + let request = request.into_inner(); + let request_id = request + .request_id + .unwrap_or_else(|| Uuid::new_v4().to_string()); + + let result = self.process_prompt( + request + .shard + .expect("No shard given. ExoPy does not allow this") + .into(), + request.prompt, + request_id, + request.inference_state, + ); + + todo!(); } async fn send_tensor( diff --git a/src/partitioning.rs b/src/partitioning.rs index 7db3b91..634c2e7 100644 --- a/src/partitioning.rs +++ b/src/partitioning.rs @@ -1,3 +1,4 @@ +use crate::Shard; use crate::topology::Topology; pub enum PartitionStrategy { @@ -5,9 +6,9 @@ pub enum PartitionStrategy { } pub struct Partition { - node_id: String, - start: f32, - end: f32, + pub node_id: String, + pub start: f32, + pub end: f32, } impl PartitionStrategy { @@ -34,30 +35,23 @@ impl PartitionStrategy { } } -pub struct ModelShard<'a> { - model_id: &'a str, - start_layer: u8, - end_layer: u8, - total_layers: u8, -} - pub fn shard_model_by_partition( - partitions: Vec, - total_layers: u8, + partition_set: &[Partition], + total_layers: u32, model_id: &str, -) -> Vec { - let mut shards: Vec> = Vec::with_capacity(partitions.len()); +) -> Vec { + let mut shards: Vec = Vec::with_capacity(partition_set.len()); - for partition in partitions { - let start_layer = (partition.start * total_layers as f32).round() as u8; - let mut end_layer = (partition.end * total_layers as f32).round() as u8 - 1; + for partition in partition_set { + let start_layer = (partition.start * total_layers as f32).round() as u32; + let mut end_layer = (partition.end * total_layers as f32).round() as u32 - 1; if end_layer < start_layer { end_layer = total_layers - 1; } - shards.push(ModelShard { - model_id, + shards.push(Shard { + model_id: model_id.to_string(), start_layer, end_layer, total_layers, diff --git a/src/topology.rs b/src/topology.rs index f45e543..0b76225 100644 --- a/src/topology.rs +++ b/src/topology.rs @@ -1,4 +1,5 @@ -use crate::{device_capability_data, node_service}; +use crate::partitioning::{shard_model_by_partition, PartitionStrategy}; +use crate::{device_capability_data, node_service, Shard}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::process::Command; @@ -10,6 +11,26 @@ pub struct Topology { pub active_node_id: Option, } +impl Topology { + pub fn get_shard_for_node(&self, base_shard: Shard, node_id: &str) -> Shard { + let partition_set = PartitionStrategy::RingMemoryWeighted.partition(&self); + + // TODO: This feels like it could be a better data structure + let partition_index = partition_set + .iter() + .position(|s| s.node_id == node_id) + .expect("Did not find node in partition set"); + + let shards = shard_model_by_partition( + &partition_set, + base_shard.total_layers.try_into().unwrap(), + base_shard.model_id.as_str(), + ); + + shards[partition_index].clone() + } +} + impl Topology { pub fn update_node(&mut self, node_id: String, device_capabilities: DeviceCapabilities) { self.nodes.insert(node_id, device_capabilities); @@ -48,8 +69,8 @@ impl Topology { } } -impl From for Topology { - fn from(proto: crate::node_service::Topology) -> Self { +impl From for Topology { + fn from(proto: node_service::Topology) -> Self { let nodes = proto .nodes .into_iter() @@ -83,8 +104,8 @@ impl From for Topology { } } -impl Into for Topology { - fn into(self) -> crate::node_service::Topology { +impl Into for Topology { + fn into(self) -> node_service::Topology { let nodes = self .nodes .iter() @@ -210,8 +231,8 @@ impl DeviceCapabilities { } } -impl From for DeviceCapabilities { - fn from(value: crate::node_service::DeviceCapabilities) -> Self { +impl From for DeviceCapabilities { + fn from(value: node_service::DeviceCapabilities) -> Self { DeviceCapabilities { model: value.model, chip: value.chip, @@ -228,8 +249,8 @@ pub struct DeviceFlops { pub int8: f64, } -impl From for DeviceFlops { - fn from(value: crate::node_service::DeviceFlops) -> Self { +impl From for DeviceFlops { + fn from(value: node_service::DeviceFlops) -> Self { DeviceFlops { fp32: value.fp32, fp16: value.fp16,