diff --git a/entity/src/account.rs b/entity/src/account.rs new file mode 100644 index 0000000..0913664 --- /dev/null +++ b/entity/src/account.rs @@ -0,0 +1,26 @@ +//! `SeaORM` Entity, @generated by sea-orm-codegen 1.1.0 + +use sea_orm::entity::prelude::*; +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Eq, Serialize, Deserialize)] +#[sea_orm(table_name = "account")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub name: String, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm(has_many = "super::transaction::Entity")] + Transaction, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Transaction.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/entity/src/expenditure.rs b/entity/src/expenditure.rs index c457147..6d6cad3 100644 --- a/entity/src/expenditure.rs +++ b/entity/src/expenditure.rs @@ -1,4 +1,4 @@ -//! `SeaORM` Entity. Generated by sea-orm-codegen 0.12.2 +//! `SeaORM` Entity, @generated by sea-orm-codegen 1.1.0 use sea_orm::entity::prelude::*; use serde::{Deserialize, Serialize}; @@ -14,6 +14,21 @@ pub struct Model { } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation {} +pub enum Relation { + #[sea_orm( + belongs_to = "super::transaction::Entity", + from = "Column::TransactionId", + to = "super::transaction::Column::Id", + on_update = "NoAction", + on_delete = "Cascade" + )] + Transaction, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Transaction.def() + } +} impl ActiveModelBehavior for ActiveModel {} diff --git a/entity/src/lib.rs b/entity/src/lib.rs index 76f7c0f..ecfd33c 100644 --- a/entity/src/lib.rs +++ b/entity/src/lib.rs @@ -1,6 +1,7 @@ -//! `SeaORM` Entity. Generated by sea-orm-codegen 0.12.2 +//! `SeaORM` Entity, @generated by sea-orm-codegen 1.1.0 pub mod prelude; +pub mod account; pub mod expenditure; pub mod transaction; diff --git a/entity/src/prelude.rs b/entity/src/prelude.rs index 6589503..464a526 100644 --- a/entity/src/prelude.rs +++ b/entity/src/prelude.rs @@ -1,4 +1,5 @@ -//! `SeaORM` Entity. Generated by sea-orm-codegen 0.12.2 +//! `SeaORM` Entity, @generated by sea-orm-codegen 1.1.0 +pub use super::account::Entity as Account; pub use super::expenditure::Entity as Expenditure; pub use super::transaction::Entity as Transaction; diff --git a/entity/src/transaction.rs b/entity/src/transaction.rs index b8c25e9..714d299 100644 --- a/entity/src/transaction.rs +++ b/entity/src/transaction.rs @@ -1,4 +1,4 @@ -//! `SeaORM` Entity. Generated by sea-orm-codegen 0.12.2 +//! `SeaORM` Entity, @generated by sea-orm-codegen 1.1.0 use sea_orm::entity::prelude::*; use serde::{Deserialize, Serialize}; @@ -18,9 +18,33 @@ pub struct Model { pub description: Option, #[sea_orm(unique)] pub identity_hash: Option, + pub account_id: Option, } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation {} +pub enum Relation { + #[sea_orm( + belongs_to = "super::account::Entity", + from = "Column::AccountId", + to = "super::account::Column::Id", + on_update = "NoAction", + on_delete = "Cascade" + )] + Account, + #[sea_orm(has_many = "super::expenditure::Entity")] + Expenditure, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Account.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Expenditure.def() + } +} impl ActiveModelBehavior for ActiveModel {} diff --git a/src/ingestion/db.rs b/src/ingestion/db.rs index 4b7950e..1dc61a3 100644 --- a/src/ingestion/db.rs +++ b/src/ingestion/db.rs @@ -20,11 +20,12 @@ pub struct Insertion { pub async fn insert( db: &DatabaseConnection, monzo_rows: Vec, + account_id: i32, ) -> Result, AppError> { let mut new_transaction_ids = Vec::new(); let insertions = monzo_rows .into_iter() - .map(MonzoRow::into_insertion) + .map(|row| MonzoRow::into_insertion(row, account_id)) .collect::, _>>()?; for insertions in insertions.chunks(400) { @@ -263,7 +264,7 @@ mod tests { .collect::, anyhow::Error>>() .unwrap(); - insert(&dbi.db, data.clone()).await?; + insert(&dbi.db, data.clone(), ).await?; let notification = listener.recv().await?; let payload = notification.payload(); let mut payload = serde_json::from_str::>(&payload)?; @@ -277,7 +278,7 @@ mod tests { assert_eq!(payload, ids, "Inserted IDs do not match"); - insert(&dbi.db, data.clone()).await?; + insert(&dbi.db, data.clone(), ).await?; let notification = listener.recv().await?; let payload = notification.payload(); let payload = serde_json::from_str::>(&payload)?; @@ -288,7 +289,7 @@ mod tests { assert_ne!(altered_data[0].compute_hash(), data[0].compute_hash(), "Alterations have the same hash"); - insert(&dbi.db, altered_data.clone()).await?; + insert(&dbi.db, altered_data.clone(), ).await?; let notification = listener.recv().await?; let payload = notification.payload(); let payload = serde_json::from_str::>(&payload)?; @@ -297,3 +298,15 @@ mod tests { Ok(()) } } + +pub(crate) async fn get_account_id(p0: &DatabaseConnection, p1: Option) -> Result { + let p1 = p1.unwrap_or("Monzo".to_string()); + + entity::prelude::Account::find() + .filter(entity::account::Column::Name.eq(p1)) + .select_only() + .column(entity::account::Column::Id) + .into_tuple::() + .one(p0).await? + .ok_or(AppError::BadRequest(anyhow!("Account not found"))) +} \ No newline at end of file diff --git a/src/ingestion/flex.rs b/src/ingestion/flex.rs index f9c3715..7d18d1f 100644 --- a/src/ingestion/flex.rs +++ b/src/ingestion/flex.rs @@ -1,5 +1,5 @@ #[allow(dead_code)] -mod headings { +pub mod headings { #[allow(unused_imports)] pub use super::super::ingestion_logic::headings::*; diff --git a/src/ingestion/ingestion_logic.rs b/src/ingestion/ingestion_logic.rs index 1e67948..2594502 100644 --- a/src/ingestion/ingestion_logic.rs +++ b/src/ingestion/ingestion_logic.rs @@ -9,6 +9,7 @@ use sea_orm::prelude::Decimal; use sea_orm::IntoActiveModel; use serde_json::Value; use std::hash::Hash; +use crate::ingestion::flex; #[allow(dead_code)] pub(crate) mod headings { @@ -77,7 +78,7 @@ impl MonzoRow { hasher.finish() as i64 } - pub fn into_insertion(self) -> Result { + pub fn into_insertion(self, account_id: i32) -> Result { let identity_hash = self.compute_hash(); let expenditures: Vec<_> = match &self.category_split { @@ -106,6 +107,7 @@ impl MonzoRow { total_amount: self.total_amount, description: self.description, identity_hash: Some(identity_hash), + account_id: Some(account_id), } .into_active_model(), diff --git a/src/ingestion/routes.rs b/src/ingestion/routes.rs index 1b950b1..0ad7c0a 100644 --- a/src/ingestion/routes.rs +++ b/src/ingestion/routes.rs @@ -19,36 +19,52 @@ pub async fn monzo_batched_json( .map(from_json_row) .collect::>()?; - db::insert(&db, data).await?; + // We default to the main account for JSON ingestion for now. + let account_id = db::get_account_id(&db, None).await?; + db::insert(&db, data, account_id).await?; Ok("Ok") } -async fn extract_csv(mut multipart: Multipart) -> Result, MultipartError> { - let csv = loop { - match multipart.next_field().await? { - Some(field) if field.name() == Some("csv") => { - break Some(field.bytes().await?); +async fn extract_csv_and_account_name(mut multipart: Multipart) -> Result<(Option, Option), MultipartError> { + let mut csv = None; + let mut account_name = None; + + while let Some(field) = multipart.next_field().await? { + match field.name() { + Some("csv") => { + csv = Some(field.bytes().await?); } - Some(_) => {} - None => break None, + Some("account_id") => { + account_name = Some(field.text().await?); + } + + _ => {} } - }; - Ok(csv) + if csv.is_some() && account_name.is_some() { + break; + } + } + + Ok((csv, account_name)) } - pub async fn monzo_batched_csv( Extension(db): Extension, multipart: Multipart, ) -> Result<&'static str, AppError> { static CSV_MISSING_ERR_MSG: &str = "No CSV file provided. Expected a multipart request with a `csv` field containing the contents of the CSV."; - let csv = extract_csv(multipart) + let (csv, account_name) = extract_csv_and_account_name(multipart) .await - .map_err(|e| AppError::BadRequest(anyhow!(e))) - .and_then(|csv| csv.ok_or(AppError::BadRequest(anyhow!(CSV_MISSING_ERR_MSG))))?; + .map_err(|e| AppError::BadRequest(anyhow!(e)))?; + + let Some(csv) = csv else { + return Err(AppError::BadRequest(anyhow!(CSV_MISSING_ERR_MSG))); + }; + + let account_id = db::get_account_id(&db, account_name).await?; let csv = Cursor::new(csv); let mut csv = csv::Reader::from_reader(csv); @@ -58,7 +74,7 @@ pub async fn monzo_batched_csv( .map(from_csv_row) .collect::>()?; - db::insert(&db, data).await?; + db::insert(&db, data, account_id).await?; Ok("Ok") } diff --git a/src/main.rs b/src/main.rs index 8c3c14b..dc8090a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -44,6 +44,10 @@ enum Commands { Csv { /// The path of the CSV file to ingest. csv_file: PathBuf, + + /// The name of the account to ingest the CSV for. + #[clap(long, short)] + account: String, }, } @@ -94,7 +98,7 @@ async fn main() -> anyhow::Result<()> { serve_web(addr, connection).await?; } - Commands::Csv { csv_file } => { + Commands::Csv { csv_file, account: account_name } => { let mut csv = csv::Reader::from_reader(File::open(csv_file)?); let data = csv.records(); let data = data @@ -102,7 +106,8 @@ async fn main() -> anyhow::Result<()> { .map(from_csv_row) .collect::>()?; - db::insert(&connection, data).await?; + let account_id = db::get_account_id(&connection, Some(account_name)).await?; + db::insert(&connection, data, account_id).await?; } }