From a129325ade5e3aee6ba88b76feb69d5346c42e0f Mon Sep 17 00:00:00 2001
From: Joshua Coles
Date: Wed, 12 Feb 2025 17:18:40 +0000
Subject: [PATCH] Stash gen of LLAMA mlx

---
 src/llama_module.rs   | 12 +++++++++---
 src/module_loading.rs | 34 ++++++++++++++++++++++++----------
 2 files changed, 33 insertions(+), 13 deletions(-)

diff --git a/src/llama_module.rs b/src/llama_module.rs
index febebcf..be99c7e 100644
--- a/src/llama_module.rs
+++ b/src/llama_module.rs
@@ -6,6 +6,9 @@ use mlx_rs::nested::NestedHashMap;
 use mlx_rs::nn;
 use mlx_rs::Array;
 use std::rc::Rc;
+use mlx_rs::builder::Builder;
+use mlx_rs::nn::RmsNormBuilder;
+use serde::Deserialize;
 
 // Define Shard struct to mirror Python dataclass
 #[derive(Debug, Clone)]
@@ -38,7 +41,7 @@ impl Shard {
 }
 
 // Define ModelArgs struct to mirror Python dataclass ModelArgs
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Deserialize)]
 pub struct ModelArgsRs {
     pub vocab_size: i32,
     pub hidden_size: i32,
@@ -102,7 +105,7 @@ impl LlamaModelRs {
         }
 
         let mut layers = Vec::new();
-        for i in 0..num_hidden_layers {
+        for i in 0..(num_hidden_layers as usize) {
             if args.shard.start_layer <= i && i <= args.shard.end_layer {
                 // Using placeholder dimensions for TransformerBlock, adjust as needed
                 layers.push(TransformerBlock::new(
@@ -121,7 +124,10 @@ impl LlamaModelRs {
 
         let mut norm = None;
         if args.shard.is_last_layer() {
-            norm = Some(nn::RmsNorm::new(args.hidden_size, args.rms_norm_eps)?);
+            norm = Some(RmsNormBuilder::new(args.hidden_size)
+                .eps(args.rms_norm_eps)
+                .build()
+                .unwrap());
         }
 
         Ok(Self {
diff --git a/src/module_loading.rs b/src/module_loading.rs
index 04df095..bfd2a99 100644
--- a/src/module_loading.rs
+++ b/src/module_loading.rs
@@ -1,5 +1,7 @@
+use std::collections::HashMap;
 use std::path::Path;
-
+use serde_json::Value;
+use crate::llama_module::{LlamaModelRs, ModelArgsRs};
 use crate::Shard;
 
 fn load_config(
@@ -35,18 +37,30 @@ async fn load_model_shard(
         "n_layers": shard.total_layers,
     });
 
-    let weight_files = glob::glob(model_path.join("model*.safetensors").to_str().unwrap())
+    let mut weight_files = glob::glob(model_path.join("model*.safetensors").to_str().unwrap())
         .unwrap()
         .collect::<Result<Vec<_>, _>>()
         .unwrap();
 
-    let weight_files = weight_files
-        .iter()
-        .map(|path| path.file_name().unwrap().to_str().unwrap())
-        .collect::<Vec<_>>();
+    let model_args = serde_json::from_value::<ModelArgsRs>(Value::Object(config)).unwrap();
+    let model = LlamaModelRs::new(model_args).unwrap();
+
+    weight_files.sort_by_key(|x| x.file_name().unwrap().to_str().unwrap());
+
+    let mut weights = HashMap::new();
+
+    for weight_file in weight_files {
+        weights.extend(mlx_rs::Array::load_safetensors(&weight_file).unwrap());
+    }
+
+    model
 
-    let weights = weight_files.iter().map(|file| {
-
-    });
     todo!();
-}
\ No newline at end of file
+}
+
+#[test]
+fn test_load_llama() {
+    load_model_shard(
+
+    )
+}
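
Note (outside the patch): the `as usize` change in `LlamaModelRs::new` makes the loop index the same type as the shard's layer bounds, so they can be compared directly. Below is a minimal stand-alone sketch of the inclusive-range check and the `is_last_layer()` gate used for the trailing `RmsNorm`; the field types and the body of `is_last_layer` are assumptions here, the real definitions live in `src/llama_module.rs`.

```rust
// Stand-in for the Shard struct referenced by the patch; field types assumed.
#[derive(Debug, Clone)]
struct ShardSketch {
    start_layer: usize,
    end_layer: usize,
    total_layers: usize,
}

impl ShardSketch {
    // Assumed to mirror Shard::is_last_layer: the shard that owns the final
    // layer is also the one that builds the trailing RmsNorm.
    fn is_last_layer(&self) -> bool {
        self.end_layer == self.total_layers - 1
    }
}

fn main() {
    let shard = ShardSketch { start_layer: 8, end_layer: 15, total_layers: 32 };

    // Same inclusive bounds check as the patched loop in LlamaModelRs::new.
    let owned: Vec<usize> = (0..shard.total_layers)
        .filter(|&i| shard.start_layer <= i && i <= shard.end_layer)
        .collect();

    println!("layers built by this shard: {owned:?}");
    println!("builds final norm: {}", shard.is_last_layer());
}
```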
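
Likewise for the new `Deserialize` derive on `ModelArgsRs`: it is what lets the `serde_json::Value` config assembled in `load_model_shard` be turned into typed model arguments via `serde_json::from_value`. A minimal sketch of that conversion, using a trimmed-down stand-in struct with only the two fields visible in the hunk (the real `ModelArgsRs` has more, and the config keys are assumed to match the field names):

```rust
use serde::Deserialize;
use serde_json::json;

// Trimmed-down stand-in for ModelArgsRs; only the fields visible in the hunk.
#[derive(Debug, Deserialize)]
struct ModelArgsSketch {
    vocab_size: i32,
    hidden_size: i32,
}

fn main() {
    // Mirrors the config object built in load_model_shard (values made up).
    let config = json!({
        "vocab_size": 32_000,
        "hidden_size": 4_096,
    });

    // Same call shape as the patch: Value -> typed args, panicking on error.
    let args: ModelArgsSketch = serde_json::from_value(config).unwrap();
    println!("{args:?}");
}
```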