Stash gen of LLAMA mlx
commit a129325ade
parent 4cd96b58b5
@@ -6,6 +6,9 @@ use mlx_rs::nested::NestedHashMap;
 use mlx_rs::nn;
 use mlx_rs::Array;
 use std::rc::Rc;
+use mlx_rs::builder::Builder;
+use mlx_rs::nn::RmsNormBuilder;
+use serde::Deserialize;
 
 // Define Shard struct to mirror Python dataclass
 #[derive(Debug, Clone)]
@@ -38,7 +41,7 @@ impl Shard {
 }
 
 // Define ModelArgs struct to mirror Python dataclass ModelArgs
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, Deserialize)]
 pub struct ModelArgsRs {
     pub vocab_size: i32,
     pub hidden_size: i32,
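Note: deriving Deserialize on ModelArgsRs is what lets the loader further down build the args straight from the parsed config JSON via serde_json::from_value. A minimal sketch of that path, with the field list abbreviated to the two fields visible in this hunk:

// Sketch only; the real ModelArgsRs has more fields than shown here.
use serde::Deserialize;
use serde_json::json;

#[derive(Debug, Clone, Deserialize)]
pub struct ModelArgsRs {
    pub vocab_size: i32,
    pub hidden_size: i32,
}

fn example() {
    // Stand-in config value; the real code uses the JSON produced by load_config.
    let config = json!({ "vocab_size": 32000, "hidden_size": 4096 });
    let args: ModelArgsRs = serde_json::from_value(config).unwrap();
    assert_eq!(args.hidden_size, 4096);
}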
@@ -102,7 +105,7 @@ impl LlamaModelRs {
         }
 
         let mut layers = Vec::new();
-        for i in 0..num_hidden_layers {
+        for i in 0..(num_hidden_layers as usize) {
             if args.shard.start_layer <= i && i <= args.shard.end_layer {
                 // Using placeholder dimensions for TransformerBlock, adjust as needed
                 layers.push(TransformerBlock::new(
@@ -121,7 +124,10 @@ impl LlamaModelRs {
 
         let mut norm = None;
         if args.shard.is_last_layer() {
-            norm = Some(nn::RmsNorm::new(args.hidden_size, args.rms_norm_eps)?);
+            norm = Some(RmsNormBuilder::new(args.hidden_size)
+                .eps(args.rms_norm_eps)
+                .build()
+                .unwrap());
         }
 
         Ok(Self {
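The norm construction above switches from the two-argument nn::RmsNorm::new constructor to the builder; mlx_rs::builder::Builder has to be in scope for .build() to resolve, which is why it shows up in the new imports. A minimal standalone sketch, assuming the RmsNormBuilder API exactly as used in this diff:

// Sketch only; error handling kept as unwrap to match the stashed code.
use mlx_rs::builder::Builder;
use mlx_rs::nn::{RmsNorm, RmsNormBuilder};

fn build_norm(hidden_size: i32, rms_norm_eps: f32) -> RmsNorm {
    RmsNormBuilder::new(hidden_size)
        .eps(rms_norm_eps)
        .build()
        .unwrap()
}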
@@ -1,5 +1,7 @@
+use std::collections::HashMap;
 use std::path::Path;
+use serde_json::Value;
+use crate::llama_module::{LlamaModelRs, ModelArgsRs};
 use crate::Shard;
 
 fn load_config(
@@ -35,18 +37,30 @@ async fn load_model_shard(
         "n_layers": shard.total_layers,
     });
 
-    let weight_files = glob::glob(model_path.join("model*.safetensors").to_str().unwrap())
+    let mut weight_files = glob::glob(model_path.join("model*.safetensors").to_str().unwrap())
         .unwrap()
        .collect::<Result<Vec<_>, _>>()
        .unwrap();
 
-    let weight_files = weight_files
-        .iter()
-        .map(|path| path.file_name().unwrap().to_str().unwrap())
-        .collect::<Vec<_>>();
+    let model_args = serde_json::from_value::<ModelArgsRs>(Value::Object(config)).unwrap();
+    let model = LlamaModelRs::new(model_args).unwrap();
+
+    weight_files.sort_by_key(|x| x.file_name().unwrap().to_str().unwrap());
+
+    let mut weights = HashMap::new();
+
+    for weight_file in weight_files {
+        weights.extend(mlx_rs::Array::load_safetensors(&weight_file).unwrap());
+    }
+
+    model
 
-    let weights = weight_files.iter().map(|file| {
-
-    });
     todo!();
 }
+
+#[test]
+fn test_load_llama() {
+    load_model_shard(
+
+    )
+}
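For context, Shard only appears in this diff through a few members (start_layer, end_layer, total_layers in the config JSON, and is_last_layer()). A hypothetical reconstruction consistent with that usage, not the actual definition from the repo:

// Hypothetical sketch inferred from how the diff uses Shard; the real struct
// (which mirrors the Python dataclass) may have more fields and different types.
#[derive(Debug, Clone)]
pub struct Shard {
    pub start_layer: usize,  // first layer handled by this shard (inclusive)
    pub end_layer: usize,    // last layer handled by this shard (inclusive)
    pub total_layers: usize, // layer count of the full model
}

impl Shard {
    // Matches the args.shard.is_last_layer() check that gates building the final RmsNorm.
    pub fn is_last_layer(&self) -> bool {
        self.end_layer == self.total_layers - 1
    }
}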