From 3f439c8a31ff9ac0a32773d64bd897f6ae3f8a7a Mon Sep 17 00:00:00 2001 From: Joshua Coles Date: Sat, 1 Feb 2025 19:45:43 +0000 Subject: [PATCH] Add additional validation --- src/main.rs | 4 +-- src/server.rs | 86 +++++++++++++++++++++++++++++++++++++++++++++------ src/worker.rs | 32 ++++++++++++++++--- 3 files changed, 107 insertions(+), 15 deletions(-) diff --git a/src/main.rs b/src/main.rs index 3c62bb5..8a86c85 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,5 @@ use axum::response::IntoResponse; -use sqlx::{Connection, PgPool}; +use sqlx::{PgPool}; use toggl::TogglApi; use worker::Worker; @@ -97,7 +97,7 @@ async fn main() { let cli = Cli::parse(); let toggl_api = TogglApi::new(&cli.api_token, cli.default_workspace_id); - let mut db = PgPool::connect(&cli.database_url).await.unwrap(); + let db = PgPool::connect(&cli.database_url).await.unwrap(); sqlx::migrate!("./migrations") .run(&db) diff --git a/src/server.rs b/src/server.rs index d9ed639..b961257 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,14 +1,73 @@ -use axum::response::IntoResponse; -use axum::{ - http::StatusCode, - routing::{get, post}, - Extension, Json, Router, -}; -use chrono::TimeDelta; -use std::net::IpAddr; - use crate::worker::Worker; use crate::AppError; +use axum::response::IntoResponse; +use axum::{ + routing::{get, post}, + Extension, Router, +}; +use chrono::TimeDelta; +use sqlx::postgres::PgListener; +use std::net::IpAddr; + +#[derive(Debug, serde::Deserialize)] +struct Payload(Vec); + +#[derive(Debug, serde::Deserialize)] +enum PayloadItem { + Num(i32), + Str(String), +} + +impl PayloadItem { + fn to_i32(self) -> Option { + match self { + PayloadItem::Num(n) => Some(n), + PayloadItem::Str(s) => s.parse().ok(), + } + } +} + +async fn listen_for_changes(pool: sqlx::PgPool, worker: Worker) { + let mut listener = PgListener::connect_with(&pool) + .await + .expect("Failed to connect to database"); + + // Enable LISTEN on the channel + listener + .listen("toggl_external_changes") + .await + .expect("Failed to LISTEN on channel"); + + while let Ok(notification) = listener.recv().await { + let payload = notification.payload(); + + if payload == "" { + worker + .update(TimeDelta::minutes(30)) + .await + .expect("Failed to sync"); + continue; + } + + match serde_json::from_str::(payload) { + Ok(payload) => { + let ids: Vec = payload + .0 + .into_iter() + .filter_map(|v| v.to_i32()) + .collect(); + + if !ids.is_empty() { + if let Err(e) = worker.fetch_time_entries_by_ids(&ids).await { + eprintln!("Error fetching time entries: {}", e); + } + } + } + + Err(e) => eprintln!("Error parsing notification payload: {}", e), + } + } +} async fn sync(Extension(worker): Extension) -> Result { worker.update(TimeDelta::days(30)).await?; @@ -17,6 +76,15 @@ async fn sync(Extension(worker): Extension) -> Result Result<(), AppError> { + // Clone the pool and worker for the notification listener + let notification_pool = worker.db.clone(); + let notification_worker = worker.clone(); + + // Spawn the notification listener + tokio::spawn(async move { + listen_for_changes(notification_pool, notification_worker).await; + }); + // build our application with a route let app = Router::new() .route("/health", get(|| async { "Ok" })) diff --git a/src/worker.rs b/src/worker.rs index e607365..d10a2c6 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -28,6 +28,30 @@ pub struct Worker { } impl Worker { + pub async fn fetch_time_entries_by_ids(&self, ids: &[i32]) -> Result<(), AppError> { + // Convert i32 IDs to u64 as that's what the Toggl API expects + let ids: Vec = ids.iter().map(|&id| id as u64).collect(); + + // Use the search API with time_entry_ids filter + let time_entries = self + .toggl_api + .search( + self.toggl_api.workspace_id, + TogglReportFilters { + time_entry_ids: Some(ids), + ..Default::default() + }, + ) + .await?; + + let time_entries = time_entries + .into_iter() + .map(|entry| entry.into_time_entry(self.toggl_api.workspace_id)) + .collect::>(); + + self.update_database(time_entries).await + } + async fn get_ids(&self) -> Result { let client_ids = sqlx::query!("select id from tracking_clients") .fetch_all(&self.db) @@ -75,10 +99,6 @@ impl Worker { } pub async fn fetch_changed_since(&self, look_back: TimeDelta) -> Result<(), AppError> { - if look_back > TimeDelta::days(90) { - return Err(AppError::LookBackTooLarge); - } - self.update_time_entries(Utc::now() - look_back).await } @@ -96,6 +116,10 @@ impl Worker { } async fn update_time_entries(&self, fetch_since: DateTime) -> Result<(), AppError> { + if Utc::now() - fetch_since > TimeDelta::days(90) { + return Err(AppError::LookBackTooLarge); + } + let time_entries = self .toggl_api .get_time_entries_for_user_modified_since(fetch_since)