exo-rs/src/module_loading.rs
2025-02-12 17:32:09 +00:00

65 lines
1.8 KiB
Rust

use std::collections::HashMap;
use std::path::Path;
use serde_json::Value;
use crate::llama_module::{LlamaModelRs, ModelArgsRs};
use crate::Shard;
/// Reads the model's JSON configuration from the directory `model_path`.
///
/// Uses `config.json` when it exists; otherwise falls back to
/// `model_index.json` in the same directory.
///
/// # Panics
///
/// Panics if the chosen file cannot be read or its contents do not parse
/// into a JSON object. The panic message names the offending file.
fn load_config(
    model_path: &Path,
) -> serde_json::Map<String, serde_json::Value> {
    let config_path = model_path.join("config.json");
    // Fall back only when config.json is absent; a present-but-broken
    // config.json still fails loudly instead of being silently replaced.
    let path = if config_path.exists() {
        config_path
    } else {
        model_path.join("model_index.json")
    };
    let text = std::fs::read_to_string(&path)
        .unwrap_or_else(|e| panic!("failed to read {}: {}", path.display(), e));
    serde_json::from_str(&text).unwrap_or_else(|e| {
        panic!("failed to parse {} as a JSON object: {}", path.display(), e)
    })
}
/// Loads the configuration and safetensors weights for one shard of a model.
///
/// Steps performed:
/// 1. Read the on-disk config via [`load_config`] and overlay `model_config`
///    on top of it (keys from `model_config` replace existing ones).
/// 2. Store the shard layout under the `"shard"` key, using the model
///    directory's name as `model_id`.
/// 3. Deserialize the merged config into [`ModelArgsRs`] and construct a
///    [`LlamaModelRs`] from it.
/// 4. Load every `model*.safetensors` file in `model_path`, in file-name
///    order, merging their tensors into one map.
///
/// The function is unfinished: it ends in `todo!()` after loading weights,
/// and `lazy` is not used yet.
///
/// # Panics
///
/// Panics if the config cannot be read or parsed, the directory name or path
/// is not valid UTF-8, no weight files are found, a weight file fails to load,
/// the merged config does not deserialize into `ModelArgsRs`, or model
/// construction fails.
async fn load_model_shard(
    model_path: &Path,
    shard: Shard,
    lazy: bool,
    model_config: serde_json::Map<String, serde_json::Value>,
) {
    // NOTE(review): this `async fn` does only blocking filesystem work and
    // never awaits; if it runs on an async executor, consider moving the
    // loading into a blocking task so it does not stall other tasks.
    let mut config = load_config(model_path);
    config.extend(model_config);

    let model_name = model_path
        .file_name()
        .and_then(|name| name.to_str())
        .expect("model path must end in a UTF-8 directory name");
    // `serde_json::Map`'s `IndexMut` panics when the key is missing, so
    // `config["shard"] = …` would crash for any config without a "shard"
    // entry; `insert` adds or replaces it.
    config.insert(
        "shard".to_string(),
        serde_json::json!({
            "model_id": model_name,
            "start_layer": shard.start_layer,
            "end_layer": shard.end_layer,
            "n_layers": shard.total_layers,
        }),
    );

    let weight_files = find_weight_files(model_path);

    let model_args = serde_json::from_value::<ModelArgsRs>(Value::Object(config))
        .expect("merged model config does not match ModelArgsRs");
    let model = LlamaModelRs::new(model_args).expect("failed to construct LlamaModelRs");

    let mut weights = HashMap::new();
    for weight_file in weight_files {
        let tensors = mlx_rs::Array::load_safetensors(&weight_file).unwrap_or_else(|e| {
            panic!("failed to load weights from {}: {:?}", weight_file.display(), e)
        });
        weights.extend(tensors);
    }
    todo!();
}

/// Returns the `model*.safetensors` files directly inside `model_path`,
/// sorted by file name; panics if there are none.
fn find_weight_files(model_path: &Path) -> Vec<std::path::PathBuf> {
    let dir = model_path
        .to_str()
        .expect("model path must be valid UTF-8 to build a glob pattern");
    // Escape the directory part so glob metacharacters in it (`*`, `?`, `[`,
    // `]`) match literally; only the file-name part is meant as a pattern.
    let pattern = Path::new(&glob::Pattern::escape(dir)).join("model*.safetensors");
    let pattern = pattern
        .to_str()
        .expect("glob pattern is assembled from UTF-8 parts");
    let mut files = glob::glob(pattern)
        .expect("weight-file glob pattern is valid")
        .collect::<Result<Vec<_>, _>>()
        .unwrap_or_else(|e| {
            panic!("failed to list weight files in {}: {}", model_path.display(), e)
        });
    // Comparing `OsStr` file names orders UTF-8 names exactly like comparing
    // them as `String`s, without allocating a key per comparison.
    files.sort_by(|a, b| a.file_name().cmp(&b.file_name()));
    assert!(
        !files.is_empty(),
        "no model*.safetensors files found in {}",
        model_path.display()
    );
    files
}
/// Smoke test for [`load_model_shard`] against a real checkpoint on disk.
///
/// Ignored by default because it needs model files. Point
/// `EXO_RS_TEST_MODEL_DIR` at a directory containing `config.json` (or
/// `model_index.json`) and `model*.safetensors`, then run
/// `cargo test -- --ignored test_load_llama`.
///
/// NOTE(review): `load_model_shard` currently ends in `todo!()`, so this test
/// panics there until that function is finished.
#[test]
#[ignore = "requires a local model checkpoint in EXO_RS_TEST_MODEL_DIR"]
fn test_load_llama() {
    use std::convert::TryInto;

    /// Drives a future to completion on the current thread using only `std`,
    /// so the test does not depend on any particular async runtime.
    fn block_on<F: std::future::Future>(future: F) -> F::Output {
        use std::future::Future;
        use std::sync::Arc;
        use std::task::{Context, Poll, Wake, Waker};

        struct ThreadWaker(std::thread::Thread);
        impl Wake for ThreadWaker {
            fn wake(self: Arc<Self>) {
                self.0.unpark();
            }
        }

        let mut future = Box::pin(future);
        let waker = Waker::from(Arc::new(ThreadWaker(std::thread::current())));
        let mut cx = Context::from_waker(&waker);
        loop {
            match future.as_mut().poll(&mut cx) {
                Poll::Ready(output) => return output,
                Poll::Pending => std::thread::park(),
            }
        }
    }

    let model_dir = std::env::var_os("EXO_RS_TEST_MODEL_DIR")
        .expect("set EXO_RS_TEST_MODEL_DIR to a local model directory");
    let model_path = Path::new(&model_dir);

    let n_layers = load_config(model_path)
        .get("num_hidden_layers")
        .and_then(Value::as_u64)
        .expect("model config has no integer `num_hidden_layers`");
    let last_layer = n_layers
        .checked_sub(1)
        .expect("model config reports zero layers");
    // NOTE(review): assumes `Shard` has exactly these three integer fields
    // (the ones `load_model_shard` reads) and that `end_layer` is inclusive;
    // adjust if the actual `Shard` definition differs.
    let shard = Shard {
        start_layer: 0,
        end_layer: last_layer.try_into().expect("layer index fits Shard field"),
        total_layers: n_layers.try_into().expect("layer count fits Shard field"),
    };
    block_on(load_model_shard(
        model_path,
        shard,
        false,
        serde_json::Map::new(),
    ));
}