Stash

This commit is contained in:
parent 01e352bc2e
commit e551efc0a4

@@ -1,450 +1,160 @@
use mlx_rs::builder::Builder;
use mlx_rs::macros::ModuleParameters;
use mlx_rs::module::Module;
use mlx_rs::module::Param;
use mlx_rs::nested::NestedHashMap;
use mlx_rs::nn::{self, Embedding, RmsNorm, RmsNormBuilder};
use mlx_rs::ops::zeros;
use mlx_rs::Array;
use serde::Deserialize;
use std::collections::HashMap;
use std::rc::Rc;
// Define Shard struct to mirror Python dataclass
#[derive(Debug, Clone, Deserialize, Default)]
pub struct Shard {
    pub name: String,
    pub start_layer: usize,
    pub end_layer: usize,
}

impl Shard {
    pub fn new(name: String, start_layer: usize, end_layer: usize) -> Self {
        Shard {
            name,
            start_layer,
            end_layer,
        }
    }

    pub fn is_first_layer(&self) -> bool {
        self.start_layer == 0
    }

    pub fn is_last_layer(&self) -> bool {
        // Assuming end_layer is inclusive and represents the last layer index in the shard
        // and num_hidden_layers is the total number of layers.
        // We would need num_hidden_layers to accurately determine the last layer.
        // For now, let's assume if end_layer is very large, it's the last layer in the shard.
        self.end_layer > 9999 // A large number as a placeholder, adjust as needed
    }
}

#[derive(Debug, Deserialize, ModuleParameters)]
struct ModelArgs {
    vocab_size: i32,
    hidden_size: i32,
    num_hidden_layers: i32,
    rms_norm_eps: f32,
}

#[derive(Debug)]
enum ShardedLayer {
    TransformerBlock,
    IdentityBlock,
}
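// Illustrative only: a minimal sketch of how a Shard might be constructed and queried,
// assuming the placeholder `end_layer > 9999` sentinel above is kept. The shard names
// and layer bounds are made-up example values, not part of the original code.
#[cfg(test)]
mod shard_tests {
    use super::Shard;

    #[test]
    fn shard_layer_bounds() {
        let first = Shard::new("example-shard-0".to_string(), 0, 15);
        assert!(first.is_first_layer());
        assert!(!first.is_last_layer());

        // With the current placeholder, only a very large end_layer counts as "last".
        let last = Shard::new("example-shard-1".to_string(), 16, 10_000);
        assert!(!last.is_first_layer());
        assert!(last.is_last_layer());
    }
}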
// Define ModelArgs struct to mirror Python dataclass ModelArgs
#[derive(Debug, Clone, Deserialize)]
pub struct ModelArgsRs {
    pub model_type: String,
    pub hidden_size: i32,
    pub num_hidden_layers: i32,
    pub intermediate_size: i32,
    pub num_attention_heads: i32,
    pub rms_norm_eps: f32,
    pub vocab_size: i32,
    pub head_dim: Option<i32>,
    pub max_position_embeddings: Option<i32>,
    pub num_key_value_heads: Option<i32>,
    pub attention_bias: bool,
    pub mlp_bias: bool,
    pub rope_theta: f32,
    pub rope_traditional: bool,
    // pub rope_scaling: Option<Dict<str, Union<float, str>>>, // Complex type, needs handling if needed
    pub tie_word_embeddings: bool,
    #[serde(default)]
    pub shard: Shard, // Using the Shard struct defined above
}

#[derive(Debug, ModuleParameters)]
struct LlamaModel {
    args: ModelArgs,
    shard: Shard,
    layers: Vec<ShardedLayer>,
    embed_tokens: Embedding,
    norm: RmsNorm,
    cache: Vec<Option<Array>>,
}
impl ModelArgsRs {
    // Add a builder pattern here if needed for easier initialization
    pub fn new(
        model_type: String,
        hidden_size: i32,
        num_hidden_layers: i32,
        intermediate_size: i32,
        num_attention_heads: i32,
        rms_norm_eps: f32,
        vocab_size: i32,
        tie_word_embeddings: bool,
        shard: Shard,
    ) -> Self {
        ModelArgsRs {
            model_type,
            hidden_size,
            num_hidden_layers,
            intermediate_size,
            num_attention_heads,
            rms_norm_eps,
            vocab_size,
            head_dim: None,                  // Default value
            max_position_embeddings: None,   // Default value
            num_key_value_heads: None,       // Default value
            attention_bias: false,           // Default value
            mlp_bias: false,                 // Default value
            rope_theta: 10000.0,             // Default value
            rope_traditional: false,         // Default value
            // rope_scaling: None,           // Default value
            tie_word_embeddings,
            shard,
        }
    }
}
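// Illustrative only: how the constructor above might be called. The dimensions below are
// made-up, Llama-style example values (not taken from any real checkpoint config).
#[allow(dead_code)]
fn example_model_args() -> ModelArgsRs {
    ModelArgsRs::new(
        "llama".to_string(), // model_type
        2048,                // hidden_size
        16,                  // num_hidden_layers
        5632,                // intermediate_size
        32,                  // num_attention_heads
        1e-5,                // rms_norm_eps
        32_000,              // vocab_size
        true,                // tie_word_embeddings
        Shard::new("example-shard".to_string(), 0, 15),
    )
}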
// Define Attention struct
#[derive(Debug, Clone, ModuleParameters)]
pub struct Attention {
    pub q_proj: nn::Linear,
    pub k_proj: nn::Linear,
    pub v_proj: nn::Linear,
    pub o_proj: nn::Linear,
    pub n_heads: i32,
    pub n_kv_heads: i32,
    pub head_dim: i32,
    pub scale: f32,
    // pub rope: Rope, // Placeholder for Rope implementation
}

impl Attention {
    pub fn new(args: &ModelArgsRs) -> Result<Self, mlx_rs::error::Exception> {
        let dim = args.hidden_size;
        let n_heads = args.num_attention_heads;
        let n_kv_heads = args.num_key_value_heads.unwrap_or(args.num_attention_heads);
        let head_dim = args.head_dim.unwrap_or(args.hidden_size / n_heads); // Default head_dim calculation
        let scale = (head_dim as f32).powf(-0.5);
        let _attention_bias = args.attention_bias; // Bias from args; not wired into the projections yet

        let q_proj = nn::Linear::new(dim, n_heads * head_dim)?;
        let k_proj = nn::Linear::new(dim, n_kv_heads * head_dim)?;
        let v_proj = nn::Linear::new(dim, n_kv_heads * head_dim)?;
        let o_proj = nn::Linear::new(n_heads * head_dim, dim)?;

        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            n_heads,
            n_kv_heads,
            head_dim,
            scale,
            // rope: Rope::new(...) // Initialize Rope here when implemented
        })
    }
}

impl Module<&Array> for Attention {
    type Output = Array;
    type Error = mlx_rs::error::Exception;

    fn forward(&mut self, input: &Array) -> Result<Self::Output, Self::Error> {
        // Placeholder for the actual attention logic. Still to implement:
        // 1. Projections (q_proj, k_proj, v_proj)
        // 2. Reshape and transpose for multi-head
        // 3. RoPE application
        // 4. Scaled dot-product attention
        // 5. Output projection (o_proj)
        let _q = self.q_proj.forward(input)?;
        let _k = self.k_proj.forward(input)?;
        let v = self.v_proj.forward(input)?;

        // Placeholder - directly return the v projection for now
        self.o_proj.forward(&v)
    }

    fn training_mode(&mut self, mode: bool) {
        self.q_proj.training_mode(mode);
        self.k_proj.training_mode(mode);
        self.v_proj.training_mode(mode);
        self.o_proj.training_mode(mode);
    }
}
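// Worked example (assumed toy numbers, no mlx calls): with hidden_size = 2048 and
// num_attention_heads = 32, the defaults above give head_dim = 2048 / 32 = 64 and
// scale = 64^-0.5 = 0.125, i.e. the 1/sqrt(d_k) factor used by scaled dot-product attention.
#[cfg(test)]
mod attention_scale_tests {
    #[test]
    fn default_head_dim_and_scale() {
        let hidden_size = 2048i32;
        let n_heads = 32i32;
        let head_dim = hidden_size / n_heads;
        let scale = (head_dim as f32).powf(-0.5);
        assert_eq!(head_dim, 64);
        assert!((scale - 0.125).abs() < 1e-6);
    }
}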
// Define MLP struct
#[derive(Debug, Clone, ModuleParameters)]
pub struct MLP {
    pub gate_proj: nn::Linear,
    pub down_proj: nn::Linear,
    pub up_proj: nn::Linear,
    mlp_bias: bool, // Store mlp_bias
}

impl MLP {
    pub fn new(args: &ModelArgsRs) -> Result<Self, mlx_rs::error::Exception> {
        let dim = args.hidden_size;
        let hidden_dim = args.intermediate_size;
        let mlp_bias = args.mlp_bias; // Get mlp_bias from args
        let gate_proj = nn::Linear::new(dim, hidden_dim)?;
        let down_proj = nn::Linear::new(hidden_dim, dim)?;
        let up_proj = nn::Linear::new(dim, hidden_dim)?;

        Ok(Self {
            gate_proj,
            down_proj,
            up_proj,
            mlp_bias,
        })
    }
}

impl Module<&Array> for MLP {
    type Output = Array;
    type Error = mlx_rs::error::Exception;

    fn forward(&mut self, input: &Array) -> Result<Self::Output, Self::Error> {
        // MLP forward pass: down_proj(silu(gate_proj(x)) * up_proj(x))
        let gate_output = self.gate_proj.forward(input)?;
        let silu_output = nn::silu(&gate_output)?; // Apply silu activation
        let up_output = self.up_proj.forward(input)?;
        let combined_output = silu_output * up_output; // Element-wise multiplication
        self.down_proj.forward(&combined_output) // Final projection
    }

    fn training_mode(&mut self, mode: bool) {
        self.gate_proj.training_mode(mode);
        self.down_proj.training_mode(mode);
        self.up_proj.training_mode(mode);
    }
}
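// Reference sketch (assumed scalar math, independent of mlx): the gated activation above is
// SwiGLU-style, y = down(silu(gate(x)) * up(x)), where silu(x) = x * sigmoid(x). A scalar
// version makes the gating easy to sanity-check.
#[cfg(test)]
mod mlp_activation_tests {
    fn silu_scalar(x: f32) -> f32 {
        x * (1.0 / (1.0 + (-x).exp()))
    }

    #[test]
    fn silu_reference_values() {
        assert!((silu_scalar(0.0) - 0.0).abs() < 1e-6);
        // silu(1) = 1 * sigmoid(1) ≈ 0.7311
        assert!((silu_scalar(1.0) - 0.731_058_6).abs() < 1e-4);
    }
}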
// TransformerBlock - still being ported to Rust
#[derive(Debug, Clone, ModuleParameters)]
pub struct TransformerBlock {
    // A dummy linear layer is kept for now alongside the real sub-modules
    pub linear: nn::Linear,
    self_attn: Attention,
    mlp: MLP,
    input_layernorm: nn::RmsNorm,
    post_attention_layernorm: nn::RmsNorm,
    args: ModelArgsRs, // Store args for potential use within TransformerBlock
}

impl TransformerBlock {
    pub fn new(args: ModelArgsRs) -> Result<Self, mlx_rs::error::Exception> {
        let linear = nn::Linear::new(1, 1).unwrap(); // Dummy linear layer, will be removed
        let self_attn = Attention::new(&args)?;
        let mlp = MLP::new(&args)?;

        let input_layernorm = nn::RmsNormBuilder::new(args.hidden_size)
            .eps(args.rms_norm_eps)
            .build()
            .unwrap();

        let post_attention_layernorm = nn::RmsNormBuilder::new(args.hidden_size)
            .eps(args.rms_norm_eps)
            .build()
            .unwrap();

        Ok(Self {
            linear, // Dummy linear layer, will be removed in future
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
            args, // Store args
        })
    }
}

impl Module<&Array> for TransformerBlock {
    type Output = Array;
    type Error = mlx_rs::error::Exception;

    fn forward(&mut self, input: &Array) -> Result<Self::Output, Self::Error> {
        // Pre-norm residual block: x + attn(norm(x)), then x + mlp(norm(x))
        let normed_input = self.input_layernorm.forward(input)?;
        let attention_output = self.self_attn.forward(&normed_input)?;
        let hidden_state = input + &attention_output; // Residual connection

        let normed_hidden_state = self.post_attention_layernorm.forward(&hidden_state)?;
        let mlp_output = self.mlp.forward(&normed_hidden_state)?;
        let output = hidden_state + &mlp_output; // Residual connection

        Ok(output)
    }

    fn training_mode(&mut self, mode: bool) {
        self.self_attn.training_mode(mode);
        self.mlp.training_mode(mode);
        self.input_layernorm.training_mode(mode);
        self.post_attention_layernorm.training_mode(mode);
    }
}
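// Illustrative only: building a single block from the made-up `example_model_args` helper
// sketched earlier in this file; it is not part of the original API.
#[allow(dead_code)]
fn example_transformer_block() -> Result<TransformerBlock, mlx_rs::error::Exception> {
    let args = example_model_args();
    TransformerBlock::new(args)
}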
#[derive(Debug, Clone, ModuleParameters)]
pub struct LlamaModelRs {
    pub args: ModelArgsRs,
    pub vocab_size: i32,
    pub num_hidden_layers: i32,
    pub embed_tokens: Option<nn::Embedding>, // Embedding layer is optional based on sharding
    pub layers: Vec<TransformerBlock>,       // Using TransformerBlock
    pub norm: Option<nn::RmsNorm>,           // RMSNorm layer is optional based on sharding
}

impl LlamaModelRs {
    pub fn new(args: ModelArgsRs) -> Result<Self, mlx_rs::error::Exception> {
        let vocab_size = args.vocab_size;
        let num_hidden_layers = args.num_hidden_layers;

        let mut embed_tokens = None;
        if args.shard.is_first_layer() || (args.shard.is_last_layer() && args.tie_word_embeddings) {
            embed_tokens = Some(nn::Embedding::new(args.vocab_size, args.hidden_size)?);
        }

        let mut layers: Vec<TransformerBlock> = Vec::new();
        for _ in 0..num_hidden_layers {
            // No sharding logic for now, apply to all layers - revisit sharding
            layers.push(TransformerBlock::new(args.clone())?); // Pass cloned args
        }

        let mut norm = None;
        if args.shard.is_last_layer() {
            norm = Some(
                nn::RmsNormBuilder::new(args.hidden_size)
                    .eps(args.rms_norm_eps)
                    .build()
                    .unwrap(),
            );
        }

        Ok(Self {
            args,
            vocab_size,
            num_hidden_layers,
            embed_tokens,
            layers,
            norm,
        })
    }
}
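// Sketch of the intended sharding rule (assumed, mirroring the range check used by the
// compact LlamaModel::new further down): a shard owns layer i iff start_layer <= i <= end_layer.
// The loop above ignores this for now and builds every layer; this helper only illustrates
// the selection logic with plain integers.
#[allow(dead_code)]
fn layers_owned_by(shard: &Shard, num_hidden_layers: usize) -> Vec<usize> {
    (0..num_hidden_layers)
        .filter(|i| shard.start_layer <= *i && *i <= shard.end_layer)
        .collect()
}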
impl Module<&Array> for LlamaModelRs {
    type Output = Array;
    type Error = mlx_rs::error::Exception;

    fn forward(&mut self, inputs: &Array) -> Result<Self::Output, Self::Error> {
        let mut h;
        if self.args.shard.is_first_layer() && self.embed_tokens.is_some() {
            h = self.embed_tokens.as_mut().unwrap().forward(inputs)?;
        } else {
            h = inputs.clone(); // Assuming input is already embedded if not the first layer
        }

        // Mask creation logic would go here - needs to be ported to Rust
        let _mask: Option<Array> = None; // Placeholder mask
        // if h.ndim() > 1 && h.shape()[1] > 1 {
        //     mask = create_attention_mask(h, cache);
        // }

        // Cache handling - needs a more detailed implementation in Rust
        let _cache: Option<Vec<Option<Array>>> = None; // Placeholder cache
        // let mut cache = cache.unwrap_or_else(|| vec![None; self.layers.len()]);

        for layer in &mut self.layers {
            h = layer.forward(&h)?; // Pass mask and cache when implemented
        }

        let normed_h = match &mut self.norm {
            Some(norm_layer) => norm_layer.forward(&h)?,
            None => h, // Skip norm if not the last layer
        };
        Ok(normed_h)
    }

    fn training_mode(&mut self, mode: bool) {
        if let Some(embed_tokens) = &mut self.embed_tokens {
            embed_tokens.training_mode(mode);
        }
        for layer in &mut self.layers {
            layer.training_mode(mode);
        }
        if let Some(norm) = &mut self.norm {
            norm.training_mode(mode);
        }
    }
}
#[derive(Debug, Clone, ModuleParameters)]
pub struct ModelRs {
    pub args: ModelArgsRs,
    pub model_type: String,
    pub model: LlamaModelRs,
    pub lm_head: Option<nn::Linear>, // LM head, optional based on tie_word_embeddings
}

impl ModelRs {
    pub fn new(args: ModelArgsRs) -> Result<Self, mlx_rs::error::Exception> {
        let model = LlamaModelRs::new(args.clone())?; // Clone args for LlamaModelRs
        let model_type = args.model_type.clone();
        let mut lm_head = None;
        if args.shard.is_last_layer() && !args.tie_word_embeddings {
            lm_head = Some(nn::Linear::new(args.hidden_size, args.vocab_size)?);
        }

        Ok(Self {
            args,
            model_type,
            model,
            lm_head,
        })
    }
}

impl Module<&Array> for ModelRs {
    type Output = Array;
    type Error = mlx_rs::error::Exception;

    fn forward(&mut self, inputs: &Array) -> Result<Self::Output, Self::Error> {
        let mut out = self.model.forward(inputs)?;

        if self.args.shard.is_last_layer() {
            if self.args.tie_word_embeddings && self.model.embed_tokens.is_some() {
                // Python uses embed_tokens.as_linear(); here the embedding weight is reused
                // directly via an explicit matmul (still a simplified stand-in for a tied lm_head).
                if let Some(embed_tokens) = &self.model.embed_tokens {
                    let params = embed_tokens.parameters();
                    if let Some(weight_param) = params.entries.get("weight") {
                        let embedding_weight = weight_param.array();
                        let weight_array = embedding_weight.transpose()?; // Assuming weight needs transpose
                        out = out.matmul(&weight_array)?; // Project hidden states onto the vocabulary
                    }
                }
            } else if let Some(lm_head) = &mut self.lm_head {
                out = lm_head.forward(&out)?;
            }
        }
        Ok(out)
    }

    fn training_mode(&mut self, mode: bool) {
        self.model.training_mode(mode);
        if let Some(lm_head) = &mut self.lm_head {
            lm_head.training_mode(mode);
        }
    }
}

impl LlamaModel {
    fn new(args: ModelArgs, shard: Shard) -> Self {
        let embed_tokens = Embedding::new(args.vocab_size, args.hidden_size).unwrap();

        let layers = (0..(args.num_hidden_layers as usize))
            .map(|i| {
                if shard.start_layer <= i && i <= shard.end_layer {
                    ShardedLayer::TransformerBlock
                } else {
                    ShardedLayer::IdentityBlock
                }
            })
            .collect::<Vec<_>>();

        let norm = RmsNormBuilder::new(args.hidden_size)
            .eps(args.rms_norm_eps)
            .build()
            .unwrap();

        Self {
            cache: vec![None; args.num_hidden_layers as usize],
            args,
            shard,
            layers,
            embed_tokens,
            norm,
        }
    }
}
impl Module<Array> for LlamaModel {
    type Output = Array;
    type Error = mlx_rs::error::Exception;

    fn forward(&mut self, input: Array) -> Result<Self::Output, Self::Error> {
        let h = if self.shard.is_first_layer() {
            self.embed_tokens.forward(&input)?
        } else {
            input
        };

        let mask = if h.ndim() > 1 && h.shape()[1] > 1 {
            Some(create_attention_mask(&h, &self.cache)?)
        } else {
            None
        };

        let h = self
            .layers
            .iter_mut()
            .zip(self.cache.iter_mut())
            .try_fold(h, |h, (layer, c)| layer.forward(&h, mask.as_ref(), c))?;

        let h = if self.shard.is_last_layer() {
            self.norm.forward(&h)?
        } else {
            h
        };

        Ok(h)
    }

    fn training_mode(&mut self, mode: bool) {
        todo!()
    }
}
fn create_attention_mask(h: &Array, cache: &[Option<HashMap<String, i32>>]) -> Result<Array, mlx_rs::error::Exception> {
    let shape = h.shape();
    let t = shape[1];

    if t > 1 {
        let (window_size, offset) = match cache {
            &[Some(ref cache), ..] => {
                let offset = *cache.get("offset").unwrap();
                if let Some(max_size) = cache.get("max_size") {
                    (Some(*max_size), i32::min(*max_size, offset))
                } else {
                    (None, offset)
                }
            }
            _ => (None, 0),
        };

        let mask = create_causal_mask(t, offset, window_size, None)?;
        mask.as_dtype(h.dtype())
    } else {
        zeros::<f32>(&[0]) // Return an empty array when T <= 1
    }
}
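// Worked example of the (window_size, offset) selection above, using plain values rather
// than a real cache entry: with offset = 70 and max_size = 64 the effective offset is
// clamped to min(64, 70) = 64, matching a rotating KV cache that only keeps 64 entries.
#[cfg(test)]
mod mask_offset_tests {
    #[test]
    fn offset_is_clamped_to_max_size() {
        let offset = 70i32;
        let max_size = Some(64i32);
        let (window, eff_offset) = match max_size {
            Some(m) => (Some(m), i32::min(m, offset)),
            None => (None, offset),
        };
        assert_eq!(window, Some(64));
        assert_eq!(eff_offset, 64);
    }
}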
fn create_causal_mask(
    n: i32,
    offset: i32,
    window_size: Option<i32>,
    lengths: Option<&Array>,
) -> Result<Array, mlx_rs::error::Exception> {
    let rinds = Array::arange(0, offset + n, 1)?;
    let linds = if offset > 0 {
        Array::arange(offset, offset + n, 1)? // Start at offset so linds holds absolute query positions
    } else {
        rinds.clone()
    };

    let linds = linds.reshape(&[-1, 1])?;
    let rinds = rinds.reshape(&[1, -1])?;

    let mut mask = linds.lt(&rinds)?;

    if let Some(w) = window_size {
        let window_mask = linds.gt(&(rinds.clone() + w))?;
        mask = mask.logical_or(&window_mask)?;
    }

    if let Some(l) = lengths {
        let l = l.reshape(&[-1, 1, 1, 1])?;
        let length_mask = rinds.greater_equal(&l)?;
        mask = mask.logical_or(&length_mask)?;
    }

    mask.multiply(-1e9)
}
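// Worked example (assumed toy size, plain Rust only): for n = 3, offset = 0 and no window,
// `linds < rinds` marks strictly-future positions, so after scaling by -1e9 the additive
// mask in row i blocks every column j > i and leaves j <= i at 0.
#[cfg(test)]
mod causal_mask_pattern_tests {
    #[test]
    fn future_positions_are_blocked() {
        let n = 3usize;
        let mask: Vec<Vec<f32>> = (0..n)
            .map(|i| (0..n).map(|j| if j > i { -1e9 } else { 0.0 }).collect())
            .collect();
        assert_eq!(mask[0], vec![0.0, -1e9, -1e9]);
        assert_eq!(mask[2], vec![0.0, 0.0, 0.0]);
    }
}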
@@ -1,7 +1,6 @@
use std::collections::HashMap;
use std::path::Path;
use serde_json::Value;
use crate::llama_module::{LlamaModelRs, ModelArgsRs};
use crate::Shard;

fn load_config(
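// Purely hypothetical sketch (the real signature is cut off above): one way a load_config
// built on these imports could read a model directory's config.json into a HashMap of
// serde_json::Value. Names and behaviour here are assumptions, not the original code.
#[allow(dead_code)]
fn load_config_sketch(model_path: &Path) -> std::io::Result<HashMap<String, Value>> {
    let raw = std::fs::read_to_string(model_path.join("config.json"))?;
    let value: Value = serde_json::from_str(&raw)
        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
    let map = value
        .as_object()
        .map(|o| o.iter().map(|(k, v)| (k.clone(), v.clone())).collect())
        .unwrap_or_default();
    Ok(map)
}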