65 lines
1.8 KiB
Rust
65 lines
1.8 KiB
Rust
use std::collections::HashMap;
|
|
use std::path::Path;
|
|
use serde_json::Value;
|
|
use crate::llama_module::{LlamaModelRs, ModelArgsRs};
|
|
use crate::Shard;
|
|
|
|
fn load_config(
|
|
model_path: &Path,
|
|
) -> serde_json::Map<String, serde_json::Value> {
|
|
let config_path = model_path.join("config.json");
|
|
let model_index = model_path.join("model_index.json");
|
|
|
|
if config_path.exists() {
|
|
let config = std::fs::read_to_string(config_path).unwrap();
|
|
serde_json::from_str(&config).unwrap()
|
|
} else {
|
|
let model_index = std::fs::read_to_string(model_index).unwrap();
|
|
serde_json::from_str(&model_index).unwrap()
|
|
}
|
|
}
|
|
|
|
async fn load_model_shard(
|
|
model_path: &Path,
|
|
shard: Shard,
|
|
lazy: bool,
|
|
model_config: serde_json::Map<String, serde_json::Value>,
|
|
) {
|
|
let mut config = load_config(model_path);
|
|
config.extend(model_config.into_iter());
|
|
|
|
let model_name = model_path.file_name().unwrap().to_str().unwrap();
|
|
|
|
config["shard"] = serde_json::json!({
|
|
"model_id": model_name,
|
|
"start_layer": shard.start_layer,
|
|
"end_layer": shard.end_layer,
|
|
"n_layers": shard.total_layers,
|
|
});
|
|
|
|
let mut weight_files = glob::glob(model_path.join("model*.safetensors").to_str().unwrap())
|
|
.unwrap()
|
|
.collect::<Result<Vec<_>, _>>()
|
|
.unwrap();
|
|
|
|
let model_args = serde_json::from_value::<ModelArgsRs>(Value::Object(config)).unwrap();
|
|
let model = LlamaModelRs::new(model_args).unwrap();
|
|
|
|
weight_files.sort_by_key(|x| x.file_name().unwrap().to_str().unwrap().to_string());
|
|
|
|
let mut weights = HashMap::new();
|
|
|
|
for weight_file in weight_files {
|
|
weights.extend(mlx_rs::Array::load_safetensors(&weight_file).unwrap());
|
|
}
|
|
|
|
todo!();
|
|
}
|
|
|
|
/// Placeholder for an end-to-end test of `load_model_shard`.
///
/// A real test must call `load_model_shard(model_path, shard, lazy,
/// model_config)` and drive the returned future to completion. That needs a
/// model directory on disk, a `Shard` value and an async executor, none of
/// which are set up here. Until they are, the test is ignored so the test
/// target still compiles.
#[test]
#[ignore = "needs a local model directory, a Shard fixture and an async executor"]
fn test_load_llama() {
    todo!("build a Shard, call load_model_shard(model_path, shard, lazy, model_config) and await it");
}
|