From 4cd96b58b5808694a3a0cbce26b8c68f9364f5f9 Mon Sep 17 00:00:00 2001
From: Joshua Coles
Date: Wed, 12 Feb 2025 17:07:06 +0000
Subject: [PATCH] Stash gen of LLAMA mlx

---
 Cargo.lock            | 217 ++++++++++++++++++++++++++++++++++++++--
 Cargo.toml            |   3 +
 src/inference.rs      |  19 ++++
 src/llama_module.rs   | 226 ++++++++++++++++++++++++++++++++++++++++++
 src/llama_test.rs     |  73 ++++++++++++++
 src/main.rs           |  11 +-
 src/module_loading.rs |  52 ++++++++++
 7 files changed, 593 insertions(+), 8 deletions(-)
 create mode 100644 src/inference.rs
 create mode 100644 src/llama_module.rs
 create mode 100644 src/llama_test.rs
 create mode 100644 src/module_loading.rs

diff --git a/Cargo.lock b/Cargo.lock
index 3a33733..f069f5a 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -145,6 +145,29 @@ version = "0.22.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
 
+[[package]]
+name = "bindgen"
+version = "0.69.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088"
+dependencies = [
+ "bitflags",
+ "cexpr",
+ "clang-sys",
+ "itertools 0.12.1",
+ "lazy_static",
+ "lazycell",
+ "log",
+ "prettyplease",
+ "proc-macro2",
+ "quote",
+ "regex",
+ "rustc-hash",
+ "shlex",
+ "syn",
+ "which",
+]
+
 [[package]]
 name = "bindgen"
 version = "0.70.1"
@@ -195,6 +218,8 @@ version = "1.2.13"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "c7777341816418c02e033934a09f20dc0ccaf65a5201ef8a450ae0105a573fda"
 dependencies = [
+ "jobserver",
+ "libc",
  "shlex",
 ]
 
@@ -311,6 +336,26 @@ version = "1.13.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0"
 
+[[package]]
+name = "enumflags2"
+version = "0.7.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ba2f4b465f5318854c6f8dd686ede6c0a9dc67d4b1ac241cf0eb51521a309147"
+dependencies = [
+ "enumflags2_derive",
+]
+
+[[package]]
+name = "enumflags2_derive"
+version = "0.7.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fc4caf64a58d7a6d65ab00639b046ff54399a39f5f2554728895ace4b297cd79"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
 [[package]]
 name = "equivalent"
 version = "1.0.1"
@@ -324,13 +369,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d"
 dependencies = [
  "libc",
- "windows-sys",
+ "windows-sys 0.52.0",
 ]
 
 [[package]]
 name = "exo-rs"
 version = "0.1.0"
 dependencies = [
+ "glob",
+ "llama-cpp-2",
  "mlx-rs",
  "network-interface",
  "phf",
@@ -341,6 +388,7 @@ dependencies = [
  "socket2",
  "system-configuration",
  "thiserror 2.0.11",
+ "tinygrad",
  "tokio",
  "tonic",
  "tonic-build",
@@ -498,6 +546,15 @@ version = "0.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
 
+[[package]]
+name = "home"
+version = "0.5.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf"
+dependencies = [
+ "windows-sys 0.59.0",
+]
+
 [[package]]
 name = "http"
 version = "1.2.0"
@@ -623,6 +680,15 @@ dependencies = [
  "hashbrown 0.15.2",
 ]
 
+[[package]]
+name = "itertools"
+version = "0.12.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569"
+dependencies = [
+ "either",
+]
+
 [[package]]
 name = "itertools"
 version = "0.13.0"
@@ -647,12 +713,27 @@ version = "1.0.14"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674"
 
+[[package]]
+name = "jobserver"
+version = "0.1.32"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0"
+dependencies = [
+ "libc",
+]
+
 [[package]]
 name = "lazy_static"
 version = "1.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
 
+[[package]]
+name = "lazycell"
+version = "1.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
+
 [[package]]
 name = "libc"
 version = "0.2.169"
@@ -675,6 +756,32 @@ version = "0.4.15"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab"
 
+[[package]]
+name = "llama-cpp-2"
+version = "0.1.93"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "44818b6967a77379b8c8e105e2684d2bd2bca999ad24cfd806d8476a80c53255"
+dependencies = [
+ "enumflags2",
+ "llama-cpp-sys-2",
+ "thiserror 1.0.69",
+ "tracing",
+ "tracing-core",
+]
+
+[[package]]
+name = "llama-cpp-sys-2"
+version = "0.1.93"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "92add8e8dabf941518dd573075721a6e0568db86c6557e176cec99bf56883c0a"
+dependencies = [
+ "bindgen 0.69.5",
+ "cc",
+ "cmake",
+ "glob",
+ "walkdir",
+]
+
 [[package]]
 name = "lock_api"
 version = "0.4.12"
@@ -706,6 +813,16 @@ version = "0.7.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
 
+[[package]]
+name = "matrixmultiply"
+version = "0.3.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a"
+dependencies = [
+ "autocfg",
+ "rawpointer",
+]
+
 [[package]]
 name = "memchr"
 version = "2.7.4"
@@ -741,7 +858,7 @@ checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd"
 dependencies = [
  "libc",
  "wasi 0.11.0+wasi-snapshot-preview1",
- "windows-sys",
+ "windows-sys 0.52.0",
 ]
 
 [[package]]
@@ -800,7 +917,7 @@ version = "0.1.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "af33a6b662998e5bb4099b1a191b4352fcb11d97706e82e4c8922fe200bb11f2"
 dependencies = [
- "bindgen",
+ "bindgen 0.70.1",
  "cc",
  "cmake",
 ]
@@ -811,6 +928,19 @@ version = "0.10.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03"
 
+[[package]]
+name = "ndarray"
+version = "0.15.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
+dependencies = [
+ "matrixmultiply",
+ "num-complex",
+ "num-integer",
+ "num-traits",
+ "rawpointer",
+]
+
 [[package]]
 name = "network-interface"
 version = "2.0.0"
@@ -852,6 +982,15 @@ dependencies = [
  "num-traits",
 ]
 
+[[package]]
+name = "num-integer"
+version = "0.1.46"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
+dependencies = [
+ "num-traits",
+]
+
 [[package]]
 name = "num-traits"
 version = "0.2.19"
@@ -1150,6 +1289,12 @@ dependencies = [
  "getrandom 0.2.15",
 ]
 
+[[package]]
+name = "rawpointer"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
+
 [[package]]
 name = "redox_syscall"
 version = "0.5.8"
@@ -1210,7 +1355,7 @@ dependencies = [
  "errno",
  "libc",
  "linux-raw-sys",
- "windows-sys",
+ "windows-sys 0.52.0",
 ]
 
 [[package]]
@@ -1235,6 +1380,15 @@ dependencies = [
  "serde_json",
 ]
 
+[[package]]
+name = "same-file"
+version = "1.0.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
+dependencies = [
+ "winapi-util",
+]
+
 [[package]]
 name = "scopeguard"
 version = "1.2.0"
@@ -1325,7 +1479,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8"
 dependencies = [
  "libc",
- "windows-sys",
+ "windows-sys 0.52.0",
 ]
 
 [[package]]
@@ -1405,7 +1559,7 @@ dependencies = [
  "getrandom 0.3.1",
  "once_cell",
  "rustix",
- "windows-sys",
+ "windows-sys 0.52.0",
 ]
 
 [[package]]
@@ -1458,6 +1612,15 @@ dependencies = [
  "once_cell",
 ]
 
+[[package]]
+name = "tinygrad"
+version = "0.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bc8951e9ced45095eb3ad7342c4e74b038bda930833df79b7debb019bb653c18"
+dependencies = [
+ "ndarray",
+]
+
 [[package]]
 name = "tokio"
 version = "1.43.0"
@@ -1473,7 +1636,7 @@ dependencies = [
  "signal-hook-registry",
  "socket2",
  "tokio-macros",
- "windows-sys",
+ "windows-sys 0.52.0",
 ]
 
 [[package]]
@@ -1703,6 +1866,16 @@ version = "0.1.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65"
 
+[[package]]
+name = "walkdir"
+version = "2.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b"
+dependencies = [
+ "same-file",
+ "winapi-util",
+]
+
 [[package]]
 name = "want"
 version = "0.3.1"
@@ -1727,6 +1900,18 @@ dependencies = [
  "wit-bindgen-rt",
 ]
 
+[[package]]
+name = "which"
+version = "4.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7"
+dependencies = [
+ "either",
+ "home",
+ "once_cell",
+ "rustix",
+]
+
 [[package]]
 name = "winapi"
 version = "0.3.9"
@@ -1743,6 +1928,15 @@ version = "0.4.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
 
+[[package]]
+name = "winapi-util"
+version = "0.1.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb"
+dependencies = [
+ "windows-sys 0.52.0",
+]
+
 [[package]]
 name = "winapi-x86_64-pc-windows-gnu"
 version = "0.4.0"
@@ -1758,6 +1952,15 @@ dependencies = [
  "windows-targets",
 ]
 
+[[package]]
+name = "windows-sys"
+version = "0.59.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b"
+dependencies = [
+ "windows-targets",
+]
+
 [[package]]
 name = "windows-targets"
 version = "0.52.6"
diff --git a/Cargo.toml b/Cargo.toml
index 7c094f2..1aab321 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -19,6 +19,9 @@ uuid = { version = "1.13.1", features = ["v4"] }
 regex = "1.11.1"
 phf = { version = "0.11.3", features = ["macros"] }
 mlx-rs = { version = "0.21.0", features = ["metal", "accelerate", "safetensors"] }
+tinygrad = "0.1.0"
+llama-cpp-2 = { version = "0.1.93", features = ["metal", "native", "sampler"] }
+glob = "0.3.2"
+encoding_rs = "0.8" # used by src/llama_test.rs
 
 [build-dependencies]
 tonic-build = "0.12.3"
diff --git a/src/inference.rs b/src/inference.rs
new file mode 100644
index 0000000..ae1b208
--- /dev/null
+++ b/src/inference.rs
@@ -0,0 +1,19 @@
+use std::collections::HashMap;
+
+use crate::node_service::{InferenceState, Tensor};
+use crate::Shard;
+
+#[derive(Debug)]
+pub struct InferenceEngine {
+    // Keyed by request_id; the value type is an assumption pending the real cache design.
+    state_cache: HashMap<String, InferenceState>,
+}
+
+impl InferenceEngine {
+    pub(crate) fn infer_tensor(
+        &self,
+        request_id: String,
+        shard: Shard,
+        tensor: Option<Tensor>,
+        inference_state: Option<InferenceState>,
+    ) -> Tensor {
+        todo!("stash: run the shard's forward pass and return the output tensor")
+    }
+}
diff --git a/src/llama_module.rs b/src/llama_module.rs
new file mode 100644
index 0000000..febebcf
--- /dev/null
+++ b/src/llama_module.rs
@@ -0,0 +1,226 @@
+use mlx_rs::macros::ModuleParameters;
+use mlx_rs::module::Module;
+use mlx_rs::module::ModuleParameters as _; // the trait; the line above imports the derive macro
+use mlx_rs::module::Param;
+use mlx_rs::nested::NestedHashMap;
+use mlx_rs::nn;
+use mlx_rs::Array;
+use std::rc::Rc;
+
+// Shard struct mirroring the Python dataclass. `n_layers` is carried along
+// (as in exo's Python Shard) so is_last_layer can be computed exactly.
+#[derive(Debug, Clone)]
+pub struct Shard {
+    pub name: String,
+    pub start_layer: usize,
+    pub end_layer: usize,
+    pub n_layers: usize,
+}
+
+impl Shard {
+    pub fn new(name: String, start_layer: usize, end_layer: usize, n_layers: usize) -> Self {
+        Shard {
+            name,
+            start_layer,
+            end_layer,
+            n_layers,
+        }
+    }
+
+    pub fn is_first_layer(&self) -> bool {
+        self.start_layer == 0
+    }
+
+    pub fn is_last_layer(&self) -> bool {
+        // end_layer is an inclusive index, so the shard is last exactly when it
+        // reaches the final layer of the model.
+        self.end_layer == self.n_layers - 1
+    }
+}
+
+// ModelArgs struct mirroring the Python dataclass ModelArgs.
+#[derive(Debug, Clone)]
+pub struct ModelArgsRs {
+    pub vocab_size: i32,
+    pub hidden_size: i32,
+    pub num_hidden_layers: i32,
+    pub num_attention_heads: i32,
+    pub num_key_value_heads: i32,
+    pub rms_norm_eps: f32,
+    pub tie_word_embeddings: bool,
+    pub model_type: String,    // assuming model_type is a string
+    pub head_dim: Option<i32>, // optional in the Python config
+    pub shard: Shard,          // the Shard struct defined above
+}
+
+// Placeholder for TransformerBlock - the real attention/MLP stack still needs
+// to be ported to Rust.
+#[derive(Debug, Clone, ModuleParameters)]
+pub struct TransformerBlock {
+    // Define the layers within TransformerBlock as needed, e.g. attention, norms, MLP.
+    // For now, a single linear layer stands in as a placeholder.
+    pub linear: nn::Linear,
+}
+
+impl TransformerBlock {
+    pub fn new(dims: i32, mlp_dims: i32) -> Self {
+        Self {
+            linear: nn::Linear::new(dims, mlp_dims).unwrap(), // example params, adjust as needed
+        }
+    }
+}
+
+impl Module<&Array> for TransformerBlock {
+    type Output = Array;
+    type Error = mlx_rs::error::Exception;
+
+    fn forward(&mut self, input: &Array) -> Result<Self::Output, Self::Error> {
+        // The real forward pass (attention + MLP) is still to be implemented;
+        // for now the input just passes through the linear layer.
+        self.linear.forward(input)
+    }
+
+    fn training_mode(&mut self, _mode: bool) {}
+}
+
+// LlamaModel struct
+#[derive(Debug, Clone, ModuleParameters)]
+pub struct LlamaModelRs {
+    pub args: ModelArgsRs,
+    pub vocab_size: i32,
+    pub num_hidden_layers: i32,
+    pub embed_tokens: Option<nn::Embedding>, // embedding layer, present only on some shards
+    pub layers: Vec<TransformerBlock>,       // placeholder TransformerBlocks for now
+    pub norm: Option<nn::RmsNorm>,           // RMSNorm layer, present only on the last shard
+}
+
+impl LlamaModelRs {
+    pub fn new(args: ModelArgsRs) -> Result<Self, mlx_rs::error::Exception> {
+        let vocab_size = args.vocab_size;
+        let num_hidden_layers = args.num_hidden_layers;
+        let mut embed_tokens = None;
+        if args.shard.is_first_layer() || (args.shard.is_last_layer() && args.tie_word_embeddings) {
+            embed_tokens = Some(nn::Embedding::new(args.vocab_size, args.hidden_size)?);
+        }
+
+        let mut layers = Vec::new();
+        for i in 0..num_hidden_layers {
+            if args.shard.start_layer as i32 <= i && i <= args.shard.end_layer as i32 {
+                // Placeholder dimensions for TransformerBlock, adjust as needed.
+                layers.push(TransformerBlock::new(
+                    args.hidden_size,
+                    args.hidden_size * 4,
+                ));
+            } else {
+                // The Python version pushes a no-op IdentityBlock for layers outside
+                // the shard; until that is ported, a placeholder block is pushed here too.
+                layers.push(TransformerBlock::new(
+                    args.hidden_size,
+                    args.hidden_size * 4,
+                ));
+            }
+        }
+
+        let mut norm = None;
+        if args.shard.is_last_layer() {
+            norm = Some(nn::RmsNorm::new(args.hidden_size, args.rms_norm_eps)?);
+        }
+
+        Ok(Self {
+            args,
+            vocab_size,
+            num_hidden_layers,
+            embed_tokens,
+            layers,
+            norm,
+        })
+    }
+}
+
+impl Module<&Array> for LlamaModelRs {
+    type Output = Array;
+    type Error = mlx_rs::error::Exception;
+
+    fn forward(&mut self, inputs: &Array) -> Result<Self::Output, Self::Error> {
+        let mut h;
+        if self.args.shard.is_first_layer() && self.embed_tokens.is_some() {
+            h = self.embed_tokens.as_mut().unwrap().forward(inputs)?;
+        } else {
+            h = inputs.clone(); // inputs are already embedded when this is not the first shard
+        }
+
+        // Mask creation still needs to be ported to Rust:
+        // let mask = None;
+        // if h.ndim() > 1 && h.shape()[1] > 1 {
+        //     mask = create_attention_mask(h, cache);
+        // }
+
+        // KV-cache handling also needs a detailed Rust implementation:
+        // let mut cache = cache.unwrap_or_else(|| vec![None; self.layers.len()]);
+
+        for layer in &mut self.layers {
+            h = layer.forward(&h)?; // pass mask and cache once implemented
+        }
+
+        if self.args.shard.is_last_layer() && self.norm.is_some() {
+            h = self.norm.as_mut().unwrap().forward(&h)?;
+        }
+        Ok(h)
+    }
+
+    fn training_mode(&mut self, _mode: bool) {}
+}
+
+// Model struct
+#[derive(Debug, Clone, ModuleParameters)]
+pub struct ModelRs {
+    pub args: ModelArgsRs,
+    pub model_type: String,
+    pub model: LlamaModelRs,
+    pub lm_head: Option<nn::Linear>, // LM head, present only when embeddings are untied
+}
+
+impl ModelRs {
+    pub fn new(args: ModelArgsRs) -> Result<Self, mlx_rs::error::Exception> {
+        let model = LlamaModelRs::new(args.clone())?; // clone args for LlamaModelRs
+        let model_type = args.model_type.clone();
+        let mut lm_head = None;
+        if args.shard.is_last_layer() && !args.tie_word_embeddings {
+            lm_head = Some(nn::Linear::new(args.hidden_size, args.vocab_size)?);
+        }
+
+        Ok(Self {
+            args,
+            model_type,
+            model,
+            lm_head,
+        })
+    }
+}
+
+impl Module<&Array> for ModelRs {
+    type Output = Array;
+    type Error = mlx_rs::error::Exception;
+
+    fn forward(&mut self, inputs: &Array) -> Result<Self::Output, Self::Error> {
+        let mut out = self.model.forward(inputs)?;
+
+        if self.args.shard.is_last_layer() {
+            if self.args.tie_word_embeddings && self.model.embed_tokens.is_some() {
+                // Tied embeddings: the logits are `out` times the transposed embedding
+                // matrix (Python's `embed_tokens.as_linear(out)`). mlx-rs exposes no
+                // direct `as_linear()` equivalent, so this projection is left
+                // unimplemented rather than silently returning the wrong array.
+                todo!("project hidden states through the tied embedding weights");
+            } else if self.lm_head.is_some() {
+                out = self.lm_head.as_mut().unwrap().forward(&out)?;
+            }
+        }
+        Ok(out)
+    }
+
+    fn training_mode(&mut self, _mode: bool) {}
+}
diff --git a/src/llama_test.rs b/src/llama_test.rs
new file mode 100644
index 0000000..c297c87
--- /dev/null
+++ b/src/llama_test.rs
@@ -0,0 +1,73 @@
+use std::io::Write; // for stdout().flush()
+
+use llama_cpp_2::context::params::LlamaContextParams;
+use llama_cpp_2::llama_backend::LlamaBackend;
+use llama_cpp_2::llama_batch::LlamaBatch;
+use llama_cpp_2::model::{AddBos, LlamaModel, Special};
+use llama_cpp_2::model::params::LlamaModelParams;
+use llama_cpp_2::sampling::LlamaSampler;
+
+#[allow(dead_code)] // scratch test, not wired into main yet
+fn test() {
+    let model_path = std::env::args().nth(1).expect("Please specify model path");
+    let backend = LlamaBackend::init().unwrap();
+    let params = LlamaModelParams::default();
+
+    let prompt =
+        "<|im_start|>user\nHello! how are you?<|im_end|>\n<|im_start|>assistant\n".to_string();
+    let model =
+        LlamaModel::load_from_file(&backend, model_path, &params).expect("unable to load model");
+    let ctx_params = LlamaContextParams::default();
+    let mut ctx = model
+        .new_context(&backend, ctx_params)
+        .expect("unable to create the llama_context");
+    let tokens_list = model
+        .str_to_token(&prompt, AddBos::Always)
+        .unwrap_or_else(|_| panic!("failed to tokenize {prompt}"));
+    let n_len = 64;
+
+    // create a llama_batch with size 512;
+    // we use this object to submit token data for decoding
+    let mut batch = LlamaBatch::new(512, 1);
+
+    let last_index = tokens_list.len() as i32 - 1;
+    for (i, token) in (0_i32..).zip(tokens_list.into_iter()) {
+        // llama_decode will output logits only for the last token of the prompt
+        let is_last = i == last_index;
+        batch.add(token, i, &[0], is_last).unwrap();
+    }
+    ctx.decode(&mut batch).expect("llama_decode() failed");
+
+    let mut n_cur = batch.n_tokens();
+
+    // The `Decoder`
+    let mut decoder = encoding_rs::UTF_8.new_decoder();
+    let mut sampler = LlamaSampler::greedy();
+
+    while n_cur <= n_len {
+        // sample the next token
+        {
+            let token = sampler.sample(&ctx, batch.n_tokens() - 1);
+
+            sampler.accept(token);
+
+            // is it an end of stream?
+            if token == model.token_eos() {
+                eprintln!();
+                break;
+            }
+
+            let output_bytes = model.token_to_bytes(token, Special::Tokenize).unwrap();
+            // use `Decoder.decode_to_string()` to avoid the intermediate buffer
+            let mut output_string = String::with_capacity(32);
+            let _decode_result =
+                decoder.decode_to_string(&output_bytes, &mut output_string, false);
+            print!("{output_string}");
+            std::io::stdout().flush().unwrap();
+
+            batch.clear();
+            batch.add(token, n_cur, &[0], true).unwrap();
+        }
+
+        n_cur += 1;
+
+        ctx.decode(&mut batch).expect("failed to eval");
+    }
+}
diff --git a/src/main.rs b/src/main.rs
index a5a10d8..70011bb 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,15 +1,20 @@
 mod device_capability_data;
 mod discovery;
+mod inference;
 mod network;
 mod orchestration;
 mod partitioning;
 mod topology;
+mod llama_test;
+mod module_loading;
+mod llama_module;
 
 use serde::{Deserialize, Serialize};
 use serde_json::Value;
 use tonic::{transport::Server, Request, Response, Status};
 
 use crate::discovery::{NodeInfo, UdpDiscovery};
+use crate::inference::InferenceEngine;
 use crate::node_service::{
     CollectTopologyRequest, Empty, ExampleRequest, HealthCheckRequest, HealthCheckResponse,
     InferenceState, Loss, PromptRequest, SendOpaqueStatusRequest, SendResultRequest, Tensor,
@@ -30,6 +35,7 @@ struct Node {
     node_info: NodeInfo,
     current_topology: Topology,
     udp_discovery: UdpDiscovery,
+    inference_engine: InferenceEngine,
 }
 
 impl Node {
@@ -45,7 +51,10 @@ impl Node {
             .current_topology
             .get_shard_for_node(base_shard, &self.node_info.node_id);
 
-        let result = self.inference_engine.infer_tensor(request_id, shard, tensor, inference_state);
+        let result: Tensor = self
+            .inference_engine
+            .infer_tensor(request_id, shard, tensor, inference_state);
+
         let result = self.process_inference_result(shard, result, request_id, inference_state);
 
         result
diff --git a/src/module_loading.rs b/src/module_loading.rs
new file mode 100644
index 0000000..04df095
--- /dev/null
+++ b/src/module_loading.rs
@@ -0,0 +1,52 @@
+use std::path::Path;
+
+use crate::Shard;
+
+fn load_config(model_path: &Path) -> serde_json::Map<String, serde_json::Value> {
+    let config_path = model_path.join("config.json");
+    let model_index = model_path.join("model_index.json");
+
+    if config_path.exists() {
+        let config = std::fs::read_to_string(config_path).unwrap();
+        serde_json::from_str(&config).unwrap()
+    } else {
+        let model_index = std::fs::read_to_string(model_index).unwrap();
+        serde_json::from_str(&model_index).unwrap()
+    }
+}
+
+async fn load_model_shard(
+    model_path: &Path,
+    shard: Shard,
+    lazy: bool,
+    model_config: serde_json::Map<String, serde_json::Value>,
+) {
+    let mut config = load_config(model_path);
+    config.extend(model_config.into_iter());
+
+    let model_name = model_path.file_name().unwrap().to_str().unwrap();
+
+    // serde_json::Map panics on index-assignment for a missing key, so insert instead.
+    config.insert(
+        "shard".to_string(),
+        serde_json::json!({
+            "model_id": model_name,
+            "start_layer": shard.start_layer,
+            "end_layer": shard.end_layer,
+            "n_layers": shard.total_layers,
+        }),
+    );
+
+    let weight_files = glob::glob(model_path.join("model*.safetensors").to_str().unwrap())
+        .unwrap()
+        .collect::<Result<Vec<_>, _>>()
+        .unwrap();
+
+    let weight_files = weight_files
+        .iter()
+        .map(|path| path.file_name().unwrap().to_str().unwrap())
+        .collect::<Vec<_>>();
+
+    // Reading the safetensors shards themselves is still unimplemented in this
+    // stash; see the sketch below for one possible shape.
+    let _ = (lazy, weight_files);
+    todo!();
+}
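+
+// Sketch (editor's addition, not part of the original stash): one plausible
+// shape for the missing weight-loading step, mirroring exo's Python loader by
+// merging every safetensors shard into a single name -> Array map.
+// ASSUMPTION: `Array::load_safetensors` is taken as the loading entry point
+// implied by mlx-rs's "safetensors" feature (enabled in Cargo.toml), returning
+// a map of tensor names to Arrays; swap in the real binding if the API differs.
+#[allow(dead_code)]
+fn load_weights_sketch(
+    model_path: &Path,
+    weight_files: &[&str],
+) -> Result<std::collections::HashMap<String, mlx_rs::Array>, mlx_rs::error::Exception> {
+    let mut weights = std::collections::HashMap::new();
+    for &file in weight_files {
+        // Each shard file yields its own name -> Array map; a well-formed
+        // checkpoint never repeats a key across files, so a plain extend is safe.
+        let arrays = mlx_rs::Array::load_safetensors(model_path.join(file))?;
+        weights.extend(arrays);
+    }
+    Ok(weights)
+}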