Stash gen of LLAMA mlx

Joshua Coles 2025-02-12 17:07:06 +00:00
parent 39ac5a86dd
commit 4cd96b58b5
7 changed files with 593 additions and 8 deletions

Cargo.lock generated

@@ -145,6 +145,29 @@ version = "0.22.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
[[package]]
name = "bindgen"
version = "0.69.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088"
dependencies = [
"bitflags",
"cexpr",
"clang-sys",
"itertools 0.12.1",
"lazy_static",
"lazycell",
"log",
"prettyplease",
"proc-macro2",
"quote",
"regex",
"rustc-hash",
"shlex",
"syn",
"which",
]
[[package]]
name = "bindgen"
version = "0.70.1"
@@ -195,6 +218,8 @@ version = "1.2.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7777341816418c02e033934a09f20dc0ccaf65a5201ef8a450ae0105a573fda"
dependencies = [
"jobserver",
"libc",
"shlex",
]
@@ -311,6 +336,26 @@ version = "1.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0"
[[package]]
name = "enumflags2"
version = "0.7.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba2f4b465f5318854c6f8dd686ede6c0a9dc67d4b1ac241cf0eb51521a309147"
dependencies = [
"enumflags2_derive",
]
[[package]]
name = "enumflags2_derive"
version = "0.7.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc4caf64a58d7a6d65ab00639b046ff54399a39f5f2554728895ace4b297cd79"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "equivalent"
version = "1.0.1"
@@ -324,13 +369,15 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d"
dependencies = [
"libc",
"windows-sys 0.52.0",
]
[[package]]
name = "exo-rs"
version = "0.1.0"
dependencies = [
"glob",
"llama-cpp-2",
"mlx-rs",
"network-interface",
"phf",
@@ -341,6 +388,7 @@ dependencies = [
"socket2",
"system-configuration",
"thiserror 2.0.11",
"tinygrad",
"tokio",
"tonic",
"tonic-build",
@@ -498,6 +546,15 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "home"
version = "0.5.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf"
dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "http"
version = "1.2.0"
@@ -623,6 +680,15 @@ dependencies = [
"hashbrown 0.15.2",
]
[[package]]
name = "itertools"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569"
dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.13.0"
@@ -647,12 +713,27 @@ version = "1.0.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674"
[[package]]
name = "jobserver"
version = "0.1.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0"
dependencies = [
"libc",
]
[[package]]
name = "lazy_static"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]]
name = "lazycell"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
[[package]]
name = "libc"
version = "0.2.169"
@@ -675,6 +756,32 @@ version = "0.4.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab"
[[package]]
name = "llama-cpp-2"
version = "0.1.93"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "44818b6967a77379b8c8e105e2684d2bd2bca999ad24cfd806d8476a80c53255"
dependencies = [
"enumflags2",
"llama-cpp-sys-2",
"thiserror 1.0.69",
"tracing",
"tracing-core",
]
[[package]]
name = "llama-cpp-sys-2"
version = "0.1.93"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92add8e8dabf941518dd573075721a6e0568db86c6557e176cec99bf56883c0a"
dependencies = [
"bindgen 0.69.5",
"cc",
"cmake",
"glob",
"walkdir",
]
[[package]]
name = "lock_api"
version = "0.4.12"
@@ -706,6 +813,16 @@ version = "0.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
[[package]]
name = "matrixmultiply"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a"
dependencies = [
"autocfg",
"rawpointer",
]
[[package]]
name = "memchr"
version = "2.7.4"
@@ -741,7 +858,7 @@ checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd"
dependencies = [
"libc",
"wasi 0.11.0+wasi-snapshot-preview1",
"windows-sys 0.52.0",
]
[[package]]
@@ -800,7 +917,7 @@ version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af33a6b662998e5bb4099b1a191b4352fcb11d97706e82e4c8922fe200bb11f2"
dependencies = [
"bindgen 0.70.1",
"cc",
"cmake",
]
@@ -811,6 +928,19 @@ version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "defc4c55412d89136f966bbb339008b474350e5e6e78d2714439c386b3137a03"
[[package]]
name = "ndarray"
version = "0.15.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32"
dependencies = [
"matrixmultiply",
"num-complex",
"num-integer",
"num-traits",
"rawpointer",
]
[[package]]
name = "network-interface"
version = "2.0.0"
@@ -852,6 +982,15 @@ dependencies = [
"num-traits",
]
[[package]]
name = "num-integer"
version = "0.1.46"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f"
dependencies = [
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.19"
@@ -1150,6 +1289,12 @@ dependencies = [
"getrandom 0.2.15",
]
[[package]]
name = "rawpointer"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"
[[package]]
name = "redox_syscall"
version = "0.5.8"
@@ -1210,7 +1355,7 @@ dependencies = [
"errno",
"libc",
"linux-raw-sys",
"windows-sys 0.52.0",
]
[[package]]
@@ -1235,6 +1380,15 @@ dependencies = [
"serde_json",
]
[[package]]
name = "same-file"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
dependencies = [
"winapi-util",
]
[[package]]
name = "scopeguard"
version = "1.2.0"
@@ -1325,7 +1479,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8"
dependencies = [
"libc",
"windows-sys 0.52.0",
]
[[package]]
@@ -1405,7 +1559,7 @@ dependencies = [
"getrandom 0.3.1",
"once_cell",
"rustix",
"windows-sys 0.52.0",
]
[[package]]
@@ -1458,6 +1612,15 @@ dependencies = [
"once_cell",
]
[[package]]
name = "tinygrad"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc8951e9ced45095eb3ad7342c4e74b038bda930833df79b7debb019bb653c18"
dependencies = [
"ndarray",
]
[[package]]
name = "tokio"
version = "1.43.0"
@@ -1473,7 +1636,7 @@ dependencies = [
"signal-hook-registry",
"socket2",
"tokio-macros",
"windows-sys 0.52.0",
]
[[package]]
@@ -1703,6 +1866,16 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65"
[[package]]
name = "walkdir"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b"
dependencies = [
"same-file",
"winapi-util",
]
[[package]]
name = "want"
version = "0.3.1"
@@ -1727,6 +1900,18 @@ dependencies = [
"wit-bindgen-rt",
]
[[package]]
name = "which"
version = "4.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7"
dependencies = [
"either",
"home",
"once_cell",
"rustix",
]
[[package]]
name = "winapi"
version = "0.3.9"
@@ -1743,6 +1928,15 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-util"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb"
dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
@@ -1758,6 +1952,15 @@ dependencies = [
"windows-targets",
]
[[package]]
name = "windows-sys"
version = "0.59.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b"
dependencies = [
"windows-targets",
]
[[package]]
name = "windows-targets"
version = "0.52.6"

Cargo.toml

@@ -19,6 +19,9 @@ uuid = { version = "1.13.1", features = ["v4"] }
regex = "1.11.1"
phf = { version = "0.11.3", features = ["macros"] }
mlx-rs = { version = "0.21.0", features = ["metal", "accelerate", "safetensors"] }
tinygrad = "0.1.0"
llama-cpp-2 = { version = "0.1.93", features = ["metal", "native", "sampler"] }
glob = "0.3.2"
[build-dependencies]
tonic-build = "0.12.3"

src/inference.rs Normal file

@@ -0,0 +1,19 @@
use std::collections::HashMap;
use crate::node_service::{InferenceState, Tensor};
use crate::Shard;
#[derive(Debug)]
pub struct InferenceEngine {
// Per-request inference state, keyed by request id. The stash left the
// value type as `_`, which does not compile; InferenceState is an assumption.
state_cache: HashMap<String, InferenceState>,
}
impl InferenceEngine {
pub(crate) fn infer_tensor(
&self,
request_id: String,
shard: Shard,
tensor: Option<Tensor>,
inference_state: Option<InferenceState>,
) -> Tensor {
todo!("run this shard's forward pass over the input tensor")
}
}
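
Note: since infer_tensor is still a stub, one observation on the cache it will drive — per-request state lookup fits HashMap's entry API, which also means the method will want &mut self rather than &self. A runnable toy sketch (ToyState is a hypothetical stand-in for the real InferenceState):

use std::collections::HashMap;

#[derive(Default, Debug)]
struct ToyState; // hypothetical stand-in for the real InferenceState

fn main() {
    let mut state_cache: HashMap<String, ToyState> = HashMap::new();
    // Look up or lazily create the state for a request id:
    let state = state_cache
        .entry("request-42".to_string())
        .or_insert_with(ToyState::default);
    println!("{state:?}");
}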

src/llama_module.rs Normal file

@@ -0,0 +1,226 @@
use mlx_rs::macros::ModuleParameters;
use mlx_rs::module::Module;
use mlx_rs::module::ModuleParameters;
use mlx_rs::nn;
use mlx_rs::Array;
// Define Shard struct to mirror Python dataclass
#[derive(Debug, Clone)]
pub struct Shard {
pub name: String,
pub start_layer: usize,
pub end_layer: usize,
}
impl Shard {
pub fn new(name: String, start_layer: usize, end_layer: usize) -> Self {
Shard {
name,
start_layer,
end_layer,
}
}
pub fn is_first_layer(&self) -> bool {
self.start_layer == 0
}
pub fn is_last_layer(&self) -> bool {
// Deciding this exactly needs the model's total layer count, which the
// shard does not carry yet (see the sketch after this impl). As a
// stopgap, treat an implausibly large inclusive end_layer as "last".
self.end_layer > 9999
}
}
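
The stopgap above becomes exact once the shard knows the total layer count; module_loading.rs already reads total_layers off crate::Shard. A minimal sketch, assuming that field is added here too (ShardExample and its field are illustrative, not part of the stash):

#[derive(Debug, Clone)]
pub struct ShardExample {
    pub name: String,
    pub start_layer: usize,
    pub end_layer: usize,    // inclusive
    pub total_layers: usize, // assumed new field, mirroring crate::Shard
}

impl ShardExample {
    pub fn is_last_layer(&self) -> bool {
        // Exact check: the shard is last when it contains the final layer index.
        self.end_layer == self.total_layers - 1
    }
}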
// Define ModelArgs struct to mirror Python dataclass ModelArgs
#[derive(Debug, Clone)]
pub struct ModelArgsRs {
pub vocab_size: i32,
pub hidden_size: i32,
pub num_hidden_layers: i32,
pub num_attention_heads: i32,
pub num_key_value_heads: i32,
pub rms_norm_eps: f32,
pub tie_word_embeddings: bool,
pub model_type: String, // Assuming model_type is a String
pub head_dim: Option<i32>, // Using Option to represent optional field
pub shard: Shard, // Using the Shard struct defined above
}
// Placeholder for TransformerBlock - You'll need to implement this in Rust
#[derive(Debug, Clone, ModuleParameters)]
pub struct TransformerBlock {
// Define the layers within TransformerBlock as needed, e.g., attention, norm, etc.
// For now, using a linear layer as a placeholder
pub linear: nn::Linear,
}
impl TransformerBlock {
pub fn new(dims: i32, mlp_dims: i32) -> Self {
Self {
linear: nn::Linear::new(dims, mlp_dims).unwrap(), // Example, adjust params
}
}
}
impl Module<&Array> for TransformerBlock {
type Output = Array;
type Error = mlx_rs::error::Exception;
fn forward(&mut self, input: &Array) -> Result<Self::Output, Self::Error> {
// Implement the forward pass logic for TransformerBlock
// For now, just passing through the linear layer
self.linear.forward(input)
}
fn training_mode(&mut self, _mode: bool) {}
}
// Define LlamaModel struct
#[derive(Debug, Clone, ModuleParameters)]
pub struct LlamaModelRs {
pub args: ModelArgsRs,
pub vocab_size: i32,
pub num_hidden_layers: i32,
pub embed_tokens: Option<nn::Embedding>, // Embedding layer is optional based on sharding
pub layers: Vec<TransformerBlock>, // Using placeholder TransformerBlock
pub norm: Option<nn::RmsNorm>, // RMSNorm layer is optional based on sharding
}
impl LlamaModelRs {
pub fn new(args: ModelArgsRs) -> Result<Self, mlx_rs::error::Exception> {
let vocab_size = args.vocab_size;
let num_hidden_layers = args.num_hidden_layers;
let mut embed_tokens = None;
if args.shard.is_first_layer() || (args.shard.is_last_layer() && args.tie_word_embeddings) {
embed_tokens = Some(nn::Embedding::new(args.vocab_size, args.hidden_size)?);
}
let mut layers = Vec::new();
for i in 0..num_hidden_layers as usize {
// shard bounds are usize while num_hidden_layers is i32, so index as usize
if args.shard.start_layer <= i && i <= args.shard.end_layer {
// Using placeholder dimensions for TransformerBlock, adjust as needed
layers.push(TransformerBlock::new(
args.hidden_size,
args.hidden_size * 4,
));
} else {
// The Python version pushes a no-op IdentityBlock here; until a Rust
// equivalent exists, a default TransformerBlock stands in (note: this
// is not actually a no-op and wastes parameters).
layers.push(TransformerBlock::new(
args.hidden_size,
args.hidden_size * 4,
));
}
}
let mut norm = None;
if args.shard.is_last_layer() {
norm = Some(nn::RmsNorm::new(args.hidden_size, args.rms_norm_eps)?);
}
Ok(Self {
args,
vocab_size,
num_hidden_layers,
embed_tokens,
layers,
norm,
})
}
}
impl Module<&Array> for LlamaModelRs {
type Output = Array;
type Error = mlx_rs::error::Exception;
fn forward(&mut self, inputs: &Array) -> Result<Self::Output, Self::Error> {
// `Module::forward` takes `&mut self` here, so optional submodules must be
// borrowed with `as_mut` rather than `as_ref`.
let mut h = if self.args.shard.is_first_layer() && self.embed_tokens.is_some() {
self.embed_tokens.as_mut().unwrap().forward(inputs)?
} else {
inputs.clone() // assume the input is already embedded past the first shard
};
// Mask creation would go here once create_attention_mask is ported to Rust:
// let mask = (h.ndim() > 1 && h.shape()[1] > 1).then(|| create_attention_mask(&h, &cache));
// Cache handling also still needs a fuller Rust implementation:
// let mut cache = cache.unwrap_or_else(|| vec![None; self.layers.len()]);
for layer in &mut self.layers {
h = layer.forward(&h)?; // pass mask and cache once implemented
}
if self.args.shard.is_last_layer() {
if let Some(norm) = self.norm.as_mut() {
h = norm.forward(&h)?;
}
}
Ok(h)
}
fn training_mode(&mut self, _mode: bool) {}
}
// Define Model struct
#[derive(Debug, Clone, ModuleParameters)]
pub struct ModelRs {
pub args: ModelArgsRs,
pub model_type: String,
pub model: LlamaModelRs,
pub lm_head: Option<nn::Linear>, // Linear layer for language model head, optional based on tie_word_embeddings
}
impl ModelRs {
pub fn new(args: ModelArgsRs) -> Result<Self, mlx_rs::error::Exception> {
let model = LlamaModelRs::new(args.clone())?; // Clone args for LlamaModel
let model_type = args.model_type.clone();
let mut lm_head = None;
if args.shard.is_last_layer() && !args.tie_word_embeddings {
lm_head = Some(nn::Linear::new(args.hidden_size, args.vocab_size)?);
}
Ok(Self {
args,
model_type,
model,
lm_head,
})
}
}
impl Module<&Array> for ModelRs {
type Output = Array;
type Error = mlx_rs::error::Exception;
fn forward(&mut self, inputs: &Array) -> Result<Self::Output, Self::Error> {
let mut out = self.model.forward(inputs)?;
if self.args.shard.is_last_layer() {
if self.args.tie_word_embeddings && self.model.embed_tokens.is_some() {
// The Python implementation projects `out` back through the embedding
// matrix via `embed_tokens.as_linear(out)`. A faithful port needs a
// matmul against the transposed embedding weight, which is not wired
// up yet, so fail loudly instead of returning the raw weight (which
// is what the previous placeholder did, incorrectly).
todo!("project hidden states through the tied embedding weights");
} else if let Some(lm_head) = self.lm_head.as_mut() {
out = lm_head.forward(&out)?;
}
}
Ok(out)
}
fn training_mode(&mut self, _mode: bool) {}
}
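
To make the intended wiring concrete, here is a hypothetical construction of the first shard of a toy model; every dimension below is made up and far smaller than a real LLaMA config, and the function name is illustrative only:

fn build_first_shard_example() -> Result<ModelRs, mlx_rs::error::Exception> {
    let args = ModelArgsRs {
        vocab_size: 32000,
        hidden_size: 64,
        num_hidden_layers: 8,
        num_attention_heads: 8,
        num_key_value_heads: 8,
        rms_norm_eps: 1e-5,
        tie_word_embeddings: true,
        model_type: "llama".to_string(),
        head_dim: None,
        // layers 0..=3 (inclusive): this shard embeds but applies no final norm
        shard: Shard::new("llama-toy".to_string(), 0, 3),
    };
    ModelRs::new(args)
}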

src/llama_test.rs Normal file

@@ -0,0 +1,73 @@
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::{AddBos, LlamaModel, Special};
use llama_cpp_2::sampling::LlamaSampler;
use std::io::Write; // for stdout().flush() in the decode loop
fn test() {
let model_path = std::env::args().nth(1).expect("Please specify model path");
let backend = LlamaBackend::init().unwrap();
let params = LlamaModelParams::default();
let prompt =
"<|im_start|>user\nHello! how are you?<|im_end|>\n<|im_start|>assistant\n".to_string();
let model =
LlamaModel::load_from_file(&backend, model_path, &params).expect("unable to load model");
let ctx_params = LlamaContextParams::default();
let mut ctx = model
.new_context(&backend, ctx_params)
.expect("unable to create the llama_context");
let tokens_list = model
.str_to_token(&prompt, AddBos::Always)
.unwrap_or_else(|_| panic!("failed to tokenize {prompt}"));
let n_len = 64;
// create a llama_batch with size 512
// we use this object to submit token data for decoding
let mut batch = LlamaBatch::new(512, 1);
let last_index = tokens_list.len() as i32 - 1;
for (i, token) in (0_i32..).zip(tokens_list.into_iter()) {
// llama_decode will output logits only for the last token of the prompt
let is_last = i == last_index;
batch.add(token, i, &[0], is_last).unwrap();
}
ctx.decode(&mut batch).expect("llama_decode() failed");
let mut n_cur = batch.n_tokens();
// Streaming UTF-8 decoder for token bytes (relies on the encoding_rs crate,
// which must be in Cargo.toml; it does not appear in the diff above).
let mut decoder = encoding_rs::UTF_8.new_decoder();
let mut sampler = LlamaSampler::greedy();
while n_cur <= n_len {
// sample the next token
{
let token = sampler.sample(&ctx, batch.n_tokens() - 1);
sampler.accept(token);
// is it an end of stream?
if token == model.token_eos() {
eprintln!();
break;
}
let output_bytes = model.token_to_bytes(token, Special::Tokenize).unwrap();
// use `Decoder.decode_to_string()` to avoid the intermediate buffer
let mut output_string = String::with_capacity(32);
let _decode_result = decoder.decode_to_string(&output_bytes, &mut output_string, false);
print!("{output_string}");
std::io::stdout().flush().unwrap();
batch.clear();
batch.add(token, n_cur, &[0], true).unwrap();
}
n_cur += 1;
ctx.decode(&mut batch).expect("failed to eval");
}
}
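
The loop above always samples at index n_tokens - 1, i.e. the logits of the most recently decoded token. If greedy decoding proves too repetitive, llama-cpp-2's sampler chain could be swapped in; a sketch, assuming chain_simple, temp, and dist exist at this crate version (verify against the crate docs):

// Hedged sketch: replace `LlamaSampler::greedy()` with a seeded temperature
// chain (constructor names assumed from llama-cpp-2's examples).
let mut sampler = LlamaSampler::chain_simple([
    LlamaSampler::temp(0.7),  // soften the distribution
    LlamaSampler::dist(1234), // then sample from it with a fixed seed
]);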

src/main.rs

@@ -1,15 +1,20 @@
mod device_capability_data;
mod discovery;
mod inference;
mod network;
mod orchestration;
mod partitioning;
mod topology;
mod llama_test;
mod module_loading;
mod llama_module;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tonic::{transport::Server, Request, Response, Status};
use crate::discovery::{NodeInfo, UdpDiscovery};
use crate::inference::InferenceEngine;
use crate::node_service::{
CollectTopologyRequest, Empty, ExampleRequest, HealthCheckRequest, HealthCheckResponse,
InferenceState, Loss, PromptRequest, SendOpaqueStatusRequest, SendResultRequest, Tensor,
@@ -30,6 +35,7 @@ struct Node {
node_info: NodeInfo,
current_topology: Topology,
udp_discovery: UdpDiscovery,
inference_engine: InferenceEngine,
}
impl Node {
@@ -45,7 +51,10 @@ impl Node {
.current_topology
.get_shard_for_node(base_shard, &self.node_info.node_id);
let result: Tensor = self
.inference_engine
.infer_tensor(request_id, shard, tensor, inference_state);
let result = self.process_inference_result(shard, result, request_id, inference_state);
result

src/module_loading.rs Normal file

@@ -0,0 +1,52 @@
use std::path::Path;
use crate::Shard;
fn load_config(
model_path: &Path,
) -> serde_json::Map<String, serde_json::Value> {
let config_path = model_path.join("config.json");
let model_index = model_path.join("model_index.json");
if config_path.exists() {
let config = std::fs::read_to_string(config_path).unwrap();
serde_json::from_str(&config).unwrap()
} else {
let model_index = std::fs::read_to_string(model_index).unwrap();
serde_json::from_str(&model_index).unwrap()
}
}
async fn load_model_shard(
model_path: &Path,
shard: Shard,
lazy: bool,
model_config: serde_json::Map<String, serde_json::Value>,
) {
let mut config = load_config(model_path);
config.extend(model_config.into_iter());
let model_name = model_path.file_name().unwrap().to_str().unwrap();
// serde_json::Map's IndexMut panics when the key is absent, so insert the
// shard entry explicitly instead of assigning through `config["shard"]`.
config.insert(
"shard".to_string(),
serde_json::json!({
"model_id": model_name,
"start_layer": shard.start_layer,
"end_layer": shard.end_layer,
"n_layers": shard.total_layers,
}),
);
let weight_files = glob::glob(model_path.join("model*.safetensors").to_str().unwrap())
.unwrap()
.collect::<Result<Vec<_>, _>>()
.unwrap();
let weight_files = weight_files
.iter()
.map(|path| path.file_name().unwrap().to_str().unwrap())
.collect::<Vec<_>>();
// TODO: deserialize each safetensors file into model weights (see the
// sketch after this file); `lazy` should gate eager materialisation.
let _ = lazy;
let _weights = weight_files.iter().map(|_file| {
// per-file tensor loading goes here
});
todo!();
}
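
For the unfinished weights step, a sketch of eager loading using the safetensors crate (an assumed extra dependency; mlx-rs's safetensors feature could serve instead). It returns raw bytes per tensor; converting dtype and shape into mlx Arrays is deliberately left out:

use std::collections::HashMap;
use std::path::Path;

// Hypothetical helper, not part of the stash.
fn load_raw_weights(model_path: &Path, weight_files: &[String]) -> HashMap<String, Vec<u8>> {
    let mut weights = HashMap::new();
    for file in weight_files {
        let bytes = std::fs::read(model_path.join(file)).expect("read safetensors file");
        let tensors = safetensors::SafeTensors::deserialize(&bytes).expect("parse safetensors");
        for (name, view) in tensors.tensors() {
            // Copy out the raw little-endian bytes for each named tensor.
            weights.insert(name, view.data().to_vec());
        }
    }
    weights
}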