Stash gen of LLAMA mlx

Joshua Coles 2025-02-12 17:18:40 +00:00
parent 4cd96b58b5
commit a129325ade
2 changed files with 33 additions and 13 deletions

File: llama_module.rs

@@ -6,6 +6,9 @@ use mlx_rs::nested::NestedHashMap;
 use mlx_rs::nn;
 use mlx_rs::Array;
 use std::rc::Rc;
+use mlx_rs::builder::Builder;
+use mlx_rs::nn::RmsNormBuilder;
+use serde::Deserialize;
 
 // Define Shard struct to mirror Python dataclass
 #[derive(Debug, Clone)]
@@ -38,7 +41,7 @@ impl Shard {
 }
 
 // Define ModelArgs struct to mirror Python dataclass ModelArgs
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Deserialize)]
 pub struct ModelArgsRs {
     pub vocab_size: i32,
     pub hidden_size: i32,
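
Adding Deserialize here is what lets the loader in the second file build a ModelArgsRs directly from the parsed config.json. A minimal sketch of that round trip, using a hypothetical trimmed-down struct (the real ModelArgsRs has more fields):

    use serde::Deserialize;

    // Hypothetical stand-in for ModelArgsRs with only two of its fields.
    #[derive(Debug, Clone, Deserialize)]
    struct MiniArgs {
        vocab_size: i32,
        hidden_size: i32,
    }

    fn main() {
        // serde_json can deserialize from a string or from an in-memory Value,
        // which is how load_model_shard uses it below.
        let raw = r#"{"vocab_size": 32000, "hidden_size": 4096}"#;
        let args: MiniArgs = serde_json::from_str(raw).unwrap();
        assert_eq!(args.hidden_size, 4096);

        let value: serde_json::Value = serde_json::from_str(raw).unwrap();
        let args2: MiniArgs = serde_json::from_value(value).unwrap();
        assert_eq!(args2.vocab_size, 32000);
    }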
@@ -102,7 +105,7 @@ impl LlamaModelRs {
         }
 
         let mut layers = Vec::new();
-        for i in 0..num_hidden_layers {
+        for i in 0..(num_hidden_layers as usize) {
             if args.shard.start_layer <= i && i <= args.shard.end_layer {
                 // Using placeholder dimensions for TransformerBlock, adjust as needed
                 layers.push(TransformerBlock::new(
@@ -121,7 +124,10 @@ impl LlamaModelRs {
 
         let mut norm = None;
         if args.shard.is_last_layer() {
-            norm = Some(nn::RmsNorm::new(args.hidden_size, args.rms_norm_eps)?);
+            norm = Some(RmsNormBuilder::new(args.hidden_size)
+                .eps(args.rms_norm_eps)
+                .build()
+                .unwrap());
         }
 
         Ok(Self {
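
The norm construction above switches from a direct constructor to mlx-rs's builder pattern; the added `use mlx_rs::builder::Builder;` import is what brings the `.build()` trait method into scope. A minimal sketch of the same call with placeholder values (exact builder signatures depend on the mlx-rs version):

    use mlx_rs::builder::Builder; // provides the .build() trait method
    use mlx_rs::nn::{RmsNorm, RmsNormBuilder};

    fn make_norm() -> RmsNorm {
        RmsNormBuilder::new(4096) // normalized dimension (placeholder)
            .eps(1e-5)            // rms_norm_eps from the model config (placeholder)
            .build()
            .unwrap()
    }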

File: (model-loading module; name not shown)

@@ -1,5 +1,7 @@
 use std::collections::HashMap;
 use std::path::Path;
+use serde_json::Value;
+use crate::llama_module::{LlamaModelRs, ModelArgsRs};
 use crate::Shard;
 
 fn load_config(
@@ -35,18 +37,30 @@ async fn load_model_shard(
         "n_layers": shard.total_layers,
     });
-    let weight_files = glob::glob(model_path.join("model*.safetensors").to_str().unwrap())
+    let mut weight_files = glob::glob(model_path.join("model*.safetensors").to_str().unwrap())
         .unwrap()
         .collect::<Result<Vec<_>, _>>()
         .unwrap();
 
-    let weight_files = weight_files
-        .iter()
-        .map(|path| path.file_name().unwrap().to_str().unwrap())
-        .collect::<Vec<_>>();
+    let model_args = serde_json::from_value::<ModelArgsRs>(Value::Object(config)).unwrap();
+    let model = LlamaModelRs::new(model_args).unwrap();
+
+    weight_files.sort_by_key(|x| x.file_name().unwrap().to_str().unwrap());
+
+    let mut weights = HashMap::new();
+    for weight_file in weight_files {
+        weights.extend(mlx_rs::Array::load_safetensors(&weight_file).unwrap());
+    }
+
+    model
 
-    let weights = weight_files.iter().map(|file| {
-    });
 
-    todo!();
-}
 }
+
+#[test]
+fn test_load_llama() {
+    load_model_shard(
+    )
+}
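
Pulling the new loader body together: the flow is to glob the sharded model*.safetensors files, sort them for a deterministic merge order, and fold each file's tensors into one map. A hedged standalone sketch, assuming Array::load_safetensors returns a name-to-tensor map (as the diff's weights.extend call implies); it sorts the PathBufs directly, since sort_by_key with a borrowed &str key does not pass the borrow checker:

    use std::collections::HashMap;
    use std::path::Path;

    fn load_weights(model_path: &Path) -> HashMap<String, mlx_rs::Array> {
        // Sharded checkpoints ship as model-00001-of-000NN.safetensors.
        let mut weight_files = glob::glob(model_path.join("model*.safetensors").to_str().unwrap())
            .unwrap()
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        weight_files.sort(); // lexicographic order matches the shard numbering

        // Merge every shard's tensors into a single name -> Array map.
        let mut weights = HashMap::new();
        for weight_file in weight_files {
            weights.extend(mlx_rs::Array::load_safetensors(&weight_file).unwrap());
        }
        weights
    }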