Stash
This commit is contained in:
		
							parent
							
								
									01e352bc2e
								
							
						
					
					
						commit
						e551efc0a4
					
				| @ -1,450 +1,160 @@ | |||||||
|  | use crate::Shard; | ||||||
| use mlx_rs::builder::Builder; | use mlx_rs::builder::Builder; | ||||||
| use mlx_rs::macros::ModuleParameters; | use mlx_rs::macros::ModuleParameters; | ||||||
| use mlx_rs::module::Module; | use mlx_rs::module::Module; | ||||||
| use mlx_rs::module::ModuleParameters; | use mlx_rs::nn::{Embedding, RmsNorm, RmsNormBuilder}; | ||||||
| use mlx_rs::module::Param; |  | ||||||
| use mlx_rs::nested::NestedHashMap; |  | ||||||
| use mlx_rs::nn; |  | ||||||
| use mlx_rs::nn::RmsNormBuilder; |  | ||||||
| use mlx_rs::Array; | use mlx_rs::Array; | ||||||
| use serde::Deserialize; | use serde::Deserialize; | ||||||
| use std::rc::Rc; | use std::collections::HashMap; | ||||||
|  | use std::env::args; | ||||||
|  | use mlx_rs::ops::zeros; | ||||||
| 
 | 
 | ||||||
| // Define Shard struct to mirror Python dataclass
 | #[derive(Debug, Deserialize, ModuleParameters)] | ||||||
