use crate::error::AppError; use crate::ingestion::ingestion_logic::MonzoRow; use anyhow::anyhow; use entity::{expenditure, transaction}; use sea_orm::sea_query::OnConflict; use sea_orm::{ ColumnTrait, DatabaseConnection, DbErr, EntityTrait, Iterable, QuerySelect, TransactionTrait, }; use sea_orm::{ConnectionTrait, DatabaseTransaction, QueryFilter}; pub struct Insertion { pub transaction: transaction::ActiveModel, pub contained_expenditures: Vec, pub identity_hash: i64, } // Note while this is more efficient in db calls, it does bind together the entire group. // We employ a batching process for now to try balance speed and failure rate, but it is worth // trying to move failures earlier and improve reporting. pub async fn insert( db: &DatabaseConnection, monzo_rows: Vec, account_id: i32, ) -> Result, AppError> { let mut new_transaction_ids = Vec::new(); let insertions = monzo_rows .into_iter() .map(|row| MonzoRow::into_insertion(row, account_id)) .collect::, _>>()?; for insertions in insertions.chunks(200) { let (new_or_updated_insertions, inserted_transaction_ids) = whittle_insertions(insertions, db).await?; if new_or_updated_insertions.is_empty() { continue; } let tx = db.begin().await?; update_transactions(&tx, &new_or_updated_insertions).await?; update_expenditures(&tx, &new_or_updated_insertions, &inserted_transaction_ids).await?; tx.commit().await?; // We wait until the transaction is committed before adding the new transaction ids to the // list to avoid issues with the transaction being rolled back. new_transaction_ids.extend(inserted_transaction_ids); } // Notify the new transactions once everything is committed. notify_new_transactions(db, &new_transaction_ids).await?; Ok(new_transaction_ids) } async fn update_expenditures( tx: &DatabaseTransaction, new_or_updated_insertions: &[&Insertion], inserted_transaction_ids: &[String], ) -> Result<(), AppError> { if new_or_updated_insertions.is_empty() { return Ok(()); } // Expenditures can change as we re-categorise them, so we delete all the old ones and // insert an entirely new set to ensure we don't end up leaving old ones around. expenditure::Entity::delete_many() .filter(expenditure::Column::TransactionId.is_in(inserted_transaction_ids)) .exec(tx) .await?; expenditure::Entity::insert_many( new_or_updated_insertions .iter() .flat_map(|i| &i.contained_expenditures) .cloned(), ) .on_conflict( OnConflict::columns(vec![ expenditure::Column::TransactionId, expenditure::Column::Category, ]) .update_columns(expenditure::Column::iter()) .to_owned(), ) .exec(tx) .await?; Ok(()) } async fn update_transactions( tx: &DatabaseTransaction, new_or_updated_insertions: &[&Insertion], ) -> Result<(), DbErr> { if new_or_updated_insertions.is_empty() { return Ok(()); } let transactions = new_or_updated_insertions .iter() .map(|i| &i.transaction) .cloned(); transaction::Entity::insert_many(transactions) .on_conflict( OnConflict::column(transaction::Column::Id) .update_columns(transaction::Column::iter()) .to_owned(), ) .exec(tx) .await?; Ok(()) } async fn whittle_insertions<'a>( insertions: &'a [Insertion], tx: &DatabaseConnection, ) -> Result<(Vec<&'a Insertion>, Vec), AppError> { let existing_hashes = transaction::Entity::find() .select_only() .columns([transaction::Column::IdentityHash]) .filter(transaction::Column::IdentityHash.is_not_null()) .into_tuple::<(i64,)>() .all(tx) .await?; tracing::debug!("Found existing entries: {existing_hashes:?}"); // We will only update those where the hash is different to avoid unnecessary updates and // notifications. let new_or_updated_insertions = insertions .iter() .filter(|i| { let hash = i.identity_hash; !existing_hashes .iter() .any(|(existing_hash,)| *existing_hash == hash) }) .collect::>(); let inserted_transaction_ids = new_or_updated_insertions .iter() .map(|i| i.transaction.id.clone().unwrap()) .collect::>(); Ok((new_or_updated_insertions, inserted_transaction_ids)) } async fn notify_new_transactions( db: &DatabaseConnection, new_transaction_ids: &[String], ) -> Result<(), AppError> { let payload = serde_json::to_string(&new_transaction_ids).map_err(|e| anyhow!(e))?; db.execute_unprepared(&format!(r#"NOTIFY monzo_new_transactions, '{payload}'"#)) .await?; Ok(()) } mod tests { use super::{insert, notify_new_transactions, update_expenditures, update_transactions}; use crate::ingestion::ingestion_logic::from_json_row; use anyhow::Error; use entity::account; use migration::MigratorTrait; use sea_orm::{ActiveModelTrait, DatabaseConnection, TransactionTrait}; use serde_json::Value; use sqlx::postgres::PgListener; use sqlx::{Executor, PgPool}; use testcontainers::runners::AsyncRunner; use testcontainers::ContainerAsync; use tokio::sync::OnceCell; #[derive(Debug)] struct DatabaseInstance { container: ContainerAsync, db: DatabaseConnection, pool: PgPool, } static INSTANCE: OnceCell = OnceCell::const_new(); async fn initialise_db() -> Result { let container = testcontainers_modules::postgres::Postgres::default() .start() .await?; // prepare connection string let connection_string = &format!( "postgres://postgres:postgres@127.0.0.1:{}/postgres", container.get_host_port_ipv4(5432).await? ); let db: DatabaseConnection = sea_orm::Database::connect(connection_string).await?; migration::Migrator::up(&db, None).await?; let pool = PgPool::connect(connection_string).await?; let instance = DatabaseInstance { container, db, pool, }; Ok(instance) } async fn get_or_initialize_db_instance() -> Result<&'static DatabaseInstance, Error> { Ok(INSTANCE .get_or_init(|| async { initialise_db().await.unwrap() }) .await) } async fn create_test_account(db: &DatabaseConnection) -> Result { let new_account = account::ActiveModel { id: sea_orm::ActiveValue::NotSet, name: sea_orm::ActiveValue::Set("Test Account".to_string()), }; let inserted = new_account.insert(db).await?; Ok(inserted.id) } #[tokio::test] async fn test_empty_insertion_list() -> Result<(), Error> { let db = get_or_initialize_db_instance().await?; let insertions = vec![]; let tx = db.db.begin().await?; update_transactions(&tx, &insertions).await?; update_expenditures(&tx, &insertions, &vec![]).await?; tx.commit().await?; Ok(()) } #[tokio::test] async fn test_notify() -> Result<(), Error> { let dbi = get_or_initialize_db_instance().await?; let mut listener = PgListener::connect_with(&dbi.pool).await?; listener.listen("monzo_new_transactions").await?; let ids = vec![ "test1".to_string(), "test2".to_string(), "test3".to_string(), ]; notify_new_transactions(&dbi.db, &ids).await?; let notification = listener.recv().await?; let payload = notification.payload(); println!("Payload: {}", payload); assert_eq!( serde_json::from_str::>(&payload)?, ids, "Payloads do not match" ); Ok(()) } #[tokio::test] async fn test_notify_on_insert() -> Result<(), Error> { let dbi = get_or_initialize_db_instance().await?; let account_id = create_test_account(&dbi.db).await?; let mut listener = PgListener::connect_with(&dbi.pool).await?; listener.listen("monzo_new_transactions").await?; let json = include_str!("../../fixtures/transactions.json"); let json: Vec> = serde_json::from_str(json).unwrap(); let data = json .iter() .map(|row| from_json_row(row)) .collect::, anyhow::Error>>()?; insert(&dbi.db, data.clone(), account_id).await?; let notification = listener.recv().await?; let payload = notification.payload(); let mut payload = serde_json::from_str::>(&payload)?; payload.sort(); let mut ids = data .iter() .map(|row| row.transaction_id.clone()) .collect::>(); ids.sort(); assert_eq!(payload, ids, "Inserted IDs do not match"); insert(&dbi.db, data.clone(), account_id).await?; let notification = listener.recv().await?; let payload = notification.payload(); let payload = serde_json::from_str::>(&payload)?; assert_eq!( payload, Vec::::new(), "Re-inserting identical rows triggered double notification" ); let mut altered_data = data.clone(); altered_data[0].description = Some("New description".to_string()); assert_ne!( altered_data[0].compute_hash(), data[0].compute_hash(), "Alterations have the same hash" ); insert(&dbi.db, altered_data.clone(), account_id).await?; let notification = listener.recv().await?; let payload = notification.payload(); let payload = serde_json::from_str::>(&payload)?; assert_eq!( payload, vec![altered_data[0].transaction_id.clone()], "Re-inserting altered row failed to re-trigger notification" ); Ok(()) } } pub(crate) async fn get_account_id( p0: &DatabaseConnection, p1: Option, ) -> Result { let p1 = p1.unwrap_or("Monzo".to_string()); entity::prelude::Account::find() .filter(entity::account::Column::Name.eq(p1)) .select_only() .column(entity::account::Column::Id) .into_tuple::() .one(p0) .await? .ok_or(AppError::BadRequest(anyhow!("Account not found"))) }