use crate::error::AppError; use anyhow::anyhow; use entity::{expenditure, transaction}; use sea_orm::sea_query::{OnConflict, PostgresQueryBuilder}; use sea_orm::{ColumnTrait, DatabaseConnection, EntityTrait, Iterable, TransactionTrait}; use sea_orm::{ ConnectionTrait, DatabaseBackend, DatabaseTransaction, DbErr, QueryFilter, QueryTrait, Statement, }; use crate::ingestion::ingestion_logic::MonzoRow; pub struct Insertion { pub transaction: transaction::ActiveModel, pub contained_expenditures: Vec, } async fn update_expenditures( tx: &DatabaseTransaction, insertions: &[Insertion], ) -> Result<(), DbErr> { // Expenditures can change as we re-categorise them, so we delete all the old ones and // insert an entirely new set to ensure we don't end up leaving old ones around. expenditure::Entity::delete_many() .filter( expenditure::Column::TransactionId .is_in(insertions.iter().map(|i| i.transaction.id.as_ref())), ) .exec(tx) .await?; expenditure::Entity::insert_many( insertions .iter() .flat_map(|i| &i.contained_expenditures) .cloned(), ) .on_conflict( OnConflict::columns(vec![ expenditure::Column::TransactionId, expenditure::Column::Category, ]) .update_columns(expenditure::Column::iter()) .to_owned(), ) .exec(tx) .await?; Ok(()) } // Note while this is more efficient in db calls, it does bind together the entire group. // We employ a batching process for now to try balance speed and failure rate, but it is worth // trying to move failures earlier and improve reporting. pub async fn insert( db: &DatabaseConnection, monzo_rows: Vec, ) -> Result, AppError> { let mut new_transaction_ids = Vec::new(); let insertions = monzo_rows .into_iter() .map(MonzoRow::into_insertion) .collect::, _>>()?; for insertions in insertions.chunks(400) { let tx = db.begin().await?; let inserted_transaction_ids = update_transactions(insertions, &tx).await?; update_expenditures(&tx, &insertions).await?; tx.commit().await?; // We wait until the transaction is committed before adding the new transaction ids to the // list to avoid issues with the transaction being rolled back. new_transaction_ids.extend(inserted_transaction_ids); } // Notify the new transactions once everything is committed. notify_new_transactions(db, &new_transaction_ids).await?; Ok(new_transaction_ids) } async fn update_transactions( insertions: &[Insertion], tx: &DatabaseTransaction, ) -> Result, AppError> { let insert = transaction::Entity::insert_many(insertions.iter().map(|i| &i.transaction).cloned()) .on_conflict( OnConflict::column(transaction::Column::Id) .update_columns(transaction::Column::iter()) .to_owned(), ) .into_query() .returning_col(transaction::Column::Id) .build(PostgresQueryBuilder); let inserted_transaction_ids = tx .query_all(Statement::from_sql_and_values( DatabaseBackend::Postgres, insert.0, insert.1, )) .await? .iter() .map(|r| r.try_get_by("id")) .collect::, _>>()?; Ok(inserted_transaction_ids) } async fn notify_new_transactions( db: &DatabaseConnection, new_transaction_ids: &[String], ) -> Result<(), AppError> { let payload = serde_json::to_string(&new_transaction_ids).map_err(|e| anyhow!(e))?; db.execute_unprepared(&format!(r#"NOTIFY monzo_new_transactions, '{payload}'"#)) .await?; Ok(()) } mod tests { use anyhow::Error; use sea_orm::{DatabaseConnection}; use sqlx::{PgPool}; use sqlx::postgres::PgListener; use testcontainers::ContainerAsync; use migration::MigratorTrait; use testcontainers::runners::AsyncRunner; use super::notify_new_transactions; async fn initialise() -> Result<(ContainerAsync, DatabaseConnection, PgPool), Error> { let container = testcontainers_modules::postgres::Postgres::default() .start() .await?; // prepare connection string let connection_string = &format!( "postgres://postgres:postgres@127.0.0.1:{}/postgres", container.get_host_port_ipv4(5432).await? ); let db: DatabaseConnection = sea_orm::Database::connect(connection_string).await?; migration::Migrator::up(&db, None).await?; let pool = PgPool::connect(connection_string) .await?; Ok((container, db, pool)) } #[tokio::test] async fn test_notify() -> Result<(), Error> { let (_container, db, pool) = initialise().await?; let mut listener = PgListener::connect_with(&pool).await?; listener.listen("monzo_new_transactions").await?; let ids = vec!["test1".to_string(), "test2".to_string(), "test3".to_string()]; notify_new_transactions( &db, &ids, ).await?; let notification = listener.recv().await?; let payload = notification.payload(); println!("Payload: {}", payload); assert_eq!(serde_json::from_str::>(&payload)?, ids, "Payloads do not match"); Ok(()) } }