use std::collections::HashMap;
use std::path::Path;

use serde_json::Value;

use crate::llama_module::{LlamaModelRs, ModelArgsRs};
use crate::Shard;

/// Reads the model's JSON configuration, preferring `config.json` and
/// falling back to `model_index.json`.
fn load_config(model_path: &Path) -> serde_json::Map<String, Value> {
    let config_path = model_path.join("config.json");
    let model_index = model_path.join("model_index.json");
    if config_path.exists() {
        let config = std::fs::read_to_string(config_path).unwrap();
        serde_json::from_str(&config).unwrap()
    } else {
        let model_index = std::fs::read_to_string(model_index).unwrap();
        serde_json::from_str(&model_index).unwrap()
    }
}

async fn load_model_shard(
    model_path: &Path,
    shard: Shard,
    lazy: bool, // reserved for lazy weight loading; not used yet
    model_config: serde_json::Map<String, Value>,
) {
    // Merge caller-supplied overrides on top of the on-disk config.
    let mut config = load_config(model_path);
    config.extend(model_config.into_iter());

    let model_name = model_path.file_name().unwrap().to_str().unwrap();
    // `serde_json::Map`'s `IndexMut` panics when the key is absent, so a
    // fresh "shard" entry must go in via `insert`.
    config.insert(
        "shard".to_string(),
        serde_json::json!({
            "model_id": model_name,
            "start_layer": shard.start_layer,
            "end_layer": shard.end_layer,
            "n_layers": shard.total_layers,
        }),
    );

    // Collect every `model*.safetensors` file in the model directory.
    let mut weight_files =
        glob::glob(model_path.join("model*.safetensors").to_str().unwrap())
            .unwrap()
            .collect::<Result<Vec<_>, _>>()
            .unwrap();

    let model_args =
        serde_json::from_value::<ModelArgsRs>(Value::Object(config)).unwrap();
    let model = LlamaModelRs::new(model_args).unwrap();

    // Sort by file name so multi-file checkpoints load deterministically.
    weight_files.sort_by_key(|x| x.file_name().unwrap().to_str().unwrap().to_string());

    let mut weights = HashMap::new();
    for weight_file in weight_files {
        weights.extend(mlx_rs::Array::load_safetensors(&weight_file).unwrap());
    }

    // TODO: apply `weights` to `model`; see the `filter_weights_for_shard`
    // sketch at the bottom of this file for one possible first step.
    todo!();
}

#[test]
#[ignore] // `load_model_shard` still ends in `todo!()`, so the test cannot pass yet.
fn test_load_llama() {
    // The path and shard layout below are illustrative assumptions for a
    // local smoke test (as is `Shard` having exactly these public fields),
    // not values shipped with the crate.
    let model_path = Path::new("models/llama-3-8b");
    let shard = Shard {
        start_layer: 0,
        end_layer: 31,
        total_layers: 32,
    };
    // `load_model_shard` is async but currently awaits nothing, so a minimal
    // executor is enough to drive it; this assumes the `futures` crate is
    // available as a dev-dependency.
    futures::executor::block_on(load_model_shard(
        model_path,
        shard,
        false,
        serde_json::Map::new(),
    ));
}
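
// A sketch of one possible next step for the `todo!()` in `load_model_shard`,
// not the crate's confirmed behavior: before handing the tensors to the model,
// keep only those belonging to this shard's layer range. The
// `"model.layers.N."` key prefix is an assumption based on the usual Llama
// safetensors naming, and `Shard`'s layer fields are assumed to be inclusive
// `usize` indices.
#[allow(dead_code)]
fn filter_weights_for_shard(
    weights: HashMap<String, mlx_rs::Array>,
    shard: &Shard,
) -> HashMap<String, mlx_rs::Array> {
    weights
        .into_iter()
        .filter(|(name, _)| {
            // Non-layer tensors (embeddings, norms, lm_head) are kept as-is.
            let Some(rest) = name.strip_prefix("model.layers.") else {
                return true;
            };
            // Parse the layer index and keep the tensor only if that index
            // falls inside the shard's [start_layer, end_layer] range.
            rest.split('.')
                .next()
                .and_then(|idx| idx.parse::<usize>().ok())
                .map_or(false, |idx| {
                    idx >= shard.start_layer && idx <= shard.end_layer
                })
        })
        .collect()
}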