Stash initial attempt at implementing process prompt before being sad

This commit is contained in:
Joshua Coles 2025-02-12 14:55:26 +00:00
parent 8d91f64dbf
commit a631a1c0a9
6 changed files with 477 additions and 36 deletions

358
Cargo.lock generated
View File

@ -145,12 +145,38 @@ version = "0.22.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
[[package]]
name = "bindgen"
version = "0.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f49d8fed880d473ea71efb9bf597651e77201bdd4893efe54c9e5d65ae04ce6f"
dependencies = [
"bitflags",
"cexpr",
"clang-sys",
"itertools 0.13.0",
"log",
"prettyplease",
"proc-macro2",
"quote",
"regex",
"rustc-hash",
"shlex",
"syn",
]
[[package]]
name = "bitflags"
version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36"
[[package]]
name = "bytemuck"
version = "1.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef657dfab802224e671f5818e9a4935f9b1957ed18e58292690cc39e7a4092a3"
[[package]]
name = "byteorder"
version = "1.5.0"
@ -172,12 +198,41 @@ dependencies = [
"shlex",
]
[[package]]
name = "cexpr"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
dependencies = [
"nom",
]
[[package]]
name = "cfg-if"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "clang-sys"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4"
dependencies = [
"glob",
"libc",
"libloading",
]
[[package]]
name = "cmake"
version = "0.1.54"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0"
dependencies = [
"cc",
]
[[package]]
name = "core-foundation"
version = "0.9.4"
@ -203,6 +258,53 @@ dependencies = [
"cfg-if",
]
[[package]]
name = "crunchy"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929"
[[package]]
name = "darling"
version = "0.20.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989"
dependencies = [
"darling_core",
"darling_macro",
]
[[package]]
name = "darling_core"
version = "0.20.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5"
dependencies = [
"fnv",
"ident_case",
"proc-macro2",
"quote",
"strsim",
"syn",
]
[[package]]
name = "darling_macro"
version = "0.20.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806"
dependencies = [
"darling_core",
"quote",
"syn",
]
[[package]]
name = "dyn-clone"
version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "feeef44e73baff3a26d371801df019877a9866a8c493d315ab00177843314f35"
[[package]]
name = "either"
version = "1.13.0"
@ -229,6 +331,7 @@ dependencies = [
name = "exo-rs"
version = "0.1.0"
dependencies = [
"mlx-rs",
"network-interface",
"phf",
"prost",
@ -342,6 +445,12 @@ version = "0.31.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f"
[[package]]
name = "glob"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2"
[[package]]
name = "h2"
version = "0.4.7"
@ -361,6 +470,16 @@ dependencies = [
"tracing",
]
[[package]]
name = "half"
version = "2.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888"
dependencies = [
"cfg-if",
"crunchy",
]
[[package]]
name = "hashbrown"
version = "0.12.3"
@ -478,6 +597,12 @@ dependencies = [
"tracing",
]
[[package]]
name = "ident_case"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
[[package]]
name = "indexmap"
version = "1.9.3"
@ -507,6 +632,15 @@ dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285"
dependencies = [
"either",
]
[[package]]
name = "itoa"
version = "1.0.14"
@ -525,6 +659,16 @@ version = "0.2.169"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a"
[[package]]
name = "libloading"
version = "0.8.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
dependencies = [
"cfg-if",
"windows-targets",
]
[[package]]
name = "linux-raw-sys"
version = "0.4.15"
@ -547,6 +691,15 @@ version = "0.4.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f"
[[package]]
name = "mach-sys"
version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48460c2e82a3a0de197152fdf8d2c2d5e43adc501501553e439bf2156e6f87c7"
dependencies = [
"fastrand",
]
[[package]]
name = "matchit"
version = "0.7.3"
@ -565,6 +718,12 @@ version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "minimal-lexical"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
[[package]]
name = "miniz_oxide"
version = "0.8.4"
@ -585,6 +744,67 @@ dependencies = [
"windows-sys",
]
[[package]]
name = "mlx-internal-macros"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f0756e4528d38dfd2c30551e3cb05f42b346d4b9fd14a867767d86353232056d"
dependencies = [
"darling",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "mlx-macros"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "177ff342309c789defa1552e763ea8fbb5548e3ec17134a45009a27fbddb6c26"
dependencies = [
"darling",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "mlx-rs"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c686ead28a57db28004d2c72f940bb3b4366ad01649899cacd06bca495f93ca"
dependencies = [
"bytemuck",
"dyn-clone",
"half",
"itertools 0.14.0",
"libc",
"mach-sys",
"mlx-internal-macros",
"mlx-macros",
"mlx-sys",
"num-complex",
"num-traits",
"num_enum",
"parking_lot",
"paste",
"safetensors",
"smallvec",
"strum",
"thiserror 1.0.69",
]
[[package]]
name = "mlx-sys"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af33a6b662998e5bb4099b1a191b4352fcb11d97706e82e4c8922fe200bb11f2"
dependencies = [
"bindgen",
"cc",
"cmake",
]
[[package]]
name = "multimap"
version = "0.10.0"
@ -603,6 +823,16 @@ dependencies = [
"winapi",
]
[[package]]
name = "nom"
version = "7.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a"
dependencies = [
"memchr",
"minimal-lexical",
]
[[package]]
name = "nu-ansi-term"
version = "0.46.0"
@ -613,6 +843,45 @@ dependencies = [
"winapi",
]
[[package]]
name = "num-complex"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495"
dependencies = [
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841"
dependencies = [
"autocfg",
]
[[package]]
name = "num_enum"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4e613fc340b2220f734a8595782c551f1250e969d87d3be1ae0579e8d4065179"
dependencies = [
"num_enum_derive",
]
[[package]]
name = "num_enum_derive"
version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af1844ef2428cc3e1cb900be36181049ef3d3193c63e43026cfe202983b27a56"
dependencies = [
"proc-macro-crate",
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "object"
version = "0.36.7"
@ -657,6 +926,12 @@ dependencies = [
"windows-targets",
]
[[package]]
name = "paste"
version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a"
[[package]]
name = "percent-encoding"
version = "2.3.1"
@ -766,6 +1041,15 @@ dependencies = [
"syn",
]
[[package]]
name = "proc-macro-crate"
version = "3.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ecf48c7ca261d60b74ab1a7b20da18bede46776b2e55535cb958eb595c5fa7b"
dependencies = [
"toml_edit",
]
[[package]]
name = "proc-macro2"
version = "1.0.93"
@ -792,7 +1076,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0f3e5beed80eb580c68e2c600937ac2c4eedabdfd5ef1e5b7ea4f3fba84497b"
dependencies = [
"heck",
"itertools",
"itertools 0.13.0",
"log",
"multimap",
"once_cell",
@ -812,7 +1096,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "157c5a9d7ea5c2ed2d9fb8f495b64759f7816c7eaea54ba3978f0d63000162e3"
dependencies = [
"anyhow",
"itertools",
"itertools 0.13.0",
"proc-macro2",
"quote",
"syn",
@ -910,6 +1194,12 @@ version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
[[package]]
name = "rustc-hash"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]]
name = "rustix"
version = "0.38.44"
@ -935,6 +1225,16 @@ version = "1.0.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd"
[[package]]
name = "safetensors"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0436dbfa2778e4ec1a00801b0ae24a1dd619499247d48b0589b679103379d0d4"
dependencies = [
"serde",
"serde_json",
]
[[package]]
name = "scopeguard"
version = "1.2.0"
@ -1028,6 +1328,34 @@ dependencies = [
"windows-sys",
]
[[package]]
name = "strsim"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]]
name = "strum"
version = "0.26.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8fec0f0aef304996cf250b31b5a10dee7980c85da9d759361292b8bca5a18f06"
dependencies = [
"strum_macros",
]
[[package]]
name = "strum_macros"
version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c6bee85a5a24955dc440386795aa378cd9cf82acd5f764469152d2270e581be"
dependencies = [
"heck",
"proc-macro2",
"quote",
"rustversion",
"syn",
]
[[package]]
name = "syn"
version = "2.0.98"
@ -1183,6 +1511,23 @@ dependencies = [
"tokio",
]
[[package]]
name = "toml_datetime"
version = "0.6.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41"
[[package]]
name = "toml_edit"
version = "0.22.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "17b4795ff5edd201c7cd6dca065ae59972ce77d1b80fa0a84d94950ece7d1474"
dependencies = [
"indexmap 2.7.1",
"toml_datetime",
"winnow",
]
[[package]]
name = "tonic"
version = "0.12.3"
@ -1477,6 +1822,15 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]]
name = "winnow"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59690dea168f2198d1a3b0cac23b8063efcd11012f10ae4698f284808c8ef603"
dependencies = [
"memchr",
]
[[package]]
name = "wit-bindgen-rt"
version = "0.33.0"

View File

@ -18,6 +18,7 @@ network-interface = "2.0.0"
uuid = { version = "1.13.1", features = ["v4"] }
regex = "1.11.1"
phf = { version = "0.11.3", features = ["macros"] }
mlx-rs = { version = "0.21.0", features = ["metal", "accelerate", "safetensors"] }
[build-dependencies]
tonic-build = "0.12.3"

View File

@ -69,6 +69,7 @@ impl Default for NodeInfo {
}
}
#[derive(Debug)]
pub struct UdpDiscovery {
discovery_handle: JoinHandle<()>,
presence_handle: JoinHandle<()>,

View File

@ -11,24 +11,53 @@ use tonic::{transport::Server, Request, Response, Status};
use crate::discovery::{NodeInfo, UdpDiscovery};
use crate::node_service::{
CollectTopologyRequest, Empty, ExampleRequest, HealthCheckRequest, HealthCheckResponse, Loss,
PromptRequest, SendOpaqueStatusRequest, SendResultRequest, Tensor, Topology as TopologyProto,
CollectTopologyRequest, Empty, ExampleRequest, HealthCheckRequest, HealthCheckResponse,
InferenceState, Loss, PromptRequest, SendOpaqueStatusRequest, SendResultRequest, Tensor,
Topology as TopologyProto,
};
use node_service::node_service_server::{NodeService, NodeServiceServer};
use node_service::TensorRequest;
use std::collections::HashSet;
use topology::Topology;
use uuid::Uuid;
pub mod node_service {
tonic::include_proto!("node_service"); // The string specified here must match the proto package name
}
#[derive(Debug)]
struct Node {
node_info: NodeInfo,
current_topology: Topology,
udp_discovery: UdpDiscovery,
}
impl Node {
#[tracing::instrument]
pub async fn process_prompt(
&self,
base_shard: Shard,
prompt: String,
request_id: String,
inference_state: Option<InferenceState>,
) {
let shard = self
.current_topology
.get_shard_for_node(base_shard, &self.node_info.node_id);
todo!();
// if shard.is_first_layer() {
// let result = self
// .inference_engine
// .infer_prompt(request_id, shard, prompt, inference_state)
// .await;
// self.process_inference_result(shard, result, request_id, inference_state)
// } else {
// self.forward_prompt(shard, prompt, request_id, inference_state)
// }
}
}
impl Default for Node {
fn default() -> Self {
let node_info = NodeInfo::default();
@ -52,9 +81,35 @@ enum OpaqueStatus {
#[derive(Debug, Deserialize, Serialize, Clone)]
struct Shard {
pub model_id: String,
pub start_layer: i32,
pub end_layer: i32,
pub n_layers: i32,
pub start_layer: u32,
pub end_layer: u32,
#[serde(rename = "n_layers")]
pub total_layers: u32,
}
impl Shard {
pub fn is_first_layer(&self) -> bool {
self.start_layer == 0
}
pub fn is_last_layer(&self) -> bool {
self.end_layer == self.total_layers - 1
}
pub fn len(&self) -> u32 {
self.end_layer - self.start_layer + 1
}
}
impl From<node_service::Shard> for Shard {
fn from(proto: node_service::Shard) -> Self {
Self {
model_id: proto.model_id,
start_layer: proto.start_layer as u32,
end_layer: proto.end_layer as u32,
total_layers: proto.n_layers as u32,
}
}
}
#[derive(Debug, Deserialize, Serialize, Clone)]
@ -141,7 +196,22 @@ impl NodeService for Node {
&self,
request: Request<PromptRequest>,
) -> Result<Response<Tensor>, Status> {
todo!()
let request = request.into_inner();
let request_id = request
.request_id
.unwrap_or_else(|| Uuid::new_v4().to_string());
let result = self.process_prompt(
request
.shard
.expect("No shard given. ExoPy does not allow this")
.into(),
request.prompt,
request_id,
request.inference_state,
);
todo!();
}
async fn send_tensor(

View File

@ -1,3 +1,4 @@
use crate::Shard;
use crate::topology::Topology;
pub enum PartitionStrategy {
@ -5,9 +6,9 @@ pub enum PartitionStrategy {
}
pub struct Partition {
node_id: String,
start: f32,
end: f32,
pub node_id: String,
pub start: f32,
pub end: f32,
}
impl PartitionStrategy {
@ -34,30 +35,23 @@ impl PartitionStrategy {
}
}
pub struct ModelShard<'a> {
model_id: &'a str,
start_layer: u8,
end_layer: u8,
total_layers: u8,
}
pub fn shard_model_by_partition(
partitions: Vec<Partition>,
total_layers: u8,
partition_set: &[Partition],
total_layers: u32,
model_id: &str,
) -> Vec<ModelShard> {
let mut shards: Vec<ModelShard<'_>> = Vec::with_capacity(partitions.len());
) -> Vec<Shard> {
let mut shards: Vec<Shard> = Vec::with_capacity(partition_set.len());
for partition in partitions {
let start_layer = (partition.start * total_layers as f32).round() as u8;
let mut end_layer = (partition.end * total_layers as f32).round() as u8 - 1;
for partition in partition_set {
let start_layer = (partition.start * total_layers as f32).round() as u32;
let mut end_layer = (partition.end * total_layers as f32).round() as u32 - 1;
if end_layer < start_layer {
end_layer = total_layers - 1;
}
shards.push(ModelShard {
model_id,
shards.push(Shard {
model_id: model_id.to_string(),
start_layer,
end_layer,
total_layers,

View File

@ -1,4 +1,5 @@
use crate::{device_capability_data, node_service};
use crate::partitioning::{shard_model_by_partition, PartitionStrategy};
use crate::{device_capability_data, node_service, Shard};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::process::Command;
@ -10,6 +11,26 @@ pub struct Topology {
pub active_node_id: Option<String>,
}
impl Topology {
pub fn get_shard_for_node(&self, base_shard: Shard, node_id: &str) -> Shard {
let partition_set = PartitionStrategy::RingMemoryWeighted.partition(&self);
// TODO: This feels like it could be a better data structure
let partition_index = partition_set
.iter()
.position(|s| s.node_id == node_id)
.expect("Did not find node in partition set");
let shards = shard_model_by_partition(
&partition_set,
base_shard.total_layers.try_into().unwrap(),
base_shard.model_id.as_str(),
);
shards[partition_index].clone()
}
}
impl Topology {
pub fn update_node(&mut self, node_id: String, device_capabilities: DeviceCapabilities) {
self.nodes.insert(node_id, device_capabilities);
@ -48,8 +69,8 @@ impl Topology {
}
}
impl From<crate::node_service::Topology> for Topology {
fn from(proto: crate::node_service::Topology) -> Self {
impl From<node_service::Topology> for Topology {
fn from(proto: node_service::Topology) -> Self {
let nodes = proto
.nodes
.into_iter()
@ -83,8 +104,8 @@ impl From<crate::node_service::Topology> for Topology {
}
}
impl Into<crate::node_service::Topology> for Topology {
fn into(self) -> crate::node_service::Topology {
impl Into<node_service::Topology> for Topology {
fn into(self) -> node_service::Topology {
let nodes = self
.nodes
.iter()
@ -210,8 +231,8 @@ impl DeviceCapabilities {
}
}
impl From<crate::node_service::DeviceCapabilities> for DeviceCapabilities {
fn from(value: crate::node_service::DeviceCapabilities) -> Self {
impl From<node_service::DeviceCapabilities> for DeviceCapabilities {
fn from(value: node_service::DeviceCapabilities) -> Self {
DeviceCapabilities {
model: value.model,
chip: value.chip,
@ -228,8 +249,8 @@ pub struct DeviceFlops {
pub int8: f64,
}
impl From<crate::node_service::DeviceFlops> for DeviceFlops {
fn from(value: crate::node_service::DeviceFlops) -> Self {
impl From<node_service::DeviceFlops> for DeviceFlops {
fn from(value: node_service::DeviceFlops) -> Self {
DeviceFlops {
fp32: value.fp32,
fp16: value.fp16,