Add a little more CLI structure for local runs

This commit is contained in:
Joshua Coles 2024-06-03 19:59:22 +01:00
parent 7fd85550ea
commit b563bbd02c

View File

@ -2,29 +2,55 @@ mod error;
mod ingestion;
use crate::error::AppError;
use crate::ingestion::db;
use crate::ingestion::ingestion_logic::from_csv_row;
use crate::ingestion::routes::{monzo_batched_csv, monzo_batched_json};
use axum::routing::{get, post};
use axum::{Extension, Router};
use clap::Parser;
use clap::{Parser, Subcommand};
use migration::{Migrator, MigratorTrait};
use sea_orm::{ConnectionTrait, DatabaseConnection};
use std::fs::File;
use std::net::SocketAddr;
use std::path::PathBuf;
use tower_http::trace::TraceLayer;
use tracing::log::LevelFilter;
#[derive(Debug, Subcommand)]
enum Commands {
Migrate {
/// Number of migration steps to perform. If not provided, all migrations will be run.
#[arg(long)]
steps: Option<u32>,
/// If we should perform migration down.
#[arg(long)]
down: bool,
},
Run {
/// If we should perform migration on startup.
#[clap(short, long, env, default_value_t = true)]
migrate: bool,
/// The server address to bind to.
#[clap(short, long, env, default_value = "0.0.0.0:3000")]
addr: SocketAddr,
},
Csv {
/// The path to the CSV file to ingest.
csv_file: PathBuf,
},
}
#[derive(Debug, clap::Parser)]
struct Config {
/// If we should perform migration on startup.
#[clap(short, long, env, default_value_t = true)]
migrate: bool,
/// The server address to bind to.
#[clap(short, long, env, default_value = "0.0.0.0:3000")]
addr: SocketAddr,
struct Cli {
/// URL to PostgreSQL database.
#[clap(short, long = "db", env)]
database_url: String,
#[command(subcommand)]
command: Commands,
}
async fn health_check(
@ -37,18 +63,48 @@ async fn health_check(
#[tokio::main]
async fn main() -> anyhow::Result<()> {
let config: Config = Config::parse();
let connection = sea_orm::ConnectOptions::new(&config.database_url)
tracing_subscriber::fmt::init();
let cli: Cli = Cli::parse();
let connection = sea_orm::ConnectOptions::new(&cli.database_url)
.sqlx_logging_level(LevelFilter::Debug)
.to_owned();
let connection = sea_orm::Database::connect(connection).await?;
if config.migrate {
Migrator::up(&connection, None).await?;
match cli.command {
Commands::Migrate { steps, down } => {
if down {
Migrator::down(&connection, steps).await?;
} else {
Migrator::up(&connection, steps).await?
}
}
Commands::Run { migrate, addr } => {
if migrate {
Migrator::up(&connection, None).await?;
}
serve_web(addr, connection).await?;
}
Commands::Csv { csv_file } => {
let mut csv = csv::Reader::from_reader(File::open(csv_file)?);
let data = csv.records();
let data = data
.filter_map(|f| f.ok())
.map(from_csv_row)
.collect::<Result<_, _>>()?;
db::insert(&connection, data).await?;
}
}
tracing_subscriber::fmt::init();
Ok(())
}
async fn serve_web(address: SocketAddr, connection: DatabaseConnection) -> anyhow::Result<()> {
let app = Router::new()
.route("/health", get(health_check))
.route("/monzo-batch-export", post(monzo_batched_json))
@ -56,9 +112,9 @@ async fn main() -> anyhow::Result<()> {
.layer(Extension(connection.clone()))
.layer(TraceLayer::new_for_http());
tracing::debug!("listening on {}", &config.addr);
let listener = tokio::net::TcpListener::bind(&config.addr).await.unwrap();
axum::serve(listener, app).await.unwrap();
tracing::info!("listening on {}", &address);
let listener = tokio::net::TcpListener::bind(&address).await?;
axum::serve(listener, app).await?;
Ok(())
}