Compare commits

...

3 Commits

Author SHA1 Message Date
Joshua Coles
d11e4fd0c4 Allow specifying the account ID and fix bug when delivering batched rows
All checks were successful
Build and Publish / Build and Test (push) Successful in 9m12s
2025-12-28 12:04:53 +00:00
Joshua Coles
a2ba83e6f8 Correct tests 2025-12-28 11:57:29 +00:00
Joshua Coles
0d564ff299 Update deps 2025-12-28 11:57:22 +00:00
5 changed files with 1582 additions and 790 deletions

2215
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -7,8 +7,8 @@ edition = "2021"
entity = { path = "entity" } entity = { path = "entity" }
migration = { path = "migration" } migration = { path = "migration" }
axum = { version = "0.7.5", features = ["multipart"] } axum = { version = "0.8.8", features = ["multipart"] }
tokio = { version = "1.37.0", features = ["full"] } tokio = { version = "1.48.0", features = ["full"] }
sea-orm = { version = "1.1.0", features = [ sea-orm = { version = "1.1.0", features = [
"sqlx-postgres", "sqlx-postgres",
"runtime-tokio-rustls", "runtime-tokio-rustls",
@ -20,14 +20,14 @@ serde_json = "1.0"
tracing-subscriber = "0.3.18" tracing-subscriber = "0.3.18"
tracing = "0.1.40" tracing = "0.1.40"
anyhow = { version = "1.0", features = ["backtrace"] } anyhow = { version = "1.0", features = ["backtrace"] }
thiserror = "1.0" thiserror = "2.0"
http = "1.1" http = "1.1"
chrono = { version = "0.4", features = ["serde"] } chrono = { version = "0.4", features = ["serde"] }
num-traits = "0.2" num-traits = "0.2"
csv = "1.3.0" csv = "1.3.0"
clap = "4.5" clap = "4.5"
testcontainers = "0.21" testcontainers = "0.26"
testcontainers-modules = { version = "0.9", features = ["postgres"] } testcontainers-modules = { version = "0.14", features = ["postgres"] }
sqlx = { version = "0.8", features = ["postgres"] } sqlx = { version = "0.8", features = ["postgres"] }
tower-http = { version = "0.6", features = ["trace"] } tower-http = { version = "0.6", features = ["trace"] }
bytes = "1.7" bytes = "1.7"

View File

@ -159,16 +159,17 @@ async fn notify_new_transactions(
mod tests { mod tests {
use super::{insert, notify_new_transactions, update_expenditures, update_transactions}; use super::{insert, notify_new_transactions, update_expenditures, update_transactions};
use crate::ingestion::ingestion_logic::from_json_row;
use anyhow::Error; use anyhow::Error;
use tokio::sync::OnceCell; use entity::account;
use migration::MigratorTrait; use migration::MigratorTrait;
use sea_orm::{DatabaseConnection, TransactionTrait}; use sea_orm::{ActiveModelTrait, DatabaseConnection, TransactionTrait};
use serde_json::Value; use serde_json::Value;
use sqlx::postgres::PgListener; use sqlx::postgres::PgListener;
use sqlx::PgPool; use sqlx::{Executor, PgPool};
use testcontainers::runners::AsyncRunner; use testcontainers::runners::AsyncRunner;
use testcontainers::ContainerAsync; use testcontainers::ContainerAsync;
use crate::ingestion::ingestion_logic::from_json_row; use tokio::sync::OnceCell;
#[derive(Debug)] #[derive(Debug)]
struct DatabaseInstance { struct DatabaseInstance {
@ -203,13 +204,19 @@ mod tests {
Ok(instance) Ok(instance)
} }
async fn get_or_initialize_db_instance() -> Result< async fn get_or_initialize_db_instance() -> Result<&'static DatabaseInstance, Error> {
&'static DatabaseInstance, Ok(INSTANCE
Error, .get_or_init(|| async { initialise_db().await.unwrap() })
> { .await)
Ok(INSTANCE.get_or_init(|| async { }
initialise_db().await.unwrap()
}).await) async fn create_test_account(db: &DatabaseConnection) -> Result<i32, Error> {
let new_account = account::ActiveModel {
id: sea_orm::ActiveValue::NotSet,
name: sea_orm::ActiveValue::Set("Test Account".to_string()),
};
let inserted = new_account.insert(db).await?;
Ok(inserted.id)
} }
#[tokio::test] #[tokio::test]
@ -253,6 +260,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_notify_on_insert() -> Result<(), Error> { async fn test_notify_on_insert() -> Result<(), Error> {
let dbi = get_or_initialize_db_instance().await?; let dbi = get_or_initialize_db_instance().await?;
let account_id = create_test_account(&dbi.db).await?;
let mut listener = PgListener::connect_with(&dbi.pool).await?; let mut listener = PgListener::connect_with(&dbi.pool).await?;
listener.listen("monzo_new_transactions").await?; listener.listen("monzo_new_transactions").await?;
@ -260,17 +268,17 @@ mod tests {
let json: Vec<Vec<Value>> = serde_json::from_str(json).unwrap(); let json: Vec<Vec<Value>> = serde_json::from_str(json).unwrap();
let data = json let data = json
.iter() .iter()
.map(|row| from_json_row(row.clone())) .map(|row| from_json_row(row))
.collect::<Result<Vec<_>, anyhow::Error>>() .collect::<Result<Vec<_>, anyhow::Error>>()?;
.unwrap();
insert(&dbi.db, data.clone(), ).await?; insert(&dbi.db, data.clone(), account_id).await?;
let notification = listener.recv().await?; let notification = listener.recv().await?;
let payload = notification.payload(); let payload = notification.payload();
let mut payload = serde_json::from_str::<Vec<String>>(&payload)?; let mut payload = serde_json::from_str::<Vec<String>>(&payload)?;
payload.sort(); payload.sort();
let mut ids = data.iter() let mut ids = data
.iter()
.map(|row| row.transaction_id.clone()) .map(|row| row.transaction_id.clone())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
@ -278,28 +286,43 @@ mod tests {
assert_eq!(payload, ids, "Inserted IDs do not match"); assert_eq!(payload, ids, "Inserted IDs do not match");
insert(&dbi.db, data.clone(), ).await?; insert(&dbi.db, data.clone(), account_id).await?;
let notification = listener.recv().await?; let notification = listener.recv().await?;
let payload = notification.payload(); let payload = notification.payload();
let payload = serde_json::from_str::<Vec<String>>(&payload)?; let payload = serde_json::from_str::<Vec<String>>(&payload)?;
assert_eq!(payload, Vec::<String>::new(), "Re-inserting identical rows triggered double notification"); assert_eq!(
payload,
Vec::<String>::new(),
"Re-inserting identical rows triggered double notification"
);
let mut altered_data = data.clone(); let mut altered_data = data.clone();
altered_data[0].description = Some("New description".to_string()); altered_data[0].description = Some("New description".to_string());
assert_ne!(altered_data[0].compute_hash(), data[0].compute_hash(), "Alterations have the same hash"); assert_ne!(
altered_data[0].compute_hash(),
data[0].compute_hash(),
"Alterations have the same hash"
);
insert(&dbi.db, altered_data.clone(), ).await?; insert(&dbi.db, altered_data.clone(), account_id).await?;
let notification = listener.recv().await?; let notification = listener.recv().await?;
let payload = notification.payload(); let payload = notification.payload();
let payload = serde_json::from_str::<Vec<String>>(&payload)?; let payload = serde_json::from_str::<Vec<String>>(&payload)?;
assert_eq!(payload, vec![altered_data[0].transaction_id.clone()], "Re-inserting altered row failed to re-trigger notification"); assert_eq!(
payload,
vec![altered_data[0].transaction_id.clone()],
"Re-inserting altered row failed to re-trigger notification"
);
Ok(()) Ok(())
} }
} }
pub(crate) async fn get_account_id(p0: &DatabaseConnection, p1: Option<String>) -> Result<i32, AppError> { pub(crate) async fn get_account_id(
p0: &DatabaseConnection,
p1: Option<String>,
) -> Result<i32, AppError> {
let p1 = p1.unwrap_or("Monzo".to_string()); let p1 = p1.unwrap_or("Monzo".to_string());
entity::prelude::Account::find() entity::prelude::Account::find()
@ -307,6 +330,7 @@ pub(crate) async fn get_account_id(p0: &DatabaseConnection, p1: Option<String>)
.select_only() .select_only()
.column(entity::account::Column::Id) .column(entity::account::Column::Id)
.into_tuple::<i32>() .into_tuple::<i32>()
.one(p0).await? .one(p0)
.await?
.ok_or(AppError::BadRequest(anyhow!("Account not found"))) .ok_or(AppError::BadRequest(anyhow!("Account not found")))
} }

View File

@ -140,7 +140,7 @@ fn parse_timestamp(date: &str, time: &str) -> anyhow::Result<NaiveDateTime> {
Ok(date.and_time(time)) Ok(date.and_time(time))
} }
pub fn from_json_row(row: Vec<Value>) -> anyhow::Result<MonzoRow> { pub fn from_json_row(row: &[Value]) -> anyhow::Result<MonzoRow> {
let date = DateTime::parse_from_rfc3339(row[headings::DATE].as_str().context("No date")?) let date = DateTime::parse_from_rfc3339(row[headings::DATE].as_str().context("No date")?)
.context("Failed to parse date")?; .context("Failed to parse date")?;
@ -178,7 +178,7 @@ fn test_json() {
let json_rows = json let json_rows = json
.iter() .iter()
.map(|row| from_json_row(row.clone())) .map(|row| from_json_row(&row))
.collect::<Result<Vec<_>, anyhow::Error>>() .collect::<Result<Vec<_>, anyhow::Error>>()
.unwrap(); .unwrap();

View File

@ -9,24 +9,53 @@ use sea_orm::DatabaseConnection;
use serde_json::Value; use serde_json::Value;
use std::io::Cursor; use std::io::Cursor;
#[derive(serde::Deserialize, Debug)]
#[serde(untagged)]
pub enum MonzoBatchedJsonInput {
Legacy(Vec<Vec<Value>>),
New {
account_id: Option<u8>,
rows: Vec<Vec<Value>>,
},
}
impl MonzoBatchedJsonInput {
fn account_id(&self) -> Option<u8> {
match self {
MonzoBatchedJsonInput::Legacy(_) => None,
MonzoBatchedJsonInput::New { account_id, .. } => *account_id,
}
}
fn rows(&self) -> &[Vec<Value>] {
match self {
MonzoBatchedJsonInput::Legacy(rows) => rows,
MonzoBatchedJsonInput::New { rows, .. } => rows,
}
}
}
pub async fn monzo_batched_json( pub async fn monzo_batched_json(
Extension(db): Extension<DatabaseConnection>, Extension(db): Extension<DatabaseConnection>,
Json(data): Json<Vec<Vec<Value>>>, Json(data): Json<MonzoBatchedJsonInput>,
) -> Result<&'static str, AppError> { ) -> Result<&'static str, AppError> {
let data = data let rows = data
.into_iter() .rows()
.skip(1) // Skip the header row. .iter()
.map(from_json_row) .skip_while(|row| row[0] == Value::String("Transaction ID".to_string()))
.map(|row| from_json_row(row.as_ref()))
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
// We default to the main account for JSON ingestion for now. // We default to the main account for JSON ingestion for now.
let account_id = db::get_account_id(&db, None).await?; let account_id = db::get_account_id(&db, data.account_id().map(|id| id.to_string())).await?;
db::insert(&db, data, account_id).await?; db::insert(&db, rows, account_id).await?;
Ok("Ok") Ok("Ok")
} }
async fn extract_csv_and_account_name(mut multipart: Multipart) -> Result<(Option<Bytes>, Option<String>), MultipartError> { async fn extract_csv_and_account_name(
mut multipart: Multipart,
) -> Result<(Option<Bytes>, Option<String>), MultipartError> {
let mut csv = None; let mut csv = None;
let mut account_name = None; let mut account_name = None;
@ -59,7 +88,7 @@ pub struct ShortcutBody {
pub async fn shortcuts_csv( pub async fn shortcuts_csv(
Extension(db): Extension<DatabaseConnection>, Extension(db): Extension<DatabaseConnection>,
Json(shortcut_body): Json<ShortcutBody> Json(shortcut_body): Json<ShortcutBody>,
) -> Result<&'static str, AppError> { ) -> Result<&'static str, AppError> {
let account_id = db::get_account_id(&db, Some(shortcut_body.account_name)).await?; let account_id = db::get_account_id(&db, Some(shortcut_body.account_name)).await?;