From b563bbd02c9bbdd6bf454a62a40be8dac2a9a97c Mon Sep 17 00:00:00 2001 From: Joshua Coles Date: Mon, 3 Jun 2024 19:59:22 +0100 Subject: [PATCH] Add a little more CLI structure for local runs --- src/main.rs | 92 ++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 74 insertions(+), 18 deletions(-) diff --git a/src/main.rs b/src/main.rs index 408f981..5201e64 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,29 +2,55 @@ mod error; mod ingestion; use crate::error::AppError; +use crate::ingestion::db; +use crate::ingestion::ingestion_logic::from_csv_row; use crate::ingestion::routes::{monzo_batched_csv, monzo_batched_json}; use axum::routing::{get, post}; use axum::{Extension, Router}; -use clap::Parser; +use clap::{Parser, Subcommand}; use migration::{Migrator, MigratorTrait}; use sea_orm::{ConnectionTrait, DatabaseConnection}; +use std::fs::File; use std::net::SocketAddr; +use std::path::PathBuf; use tower_http::trace::TraceLayer; use tracing::log::LevelFilter; +#[derive(Debug, Subcommand)] +enum Commands { + Migrate { + /// Number of migration steps to perform. If not provided, all migrations will be run. + #[arg(long)] + steps: Option, + + /// If we should perform migration down. + #[arg(long)] + down: bool, + }, + Run { + /// If we should perform migration on startup. + #[clap(short, long, env, default_value_t = true)] + migrate: bool, + + /// The server address to bind to. + #[clap(short, long, env, default_value = "0.0.0.0:3000")] + addr: SocketAddr, + }, + + Csv { + /// The path to the CSV file to ingest. + csv_file: PathBuf, + }, +} + #[derive(Debug, clap::Parser)] -struct Config { - /// If we should perform migration on startup. - #[clap(short, long, env, default_value_t = true)] - migrate: bool, - - /// The server address to bind to. - #[clap(short, long, env, default_value = "0.0.0.0:3000")] - addr: SocketAddr, - +struct Cli { /// URL to PostgreSQL database. #[clap(short, long = "db", env)] database_url: String, + + #[command(subcommand)] + command: Commands, } async fn health_check( @@ -37,18 +63,48 @@ async fn health_check( #[tokio::main] async fn main() -> anyhow::Result<()> { - let config: Config = Config::parse(); - let connection = sea_orm::ConnectOptions::new(&config.database_url) + tracing_subscriber::fmt::init(); + + let cli: Cli = Cli::parse(); + let connection = sea_orm::ConnectOptions::new(&cli.database_url) .sqlx_logging_level(LevelFilter::Debug) .to_owned(); let connection = sea_orm::Database::connect(connection).await?; - if config.migrate { - Migrator::up(&connection, None).await?; + match cli.command { + Commands::Migrate { steps, down } => { + if down { + Migrator::down(&connection, steps).await?; + } else { + Migrator::up(&connection, steps).await? + } + } + + Commands::Run { migrate, addr } => { + if migrate { + Migrator::up(&connection, None).await?; + } + + serve_web(addr, connection).await?; + } + + Commands::Csv { csv_file } => { + let mut csv = csv::Reader::from_reader(File::open(csv_file)?); + let data = csv.records(); + let data = data + .filter_map(|f| f.ok()) + .map(from_csv_row) + .collect::>()?; + + db::insert(&connection, data).await?; + } } - tracing_subscriber::fmt::init(); + Ok(()) +} + +async fn serve_web(address: SocketAddr, connection: DatabaseConnection) -> anyhow::Result<()> { let app = Router::new() .route("/health", get(health_check)) .route("/monzo-batch-export", post(monzo_batched_json)) @@ -56,9 +112,9 @@ async fn main() -> anyhow::Result<()> { .layer(Extension(connection.clone())) .layer(TraceLayer::new_for_http()); - tracing::debug!("listening on {}", &config.addr); - let listener = tokio::net::TcpListener::bind(&config.addr).await.unwrap(); - axum::serve(listener, app).await.unwrap(); + tracing::info!("listening on {}", &address); + let listener = tokio::net::TcpListener::bind(&address).await?; + axum::serve(listener, app).await?; Ok(()) }