Merge pull request 'Overhaul database driver' (#9) from joel/db-driver into main

Reviewed-on: xacrimon/no-no#9
Joel Wejdenstål committed 2024-07-01 04:14:04 +00:00
commit 08ddc83d39
14 changed files with 493 additions and 537 deletions


@@ -0,0 +1,144 @@
use rusqlite::types::FromSql;
/// `row_extract` is a helper that extracts a tuple of values from a `rusqlite::Row`.
pub(crate) fn row_extract<T: FromRow>(row: &rusqlite::Row) -> rusqlite::Result<T> {
T::from_row(row)
}
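// A hypothetical usage sketch (not part of this commit): `row_extract`
// satisfies the `FnOnce(&Row) -> rusqlite::Result<T>` mapper that rusqlite's
// `query_row` expects, so typed tuples fall out directly. The table and
// column names below are illustrative assumptions.
//
//     let (id, name): (i32, String) = conn
//         .prepare_cached("SELECT id, username FROM accounts WHERE id = ?")?
//         .query_row((1,), row_extract)?;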
/// `FromRow` is a trait for converting a `rusqlite::Row` into a tuple of values.
/// It is used internally by `row_extract` to pull typed columns out of a query row.
pub(crate) trait FromRow {
fn from_row(row: &rusqlite::Row) -> rusqlite::Result<Self>
where
Self: Sized;
}
impl<A> FromRow for (A,)
where
A: FromSql,
{
fn from_row(row: &rusqlite::Row) -> rusqlite::Result<Self> {
Ok((row.get(0)?,))
}
}
impl<A, B> FromRow for (A, B)
where
A: FromSql,
B: FromSql,
{
fn from_row(row: &rusqlite::Row) -> rusqlite::Result<Self> {
Ok((row.get(0)?, row.get(1)?))
}
}
impl<A, B, C> FromRow for (A, B, C)
where
A: FromSql,
B: FromSql,
C: FromSql,
{
fn from_row(row: &rusqlite::Row) -> rusqlite::Result<Self> {
Ok((row.get(0)?, row.get(1)?, row.get(2)?))
}
}
impl<A, B, C, D> FromRow for (A, B, C, D)
where
A: FromSql,
B: FromSql,
C: FromSql,
D: FromSql,
{
fn from_row(row: &rusqlite::Row) -> rusqlite::Result<Self> {
Ok((row.get(0)?, row.get(1)?, row.get(2)?, row.get(3)?))
}
}
impl<A, B, C, D, E> FromRow for (A, B, C, D, E)
where
A: FromSql,
B: FromSql,
C: FromSql,
D: FromSql,
E: FromSql,
{
fn from_row(row: &rusqlite::Row) -> rusqlite::Result<Self> {
Ok((
row.get(0)?,
row.get(1)?,
row.get(2)?,
row.get(3)?,
row.get(4)?,
))
}
}
impl<A, B, C, D, E, F> FromRow for (A, B, C, D, E, F)
where
A: FromSql,
B: FromSql,
C: FromSql,
D: FromSql,
E: FromSql,
F: FromSql,
{
fn from_row(row: &rusqlite::Row) -> rusqlite::Result<Self> {
Ok((
row.get(0)?,
row.get(1)?,
row.get(2)?,
row.get(3)?,
row.get(4)?,
row.get(5)?,
))
}
}
impl<A, B, C, D, E, F, G> FromRow for (A, B, C, D, E, F, G)
where
A: FromSql,
B: FromSql,
C: FromSql,
D: FromSql,
E: FromSql,
F: FromSql,
G: FromSql,
{
fn from_row(row: &rusqlite::Row) -> rusqlite::Result<Self> {
Ok((
row.get(0)?,
row.get(1)?,
row.get(2)?,
row.get(3)?,
row.get(4)?,
row.get(5)?,
row.get(6)?,
))
}
}
impl<A, B, C, D, E, F, G, H> FromRow for (A, B, C, D, E, F, G, H)
where
A: FromSql,
B: FromSql,
C: FromSql,
D: FromSql,
E: FromSql,
F: FromSql,
G: FromSql,
H: FromSql,
{
fn from_row(row: &rusqlite::Row) -> rusqlite::Result<Self> {
Ok((
row.get(0)?,
row.get(1)?,
row.get(2)?,
row.get(3)?,
row.get(4)?,
row.get(5)?,
row.get(6)?,
row.get(7)?,
))
}
}


@@ -0,0 +1,73 @@
use std::{fs, path::Path};
use anyhow::{anyhow, Result};
use rusqlite::Connection;
use super::params;
/// `apply_pragmas` applies the database settings defined in `params::DB_SETTINGS`
/// to the given connection using `PRAGMA` statements.
pub(super) fn apply_pragmas(conn: &Connection) -> Result<()> {
for (pragma, value) in params::DB_SETTINGS {
conn.pragma_update(None, pragma, value)?;
}
Ok(())
}
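// `pragma_update` issues statements of the shape `PRAGMA <name> = <value>`,
// so the loop above is roughly equivalent to executing, for example:
//     PRAGMA journal_mode = delete;
//     PRAGMA synchronous = full;
// once per entry in `params::DB_SETTINGS`.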
/// `apply_migrations` applies migrations found in the `migrations` directory
/// if the current database version is outdated.
pub(super) fn apply_migrations(conn: &mut Connection) -> Result<()> {
let tx = conn.transaction()?;
let current_version: i32 =
tx.query_row("SELECT user_version FROM pragma_user_version", [], |row| {
row.get(0)
})?;
let migrations = load_migrations(current_version)?;
for (_name, version, migration) in migrations {
if version > current_version {
tx.execute_batch(&migration)?;
tx.pragma_update(None, "user_version", version)?;
}
}
tx.commit()?;
Ok(())
}
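// Migration files are expected to be named `<version>-<description>.sql`,
// where the numeric prefix before the first `-` is the schema version
// (the exact file names below are illustrative, not part of this commit):
//     migrations/1-create-accounts.sql
//     migrations/2-add-sessions.sql
// `load_migrations` returns them sorted by version; only versions above the
// current `user_version` are applied.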
fn load_migrations(above: i32) -> Result<Vec<(String, i32, String)>> {
let mut migrations = Vec::new();
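// Heuristic: support running both from the repository root and from inside
// the `backend` directory.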
let path = if Path::exists(Path::new("./backend")) {
"./backend/migrations"
} else {
"./migrations"
};
for entry in fs::read_dir(path)? {
let entry = entry?;
let (name, version) = extract_key_from_entry(&entry)?;
let migration = fs::read_to_string(entry.path())?;
if version > above {
migrations.push((name, version, migration));
}
}
migrations.sort_by_key(|(_, num, _)| *num);
Ok(migrations)
}
fn extract_key_from_entry(entry: &fs::DirEntry) -> Result<(String, i32)> {
let raw_name = entry.file_name();
let name = raw_name
.to_str()
.ok_or_else(|| anyhow!("invalid file name"))?;
let num = name
.split('-')
.next()
.ok_or_else(|| anyhow!("missing number in migration name"))?
.parse()?;
Ok((name.to_owned(), num))
}


@@ -0,0 +1,99 @@
mod helpers;
mod initialize;
mod params;
mod pool;
use std::{
mem::MaybeUninit,
sync::{Arc, Condvar, Mutex},
};
use anyhow::Result;
pub(crate) use helpers::row_extract;
use pool::Pool;
use rusqlite::Connection;
use tokio::task;
use crate::Config;
/// `Database` wraps a SQLite connection pool and can be shared across threads.
pub(crate) struct Database {
pool: Mutex<Pool>,
pool_notify: Condvar,
}
impl Database {
/// `new` creates a new `Database` instance with the given configuration.
pub(crate) fn new(config: &Config) -> Result<Self> {
let path = config.instance.data_directory.join(params::DB_FILE_NAME);
let pool = Pool::new(path);
Ok(Self {
pool: Mutex::new(pool),
pool_notify: Condvar::new(),
})
}
fn acquire(&self) -> Result<Connection> {
self.pool.lock().unwrap().acquire(&self.pool_notify)
}
fn release(&self, conn: Connection) {
self.pool.lock().unwrap().release(conn, &self.pool_notify)
}
}
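// `ConnGuard` returns the connection to the pool when dropped. The connection
// is stored as `MaybeUninit` so that `Drop` can move it out by value via
// `assume_init_read`; the field is never accessed again after that read.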
struct ConnGuard<'a> {
db: &'a Database,
conn: MaybeUninit<Connection>,
}
impl ConnGuard<'_> {
fn conn_mut(&mut self) -> &mut Connection {
unsafe { self.conn.assume_init_mut() }
}
}
impl Drop for ConnGuard<'_> {
fn drop(&mut self) {
let conn = unsafe { self.conn.assume_init_read() };
self.db.release(conn);
}
}
/// `query` runs a closure against a pooled database connection on a blocking thread and returns its result.
pub(crate) async fn query<F, T>(db: &Arc<Database>, query: F) -> Result<T>
where
T: Send + 'static,
F: FnOnce(&rusqlite::Connection) -> Result<T> + Send + 'static,
{
query_inner(db, move |conn| query(conn)).await
}
/// `query_tx` runs a closure against a fresh transaction; the transaction rolls back on drop unless the closure commits it.
pub(crate) async fn query_tx<F, T>(db: &Arc<Database>, query: F) -> Result<T>
where
T: Send + 'static,
F: FnOnce(rusqlite::Transaction) -> Result<T> + Send + 'static,
{
query_inner(db, move |conn| query(conn.transaction()?)).await
}
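// A hypothetical call site (not part of this commit) showing a read running
// through the pool; the closure executes on a blocking thread. The SQL and
// the binding name `db` are illustrative assumptions.
//
//     let (count,): (i64,) = query(&db, move |conn| {
//         Ok(conn.query_row("SELECT count(*) FROM snippets", [], row_extract)?)
//     })
//     .await?;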
async fn query_inner<F, T>(db: &Arc<Database>, query: F) -> Result<T>
where
T: Send + 'static,
F: FnOnce(&mut rusqlite::Connection) -> Result<T> + Send + 'static,
{
let db = Arc::clone(db);
task::spawn_blocking(move || {
let conn = db.acquire()?;
let mut guard = ConnGuard {
db: &db,
conn: MaybeUninit::new(conn),
};
query(guard.conn_mut())
})
.await
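// Join errors only arise from a panicked or cancelled task, so propagate
// them as panics rather than surfacing them as query failures.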
.unwrap()
}


@@ -0,0 +1,28 @@
use std::time::Duration;
/// SQLite PRAGMA settings to apply to database connections.
pub(super) static DB_SETTINGS: &[(&str, &str)] = &[
("journal_mode", "delete"),
("synchronous", "full"),
("cache_size", "-8192"),
("busy_timeout", "5000"),
("temp_store", "memory"),
("foreign_key", "on"),
];
/// Name of the SQLite database file on disk within the data directory.
pub(super) const DB_FILE_NAME: &str = "persistent";
/// Minimum number of connections to keep open in the pool.
/// Ensures that some connections are always immediately available as idle ones are recycled.
pub(super) const MIN_CONNECTIONS: usize = 1;
/// Maximum number of connections to keep open in the pool.
/// This is therefore also the maximum number of queries that can execute concurrently.
pub(super) const MAX_CONNECTIONS: usize = 4;
/// Maximum amount of time a connection will be kept alive after it was last used.
pub(super) const CONNECTION_KEEPALIVE: Duration = Duration::from_secs(60);
/// Maximum amount of time to wait for a connection to become available in the pool.
pub(super) const CONNECTION_ACQUIRE_TIMEOUT: Duration = Duration::from_secs(15);


@@ -0,0 +1,84 @@
use std::{
path::PathBuf,
sync::{Condvar, MutexGuard},
time::Instant,
};
use anyhow::{anyhow, Result};
use rusqlite::Connection;
use super::{initialize, params};
/// `Pool` is a connection pool for SQLite databases with dynamic resizing.
pub struct Pool {
connections: Vec<Entry>,
open_connections: usize,
db_path: PathBuf,
}
impl Pool {
/// `new` creates a new connection pool with the given database path.
pub fn new(db_path: PathBuf) -> Self {
Self {
connections: Vec::new(),
open_connections: 0,
db_path,
}
}
/// `acquire` returns a connection from the pool or creates a new one if possible.
/// If no connection is available and the pool is at capacity, it will wait for one to be released.
pub fn acquire(mut self: MutexGuard<'_, Self>, condvar: &Condvar) -> Result<Connection> {
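// Fast path: reuse an idle connection, or open a fresh one while under
// capacity. Otherwise block on the condvar until `release` signals,
// giving up after `CONNECTION_ACQUIRE_TIMEOUT`.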
loop {
if let Some(entry) = self.connections.pop() {
break Ok(entry.conn);
} else if self.open_connections < params::MAX_CONNECTIONS {
break self.open_connection();
}
let (guard, wait) = condvar
.wait_timeout(self, params::CONNECTION_ACQUIRE_TIMEOUT)
.unwrap();
if wait.timed_out() {
break Err(anyhow!("timed out waiting for a connection"));
}
self = guard;
}
}
/// `release` returns a connection to the pool for reuse.
pub fn release(mut self: MutexGuard<'_, Self>, conn: Connection, condvar: &Condvar) {
self.recycle_unused_connections();
self.connections.push(Entry {
conn,
last_used: Instant::now(),
});
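// Unlock before notifying so the woken waiter can take the mutex
// without immediately blocking on it again.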
drop(self);
condvar.notify_one();
}
fn open_connection(&mut self) -> Result<Connection> {
let mut conn = Connection::open(&self.db_path)?;
initialize::apply_pragmas(&conn)?;
initialize::apply_migrations(&mut conn)?;
self.open_connections += 1;
Ok(conn)
}
fn recycle_unused_connections(&mut self) {
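// The `+ 1` accounts for the connection `release` is about to push back,
// so at least `MIN_CONNECTIONS` connections survive a recycling pass.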
while self.connections.len() + 1 > params::MIN_CONNECTIONS
&& self.connections[0].last_used.elapsed() > params::CONNECTION_KEEPALIVE
{
self.connections.remove(0);
self.open_connections -= 1;
}
}
}
struct Entry {
conn: Connection,
last_used: Instant,
}


@@ -1,6 +1,8 @@
#![feature(async_closure)]
#![feature(arbitrary_self_types)]
pub mod config;
mod database;
mod error;
mod routes;
mod services;
@@ -19,11 +21,9 @@ use axum::{
};
pub use config::sample_prod_config;
use config::Config;
use database::Database;
use routes::PublicCache;
use services::{
database::{Database, DatabaseAnalyzeJob},
jobs::JobsService,
};
use services::jobs::JobsService;
use session::SessionService;
use specials::Closer;
pub use tokio;
@@ -32,10 +32,9 @@ use tracing::info;
pub async fn build_app(config: &Arc<Config>) -> Result<(Router, Closer, Vec<(Duration, &str)>)> {
let mut closer = Closer::new();
let db = Database::new(config).await?;
let db = Arc::new(Database::new(config)?);
let session_service = SessionService::new(Arc::clone(&db))?;
let jobs_service = JobsService::new(Arc::clone(&db), Arc::clone(&session_service), &mut closer);
jobs_service.schedule::<DatabaseAnalyzeJob>().await;
let upcoming = jobs_service.upcoming().await;
let origin = config
.web


@@ -11,8 +11,8 @@ use serde_json::json;
use validator::Validate;
use crate::{
database::{query_tx, Database},
error::AppError,
services::database::{query_tx, Database},
utils,
utils::{validate_password, validate_username, ValidatedJson},
};
@@ -37,7 +37,7 @@
ValidatedJson(payload): ValidatedJson<Payload>,
) -> Result<Json<serde_json::Value>, AppError> {
let hash = utils::hash_password(&payload.password);
query_tx(&db, |tx| {
query_tx(&db, move |tx| {
let account_id: i32 = tx
.prepare_cached(
"INSERT INTO accounts (created_at, username, display_name, email) VALUES \


@@ -11,8 +11,8 @@ use serde_json::json;
use validator::Validate;
use crate::{
database::{query, row_extract, Database},
error::AppError,
services::database::{query, row_extract, Database},
session::SessionService,
utils::{self, validate_login_identifier, validate_password, ValidatedJson},
};
@@ -43,7 +43,7 @@
Extension(session_service): Extension<Arc<SessionService>>,
ValidatedJson(payload): ValidatedJson<Payload>,
) -> Result<Json<serde_json::Value>, AppError> {
let (id, _username, password): (i32, String, String) = query(&db, |conn| {
let (id, _username, password): (i32, String, String) = query(&db, move |conn| {
let sql = format!(
"SELECT id, username, password FROM accounts INNER JOIN identities ON accounts.id = \
identities.account_id WHERE kind = 'local' AND {} = ?",


@@ -12,8 +12,8 @@ use serde_json::json;
use validator::Validate;
use crate::{
database::{query_tx, row_extract, Database},
error::AppError,
services::database::{query_tx, row_extract, Database},
session::{OptionalSession, Session},
utils::{validate_snippet, Language, SnippetPath, ValidatedJson, Visibility},
};
@@ -53,7 +53,7 @@
Query(parameters): Query<Parameters>,
OptionalSession(session): OptionalSession,
) -> Result<Json<Snippet>, AppError> {
let snippet = query_tx(&db, |tx| {
let snippet = query_tx(&db, move |tx| {
let (id, author, visibility, lang): (i32, i32, Visibility, Option<Language>) = tx
.prepare_cached(
"SELECT id, author, visibility, lang FROM snippets WHERE name = ? AND author = \
@@ -121,7 +121,7 @@ pub async fn create_handler(
session: Session,
ValidatedJson(payload): ValidatedJson<Payload>,
) -> Result<Json<serde_json::Value>, AppError> {
query_tx(&db, |tx| {
query_tx(&db, move |tx| {
let (snippet,): (i32,) = tx
.prepare_cached(
"INSERT INTO snippets (author, name, visibility, lang) VALUES (?, ?, ?, ?) \


@@ -4,8 +4,8 @@ use anyhow::Result;
use axum::extract::Extension;
use crate::{
database::{query_tx, row_extract, Database},
error::AppError,
services::database::{query_tx, row_extract, Database},
};
pub async fn handler(Extension(db): Extension<Arc<Database>>) -> Result<String, AppError> {


@@ -1,483 +0,0 @@
use std::{
fs,
mem::MaybeUninit,
ops::{Deref, DerefMut},
path::{Path, PathBuf},
sync::{Arc, Mutex},
thread,
time::{Duration, Instant},
};
use anyhow::{anyhow, bail, Result};
use axum::async_trait;
use rusqlite::{types::FromSql, Connection, TransactionBehavior};
use serde::{Deserialize, Serialize};
use serde_json::json;
use tokio::{
sync::{Semaphore, SemaphorePermit},
task,
};
use tracing::{debug, info, instrument, warn};
use crate::{
config::Config,
services::jobs::{Job, Schedulable, State},
};
static DB_SETTINGS: &[(&str, &str)] = &[
("journal_mode", "delete"),
("synchronous", "full"),
("cache_size", "-8192"),
("busy_timeout", "100"),
("temp_store", "memory"),
("foreign_key", "on"),
];
const DB_TIMEOUT: Duration = Duration::from_secs(10);
const DB_FILE_NAME: &str = "persistent";
const MIN_CONNECTIONS: usize = 1;
const MAX_CONNECTIONS: usize = 4;
const KEEPALIVE_CONNECTIONS: Duration = Duration::from_secs(60);
pub struct Database {
tracker: Semaphore,
connections: Mutex<ConnectionList>,
path: PathBuf,
}
impl Database {
pub async fn new(config: &Config) -> Result<Arc<Self>> {
debug!("detected sqlite library version: {}", rusqlite::version());
info!("connecting to database...");
let path = config.instance.data_directory.join(DB_FILE_NAME);
let db = Self {
tracker: Semaphore::new(0),
connections: Mutex::new(ConnectionList::new()),
path,
};
db.acquire_conn().await?;
Ok(Arc::new(db))
}
pub async fn acquire_conn(&self) -> Result<ConnectionGuard<'_>> {
let permit = match self.tracker.try_acquire() {
Ok(permit) => permit,
Err(_) => {
{
let mut connections = self.connections.lock().unwrap();
if connections.open_count() < MAX_CONNECTIONS {
let mut conn = Connection::open(&self.path)?;
apply_pragmas(&conn)?;
apply_migrations(&mut conn)?;
connections.add_connection(conn);
self.tracker.add_permits(1);
}
}
self.tracker.acquire().await?
},
};
let conn = self.connections.lock().unwrap().acquire();
Ok(ConnectionGuard {
conn: MaybeUninit::new(conn),
database: self,
_permit: permit,
})
}
}
struct ConnectionList {
open_connections: usize,
connections: Vec<(Connection, Instant)>,
}
impl ConnectionList {
fn new() -> Self {
Self {
open_connections: 0,
connections: Vec::new(),
}
}
fn open_count(&self) -> usize {
self.connections.len()
}
fn add_connection(&mut self, conn: Connection) {
self.connections.push((conn, Instant::now()));
self.open_connections += 1;
}
fn acquire(&mut self) -> Connection {
let (conn, _) = self.connections.pop().unwrap();
conn
}
fn release(&mut self, conn: Connection) {
self.connections.push((conn, Instant::now()));
if self.connections.len() > MIN_CONNECTIONS
&& self.connections[0].1.elapsed() > KEEPALIVE_CONNECTIONS
{
self.connections.remove(0);
}
}
}
pub struct ConnectionGuard<'a> {
conn: MaybeUninit<Connection>,
database: &'a Database,
_permit: SemaphorePermit<'a>,
}
impl ConnectionGuard<'_> {
pub fn conn(&self) -> &Connection {
unsafe { self.conn.assume_init_ref() }
}
pub fn conn_mut(&mut self) -> &mut Connection {
unsafe { self.conn.assume_init_mut() }
}
}
impl Deref for ConnectionGuard<'_> {
type Target = Connection;
fn deref(&self) -> &Self::Target {
self.conn()
}
}
impl DerefMut for ConnectionGuard<'_> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.conn_mut()
}
}
impl Drop for ConnectionGuard<'_> {
fn drop(&mut self) {
let conn = unsafe { self.conn.assume_init_read() };
let mut connections = self.database.connections.lock().unwrap();
connections.release(conn);
}
}
fn apply_pragmas(conn: &Connection) -> Result<()> {
for (pragma, value) in DB_SETTINGS {
conn.pragma_update(None, pragma, value)?;
}
Ok(())
}
fn apply_migrations(conn: &mut Connection) -> Result<()> {
let tx = conn.transaction_with_behavior(TransactionBehavior::Exclusive)?;
let current_version: i32 =
tx.query_row("SELECT user_version FROM pragma_user_version", [], |row| {
row.get(0)
})?;
let migrations = load_migrations(current_version)?;
for (name, version, migration) in migrations {
if version > current_version {
info!("applying migration: {}...", name);
tx.execute_batch(&migration)?;
tx.pragma_update(None, "user_version", version)?;
}
}
tx.commit()?;
Ok(())
}
fn load_migrations(above: i32) -> Result<Vec<(String, i32, String)>> {
let mut migrations = Vec::new();
let path = if Path::exists(Path::new("./backend")) {
"./backend/migrations"
} else {
"./migrations"
};
for entry in fs::read_dir(path)? {
let entry = entry?;
let (name, version) = extract_key_from_entry(&entry)?;
let migration = fs::read_to_string(entry.path())?;
if version > above {
migrations.push((name, version, migration));
}
}
migrations.sort_by_key(|(_, num, _)| *num);
Ok(migrations)
}
fn extract_key_from_entry(entry: &fs::DirEntry) -> Result<(String, i32)> {
let raw_name = entry.file_name();
let name = raw_name
.to_str()
.ok_or_else(|| anyhow!("invalid file name"))?;
let num = name
.split('-')
.next()
.ok_or_else(|| anyhow!("missing number in migration name"))?
.parse()?;
Ok((name.to_owned(), num))
}
trait QueryMode {
type Handle<'conn>;
fn handle(connection: &mut rusqlite::Connection) -> rusqlite::Result<Self::Handle<'_>>;
}
struct NoTx;
impl QueryMode for NoTx {
type Handle<'conn> = &'conn rusqlite::Connection;
fn handle(connection: &mut rusqlite::Connection) -> rusqlite::Result<Self::Handle<'_>> {
Ok(connection)
}
}
struct Tx;
impl QueryMode for Tx {
type Handle<'conn> = rusqlite::Transaction<'conn>;
fn handle(connection: &mut rusqlite::Connection) -> rusqlite::Result<Self::Handle<'_>> {
connection.transaction()
}
}
#[instrument(skip(db, query), err)]
async fn query_inner<M, F, T>(db: &Database, mut query: F) -> Result<T>
where
T: Send,
M: QueryMode,
F: FnMut(M::Handle<'_>) -> Result<T> + Send,
{
let conn = &mut *db.acquire_conn().await?;
task::block_in_place(move || {
let start = Instant::now();
let end = start + DB_TIMEOUT;
let ret = loop {
let handle = M::handle(conn)?;
match query(handle) {
Ok(item) => break Ok(item),
Err(err) => {
if let Some(sq_err) = err.downcast_ref::<rusqlite::Error>() {
if let Some(code) = sq_err.sqlite_error_code() {
if code == rusqlite::ErrorCode::DatabaseBusy {
warn!("database is busy, retrying");
if Instant::now() > end {
bail!("database busy, timed out");
} else {
thread::yield_now();
continue;
}
}
}
}
break Err(err);
},
}
};
debug!("transaction took {:?}", start.elapsed());
ret
})
}
pub async fn query<F, T>(db: &Database, query: F) -> Result<T>
where
T: Send,
F: FnMut(&rusqlite::Connection) -> Result<T> + Send,
{
query_inner::<NoTx, _, _>(db, query).await
}
pub async fn query_tx<F, T>(db: &Database, query: F) -> Result<T>
where
T: Send,
F: FnMut(rusqlite::Transaction) -> Result<T> + Send,
{
query_inner::<Tx, _, _>(db, query).await
}
#[derive(Default, Serialize, Deserialize)]
pub struct DatabaseAnalyzeJob;
#[async_trait]
#[typetag::serde]
impl Job for DatabaseAnalyzeJob {
async fn run(&self, state: &State) -> Result<serde_json::Value> {
query(&state.db, |conn| {
conn.execute_batch(
r#"
PRAGMA analysis_limit=400;
ANALYZE;
"#,
)?;
Ok(())
})
.await?;
Ok(json!({}))
}
}
impl Schedulable for DatabaseAnalyzeJob {
const INTERVAL: Duration = Duration::from_secs(60 * 60 * 24);
}
pub fn row_extract<T: FromRow>(row: &rusqlite::Row) -> rusqlite::Result<T> {
T::from_row(row)
}
pub trait FromRow {
fn from_row(row: &rusqlite::Row) -> rusqlite::Result<Self>
where
Self: Sized;
}
impl<A> FromRow for (A,)
where
A: FromSql,
{
fn from_row(row: &rusqlite::Row) -> rusqlite::Result<Self> {
Ok((row.get(0)?,))
}
}
impl<A, B> FromRow for (A, B)
where
A: FromSql,
B: FromSql,
{
fn from_row(row: &rusqlite::Row) -> rusqlite::Result<Self> {
Ok((row.get(0)?, row.get(1)?))
}
}
impl<A, B, C> FromRow for (A, B, C)
where
A: FromSql,
B: FromSql,
C: FromSql,
{
fn from_row(row: &rusqlite::Row) -> rusqlite::Result<Self> {
Ok((row.get(0)?, row.get(1)?, row.get(2)?))
}
}
impl<A, B, C, D> FromRow for (A, B, C, D)
where
A: FromSql,
B: FromSql,
C: FromSql,
D: FromSql,
{
fn from_row(row: &rusqlite::Row) -> rusqlite::Result<Self> {
Ok((row.get(0)?, row.get(1)?, row.get(2)?, row.get(3)?))
}
}
impl<A, B, C, D, E> FromRow for (A, B, C, D, E)
where
A: FromSql,
B: FromSql,
C: FromSql,
D: FromSql,
E: FromSql,
{
fn from_row(row: &rusqlite::Row) -> rusqlite::Result<Self> {
Ok((
row.get(0)?,
row.get(1)?,
row.get(2)?,
row.get(3)?,
row.get(4)?,
))
}
}
impl<A, B, C, D, E, F> FromRow for (A, B, C, D, E, F)
where
A: FromSql,
B: FromSql,
C: FromSql,
D: FromSql,
E: FromSql,
F: FromSql,
{
fn from_row(row: &rusqlite::Row) -> rusqlite::Result<Self> {
Ok((
row.get(0)?,
row.get(1)?,
row.get(2)?,
row.get(3)?,
row.get(4)?,
row.get(5)?,
))
}
}
impl<A, B, C, D, E, F, G> FromRow for (A, B, C, D, E, F, G)
where
A: FromSql,
B: FromSql,
C: FromSql,
D: FromSql,
E: FromSql,
F: FromSql,
G: FromSql,
{
fn from_row(row: &rusqlite::Row) -> rusqlite::Result<Self> {
Ok((
row.get(0)?,
row.get(1)?,
row.get(2)?,
row.get(3)?,
row.get(4)?,
row.get(5)?,
row.get(6)?,
))
}
}
impl<A, B, C, D, E, F, G, H> FromRow for (A, B, C, D, E, F, G, H)
where
A: FromSql,
B: FromSql,
C: FromSql,
D: FromSql,
E: FromSql,
F: FromSql,
G: FromSql,
H: FromSql,
{
fn from_row(row: &rusqlite::Row) -> rusqlite::Result<Self> {
Ok((
row.get(0)?,
row.get(1)?,
row.get(2)?,
row.get(3)?,
row.get(4)?,
row.get(5)?,
row.get(6)?,
row.get(7)?,
))
}
}


@@ -14,8 +14,11 @@ use tokio::{
};
use tracing::{debug, error, info, instrument, trace};
use super::database::{query, row_extract, Database};
use crate::{session::SessionService, specials::Closer};
use crate::{
database::{query, row_extract, Database},
session::SessionService,
specials::Closer,
};
const JOB_POLL_INTERVAL: Duration = Duration::from_secs(10);
const JOB_SCHEDULER_POLL_INTERVAL: Duration = Duration::from_secs(60);
@@ -130,7 +133,7 @@ impl JobsService {
#[instrument(skip(db, scheduler, close_rx))]
#[allow(clippy::type_complexity)]
async fn overwatch(
db: &Database,
db: &Arc<Database>,
scheduler: &Mutex<
Vec<(
Instant,
@@ -141,15 +144,16 @@
>,
mut close_rx: oneshot::Receiver<()>,
) {
let schedule_pending = || async {
let scheduler = &mut *scheduler.lock().await;
for (next, interval, job, _) in scheduler {
let schedule_pending = async || {
let mut scheduler = scheduler.lock().await;
for (next, interval, job, _) in &mut *scheduler {
if *next > Instant::now() {
continue;
}
*next = Instant::now() + *interval;
query(db, |conn| {
let job = *job;
query(db, move |conn| {
job(conn)?;
Ok(())
})
@@ -191,7 +195,7 @@ async fn worker_jobs_loop(state: &State, mut close_rx: oneshot::Receiver<()>) {
}
}
async fn load_job(db: &Database) -> Result<(i32, Box<dyn Job>)> {
async fn load_job(db: &Arc<Database>) -> Result<(i32, Box<dyn Job>)> {
let (id, args) = loop {
let maybe_data = query(db, |conn| {
let maybe_row: rusqlite::Result<(i32, String)> = conn
@@ -248,7 +252,7 @@ async fn run_job(state: &State) -> Result<(i32, Result<serde_json::Value>)> {
}
_ = lock_update.tick() => {
debug!("updating lock for background job with id: {}", id);
query(&state.db, |conn| {
query(&state.db, move |conn| {
conn.prepare_cached("UPDATE job_queue SET active = unixepoch() WHERE id = ?")?.execute((id,))?;
Ok(())
}).await?;
@@ -266,8 +270,8 @@ async fn save_job_result(state: &State, id: i32, res: &Result<serde_json::Value>
}
async fn save_job_result(state: &State, id: i32, res: &Result<serde_json::Value>) -> Result<()> {
let update = async |job_state: &str, output: &str| {
query(&state.db, |conn| {
let update = async |job_state: &'static str, output: String| {
query(&state.db, move |conn| {
conn.prepare_cached(
"UPDATE job_queue SET state = ?, completed = unixepoch(), output = ?, active = \
NULL WHERE id = ?",
@@ -282,14 +286,14 @@ async fn save_job_result(state: &State, id: i32, res: &Result<serde_json::Value>
match res {
Ok(output) => {
let output = serde_json::to_string(&output)?;
update("completed", &output).await?;
update("completed", output).await?;
info!("successfully completed background job with id: {}", id);
Ok(())
},
Err(err) => {
let output = json!({ "error": err.to_string() });
let output = serde_json::to_string(&output)?;
update("failed", &output).await?;
update("failed", output).await?;
error!("error executing background job with id {}: {}", id, err);
Ok(())
},


@@ -1,2 +1 @@
pub mod database;
pub mod jobs;


@@ -9,8 +9,10 @@ use axum::{
use base64::prelude::{Engine, BASE64_URL_SAFE_NO_PAD};
use serde::{Deserialize, Serialize};
use super::{services::database::query, utils};
use crate::services::database::Database;
use crate::{
database::{query, Database},
utils,
};
const SESSION_LIFESPAN: Duration = Duration::from_secs(60 * 60 * 24 * 30);
@@ -23,17 +25,20 @@ impl SessionService {
Ok(Arc::new(Self { db }))
}
pub async fn try_get(&self, token: &str) -> Result<Session> {
let (account_id, sudo_until, expires): (i32, Option<i32>, i32) = query(&self.db, |conn| {
let data = conn
.prepare_cached(
"SELECT account_id, sudo_until, expires FROM sessions WHERE token = ?",
)?
.query_row((token,), |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)))?;
pub async fn try_get(&self, token: String) -> Result<Session> {
let (account_id, sudo_until, expires): (i32, Option<i32>, i32) = {
let token = token.clone();
query(&self.db, |conn| {
let data = conn
.prepare_cached(
"SELECT account_id, sudo_until, expires FROM sessions WHERE token = ?",
)?
.query_row((token,), |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)))?;
Ok(data)
})
.await?;
Ok(data)
})
.await?
};
let now = utils::timestamp();
if now > expires {
@@ -41,7 +46,7 @@
}
Ok(Session {
token: token.to_string(),
token,
uid: account_id,
is_sudo: sudo_until.map_or(false, |until| until > now),
})
@@ -52,22 +57,26 @@
let token = BASE64_URL_SAFE_NO_PAD.encode(secret.to_le_bytes());
let expires = utils::timestamp_advance(utils::timestamp(), SESSION_LIFESPAN);
query(&self.db, |conn| {
conn.execute(
"INSERT INTO sessions (token, account_id, sudo_until, expires) VALUES (?, ?, ?, ?)",
(&token, uid, None::<bool>, expires),
)?;
{
let token = token.clone();
query(&self.db, move |conn| {
conn.execute(
"INSERT INTO sessions (token, account_id, sudo_until, expires) VALUES (?, ?, \
?, ?)",
(token, uid, None::<bool>, expires),
)?;
Ok(())
})
.await?;
Ok(())
})
.await?;
}
Ok(token)
}
pub async fn invalidate(&self, session: &Session) -> Result<()> {
query(&self.db, |conn| {
conn.execute("DELETE FROM sessions WHERE token = ?", (&session.token,))?;
pub async fn invalidate(&self, token: String) -> Result<()> {
query(&self.db, move |conn| {
conn.execute("DELETE FROM sessions WHERE token = ?", (token,))?;
Ok(())
})
.await
@@ -100,7 +109,7 @@ where
let service = parts.extensions.get::<Arc<SessionService>>().unwrap();
let session = service
.try_get(token)
.try_get(token.to_owned())
.await
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "failed to fetch session"))?;
@@ -128,7 +137,7 @@ where
let service = parts.extensions.get::<Arc<SessionService>>().unwrap();
let session = service
.try_get(token)
.try_get(token.to_owned())
.await
.map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "failed to fetch session"))?;