Stash gen of LLAMA mlx
parent 4cd96b58b5
commit a129325ade
@@ -6,6 +6,9 @@ use mlx_rs::nested::NestedHashMap;
 use mlx_rs::nn;
 use mlx_rs::Array;
 use std::rc::Rc;
+use mlx_rs::builder::Builder;
+use mlx_rs::nn::RmsNormBuilder;
+use serde::Deserialize;
 
 // Define Shard struct to mirror Python dataclass
 #[derive(Debug, Clone)]
@@ -38,7 +41,7 @@ impl Shard {
 }
 
 // Define ModelArgs struct to mirror Python dataclass ModelArgs
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Deserialize)]
 pub struct ModelArgsRs {
     pub vocab_size: i32,
     pub hidden_size: i32,
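Note on the Deserialize derive above: it lets the config JSON assembled in load_model_shard be parsed directly into ModelArgsRs. A minimal sketch of that round trip, with only the two fields shown above (the real struct has more fields, and the values here are placeholders):

    use serde::Deserialize;
    use serde_json::json;

    #[derive(Debug, Clone, Deserialize)]
    pub struct ModelArgsRs {
        pub vocab_size: i32,
        pub hidden_size: i32,
    }

    fn parse_args() -> ModelArgsRs {
        // Placeholder config; the real values come from the model's config.json.
        let config = json!({ "vocab_size": 32000, "hidden_size": 4096 });
        serde_json::from_value::<ModelArgsRs>(config).unwrap()
    }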
@@ -102,7 +105,7 @@ impl LlamaModelRs {
         }
 
         let mut layers = Vec::new();
-        for i in 0..num_hidden_layers {
+        for i in 0..(num_hidden_layers as usize) {
             if args.shard.start_layer <= i && i <= args.shard.end_layer {
                 // Using placeholder dimensions for TransformerBlock, adjust as needed
                 layers.push(TransformerBlock::new(
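The start_layer/end_layer range check relies on the Shard struct defined earlier to mirror the Python dataclass. A rough sketch of the shape this code appears to assume; only start_layer, end_layer, total_layers and is_last_layer occur in the diff, the rest (field types, is_first_layer) is hypothetical:

    #[derive(Debug, Clone)]
    pub struct Shard {
        pub start_layer: usize,  // first decoder layer (inclusive) owned by this shard
        pub end_layer: usize,    // last decoder layer (inclusive) owned by this shard
        pub total_layers: usize, // layer count of the full model
    }

    impl Shard {
        // Hypothetical helper mirroring the Python dataclass.
        pub fn is_first_layer(&self) -> bool {
            self.start_layer == 0
        }

        // Used in the next hunk to decide whether this shard owns the final RmsNorm.
        pub fn is_last_layer(&self) -> bool {
            self.end_layer == self.total_layers - 1
        }
    }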
@@ -121,7 +124,10 @@ impl LlamaModelRs {
 
         let mut norm = None;
         if args.shard.is_last_layer() {
-            norm = Some(nn::RmsNorm::new(args.hidden_size, args.rms_norm_eps)?);
+            norm = Some(RmsNormBuilder::new(args.hidden_size)
+                .eps(args.rms_norm_eps)
+                .build()
+                .unwrap());
         }
 
         Ok(Self {
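The builder form above lets rms_norm_eps be set explicitly on the final-layer norm. Pulled out as a standalone sketch, reusing only the calls already visible in the diff (the f32 type for eps is an assumption):

    use mlx_rs::builder::Builder;
    use mlx_rs::nn::{RmsNorm, RmsNormBuilder};

    // Construct the final-layer norm with an explicit epsilon.
    fn make_norm(hidden_size: i32, eps: f32) -> RmsNorm {
        RmsNormBuilder::new(hidden_size)
            .eps(eps)
            .build()
            .unwrap()
    }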
@@ -1,5 +1,7 @@
 use std::collections::HashMap;
+use std::path::Path;
 
 use serde_json::Value;
+use crate::llama_module::{LlamaModelRs, ModelArgsRs};
 use crate::Shard;
 
 fn load_config(
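The body of load_config sits outside this hunk; a plausible sketch of what it might do (the signature, the config.json file name, and the error handling are hypothetical), reading the model directory's config into a JSON object:

    use std::path::Path;
    use serde_json::{Map, Value};

    // Hypothetical sketch: read <model_path>/config.json into a JSON object
    // that load_model_shard can extend with shard information.
    fn load_config(model_path: &Path) -> Map<String, Value> {
        let raw = std::fs::read_to_string(model_path.join("config.json")).unwrap();
        serde_json::from_str::<Value>(&raw)
            .unwrap()
            .as_object()
            .cloned()
            .unwrap()
    }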
@@ -35,18 +37,30 @@ async fn load_model_shard(
         "n_layers": shard.total_layers,
     });
 
-    let weight_files = glob::glob(model_path.join("model*.safetensors").to_str().unwrap())
+    let mut weight_files = glob::glob(model_path.join("model*.safetensors").to_str().unwrap())
         .unwrap()
         .collect::<Result<Vec<_>, _>>()
         .unwrap();
 
-    let weight_files = weight_files
-        .iter()
-        .map(|path| path.file_name().unwrap().to_str().unwrap())
-        .collect::<Vec<_>>();
-    let model_args = serde_json::from_value::<ModelArgsRs>(Value::Object(config)).unwrap();
-    let model = LlamaModelRs::new(model_args).unwrap();
-
-    model
+    weight_files.sort_by_key(|x| x.file_name().unwrap().to_str().unwrap());
+
+    let mut weights = HashMap::new();
+
+    for weight_file in weight_files {
+        weights.extend(mlx_rs::Array::load_safetensors(&weight_file).unwrap());
+    }
+
+    let weights = weight_files.iter().map(|file| {
+
+    });
+    todo!();
 }
 }
+
+#[test]
+fn test_load_llama() {
+    load_model_shard(
+
+    )
+}
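The new loading path globs the model*.safetensors shards, sorts them by file name, and merges every tensor into one map before the (still todo) model wiring. The same flow as a standalone helper, a sketch assuming Array::load_safetensors yields name-to-tensor pairs as the diff uses it; note that sort_by_key needs an owned key, so to_os_string() is used here rather than the borrowed &str in the diff:

    use std::collections::HashMap;
    use std::path::Path;

    use mlx_rs::Array;

    // Sketch: collect all safetensors shards under model_path and merge their tensors.
    fn load_all_weights(model_path: &Path) -> HashMap<String, Array> {
        let mut weight_files = glob::glob(model_path.join("model*.safetensors").to_str().unwrap())
            .unwrap()
            .collect::<Result<Vec<_>, _>>()
            .unwrap();

        // Sort shards by file name so tensors load in a deterministic order.
        weight_files.sort_by_key(|p| p.file_name().unwrap().to_os_string());

        let mut weights = HashMap::new();
        for weight_file in weight_files {
            weights.extend(Array::load_safetensors(&weight_file).unwrap());
        }
        weights
    }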