| #[derive(Debug, Clone, Deserialize, Default)] | struct ModelArgs { | ||||||
| pub struct Shard { |     vocab_size: i32, | ||||||
|     pub name: String, |     hidden_size: i32, | ||||||
|     pub start_layer: usize, |     num_hidden_layers: i32, | ||||||
|     pub end_layer: usize, |     rms_norm_eps: f32, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl Shard { | #[derive(Debug)] | ||||||
|     pub fn new(name: String, start_layer: usize, end_layer: usize) -> Self { | enum ShardedLayer { | ||||||
|         Shard { |     TransformerBlock, | ||||||
|             name, |     IdentityBlock, | ||||||
|             start_layer, |  | ||||||
|             end_layer, |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     pub fn is_first_layer(&self) -> bool { |  | ||||||
|         self.start_layer == 0 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     pub fn is_last_layer(&self) -> bool { |  | ||||||
|         // Assuming end_layer is inclusive and represents the last layer index in the shard
 |  | ||||||
|         // and num_hidden_layers is the total number of layers.
 |  | ||||||
|         // We would need num_hidden_layers to accurately determine the last layer.
 |  | ||||||
|         // For now, let's assume if end_layer is very large, it's the last layer in shard.
 |  | ||||||
|         self.end_layer > 9999 // A large number as a placeholder, adjust as needed
 |  | ||||||
|     } |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Define ModelArgs struct to mirror Python dataclass ModelArgs
 | #[derive(Debug, ModuleParameters)] | ||||||
| #[derive(Debug, Clone, Deserialize)] | struct LlamaModel { | ||||||
| pub struct ModelArgsRs { |     args: ModelArgs, | ||||||
|     pub model_type: String, |     shard: Shard, | ||||||
|     pub hidden_size: i32, |     layers: Vec<ShardedLayer>, | ||||||
|     pub num_hidden_layers: i32, |  | ||||||
|     pub intermediate_size: i32, |  | ||||||
|     pub num_attention_heads: i32, |  | ||||||
|     pub rms_norm_eps: f32, |  | ||||||
|     pub vocab_size: i32, |  | ||||||
|     pub head_dim: Option<i32>, |  | ||||||
|     pub max_position_embeddings: Option<i32>, // Added max_position_embeddings
 |  | ||||||
|     pub num_key_value_heads: Option<i32>,     // Added num_key_value_heads
 |  | ||||||
|     pub attention_bias: bool,                 // Added attention_bias
 |  | ||||||
|     pub mlp_bias: bool,                       // Added mlp_bias
 |  | ||||||
|     pub rope_theta: f32,                      // Added rope_theta
 |  | ||||||
|     pub rope_traditional: bool,               // Added rope_traditional
 |  | ||||||
|     // pub rope_scaling: Option<Dict<str, Union<float, str>>>, // Complex type, needs handling if needed
 |  | ||||||
|     pub tie_word_embeddings: bool, |  | ||||||
| 
 | 
 | ||||||
|     #[serde(default)] |     embed_tokens: Embedding, | ||||||
|     pub shard: Shard, // Using the Shard struct defined above
 |     norm: RmsNorm, | ||||||
|  | 
 | ||||||
|  |     cache: Vec<Option<Array>>, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl ModelArgsRs { | impl LlamaModel { | ||||||
|     // Add a constructor or builder pattern here if needed for easier initialization
 |     fn new(args: ModelArgs, shard: Shard) -> Self { | ||||||
|     pub fn new( |         let embed_tokens = Embedding::new(args.vocab_size, args.hidden_size).unwrap(); | ||||||
|         model_type: String, |  | ||||||
|         hidden_size: i32, |  | ||||||
|         num_hidden_layers: i32, |  | ||||||
|         intermediate_size: i32, |  | ||||||
|         num_attention_heads: i32, |  | ||||||
|         rms_norm_eps: f32, |  | ||||||
|         vocab_size: i32, |  | ||||||
|         tie_word_embeddings: bool, |  | ||||||
|         shard: Shard, |  | ||||||
|     ) -> Self { |  | ||||||
|         ModelArgsRs { |  | ||||||
|             model_type, |  | ||||||
|             hidden_size, |  | ||||||
|             num_hidden_layers, |  | ||||||
|             intermediate_size, |  | ||||||
|             num_attention_heads, |  | ||||||
|             rms_norm_eps, |  | ||||||
|             vocab_size, |  | ||||||
|             head_dim: None,                // Default value
 |  | ||||||
|             max_position_embeddings: None, // Default value
 |  | ||||||
|             num_key_value_heads: None,     // Default value
 |  | ||||||
|             attention_bias: false,         // Default value
 |  | ||||||
|             mlp_bias: false,               // Default value
 |  | ||||||
|             rope_theta: 10000.0,           // Default value
 |  | ||||||
|             rope_traditional: false,       // Default value
 |  | ||||||
|             // rope_scaling: None,           // Default value
 |  | ||||||
|             tie_word_embeddings, |  | ||||||
|             shard, |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 | 
 | ||||||
| // Define Attention struct
 |         let layers = (0..(args.num_hidden_layers as u32)).map(|i| { | ||||||
| #[derive(Debug, Clone, ModuleParameters)] |             if shard.start_layer <= i && i <= shard.end_layer { | ||||||
| pub struct Attention { |                 ShardedLayer::TransformerBlock | ||||||
|     pub q_proj: nn::Linear, |             } else { | ||||||
|     pub k_proj: nn::Linear, |                 ShardedLayer::IdentityBlock | ||||||
|     pub v_proj: nn::Linear, |  | ||||||
|     pub o_proj: nn::Linear, |  | ||||||
|     pub n_heads: i32, |  | ||||||
|     pub n_kv_heads: i32, |  | ||||||
|     pub head_dim: i32, |  | ||||||
|     pub scale: f32, |  | ||||||
|     // pub rope: Rope, // Placeholder for Rope implementation
 |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl Attention { |  | ||||||
|     pub fn new(args: &ModelArgsRs) -> Result<Self, mlx_rs::error::Exception> { |  | ||||||
|         let dim = args.hidden_size; |  | ||||||
|         let n_heads = args.num_attention_heads; |  | ||||||
|         let n_kv_heads = args.num_key_value_heads.unwrap_or(args.num_attention_heads); |  | ||||||
|         let head_dim = args.head_dim.unwrap_or(args.hidden_size / n_heads); // Default head_dim calculation
 |  | ||||||
|         let scale = (head_dim as f32).powf(-0.5); |  | ||||||
|         let attention_bias = args.attention_bias; // Use bias from args
 |  | ||||||
| 
 |  | ||||||
|         let q_proj = nn::Linear::new(dim, n_heads * head_dim)?; |  | ||||||
|         let k_proj = nn::Linear::new(dim, n_kv_heads * head_dim)?; |  | ||||||
|         let v_proj = nn::Linear::new(dim, n_kv_heads * head_dim)?; |  | ||||||
|         let o_proj = nn::Linear::new(n_heads * head_dim, dim)?; |  | ||||||
| 
 |  | ||||||
|         Ok(Self { |  | ||||||
|             q_proj, |  | ||||||
|             k_proj, |  | ||||||
|             v_proj, |  | ||||||
|             o_proj, |  | ||||||
|             n_heads, |  | ||||||
|             n_kv_heads, |  | ||||||
|             head_dim, |  | ||||||
|             scale, |  | ||||||
|             // rope: Rope::new(...) // Initialize Rope here when implemented
 |  | ||||||
|         }) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl Module<&Array> for Attention { |  | ||||||
|     type Output = Array; |  | ||||||
|     type Error = mlx_rs::error::Exception; |  | ||||||
| 
 |  | ||||||
|     fn forward(&mut self, input: &Array) -> Result<Self::Output, Self::Error> { |  | ||||||
|         // Placeholder for actual attention logic
 |  | ||||||
|         // Need to implement:
 |  | ||||||
|         // 1. Projections (q_proj, k_proj, v_proj)
 |  | ||||||
|         // 2. Reshape and transpose for multi-head
 |  | ||||||
|         // 3. RoPE application
 |  | ||||||
|         // 4. Scaled dot-product attention
 |  | ||||||
|         // 5. Output projection (o_proj)
 |  | ||||||
| 
 |  | ||||||
|         let q = self.q_proj.forward(input)?; |  | ||||||
|         let k = self.k_proj.forward(input)?; |  | ||||||
|         let v = self.v_proj.forward(input)?; |  | ||||||
| 
 |  | ||||||
|         // Placeholder - directly return v projection for now
 |  | ||||||
|         self.o_proj.forward(&v) |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     fn training_mode(&mut self, mode: bool) { |  | ||||||
|         self.q_proj.training_mode(mode); |  | ||||||
|         self.k_proj.training_mode(mode); |  | ||||||
|         self.v_proj.training_mode(mode); |  | ||||||
|         self.o_proj.training_mode(mode); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // Define MLP struct
 |  | ||||||
| #[derive(Debug, Clone, ModuleParameters)] |  | ||||||
| pub struct MLP { |  | ||||||
|     pub gate_proj: nn::Linear, |  | ||||||
|     pub down_proj: nn::Linear, |  | ||||||
|     pub up_proj: nn::Linear, |  | ||||||
|     mlp_bias: bool, // Store mlp_bias
 |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl MLP { |  | ||||||
|     pub fn new(args: &ModelArgsRs) -> Result<Self, mlx_rs::error::Exception> { |  | ||||||
|         let dim = args.hidden_size; |  | ||||||
|         let hidden_dim = args.intermediate_size; |  | ||||||
|         let mlp_bias = args.mlp_bias; // Get mlp_bias from args
 |  | ||||||
|         let gate_proj = nn::Linear::new(dim, hidden_dim)?; |  | ||||||
|         let down_proj = nn::Linear::new(hidden_dim, dim)?; |  | ||||||
|         let up_proj = nn::Linear::new(dim, hidden_dim)?; |  | ||||||
| 
 |  | ||||||
|         Ok(Self { |  | ||||||
|             gate_proj, |  | ||||||
|             down_proj, |  | ||||||
|             up_proj, |  | ||||||
|             mlp_bias, // Store mlp_bias
 |  | ||||||
|         }) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl Module<&Array> for MLP { |  | ||||||
|     type Output = Array; |  | ||||||
|     type Error = mlx_rs::error::Exception; |  | ||||||
| 
 |  | ||||||
|     fn forward(&mut self, input: &Array) -> Result<Self::Output, Self::Error> { |  | ||||||
|         // Implement MLP forward pass using nn::silu
 |  | ||||||
|         let gate_output = self.gate_proj.forward(input)?; |  | ||||||
|         let silu_output = nn::silu(&gate_output)?; // Apply silu activation
 |  | ||||||
|         let up_output = self.up_proj.forward(input)?; |  | ||||||
|         let combined_output = silu_output * up_output; // Element-wise multiplication
 |  | ||||||
|         self.down_proj.forward(&combined_output) // Final projection
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     fn training_mode(&mut self, mode: bool) { |  | ||||||
|         self.gate_proj.training_mode(mode); |  | ||||||
|         self.down_proj.training_mode(mode); |  | ||||||
|         self.up_proj.training_mode(mode); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| // ... existing code ...
 |  | ||||||
| 
 |  | ||||||
| // ... existing code ...
 |  | ||||||
| // Placeholder for TransformerBlock - You'll need to implement this in Rust
 |  | ||||||
| #[derive(Debug, Clone, ModuleParameters)] |  | ||||||
| pub struct TransformerBlock { |  | ||||||
|     // Define the layers within TransformerBlock as needed, e.g., attention, norm, etc.
 |  | ||||||
|     // For now, using a linear layer as a placeholder
 |  | ||||||
|     pub linear: nn::Linear, |  | ||||||
|     self_attn: Attention, |  | ||||||
|     mlp: MLP, |  | ||||||
|     input_layernorm: nn::RmsNorm, |  | ||||||
|     post_attention_layernorm: nn::RmsNorm, |  | ||||||
|     args: ModelArgsRs, // Store args for potential use within TransformerBlock
 |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl TransformerBlock { |  | ||||||
|     pub fn new(args: ModelArgsRs) -> Result<Self, mlx_rs::error::Exception> { |  | ||||||
|         let linear = nn::Linear::new(1, 1).unwrap(); // Dummy linear layer, will be removed
 |  | ||||||
|         let self_attn = Attention::new(&args)?; |  | ||||||
|         let mlp = MLP::new(&args)?; |  | ||||||
| 
 |  | ||||||
|         let input_layernorm = nn::RmsNormBuilder::new(args.hidden_size) |  | ||||||
|             .eps(args.rms_norm_eps) |  | ||||||
|             .build() |  | ||||||
|             .unwrap(); |  | ||||||
| 
 |  | ||||||
|         let post_attention_layernorm = nn::RmsNormBuilder::new(args.hidden_size) |  | ||||||
|             .eps(args.rms_norm_eps) |  | ||||||
|             .build() |  | ||||||
|             .unwrap(); |  | ||||||
| 
 |  | ||||||
|         Ok(Self { |  | ||||||
|             linear, // Dummy linear layer, will be removed in future
 |  | ||||||
|             self_attn, |  | ||||||
|             mlp, |  | ||||||
|             input_layernorm, |  | ||||||
|             post_attention_layernorm, |  | ||||||
|             args, // Store args
 |  | ||||||
|         }) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl Module<&Array> for TransformerBlock { |  | ||||||
|     type Output = Array; |  | ||||||
|     type Error = mlx_rs::error::Exception; |  | ||||||
| 
 |  | ||||||
|     fn forward(&mut self, input: &Array) -> Result<Self::Output, Self::Error> { |  | ||||||
|         // Implement the forward pass logic for TransformerBlock
 |  | ||||||
|         // For now, just passing through the linear layer
 |  | ||||||
|         // self.linear.forward(input) // Old placeholder
 |  | ||||||
| 
 |  | ||||||
|         let normed_input = self.input_layernorm.forward(input)?; |  | ||||||
|         let attention_output = self.self_attn.forward(&normed_input)?; |  | ||||||
|         let hidden_state = input + &attention_output; // Residual connection
 |  | ||||||
| 
 |  | ||||||
|         let normed_hidden_state = self.post_attention_layernorm.forward(&hidden_state)?; |  | ||||||
|         let mlp_output = self.mlp.forward(&normed_hidden_state)?; |  | ||||||
|         let output = hidden_state + &mlp_output; // Residual connection
 |  | ||||||
| 
 |  | ||||||
|         Ok(output) |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     fn training_mode(&mut self, mode: bool) { |  | ||||||
|         self.self_attn.training_mode(mode); |  | ||||||
|         self.mlp.training_mode(mode); |  | ||||||
|         self.input_layernorm.training_mode(mode); |  | ||||||
|         self.post_attention_layernorm.training_mode(mode); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| // ... existing code ...
 |  | ||||||
| 
 |  | ||||||
| // ... existing code ...
 |  | ||||||
| #[derive(Debug, Clone, ModuleParameters)] |  | ||||||
| pub struct LlamaModelRs { |  | ||||||
|     pub args: ModelArgsRs, |  | ||||||
|     pub vocab_size: i32, |  | ||||||
|     pub num_hidden_layers: i32, |  | ||||||
|     pub embed_tokens: Option<nn::Embedding>, // Embedding layer is optional based on sharding
 |  | ||||||
|     pub layers: Vec<TransformerBlock>,       // Using TransformerBlock
 |  | ||||||
|     pub norm: Option<nn::RmsNorm>,           // RMSNorm layer is optional based on sharding
 |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl LlamaModelRs { |  | ||||||
|     pub fn new(args: ModelArgsRs) -> Result<Self, mlx_rs::error::Exception> { |  | ||||||
|         let vocab_size = args.vocab_size; |  | ||||||
|         let num_hidden_layers = args.num_hidden_layers; |  | ||||||
|         let mut embed_tokens = None; |  | ||||||
|         if args.shard.is_first_layer() || (args.shard.is_last_layer() && args.tie_word_embeddings) { |  | ||||||
|             embed_tokens = Some(nn::Embedding::new(args.vocab_size, args.hidden_size)?); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         let mut layers: Vec<TransformerBlock> = Vec::new(); // Specify type here
 |  | ||||||
|         for _ in 0..num_hidden_layers { |  | ||||||
|             // No sharding logic for now, apply to all layers - revisit sharding
 |  | ||||||
|             layers.push(TransformerBlock::new(args.clone())?); // Pass cloned args
 |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         let mut norm = None; |  | ||||||
|         if args.shard.is_last_layer() { |  | ||||||
|             norm = Some( |  | ||||||
|                 nn::RmsNormBuilder::new(args.hidden_size) |  | ||||||
|                     .eps(args.rms_norm_eps) |  | ||||||
|                     .build() |  | ||||||
|                     .unwrap(), |  | ||||||
|             ); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         Ok(Self { |  | ||||||
|             args, |  | ||||||
|             vocab_size, |  | ||||||
|             num_hidden_layers, |  | ||||||
|             embed_tokens, |  | ||||||
|             layers, |  | ||||||
|             norm, |  | ||||||
|         }) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl Module<&Array> for LlamaModelRs { |  | ||||||
|     type Output = Array; |  | ||||||
|     type Error = mlx_rs::error::Exception; |  | ||||||
| 
 |  | ||||||
|     fn forward(&mut self, inputs: &Array) -> Result<Self::Output, Self::Error> { |  | ||||||
|         let mut h; |  | ||||||
|         if self.args.shard.is_first_layer() && self.embed_tokens.is_some() { |  | ||||||
|             h = self.embed_tokens.as_ref().unwrap().forward(inputs)?; |  | ||||||
|         } else { |  | ||||||
|             h = inputs.clone(); // Assuming input is already embedded if not the first layer
 |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         // Mask creation logic would go here - needs to be implemented in Rust
 |  | ||||||
|         let mask: Option<Array> = None; // Placeholder mask
 |  | ||||||
|                                         // if h.ndim() > 1 && h.shape()[1] > 1 {
 |  | ||||||
|                                         //     mask = create_attention_mask(h, cache); // Need to port create_attention_mask to Rust
 |  | ||||||
|                                         // }
 |  | ||||||
| 
 |  | ||||||
|         // Cache handling - needs more detailed implementation for Rust
 |  | ||||||
|         let mut cache: Option<Vec<Option<Array>>> = None; // Placeholder cache
 |  | ||||||
|                                                           // let mut cache = cache.unwrap_or_else(|| vec![None; self.layers.len()]);
 |  | ||||||
| 
 |  | ||||||
|         for layer in &mut self.layers { |  | ||||||
|             h = layer.forward(&h)?; // Pass mask and cache when implemented
 |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         let normed_h = match &mut self.norm { |  | ||||||
|             Some(norm_layer) => norm_layer.forward(&h)?, |  | ||||||
|             None => h, // Skip norm if not the last layer
 |  | ||||||
|         }; |  | ||||||
|         Ok(normed_h) |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     fn training_mode(&mut self, mode: bool) { |  | ||||||
|         if let Some(embed_tokens) = &mut self.embed_tokens { |  | ||||||
|             embed_tokens.training_mode(mode); |  | ||||||
|         } |  | ||||||
|         for layer in &mut self.layers { |  | ||||||
|             layer.training_mode(mode); |  | ||||||
|         } |  | ||||||
|         if let Some(norm) = &mut self.norm { |  | ||||||
|             norm.training_mode(mode); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| // ... existing code ...
 |  | ||||||
| 
 |  | ||||||
| // ... existing code ...
 |  | ||||||
| #[derive(Debug, Clone, ModuleParameters)] |  | ||||||
| pub struct ModelRs { |  | ||||||
|     pub args: ModelArgsRs, |  | ||||||
|     pub model_type: String, |  | ||||||
|     pub model: LlamaModelRs, |  | ||||||
|     pub lm_head: Option<nn::Linear>, // Linear layer for language model head, optional based on tie_word_embeddings
 |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl ModelRs { |  | ||||||
|     pub fn new(args: ModelArgsRs) -> Result<Self, mlx_rs::error::Exception> { |  | ||||||
|         let model = LlamaModelRs::new(args.clone())?; // Clone args for LlamaModel
 |  | ||||||
|         let model_type = args.model_type.clone(); |  | ||||||
|         let mut lm_head = None; |  | ||||||
|         if args.shard.is_last_layer() && !args.tie_word_embeddings { |  | ||||||
|             lm_head = Some(nn::Linear::new(args.hidden_size, args.vocab_size)?); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         Ok(Self { |  | ||||||
|             args, |  | ||||||
|             model_type, |  | ||||||
|             model, |  | ||||||
|             lm_head, |  | ||||||
|         }) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl Module<&Array> for ModelRs { |  | ||||||
|     type Output = Array; |  | ||||||
|     type Error = mlx_rs::error::Exception; |  | ||||||
| 
 |  | ||||||
|     fn forward(&mut self, inputs: &Array) -> Result<Self::Output, Self::Error> { |  | ||||||
|         let mut out = self.model.forward(inputs)?; |  | ||||||
| 
 |  | ||||||
|         if self.args.shard.is_last_layer() { |  | ||||||
|             if self.args.tie_word_embeddings && self.model.embed_tokens.is_some() { |  | ||||||
|                 // Need to implement as_linear() equivalent in Rust or directly use embedding weights for linear transformation
 |  | ||||||
|                 // Placeholder - direct linear transformation using embedding weights is not directly available in mlx-rs as in python
 |  | ||||||
|                 if let Some(embed_tokens) = &self.model.embed_tokens { |  | ||||||
|                     let params = embed_tokens.parameters(); |  | ||||||
|                     if let Some(weight_param) = params.entries.get("weight") { |  | ||||||
|                         // This is a very simplified placeholder - needs proper matrix multiplication with 'out' and 'weight_param'
 |  | ||||||
|                         // out = weight_param.clone(); // Incorrect - replace with actual linear transformation
 |  | ||||||
|                         // Placeholder: use linear layer with embedding weights (not directly supported in mlx-rs)
 |  | ||||||
|                         let embedding_weight = weight_param.array(); |  | ||||||
|                         let weight_array = embedding_weight.transpose()?; // Assuming weight needs transpose
 |  | ||||||
|                         let weight_arr_ref = &weight_array; |  | ||||||
|                         let out_matmul = out.matmul(weight_arr_ref)?; // Perform matrix multiplication
 |  | ||||||
|                         out = out_matmul; |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|             } else if let Some(lm_head) = &mut self.lm_head { |  | ||||||
|                 out = lm_head.forward(&out)?; |  | ||||||
|             } |             } | ||||||
|         } |         }).collect::<Vec<_>>(); | ||||||
|         Ok(out) |  | ||||||
|     } |  | ||||||
| 
 | 
 | ||||||
|     fn training_mode(&mut self, mode: bool) { |         let norm = RmsNormBuilder::new(args.hidden_size) | ||||||
|         self.model.training_mode(mode); |             .eps(args.rms_norm_eps) | ||||||
|         if let Some(lm_head) = &mut self.lm_head { |             .build() | ||||||
|             lm_head.training_mode(mode); |             .unwrap(); | ||||||
|  | 
 | ||||||
|  |         Self { | ||||||
|  |             cache: vec![None; args.num_hidden_layers as usize], | ||||||
|  |             args, | ||||||
|  |             shard, | ||||||
|  |             layers, | ||||||
|  |             embed_tokens, | ||||||
|  |             norm | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
| } | } | ||||||
| // ... existing code ...
 | 
 | ||||||
|  | impl Module<Array> for LlamaModel { | ||||||
|  |     type Output = Array; | ||||||
|  |     type Error = mlx_rs::error::Exception; | ||||||
|  | 
 | ||||||
|  |     fn forward(&mut self, input: Array) -> Result<Self::Output, Self::Error> { | ||||||
|  |         let h = if self.shard.is_first_layer() { | ||||||
|  |             self.embed_tokens.forward(&input)? | ||||||
|  |         } else { | ||||||
|  |             input | ||||||
|  |         }; | ||||||
|  | 
 | ||||||
|  |         let mut mask = if h.ndim() > 1 && h.shape()[1] > 1 { | ||||||
|  |             Some(create_attention_mask(&h, &self.cache)?) | ||||||
|  |         } else { | ||||||
|  |             None | ||||||
|  |         }; | ||||||
|  | 
 | ||||||
|  |         let h = self.layers.iter_mut().zip(self.cache.iter_mut()) | ||||||
|  |         .fold(h, |h, (layer, c)| { | ||||||
|  |             layer.forward(&h, mask.as_ref(), c)? | ||||||
|  |         }); | ||||||
|  | 
 | ||||||
|  |         let h = if self.shard.is_last_layer() { | ||||||
|  |             self.norm.forward(&h)? | ||||||
|  |         } else { | ||||||
|  |             h | ||||||
|  |         }; | ||||||
|  | 
 | ||||||
|  |         Ok(h)    
 | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn training_mode(&mut self, mode: bool) { | ||||||
|  |         todo!() | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | fn create_attention_mask(h: &Array, cache: &[Option<HashMap<String, i32>>]) -> Result<Array, mlx_rs::error::Exception> { | ||||||
|  |     let shape = h.shape(); | ||||||
|  |     let t = shape[1]; | ||||||
|  | 
 | ||||||
|  |     if t > 1 { | ||||||
|  |         let (window_size, offset) = match cache { | ||||||
|  |             &[Some(ref cache), ..] => { | ||||||
|  |                 let offset = *cache.get("offset").unwrap(); | ||||||
|  | 
 | ||||||
|  |                 if let Some(max_size) = cache.get("max_size") { | ||||||
|  |                     (Some(*max_size), i32::min(*max_size, offset)) | ||||||
|  |                 } else { | ||||||
|  |                     (None, offset) | ||||||
|  |                 } | ||||||
|  |             }, | ||||||
|  |             _ => (None, 0), | ||||||
|  |         }; | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  |         let mask = create_causal_mask(t, offset, window_size, None)?; | ||||||
|  |         mask.as_dtype(h.dtype()) | ||||||
|  |     } else { | ||||||
|  |         Ok(zeros(&[0])) // Return empty array when T <= 1
 | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | fn create_causal_mask( | ||||||
|  |     n: i32, | ||||||
|  |     offset: i32, | ||||||
|  |     window_size: Option<i32>, | ||||||
|  |     lengths: Option<&Array> | ||||||
|  | ) -> Result<Array, mlx_rs::error::Exception> { | ||||||
|  |     let rinds = Array::arange(0, offset + n, 1)?; | ||||||
|  |     let linds = if offset > 0 { | ||||||
|  |         Array::arange(0, offset + n, 1)? | ||||||
|  |     } else { | ||||||
|  |         rinds.clone() | ||||||
|  |     }; | ||||||
|  | 
 | ||||||
|  |     let linds = linds.reshape(&[-1, 1])?; | ||||||
|  |     let rinds = rinds.reshape(&[1, -1])?; | ||||||
|  |     
 | ||||||
|  |     let mut mask = linds.lt(&rinds)?; | ||||||
|  | 
 | ||||||
|  |     if let Some(w) = window_size { | ||||||
|  |         let window_mask = linds.gt(&(rinds + w))?; | ||||||
|  |         mask = mask.logical_or(&window_mask)?; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     if let Some(l) = lengths { | ||||||
|  |         let l = l.reshape(&[-1, 1, 1, 1])?; | ||||||
|  |         let length_mask = rinds.greater_equal(&l)?; | ||||||
|  |         mask = mask.logical_or(&length_mask)?; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     mask.multiply(-1e9) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | |||||||
| @ -1,7 +1,6 @@ | |||||||
| use std::collections::HashMap; | use std::collections::HashMap; | ||||||
| use std::path::Path; | use std::path::Path; | ||||||
| use serde_json::Value; | use serde_json::Value; | ||||||
| use crate::llama_module::{LlamaModelRs, ModelArgsRs}; |  | ||||||
| use crate::Shard; | use crate::Shard; | ||||||
| 
 | 
 | ||||||
| fn load_config( | fn load_config( | ||||||
|  | |||||||
		Loading…
	
		Reference in New Issue
	
	Block a user