Stash so much code
parent a129325ade
commit 01e352bc2e
@@ -1,9 +1,10 @@
use std::collections::HashMap;
use crate::node_service::{InferenceState, Tensor};
use crate::Shard;

#[derive(Debug)]
pub struct InferenceEngine {
    state_cache: HashMap<String, _>,
    // state_cache: HashMap<String, _>,
}

impl InferenceEngine {

@@ -1,17 +1,17 @@
use mlx_rs::builder::Builder;
use mlx_rs::macros::ModuleParameters;
use mlx_rs::module::Module;
use mlx_rs::module::ModuleParameters;
use mlx_rs::module::Param;
use mlx_rs::nested::NestedHashMap;
use mlx_rs::nn;
use mlx_rs::Array;
use std::rc::Rc;
use mlx_rs::builder::Builder;
use mlx_rs::nn::RmsNormBuilder;
use mlx_rs::Array;
use serde::Deserialize;
use std::rc::Rc;

// Define Shard struct to mirror Python dataclass
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Deserialize, Default)]
pub struct Shard {
    pub name: String,
    pub start_layer: usize,
@@ -43,31 +43,219 @@ impl Shard {
// Define ModelArgs struct to mirror Python dataclass ModelArgs
#[derive(Debug, Clone, Deserialize)]
pub struct ModelArgsRs {
    pub vocab_size: i32,
    pub model_type: String,
    pub hidden_size: i32,
    pub num_hidden_layers: i32,
    pub intermediate_size: i32,
    pub num_attention_heads: i32,
    pub num_key_value_heads: i32,
    pub rms_norm_eps: f32,
    pub vocab_size: i32,
    pub head_dim: Option<i32>,
    pub max_position_embeddings: Option<i32>, // Added max_position_embeddings
    pub num_key_value_heads: Option<i32>, // Added num_key_value_heads
    pub attention_bias: bool, // Added attention_bias
    pub mlp_bias: bool, // Added mlp_bias
    pub rope_theta: f32, // Added rope_theta
    pub rope_traditional: bool, // Added rope_traditional
    // pub rope_scaling: Option<Dict<str, Union<float, str>>>, // Complex type, needs handling if needed
    pub tie_word_embeddings: bool,
    pub model_type: String, // Assuming model_type is a String
    pub head_dim: Option<i32>, // Using Option to represent optional field
    pub shard: Shard, // Using the Shard struct defined above

    #[serde(default)]
    pub shard: Shard, // Using the Shard struct defined above
}

impl ModelArgsRs {
    // Add a constructor or builder pattern here if needed for easier initialization
    pub fn new(
        model_type: String,
        hidden_size: i32,
        num_hidden_layers: i32,
        intermediate_size: i32,
        num_attention_heads: i32,
        rms_norm_eps: f32,
        vocab_size: i32,
        tie_word_embeddings: bool,
        shard: Shard,
    ) -> Self {
        ModelArgsRs {
            model_type,
            hidden_size,
            num_hidden_layers,
            intermediate_size,
            num_attention_heads,
            rms_norm_eps,
            vocab_size,
            head_dim: None, // Default value
            max_position_embeddings: None, // Default value
            num_key_value_heads: None, // Default value
            attention_bias: false, // Default value
            mlp_bias: false, // Default value
            rope_theta: 10000.0, // Default value
            rope_traditional: false, // Default value
            // rope_scaling: None, // Default value
            tie_word_embeddings,
            shard,
        }
    }
}
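// Deserialization sketch (not part of this commit): with `Deserialize` derived,
// a model config.json can be parsed into ModelArgsRs via serde_json. This
// assumes the deduplicated field set above (the Option<i32> variants where a
// field appears twice in the diff); the concrete values are illustrative only,
// and non-Option fields without #[serde(default)] must be present in the JSON.
//
// let config_json = r#"{
//     "model_type": "llama",
//     "hidden_size": 4096,
//     "num_hidden_layers": 32,
//     "intermediate_size": 11008,
//     "num_attention_heads": 32,
//     "rms_norm_eps": 1e-5,
//     "vocab_size": 32000,
//     "attention_bias": false,
//     "mlp_bias": false,
//     "rope_theta": 10000.0,
//     "rope_traditional": false,
//     "tie_word_embeddings": true
// }"#;
// let args: ModelArgsRs = serde_json::from_str(config_json).unwrap();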

// Define Attention struct
#[derive(Debug, Clone, ModuleParameters)]
pub struct Attention {
    pub q_proj: nn::Linear,
    pub k_proj: nn::Linear,
    pub v_proj: nn::Linear,
    pub o_proj: nn::Linear,
    pub n_heads: i32,
    pub n_kv_heads: i32,
    pub head_dim: i32,
    pub scale: f32,
    // pub rope: Rope, // Placeholder for Rope implementation
}

impl Attention {
    pub fn new(args: &ModelArgsRs) -> Result<Self, mlx_rs::error::Exception> {
        let dim = args.hidden_size;
        let n_heads = args.num_attention_heads;
        let n_kv_heads = args.num_key_value_heads.unwrap_or(args.num_attention_heads);
        let head_dim = args.head_dim.unwrap_or(args.hidden_size / n_heads); // Default head_dim calculation
        let scale = (head_dim as f32).powf(-0.5);
        let attention_bias = args.attention_bias; // Use bias from args

        let q_proj = nn::Linear::new(dim, n_heads * head_dim)?;
        let k_proj = nn::Linear::new(dim, n_kv_heads * head_dim)?;
        let v_proj = nn::Linear::new(dim, n_kv_heads * head_dim)?;
        let o_proj = nn::Linear::new(n_heads * head_dim, dim)?;

        Ok(Self {
            q_proj,
            k_proj,
            v_proj,
            o_proj,
            n_heads,
            n_kv_heads,
            head_dim,
            scale,
            // rope: Rope::new(...) // Initialize Rope here when implemented
        })
    }
}

impl Module<&Array> for Attention {
    type Output = Array;
    type Error = mlx_rs::error::Exception;

    fn forward(&mut self, input: &Array) -> Result<Self::Output, Self::Error> {
        // Placeholder for actual attention logic
        // Need to implement:
        // 1. Projections (q_proj, k_proj, v_proj)
        // 2. Reshape and transpose for multi-head
        // 3. RoPE application
        // 4. Scaled dot-product attention
        // 5. Output projection (o_proj)

        let q = self.q_proj.forward(input)?;
        let k = self.k_proj.forward(input)?;
        let v = self.v_proj.forward(input)?;

        // Placeholder - directly return v projection for now
        self.o_proj.forward(&v)
    }

    fn training_mode(&mut self, mode: bool) {
        self.q_proj.training_mode(mode);
        self.k_proj.training_mode(mode);
        self.v_proj.training_mode(mode);
        self.o_proj.training_mode(mode);
    }
}
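// Math outline for the numbered steps in `forward` above (a hedged sketch of
// the standard multi-head attention, not yet code):
//   q, k, v are reshaped to (batch, n_heads or n_kv_heads, seq_len, head_dim),
//   RoPE is applied to q and k, then
//   scores = softmax(q . k^T * scale + mask)
//   attn   = scores . v, reshaped back to (batch, seq_len, n_heads * head_dim)
//   output = o_proj(attn)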

// Define MLP struct
#[derive(Debug, Clone, ModuleParameters)]
pub struct MLP {
    pub gate_proj: nn::Linear,
    pub down_proj: nn::Linear,
    pub up_proj: nn::Linear,
    mlp_bias: bool, // Store mlp_bias
}

impl MLP {
    pub fn new(args: &ModelArgsRs) -> Result<Self, mlx_rs::error::Exception> {
        let dim = args.hidden_size;
        let hidden_dim = args.intermediate_size;
        let mlp_bias = args.mlp_bias; // Get mlp_bias from args
        let gate_proj = nn::Linear::new(dim, hidden_dim)?;
        let down_proj = nn::Linear::new(hidden_dim, dim)?;
        let up_proj = nn::Linear::new(dim, hidden_dim)?;

        Ok(Self {
            gate_proj,
            down_proj,
            up_proj,
            mlp_bias, // Store mlp_bias
        })
    }
}

impl Module<&Array> for MLP {
    type Output = Array;
    type Error = mlx_rs::error::Exception;

    fn forward(&mut self, input: &Array) -> Result<Self::Output, Self::Error> {
        // Implement MLP forward pass using nn::silu
        let gate_output = self.gate_proj.forward(input)?;
        let silu_output = nn::silu(&gate_output)?; // Apply silu activation
        let up_output = self.up_proj.forward(input)?;
        let combined_output = silu_output * up_output; // Element-wise multiplication
        self.down_proj.forward(&combined_output) // Final projection
    }

    fn training_mode(&mut self, mode: bool) {
        self.gate_proj.training_mode(mode);
        self.down_proj.training_mode(mode);
        self.up_proj.training_mode(mode);
    }
}
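// In formula form, the forward pass above is the SwiGLU-style MLP used by Llama:
//   output = down_proj(silu(gate_proj(x)) * up_proj(x))
// where `*` is element-wise multiplication, matching the code in `forward`.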
// ... existing code ...

// ... existing code ...
// Placeholder for TransformerBlock - You'll need to implement this in Rust
#[derive(Debug, Clone, ModuleParameters)]
pub struct TransformerBlock {
    // Define the layers within TransformerBlock as needed, e.g., attention, norm, etc.
    // For now, using a linear layer as a placeholder
    pub linear: nn::Linear,
    self_attn: Attention,
    mlp: MLP,
    input_layernorm: nn::RmsNorm,
    post_attention_layernorm: nn::RmsNorm,
    args: ModelArgsRs, // Store args for potential use within TransformerBlock
}

impl TransformerBlock {
    pub fn new(dims: i32, mlp_dims: i32) -> Self {
        Self {
            linear: nn::Linear::new(dims, mlp_dims).unwrap(), // Example, adjust params
        }
    pub fn new(args: ModelArgsRs) -> Result<Self, mlx_rs::error::Exception> {
        let linear = nn::Linear::new(1, 1).unwrap(); // Dummy linear layer, will be removed
        let self_attn = Attention::new(&args)?;
        let mlp = MLP::new(&args)?;

        let input_layernorm = nn::RmsNormBuilder::new(args.hidden_size)
            .eps(args.rms_norm_eps)
            .build()
            .unwrap();

        let post_attention_layernorm = nn::RmsNormBuilder::new(args.hidden_size)
            .eps(args.rms_norm_eps)
            .build()
            .unwrap();

        Ok(Self {
            linear, // Dummy linear layer, will be removed in future
            self_attn,
            mlp,
            input_layernorm,
            post_attention_layernorm,
            args, // Store args
        })
    }
}

@@ -78,20 +266,36 @@ impl Module<&Array> for TransformerBlock {
    fn forward(&mut self, input: &Array) -> Result<Self::Output, Self::Error> {
        // Implement the forward pass logic for TransformerBlock
        // For now, just passing through the linear layer
        self.linear.forward(input)
        // self.linear.forward(input) // Old placeholder

        let normed_input = self.input_layernorm.forward(input)?;
        let attention_output = self.self_attn.forward(&normed_input)?;
        let hidden_state = input + &attention_output; // Residual connection

        let normed_hidden_state = self.post_attention_layernorm.forward(&hidden_state)?;
        let mlp_output = self.mlp.forward(&normed_hidden_state)?;
        let output = hidden_state + &mlp_output; // Residual connection

        Ok(output)
    }

    fn training_mode(&mut self, mode: bool) {}
    fn training_mode(&mut self, mode: bool) {
        self.self_attn.training_mode(mode);
        self.mlp.training_mode(mode);
        self.input_layernorm.training_mode(mode);
        self.post_attention_layernorm.training_mode(mode);
    }
}
// ... existing code ...

// Define LlamaModel struct
// ... existing code ...
#[derive(Debug, Clone, ModuleParameters)]
pub struct LlamaModelRs {
    pub args: ModelArgsRs,
    pub vocab_size: i32,
    pub num_hidden_layers: i32,
    pub embed_tokens: Option<nn::Embedding>, // Embedding layer is optional based on sharding
    pub layers: Vec<TransformerBlock>, // Using placeholder TransformerBlock
    pub layers: Vec<TransformerBlock>, // Using TransformerBlock
    pub norm: Option<nn::RmsNorm>, // RMSNorm layer is optional based on sharding
}

@@ -104,30 +308,20 @@ impl LlamaModelRs {
            embed_tokens = Some(nn::Embedding::new(args.vocab_size, args.hidden_size)?);
        }

        let mut layers = Vec::new();
        for i in 0..(num_hidden_layers as usize) {
            if args.shard.start_layer <= i && i <= args.shard.end_layer {
                // Using placeholder dimensions for TransformerBlock, adjust as needed
                layers.push(TransformerBlock::new(
                    args.hidden_size,
                    args.hidden_size * 4,
                ));
            } else {
                // Placeholder for IdentityBlock - you might need to create a Rust version if needed
                // For now, just pushing a default TransformerBlock or handle differently
                layers.push(TransformerBlock::new(
                    args.hidden_size,
                    args.hidden_size * 4,
                )); // IdentityBlock() in Python seems to be a no-op
            }
        let mut layers: Vec<TransformerBlock> = Vec::new(); // Specify type here
        for _ in 0..num_hidden_layers {
            // No sharding logic for now, apply to all layers - revisit sharding
            layers.push(TransformerBlock::new(args.clone())?); // Pass cloned args
        }

        let mut norm = None;
        if args.shard.is_last_layer() {
            norm = Some(RmsNormBuilder::new(args.hidden_size)
                .eps(args.rms_norm_eps)
                .build()
                .unwrap());
            norm = Some(
                nn::RmsNormBuilder::new(args.hidden_size)
                    .eps(args.rms_norm_eps)
                    .build()
                    .unwrap(),
            );
        }

        Ok(Self {
@@ -154,28 +348,41 @@ impl Module<&Array> for LlamaModelRs {
        }

        // Mask creation logic would go here - needs to be implemented in Rust
        // let mask = None;
        // if h.ndim() > 1 && h.shape()[1] > 1 {
        //     mask = create_attention_mask(h, cache); // Need to port create_attention_mask to Rust
        // }
        let mask: Option<Array> = None; // Placeholder mask
        // if h.ndim() > 1 && h.shape()[1] > 1 {
        //     mask = create_attention_mask(h, cache); // Need to port create_attention_mask to Rust
        // }

        // Cache handling - needs more detailed implementation for Rust
        // let mut cache = cache.unwrap_or_else(|| vec![None; self.layers.len()]);
        let mut cache: Option<Vec<Option<Array>>> = None; // Placeholder cache
        // let mut cache = cache.unwrap_or_else(|| vec![None; self.layers.len()]);

        for layer in &mut self.layers {
            h = layer.forward(&h)?; // Pass mask and cache when implemented
        }

        if self.args.shard.is_last_layer() && self.norm.is_some() {
            h = self.norm.as_ref().unwrap().forward(&h)?;
        }
        Ok(h)
        let normed_h = match &mut self.norm {
            Some(norm_layer) => norm_layer.forward(&h)?,
            None => h, // Skip norm if not the last layer
        };
        Ok(normed_h)
    }

    fn training_mode(&mut self, mode: bool) {}
    fn training_mode(&mut self, mode: bool) {
        if let Some(embed_tokens) = &mut self.embed_tokens {
            embed_tokens.training_mode(mode);
        }
        for layer in &mut self.layers {
            layer.training_mode(mode);
        }
        if let Some(norm) = &mut self.norm {
            norm.training_mode(mode);
        }
    }
}
// ... existing code ...

// Define Model struct
// ... existing code ...
#[derive(Debug, Clone, ModuleParameters)]
pub struct ModelRs {
    pub args: ModelArgsRs,
@@ -214,19 +421,30 @@ impl Module<&Array> for ModelRs {
        // Need to implement as_linear() equivalent in Rust or directly use embedding weights for linear transformation
        // Placeholder - direct linear transformation using embedding weights is not directly available in mlx-rs as in python
        if let Some(embed_tokens) = &self.model.embed_tokens {
            if let Ok(params) = embed_tokens.parameters() {
                if let Some(weight_param) = params.get("weight") {
                    // This is a very simplified placeholder - needs proper matrix multiplication with 'out' and 'weight_param'
                    out = weight_param.clone(); // Incorrect - replace with actual linear transformation
                }
            let params = embed_tokens.parameters();
            if let Some(weight_param) = params.entries.get("weight") {
                // This is a very simplified placeholder - needs proper matrix multiplication with 'out' and 'weight_param'
                // out = weight_param.clone(); // Incorrect - replace with actual linear transformation
                // Placeholder: use linear layer with embedding weights (not directly supported in mlx-rs)
                let embedding_weight = weight_param.array();
                let weight_array = embedding_weight.transpose()?; // Assuming weight needs transpose
                let weight_arr_ref = &weight_array;
                let out_matmul = out.matmul(weight_arr_ref)?; // Perform matrix multiplication
                out = out_matmul;
            }
        }
        } else if self.lm_head.is_some() {
            out = self.lm_head.as_ref().unwrap().forward(&out)?;
        } else if let Some(lm_head) = &mut self.lm_head {
            out = lm_head.forward(&out)?;
        }
        }
        Ok(out)
    }

    fn training_mode(&mut self, mode: bool) {}
    fn training_mode(&mut self, mode: bool) {
        self.model.training_mode(mode);
        if let Some(lm_head) = &mut self.lm_head {
            lm_head.training_mode(mode);
        }
    }
}
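// Note on the embed_tokens branch above: when tie_word_embeddings is set, the
// output logits are produced by multiplying the hidden states with the
// transposed embedding matrix instead of a separate lm_head, which is what the
// transpose()/matmul() placeholder is aiming for.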
// ... existing code ...

@@ -39,7 +39,7 @@ fn test() {
    let mut n_cur = batch.n_tokens();

    // The `Decoder`
    let mut decoder = encoding_rs::UTF_8.new_decoder();
    // let mut decoder = encoding_rs::UTF_8.new_decoder();
    let mut sampler = LlamaSampler::greedy();

    while n_cur <= n_len {
@@ -58,9 +58,9 @@ fn test() {
        let output_bytes = model.token_to_bytes(token, Special::Tokenize).unwrap();
        // use `Decoder.decode_to_string()` to avoid the intermediate buffer
        let mut output_string = String::with_capacity(32);
        let _decode_result = decoder.decode_to_string(&output_bytes, &mut output_string, false);
        print!("{output_string}");
        std::io::stdout().flush().unwrap();
        // let _decode_result = decoder.decode_to_string(&output_bytes, &mut output_string, false);
        // print!("{output_string}");
        // std::io::stdout().flush().unwrap();

        batch.clear();
        batch.add(token, n_cur, &[0], true).unwrap();

@@ -45,7 +45,7 @@ async fn load_model_shard(
    let model_args = serde_json::from_value::<ModelArgsRs>(Value::Object(config)).unwrap();
    let model = LlamaModelRs::new(model_args).unwrap();

    weight_files.sort_by_key(|x| x.file_name().unwrap().to_str().unwrap());
    weight_files.sort_by_key(|x| x.file_name().unwrap().to_str().unwrap().to_string());

    let mut weights = HashMap::new();

@@ -53,8 +53,6 @@ async fn load_model_shard(
        weights.extend(mlx_rs::Array::load_safetensors(&weight_file).unwrap());
    }

    model

    todo!();
}
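// Remaining work for load_model_shard (a sketch of intent, not implemented in
// this commit): map the safetensors keys collected in `weights` onto the
// module's parameter tree, respect the shard's start_layer/end_layer range,
// and load the arrays into `model` via whatever parameter-update API mlx-rs
// exposes, so the function can return the populated model instead of hitting
// `todo!()`.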