diff --git a/src/llama_module.rs b/src/llama_module.rs index 070dd20..591484d 100644 --- a/src/llama_module.rs +++ b/src/llama_module.rs @@ -1,450 +1,160 @@ +use crate::Shard; use mlx_rs::builder::Builder; use mlx_rs::macros::ModuleParameters; use mlx_rs::module::Module; -use mlx_rs::module::ModuleParameters; -use mlx_rs::module::Param; -use mlx_rs::nested::NestedHashMap; -use mlx_rs::nn; -use mlx_rs::nn::RmsNormBuilder; +use mlx_rs::nn::{Embedding, RmsNorm, RmsNormBuilder}; use mlx_rs::Array; use serde::Deserialize; -use std::rc::Rc; +use std::collections::HashMap; +use std::env::args; +use mlx_rs::ops::zeros; -// Define Shard struct to mirror Python dataclass -#[derive(Debug, Clone, Deserialize, Default)] -pub struct Shard { - pub name: String, - pub start_layer: usize, - pub end_layer: usize, +#[derive(Debug, Deserialize, ModuleParameters)] +struct ModelArgs { + vocab_size: i32, + hidden_size: i32, + num_hidden_layers: i32, + rms_norm_eps: f32, } -impl Shard { - pub fn new(name: String, start_layer: usize, end_layer: usize) -> Self { - Shard { - name, - start_layer, - end_layer, - } - } - - pub fn is_first_layer(&self) -> bool { - self.start_layer == 0 - } - - pub fn is_last_layer(&self) -> bool { - // Assuming end_layer is inclusive and represents the last layer index in the shard - // and num_hidden_layers is the total number of layers. - // We would need num_hidden_layers to accurately determine the last layer. - // For now, let's assume if end_layer is very large, it's the last layer in shard. - self.end_layer > 9999 // A large number as a placeholder, adjust as needed - } +#[derive(Debug)] +enum ShardedLayer { + TransformerBlock, + IdentityBlock, } -// Define ModelArgs struct to mirror Python dataclass ModelArgs -#[derive(Debug, Clone, Deserialize)] -pub struct ModelArgsRs { - pub model_type: String, - pub hidden_size: i32, - pub num_hidden_layers: i32, - pub intermediate_size: i32, - pub num_attention_heads: i32, - pub rms_norm_eps: f32, - pub vocab_size: i32, - pub head_dim: Option, - pub max_position_embeddings: Option, // Added max_position_embeddings - pub num_key_value_heads: Option, // Added num_key_value_heads - pub attention_bias: bool, // Added attention_bias - pub mlp_bias: bool, // Added mlp_bias - pub rope_theta: f32, // Added rope_theta - pub rope_traditional: bool, // Added rope_traditional - // pub rope_scaling: Option>>, // Complex type, needs handling if needed - pub tie_word_embeddings: bool, +#[derive(Debug, ModuleParameters)] +struct LlamaModel { + args: ModelArgs, + shard: Shard, + layers: Vec, - #[serde(default)] - pub shard: Shard, // Using the Shard struct defined above + embed_tokens: Embedding, + norm: RmsNorm, + + cache: Vec>, } -impl ModelArgsRs { - // Add a constructor or builder pattern here if needed for easier initialization - pub fn new( - model_type: String, - hidden_size: i32, - num_hidden_layers: i32, - intermediate_size: i32, - num_attention_heads: i32, - rms_norm_eps: f32, - vocab_size: i32, - tie_word_embeddings: bool, - shard: Shard, - ) -> Self { - ModelArgsRs { - model_type, - hidden_size, - num_hidden_layers, - intermediate_size, - num_attention_heads, - rms_norm_eps, - vocab_size, - head_dim: None, // Default value - max_position_embeddings: None, // Default value - num_key_value_heads: None, // Default value - attention_bias: false, // Default value - mlp_bias: false, // Default value - rope_theta: 10000.0, // Default value - rope_traditional: false, // Default value - // rope_scaling: None, // Default value - tie_word_embeddings, - 
shard, - } - } -} +impl LlamaModel { + fn new(args: ModelArgs, shard: Shard) -> Self { + let embed_tokens = Embedding::new(args.vocab_size, args.hidden_size).unwrap(); -// Define Attention struct -#[derive(Debug, Clone, ModuleParameters)] -pub struct Attention { - pub q_proj: nn::Linear, - pub k_proj: nn::Linear, - pub v_proj: nn::Linear, - pub o_proj: nn::Linear, - pub n_heads: i32, - pub n_kv_heads: i32, - pub head_dim: i32, - pub scale: f32, - // pub rope: Rope, // Placeholder for Rope implementation -} - -impl Attention { - pub fn new(args: &ModelArgsRs) -> Result { - let dim = args.hidden_size; - let n_heads = args.num_attention_heads; - let n_kv_heads = args.num_key_value_heads.unwrap_or(args.num_attention_heads); - let head_dim = args.head_dim.unwrap_or(args.hidden_size / n_heads); // Default head_dim calculation - let scale = (head_dim as f32).powf(-0.5); - let attention_bias = args.attention_bias; // Use bias from args - - let q_proj = nn::Linear::new(dim, n_heads * head_dim)?; - let k_proj = nn::Linear::new(dim, n_kv_heads * head_dim)?; - let v_proj = nn::Linear::new(dim, n_kv_heads * head_dim)?; - let o_proj = nn::Linear::new(n_heads * head_dim, dim)?; - - Ok(Self { - q_proj, - k_proj, - v_proj, - o_proj, - n_heads, - n_kv_heads, - head_dim, - scale, - // rope: Rope::new(...) // Initialize Rope here when implemented - }) - } -} - -impl Module<&Array> for Attention { - type Output = Array; - type Error = mlx_rs::error::Exception; - - fn forward(&mut self, input: &Array) -> Result { - // Placeholder for actual attention logic - // Need to implement: - // 1. Projections (q_proj, k_proj, v_proj) - // 2. Reshape and transpose for multi-head - // 3. RoPE application - // 4. Scaled dot-product attention - // 5. Output projection (o_proj) - - let q = self.q_proj.forward(input)?; - let k = self.k_proj.forward(input)?; - let v = self.v_proj.forward(input)?; - - // Placeholder - directly return v projection for now - self.o_proj.forward(&v) - } - - fn training_mode(&mut self, mode: bool) { - self.q_proj.training_mode(mode); - self.k_proj.training_mode(mode); - self.v_proj.training_mode(mode); - self.o_proj.training_mode(mode); - } -} - -// Define MLP struct -#[derive(Debug, Clone, ModuleParameters)] -pub struct MLP { - pub gate_proj: nn::Linear, - pub down_proj: nn::Linear, - pub up_proj: nn::Linear, - mlp_bias: bool, // Store mlp_bias -} - -impl MLP { - pub fn new(args: &ModelArgsRs) -> Result { - let dim = args.hidden_size; - let hidden_dim = args.intermediate_size; - let mlp_bias = args.mlp_bias; // Get mlp_bias from args - let gate_proj = nn::Linear::new(dim, hidden_dim)?; - let down_proj = nn::Linear::new(hidden_dim, dim)?; - let up_proj = nn::Linear::new(dim, hidden_dim)?; - - Ok(Self { - gate_proj, - down_proj, - up_proj, - mlp_bias, // Store mlp_bias - }) - } -} - -impl Module<&Array> for MLP { - type Output = Array; - type Error = mlx_rs::error::Exception; - - fn forward(&mut self, input: &Array) -> Result { - // Implement MLP forward pass using nn::silu - let gate_output = self.gate_proj.forward(input)?; - let silu_output = nn::silu(&gate_output)?; // Apply silu activation - let up_output = self.up_proj.forward(input)?; - let combined_output = silu_output * up_output; // Element-wise multiplication - self.down_proj.forward(&combined_output) // Final projection - } - - fn training_mode(&mut self, mode: bool) { - self.gate_proj.training_mode(mode); - self.down_proj.training_mode(mode); - self.up_proj.training_mode(mode); - } -} -// ... existing code ... - -// ... existing code ... 
-// Placeholder for TransformerBlock - You'll need to implement this in Rust -#[derive(Debug, Clone, ModuleParameters)] -pub struct TransformerBlock { - // Define the layers within TransformerBlock as needed, e.g., attention, norm, etc. - // For now, using a linear layer as a placeholder - pub linear: nn::Linear, - self_attn: Attention, - mlp: MLP, - input_layernorm: nn::RmsNorm, - post_attention_layernorm: nn::RmsNorm, - args: ModelArgsRs, // Store args for potential use within TransformerBlock -} - -impl TransformerBlock { - pub fn new(args: ModelArgsRs) -> Result { - let linear = nn::Linear::new(1, 1).unwrap(); // Dummy linear layer, will be removed - let self_attn = Attention::new(&args)?; - let mlp = MLP::new(&args)?; - - let input_layernorm = nn::RmsNormBuilder::new(args.hidden_size) - .eps(args.rms_norm_eps) - .build() - .unwrap(); - - let post_attention_layernorm = nn::RmsNormBuilder::new(args.hidden_size) - .eps(args.rms_norm_eps) - .build() - .unwrap(); - - Ok(Self { - linear, // Dummy linear layer, will be removed in future - self_attn, - mlp, - input_layernorm, - post_attention_layernorm, - args, // Store args - }) - } -} - -impl Module<&Array> for TransformerBlock { - type Output = Array; - type Error = mlx_rs::error::Exception; - - fn forward(&mut self, input: &Array) -> Result { - // Implement the forward pass logic for TransformerBlock - // For now, just passing through the linear layer - // self.linear.forward(input) // Old placeholder - - let normed_input = self.input_layernorm.forward(input)?; - let attention_output = self.self_attn.forward(&normed_input)?; - let hidden_state = input + &attention_output; // Residual connection - - let normed_hidden_state = self.post_attention_layernorm.forward(&hidden_state)?; - let mlp_output = self.mlp.forward(&normed_hidden_state)?; - let output = hidden_state + &mlp_output; // Residual connection - - Ok(output) - } - - fn training_mode(&mut self, mode: bool) { - self.self_attn.training_mode(mode); - self.mlp.training_mode(mode); - self.input_layernorm.training_mode(mode); - self.post_attention_layernorm.training_mode(mode); - } -} -// ... existing code ... - -// ... existing code ... 
-#[derive(Debug, Clone, ModuleParameters)] -pub struct LlamaModelRs { - pub args: ModelArgsRs, - pub vocab_size: i32, - pub num_hidden_layers: i32, - pub embed_tokens: Option, // Embedding layer is optional based on sharding - pub layers: Vec, // Using TransformerBlock - pub norm: Option, // RMSNorm layer is optional based on sharding -} - -impl LlamaModelRs { - pub fn new(args: ModelArgsRs) -> Result { - let vocab_size = args.vocab_size; - let num_hidden_layers = args.num_hidden_layers; - let mut embed_tokens = None; - if args.shard.is_first_layer() || (args.shard.is_last_layer() && args.tie_word_embeddings) { - embed_tokens = Some(nn::Embedding::new(args.vocab_size, args.hidden_size)?); - } - - let mut layers: Vec = Vec::new(); // Specify type here - for _ in 0..num_hidden_layers { - // No sharding logic for now, apply to all layers - revisit sharding - layers.push(TransformerBlock::new(args.clone())?); // Pass cloned args - } - - let mut norm = None; - if args.shard.is_last_layer() { - norm = Some( - nn::RmsNormBuilder::new(args.hidden_size) - .eps(args.rms_norm_eps) - .build() - .unwrap(), - ); - } - - Ok(Self { - args, - vocab_size, - num_hidden_layers, - embed_tokens, - layers, - norm, - }) - } -} - -impl Module<&Array> for LlamaModelRs { - type Output = Array; - type Error = mlx_rs::error::Exception; - - fn forward(&mut self, inputs: &Array) -> Result { - let mut h; - if self.args.shard.is_first_layer() && self.embed_tokens.is_some() { - h = self.embed_tokens.as_ref().unwrap().forward(inputs)?; - } else { - h = inputs.clone(); // Assuming input is already embedded if not the first layer - } - - // Mask creation logic would go here - needs to be implemented in Rust - let mask: Option = None; // Placeholder mask - // if h.ndim() > 1 && h.shape()[1] > 1 { - // mask = create_attention_mask(h, cache); // Need to port create_attention_mask to Rust - // } - - // Cache handling - needs more detailed implementation for Rust - let mut cache: Option>> = None; // Placeholder cache - // let mut cache = cache.unwrap_or_else(|| vec![None; self.layers.len()]); - - for layer in &mut self.layers { - h = layer.forward(&h)?; // Pass mask and cache when implemented - } - - let normed_h = match &mut self.norm { - Some(norm_layer) => norm_layer.forward(&h)?, - None => h, // Skip norm if not the last layer - }; - Ok(normed_h) - } - - fn training_mode(&mut self, mode: bool) { - if let Some(embed_tokens) = &mut self.embed_tokens { - embed_tokens.training_mode(mode); - } - for layer in &mut self.layers { - layer.training_mode(mode); - } - if let Some(norm) = &mut self.norm { - norm.training_mode(mode); - } - } -} -// ... existing code ... - -// ... existing code ... 
-#[derive(Debug, Clone, ModuleParameters)]
-pub struct ModelRs {
-    pub args: ModelArgsRs,
-    pub model_type: String,
-    pub model: LlamaModelRs,
-    pub lm_head: Option<nn::Linear>, // Linear layer for language model head, optional based on tie_word_embeddings
-}
-
-impl ModelRs {
-    pub fn new(args: ModelArgsRs) -> Result<Self, mlx_rs::error::Exception> {
-        let model = LlamaModelRs::new(args.clone())?; // Clone args for LlamaModel
-        let model_type = args.model_type.clone();
-        let mut lm_head = None;
-        if args.shard.is_last_layer() && !args.tie_word_embeddings {
-            lm_head = Some(nn::Linear::new(args.hidden_size, args.vocab_size)?);
-        }
-
-        Ok(Self {
-            args,
-            model_type,
-            model,
-            lm_head,
-        })
-    }
-}
-
-impl Module<&Array> for ModelRs {
-    type Output = Array;
-    type Error = mlx_rs::error::Exception;
-
-    fn forward(&mut self, inputs: &Array) -> Result<Self::Output, Self::Error> {
-        let mut out = self.model.forward(inputs)?;
-
-        if self.args.shard.is_last_layer() {
-            if self.args.tie_word_embeddings && self.model.embed_tokens.is_some() {
-                // Need to implement as_linear() equivalent in Rust or directly use embedding weights for linear transformation
-                // Placeholder - direct linear transformation using embedding weights is not directly available in mlx-rs as in python
-                if let Some(embed_tokens) = &self.model.embed_tokens {
-                    let params = embed_tokens.parameters();
-                    if let Some(weight_param) = params.entries.get("weight") {
-                        // This is a very simplified placeholder - needs proper matrix multiplication with 'out' and 'weight_param'
-                        // out = weight_param.clone(); // Incorrect - replace with actual linear transformation
-                        // Placeholder: use linear layer with embedding weights (not directly supported in mlx-rs)
-                        let embedding_weight = weight_param.array();
-                        let weight_array = embedding_weight.transpose()?; // Assuming weight needs transpose
-                        let weight_arr_ref = &weight_array;
-                        let out_matmul = out.matmul(weight_arr_ref)?; // Perform matrix multiplication
-                        out = out_matmul;
-                    }
-                }
-            } else if let Some(lm_head) = &mut self.lm_head {
-                out = lm_head.forward(&out)?;
+        // Only the layers owned by this shard become real transformer blocks;
+        // the rest stay identity placeholders.
+        let layers = (0..(args.num_hidden_layers as u32)).map(|i| {
+            if shard.start_layer <= i && i <= shard.end_layer {
+                ShardedLayer::TransformerBlock
+            } else {
+                ShardedLayer::IdentityBlock
             }
-        }
-        Ok(out)
-    }
+    }).collect::<Vec<_>>();
-
-    fn training_mode(&mut self, mode: bool) {
-        self.model.training_mode(mode);
-        if let Some(lm_head) = &mut self.lm_head {
-            lm_head.training_mode(mode);
+        let norm = RmsNormBuilder::new(args.hidden_size)
+            .eps(args.rms_norm_eps)
+            .build()
+            .unwrap();
+
+        Self {
+            cache: vec![None; args.num_hidden_layers as usize],
+            args,
+            shard,
+            layers,
+            embed_tokens,
+            norm
         }
     }
 }
-// ... existing code ...
+
+impl Module<Array> for LlamaModel {
+    type Output = Array;
+    type Error = mlx_rs::error::Exception;
+
+    fn forward(&mut self, input: Array) -> Result<Self::Output, Self::Error> {
+        let h = if self.shard.is_first_layer() {
+            self.embed_tokens.forward(&input)?
+        } else {
+            input
+        };
+
+        let mask = if h.ndim() > 1 && h.shape()[1] > 1 {
+            Some(create_attention_mask(&h, &self.cache)?)
+        } else {
+            None
+        };
+
+        // ShardedLayer still needs a forward(&h, mask, cache) implementation;
+        // try_fold lets each layer's Result propagate instead of using `?` inside a plain fold closure.
+        let h = self.layers.iter_mut().zip(self.cache.iter_mut())
+            .try_fold(h, |h, (layer, c)| layer.forward(&h, mask.as_ref(), c))?;
+
+        let h = if self.shard.is_last_layer() {
+            self.norm.forward(&h)?
+        } else {
+            h
+        };
+
+        Ok(h)
+    }
+
+    fn training_mode(&mut self, mode: bool) {
+        // TODO: propagate `mode` to embed_tokens, norm, and the transformer layers once they exist.
+        todo!()
+    }
+}
+
+// Each cache entry holds the per-layer cache metadata ("offset", "max_size").
+fn create_attention_mask(
+    h: &Array,
+    cache: &[Option<HashMap<String, i32>>],
+) -> Result<Array, mlx_rs::error::Exception> {
+    let shape = h.shape();
+    let t = shape[1];
+
+    if t > 1 {
+        // Pull the window size and offset from the first layer's cache entry, if any.
+        let (window_size, offset) = match cache {
+            &[Some(ref cache), ..] => {
+                let offset = *cache.get("offset").unwrap();
+
+                if let Some(max_size) = cache.get("max_size") {
+                    (Some(*max_size), i32::min(*max_size, offset))
+                } else {
+                    (None, offset)
+                }
+            },
+            _ => (None, 0),
+        };
+
+        let mask = create_causal_mask(t, offset, window_size, None)?;
+        mask.as_dtype(h.dtype())
+    } else {
+        zeros::<f32>(&[0]) // No mask is needed when T <= 1; return an empty array.
+    }
+}
+
+fn create_causal_mask(
+    n: i32,
+    offset: i32,
+    window_size: Option<i32>,
+    lengths: Option<&Array>,
+) -> Result<Array, mlx_rs::error::Exception> {
+    let rinds = Array::arange(0, offset + n, 1)?;
+    let linds = if offset > 0 {
+        // Query positions start at `offset` when a cache prefix is already present.
+        Array::arange(offset, offset + n, 1)?
+    } else {
+        rinds.clone()
+    };
+
+    let linds = linds.reshape(&[-1, 1])?;
+    let rinds = rinds.reshape(&[1, -1])?;
+
+    let mut mask = linds.lt(&rinds)?;
+
+    if let Some(w) = window_size {
+        let window_mask = linds.gt(&(rinds + w))?;
+        mask = mask.logical_or(&window_mask)?;
+    }
+
+    if let Some(l) = lengths {
+        let l = l.reshape(&[-1, 1, 1, 1])?;
+        let length_mask = rinds.greater_equal(&l)?;
+        mask = mask.logical_or(&length_mask)?;
+    }
+
+    mask.multiply(-1e9)
+}
+
diff --git a/src/module_loading.rs b/src/module_loading.rs
index 707880f..edaef13 100644
--- a/src/module_loading.rs
+++ b/src/module_loading.rs
@@ -1,7 +1,6 @@
 use std::collections::HashMap;
 use std::path::Path;
 use serde_json::Value;
-use crate::llama_module::{LlamaModelRs, ModelArgsRs};
 use crate::Shard;
 
 fn load_config(