diff --git a/migration/src/m20240529_195030_add_transaction_identity_hash.rs b/migration/src/m20240529_195030_add_transaction_identity_hash.rs index 0577295..842de80 100644 --- a/migration/src/m20240529_195030_add_transaction_identity_hash.rs +++ b/migration/src/m20240529_195030_add_transaction_identity_hash.rs @@ -6,24 +6,29 @@ pub struct Migration; #[async_trait::async_trait] impl MigrationTrait for Migration { async fn up(&self, manager: &SchemaManager) -> Result<(), DbErr> { - manager.alter_table( - TableAlterStatement::new() - .table(Transaction::Table) - .add_column( - ColumnDef::new(Transaction::IdentityHash) - .big_integer() - .unique_key(), - ).to_owned() - ).await + manager + .alter_table( + TableAlterStatement::new() + .table(Transaction::Table) + .add_column( + ColumnDef::new(Transaction::IdentityHash) + .big_integer() + .unique_key(), + ) + .to_owned(), + ) + .await } async fn down(&self, manager: &SchemaManager) -> Result<(), DbErr> { - manager.alter_table( - TableAlterStatement::new() - .table(Transaction::Table) - .drop_column(Transaction::IdentityHash) - .to_owned() - ).await + manager + .alter_table( + TableAlterStatement::new() + .table(Transaction::Table) + .drop_column(Transaction::IdentityHash) + .to_owned(), + ) + .await } } diff --git a/src/ingestion/db.rs b/src/ingestion/db.rs index 2695f67..e08008d 100644 --- a/src/ingestion/db.rs +++ b/src/ingestion/db.rs @@ -1,51 +1,17 @@ use crate::error::AppError; +use crate::ingestion::ingestion_logic::MonzoRow; use anyhow::anyhow; use entity::{expenditure, transaction}; -use sea_orm::sea_query::{OnConflict, PostgresQueryBuilder}; -use sea_orm::{ColumnTrait, DatabaseConnection, EntityTrait, Iterable, TransactionTrait}; +use sea_orm::sea_query::OnConflict; use sea_orm::{ - ConnectionTrait, DatabaseBackend, DatabaseTransaction, DbErr, QueryFilter, QueryTrait, - Statement, + ColumnTrait, DatabaseConnection, EntityTrait, Iterable, QuerySelect, TransactionTrait, }; -use 
crate::ingestion::ingestion_logic::MonzoRow; +use sea_orm::{ConnectionTrait, DatabaseTransaction, QueryFilter}; pub struct Insertion { pub transaction: transaction::ActiveModel, pub contained_expenditures: Vec<expenditure::ActiveModel>, -} - -async fn update_expenditures( - tx: &DatabaseTransaction, - insertions: &[Insertion], -) -> Result<(), DbErr> { - // Expenditures can change as we re-categorise them, so we delete all the old ones and - // insert an entirely new set to ensure we don't end up leaving old ones around. - expenditure::Entity::delete_many() - .filter( - expenditure::Column::TransactionId - .is_in(insertions.iter().map(|i| i.transaction.id.as_ref())), - ) - .exec(tx) - .await?; - - expenditure::Entity::insert_many( - insertions - .iter() - .flat_map(|i| &i.contained_expenditures) - .cloned(), - ) - .on_conflict( - OnConflict::columns(vec![ - expenditure::Column::TransactionId, - expenditure::Column::Category, - ]) - .update_columns(expenditure::Column::iter()) - .to_owned(), - ) - .exec(tx) - .await?; - - Ok(()) + pub identity_hash: i64, } // Note while this is more efficient in db calls, it does bind together the entire group. 
@@ -63,8 +29,11 @@ pub async fn insert( for insertions in insertions.chunks(400) { let tx = db.begin().await?; - let inserted_transaction_ids = update_transactions(insertions, &tx).await?; - update_expenditures(&tx, &insertions).await?; + let (new_or_updated_insertions, inserted_transaction_ids) = + whittle_insertions(insertions, &tx).await?; + + update_transactions(&tx, &new_or_updated_insertions).await?; + update_expenditures(&tx, &new_or_updated_insertions, &inserted_transaction_ids).await?; tx.commit().await?; // We wait until the transaction is committed before adding the new transaction ids to the @@ -78,34 +47,87 @@ Ok(new_transaction_ids) } -async fn update_transactions( - insertions: &[Insertion], +async fn update_expenditures( tx: &DatabaseTransaction, -) -> Result<Vec<String>, AppError> { + new_or_updated_insertions: &[&Insertion], + inserted_transaction_ids: &[String], +) -> Result<(), AppError> { + // Expenditures can change as we re-categorise them, so we delete all the old ones and + // insert an entirely new set to ensure we don't end up leaving old ones around. 
+ expenditure::Entity::delete_many() + .filter(expenditure::Column::TransactionId.is_in(inserted_transaction_ids)) + .exec(tx) + .await?; + expenditure::Entity::insert_many( + new_or_updated_insertions + .into_iter() + .flat_map(|i| &i.contained_expenditures) + .cloned(), + ) + .on_conflict( + OnConflict::columns(vec![ + expenditure::Column::TransactionId, + expenditure::Column::Category, + ]) + .update_columns(expenditure::Column::iter()) + .to_owned(), + ) + .exec(tx) + .await?; + Ok(()) +} - let insert = - transaction::Entity::insert_many(insertions.iter().map(|i| &i.transaction).cloned()) - .on_conflict( - OnConflict::columns([transaction::Column::Id, transaction::Column::IdentityHash]) - .update_columns(transaction::Column::iter()) - .to_owned(), - ) - .into_query() - .returning_col(transaction::Column::Id) - .build(PostgresQueryBuilder); - - let inserted_transaction_ids = tx - .query_all(Statement::from_sql_and_values( - DatabaseBackend::Postgres, - insert.0, - insert.1, - )) - .await? 
+async fn update_transactions( + tx: &DatabaseTransaction, + new_or_updated_insertions: &[&Insertion], +) -> Result<(), AppError> { + let transactions = new_or_updated_insertions .iter() - .map(|r| r.try_get_by("id")) - .collect::<Result<Vec<_>, _>>()?; - Ok(inserted_transaction_ids) + .map(|i| &i.transaction) + .cloned(); + + transaction::Entity::insert_many(transactions) + .on_conflict( + OnConflict::column(transaction::Column::Id) + .update_columns(transaction::Column::iter()) + .to_owned(), + ) + .exec(tx) + .await?; + Ok(()) +} + +async fn whittle_insertions<'a>( + insertions: &'a [Insertion], + tx: &DatabaseTransaction, +) -> Result<(Vec<&'a Insertion>, Vec<String>), AppError> { + let existing_hashes = transaction::Entity::find() + .select_only() + .columns([transaction::Column::IdentityHash]) + .filter(transaction::Column::IdentityHash.is_not_null()) + .into_tuple::<(i64,)>() + .all(tx) + .await?; + + // We will only update those where the hash is different to avoid unnecessary updates and + // notifications. 
+ let new_or_updated_insertions = insertions + .into_iter() + .filter(|i| { + let hash = i.identity_hash; + !existing_hashes + .iter() + .any(|(existing_hash,)| *existing_hash == hash) + }) + .collect::<Vec<_>>(); + + let inserted_transaction_ids = new_or_updated_insertions + .iter() + .map(|i| i.transaction.id.clone().unwrap()) + .collect::<Vec<_>>(); + + Ok((new_or_updated_insertions, inserted_transaction_ids)) } async fn notify_new_transactions( @@ -121,16 +143,23 @@ } mod tests { - use anyhow::Error; - use sea_orm::{DatabaseConnection}; - use sqlx::{PgPool}; - use sqlx::postgres::PgListener; - use testcontainers::ContainerAsync; - use migration::MigratorTrait; - use testcontainers::runners::AsyncRunner; use super::notify_new_transactions; + use anyhow::Error; + use migration::MigratorTrait; + use sea_orm::DatabaseConnection; + use sqlx::postgres::PgListener; + use sqlx::PgPool; + use testcontainers::runners::AsyncRunner; + use testcontainers::ContainerAsync; - async fn initialise() -> Result<(ContainerAsync<testcontainers_modules::postgres::Postgres>, DatabaseConnection, PgPool), Error> { + async fn initialise() -> Result< + ( + ContainerAsync<testcontainers_modules::postgres::Postgres>, + DatabaseConnection, + PgPool, + ), + Error, + > { let container = testcontainers_modules::postgres::Postgres::default() .start() .await?; @@ -144,8 +173,7 @@ let db: DatabaseConnection = sea_orm::Database::connect(connection_string).await?; migration::Migrator::up(&db, None).await?; - let pool = PgPool::connect(connection_string) - .await?; + let pool = PgPool::connect(connection_string).await?; Ok((container, db, pool)) } @@ -156,18 +184,23 @@ let mut listener = PgListener::connect_with(&pool).await?; listener.listen("monzo_new_transactions").await?; - let ids = vec!["test1".to_string(), "test2".to_string(), "test3".to_string()]; + let ids = vec![ + "test1".to_string(), + "test2".to_string(), + "test3".to_string(), + ]; - notify_new_transactions( - &db, - &ids, - ).await?; + notify_new_transactions(&db, &ids).await?; 
let notification = listener.recv().await?; let payload = notification.payload(); println!("Payload: {}", payload); - assert_eq!(serde_json::from_str::<Vec<String>>(&payload)?, ids, "Payloads do not match"); + assert_eq!( + serde_json::from_str::<Vec<String>>(&payload)?, + ids, + "Payloads do not match" + ); Ok(()) } -} \ No newline at end of file +} diff --git a/src/ingestion/ingestion_logic.rs b/src/ingestion/ingestion_logic.rs index 5e2a60d..cc316c8 100644 --- a/src/ingestion/ingestion_logic.rs +++ b/src/ingestion/ingestion_logic.rs @@ -1,4 +1,3 @@ -use std::hash::Hash; use crate::ingestion::db::Insertion; use anyhow::{anyhow, Context}; use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime}; @@ -9,6 +8,7 @@ use num_traits::FromPrimitive; use sea_orm::prelude::Decimal; use sea_orm::IntoActiveModel; use serde_json::Value; +use std::hash::Hash; #[allow(dead_code)] mod headings { @@ -62,7 +62,8 @@ impl MonzoRow { transaction_id: monzo_transaction_id.to_string(), category, amount, - }.into_active_model()) + } + .into_active_model()) } /// Compute a hash of this row, returning the number as an i64 to be used as a unique constraint @@ -77,7 +78,9 @@ impl MonzoRow { } pub fn into_insertion(self) -> Result { - let expenditures: Vec<_> = match self.category_split { + let identity_hash = self.compute_hash(); + + let expenditures: Vec<_> = match &self.category_split { Some(split) if !split.is_empty() => split .split(',') .map(|section| Self::parse_section(&self.transaction_id, section)) @@ -88,7 +91,7 @@ amount: self.total_amount, transaction_id: self.transaction_id.clone(), } - .into_active_model()], + .into_active_model()], }; Ok(Insertion { @@ -102,10 +105,12 @@ impl MonzoRow { receipt: self.receipt, total_amount: self.total_amount, description: self.description, - identity_hash: Some(self.compute_hash()), - }.into_active_model(), + identity_hash: Some(identity_hash), + } + .into_active_model(), contained_expenditures: expenditures, + identity_hash, }) } }