1
0
Fork 0
mirror of https://github.com/dani-garcia/vaultwarden.git synced 2025-07-08 21:44:57 +00:00

Update to rocket 0.5 and made code async, missing updating all db calls, that are currently blocking

This commit is contained in:
Daniel García 2021-11-07 18:53:39 +01:00
parent 89fe05b6cc
commit 0b7d6bf6df
No known key found for this signature in database
GPG key ID: FC8A7D14C3CD543A
30 changed files with 1314 additions and 1028 deletions

View file

@ -1,8 +1,16 @@
use std::{sync::Arc, time::Duration};
use diesel::r2d2::{ConnectionManager, Pool, PooledConnection};
use rocket::{
http::Status,
outcome::IntoOutcome,
request::{FromRequest, Outcome},
Request, State,
Request,
};
use tokio::{
sync::{Mutex, OwnedSemaphorePermit, Semaphore},
time::timeout,
};
use crate::{
@ -22,6 +30,23 @@ pub mod __mysql_schema;
#[path = "schemas/postgresql/schema.rs"]
pub mod __postgresql_schema;
// There changes are based on Rocket 0.5-rc wrapper of Diesel: https://github.com/SergioBenitez/Rocket/blob/v0.5-rc/contrib/sync_db_pools
// A wrapper around spawn_blocking that propagates panics to the calling code.
pub async fn run_blocking<F, R>(job: F) -> R
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
match tokio::task::spawn_blocking(job).await {
Ok(ret) => ret,
Err(e) => match e.try_into_panic() {
Ok(panic) => std::panic::resume_unwind(panic),
Err(_) => unreachable!("spawn_blocking tasks are never cancelled"),
},
}
}
// This is used to generate the main DbConn and DbPool enums, which contain one variant for each database supported
macro_rules! generate_connections {
( $( $name:ident: $ty:ty ),+ ) => {
@ -29,12 +54,53 @@ macro_rules! generate_connections {
#[derive(Eq, PartialEq)]
pub enum DbConnType { $( $name, )+ }
pub struct DbConn {
conn: Arc<Mutex<Option<DbConnInner>>>,
permit: Option<OwnedSemaphorePermit>,
}
#[allow(non_camel_case_types)]
pub enum DbConn { $( #[cfg($name)] $name(PooledConnection<ConnectionManager< $ty >>), )+ }
pub enum DbConnInner { $( #[cfg($name)] $name(PooledConnection<ConnectionManager< $ty >>), )+ }
#[derive(Clone)]
pub struct DbPool {
// This is an 'Option' so that we can drop the pool in a 'spawn_blocking'.
pool: Option<DbPoolInner>,
semaphore: Arc<Semaphore>
}
#[allow(non_camel_case_types)]
#[derive(Clone)]
pub enum DbPool { $( #[cfg($name)] $name(Pool<ConnectionManager< $ty >>), )+ }
pub enum DbPoolInner { $( #[cfg($name)] $name(Pool<ConnectionManager< $ty >>), )+ }
impl Drop for DbConn {
fn drop(&mut self) {
let conn = self.conn.clone();
let permit = self.permit.take();
// Since connection can't be on the stack in an async fn during an
// await, we have to spawn a new blocking-safe thread...
tokio::task::spawn_blocking(move || {
// And then re-enter the runtime to wait on the async mutex, but in a blocking fashion.
let mut conn = tokio::runtime::Handle::current().block_on(conn.lock_owned());
if let Some(conn) = conn.take() {
drop(conn);
}
// Drop permit after the connection is dropped
drop(permit);
});
}
}
impl Drop for DbPool {
fn drop(&mut self) {
let pool = self.pool.take();
tokio::task::spawn_blocking(move || drop(pool));
}
}
impl DbPool {
// For the given database URL, guess it's type, run migrations create pool and return it
@ -50,9 +116,13 @@ macro_rules! generate_connections {
let manager = ConnectionManager::new(&url);
let pool = Pool::builder()
.max_size(CONFIG.database_max_conns())
.connection_timeout(Duration::from_secs(CONFIG.database_timeout()))
.build(manager)
.map_res("Failed to create pool")?;
return Ok(Self::$name(pool));
return Ok(DbPool {
pool: Some(DbPoolInner::$name(pool)),
semaphore: Arc::new(Semaphore::new(CONFIG.database_max_conns() as usize)),
});
}
#[cfg(not($name))]
#[allow(unreachable_code)]
@ -61,10 +131,26 @@ macro_rules! generate_connections {
)+ }
}
// Get a connection from the pool
pub fn get(&self) -> Result<DbConn, Error> {
match self { $(
pub async fn get(&self) -> Result<DbConn, Error> {
let duration = Duration::from_secs(CONFIG.database_timeout());
let permit = match timeout(duration, self.semaphore.clone().acquire_owned()).await {
Ok(p) => p.expect("Semaphore should be open"),
Err(_) => {
err!("Timeout waiting for database connection");
}
};
match self.pool.as_ref().expect("DbPool.pool should always be Some()") { $(
#[cfg($name)]
Self::$name(p) => Ok(DbConn::$name(p.get().map_res("Error retrieving connection from pool")?)),
DbPoolInner::$name(p) => {
let pool = p.clone();
let c = run_blocking(move || pool.get_timeout(duration)).await.map_res("Error retrieving connection from pool")?;
return Ok(DbConn {
conn: Arc::new(Mutex::new(Some(DbConnInner::$name(c)))),
permit: Some(permit)
});
},
)+ }
}
}
@ -113,42 +199,95 @@ macro_rules! db_run {
db_run! { $conn: sqlite, mysql, postgresql $body }
};
// Different code for each db
( $conn:ident: $( $($db:ident),+ $body:block )+ ) => {{
#[allow(unused)] use diesel::prelude::*;
match $conn {
$($(
#[cfg($db)]
crate::db::DbConn::$db(ref $conn) => {
paste::paste! {
#[allow(unused)] use crate::db::[<__ $db _schema>]::{self as schema, *};
#[allow(unused)] use [<__ $db _model>]::*;
#[allow(unused)] use crate::db::FromDb;
}
$body
},
)+)+
}}
};
// Same for all dbs
( @raw $conn:ident: $body:block ) => {
db_run! { @raw $conn: sqlite, mysql, postgresql $body }
};
// Different code for each db
( @raw $conn:ident: $( $($db:ident),+ $body:block )+ ) => {
( $conn:ident: $( $($db:ident),+ $body:block )+ ) => {{
#[allow(unused)] use diesel::prelude::*;
#[allow(unused_variables)]
match $conn {
$($(
#[cfg($db)]
crate::db::DbConn::$db(ref $conn) => {
$body
},
)+)+
}
};
// It is important that this inner Arc<Mutex<>> (or the OwnedMutexGuard
// derived from it) never be a variable on the stack at an await point,
// where Drop might be called at any time. This causes (synchronous)
// Drop to be called from asynchronous code, which some database
// wrappers do not or can not handle.
let conn = $conn.conn.clone();
// Since connection can't be on the stack in an async fn during an
// await, we have to spawn a new blocking-safe thread...
/*
run_blocking(move || {
// And then re-enter the runtime to wait on the async mutex, but in
// a blocking fashion.
let mut conn = tokio::runtime::Handle::current().block_on(conn.lock_owned());
let conn = conn.as_mut().expect("internal invariant broken: self.connection is Some");
*/
let mut __conn_mutex = conn.try_lock_owned().unwrap();
let conn = __conn_mutex.as_mut().unwrap();
match conn {
$($(
#[cfg($db)]
crate::db::DbConnInner::$db($conn) => {
paste::paste! {
#[allow(unused)] use crate::db::[<__ $db _schema>]::{self as schema, *};
#[allow(unused)] use [<__ $db _model>]::*;
#[allow(unused)] use crate::db::FromDb;
}
/*
// Since connection can't be on the stack in an async fn during an
// await, we have to spawn a new blocking-safe thread...
run_blocking(move || {
// And then re-enter the runtime to wait on the async mutex, but in
// a blocking fashion.
let mut conn = tokio::runtime::Handle::current().block_on(async {
conn.lock_owned().await
});
let conn = conn.as_mut().expect("internal invariant broken: self.connection is Some");
f(conn)
}).await;*/
$body
},
)+)+
}
// }).await
}};
( @raw $conn:ident: $( $($db:ident),+ $body:block )+ ) => {{
#[allow(unused)] use diesel::prelude::*;
// It is important that this inner Arc<Mutex<>> (or the OwnedMutexGuard
// derived from it) never be a variable on the stack at an await point,
// where Drop might be called at any time. This causes (synchronous)
// Drop to be called from asynchronous code, which some database
// wrappers do not or can not handle.
let conn = $conn.conn.clone();
// Since connection can't be on the stack in an async fn during an
// await, we have to spawn a new blocking-safe thread...
run_blocking(move || {
// And then re-enter the runtime to wait on the async mutex, but in
// a blocking fashion.
let mut conn = tokio::runtime::Handle::current().block_on(conn.lock_owned());
match conn.as_mut().expect("internal invariant broken: self.connection is Some") {
$($(
#[cfg($db)]
crate::db::DbConnInner::$db($conn) => {
paste::paste! {
#[allow(unused)] use crate::db::[<__ $db _schema>]::{self as schema, *};
// @RAW: #[allow(unused)] use [<__ $db _model>]::*;
#[allow(unused)] use crate::db::FromDb;
}
$body
},
)+)+
}
}).await
}};
}
pub trait FromDb {
@ -227,9 +366,10 @@ pub mod models;
/// Creates a back-up of the sqlite database
/// MySQL/MariaDB and PostgreSQL are not supported.
pub fn backup_database(conn: &DbConn) -> Result<(), Error> {
pub async fn backup_database(conn: &DbConn) -> Result<(), Error> {
db_run! {@raw conn:
postgresql, mysql {
let _ = conn;
err!("PostgreSQL and MySQL/MariaDB do not support this backup feature");
}
sqlite {
@ -244,7 +384,7 @@ pub fn backup_database(conn: &DbConn) -> Result<(), Error> {
}
/// Get the SQL Server version
pub fn get_sql_server_version(conn: &DbConn) -> String {
pub async fn get_sql_server_version(conn: &DbConn) -> String {
db_run! {@raw conn:
postgresql, mysql {
no_arg_sql_function!(version, diesel::sql_types::Text);
@ -260,15 +400,14 @@ pub fn get_sql_server_version(conn: &DbConn) -> String {
/// Attempts to retrieve a single connection from the managed database pool. If
/// no pool is currently managed, fails with an `InternalServerError` status. If
/// no connections are available, fails with a `ServiceUnavailable` status.
impl<'a, 'r> FromRequest<'a, 'r> for DbConn {
#[rocket::async_trait]
impl<'r> FromRequest<'r> for DbConn {
type Error = ();
fn from_request(request: &'a Request<'r>) -> Outcome<DbConn, ()> {
// https://github.com/SergioBenitez/Rocket/commit/e3c1a4ad3ab9b840482ec6de4200d30df43e357c
let pool = try_outcome!(request.guard::<State<DbPool>>());
match pool.get() {
Ok(conn) => Outcome::Success(conn),
Err(_) => Outcome::Failure((Status::ServiceUnavailable, ())),
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
match request.rocket().state::<DbPool>() {
Some(p) => p.get().await.map_err(|_| ()).into_outcome(Status::ServiceUnavailable),
None => Outcome::Failure((Status::InternalServerError, ())),
}
}
}