Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ description = "An async extension for Diesel the safe, extensible ORM and Query
rust-version = "1.78.0"

[dependencies]
async-trait = "0.1.66"
futures-channel = { version = "0.3.17", default-features = false, features = [
"std",
"sink",
Expand All @@ -30,6 +29,7 @@ mysql_async = { version = "0.34", optional = true, default-features = false, fea
mysql_common = { version = "0.32", optional = true, default-features = false }

bb8 = { version = "0.9", optional = true }
async-trait = { version = "0.1.66", optional = true }
deadpool = { version = "0.12", optional = true, default-features = false, features = [
"managed",
] }
Expand Down Expand Up @@ -80,7 +80,7 @@ sync-connection-wrapper = ["tokio/rt"]
async-connection-wrapper = ["tokio/net"]
pool = []
r2d2 = ["pool", "diesel/r2d2"]
bb8 = ["pool", "dep:bb8"]
bb8 = ["pool", "dep:bb8", "dep:async-trait"]
mobc = ["pool", "dep:mobc"]
deadpool = ["pool", "dep:deadpool"]

Expand Down
2 changes: 1 addition & 1 deletion src/async_connection_wrapper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ mod implementation {
runtime: &'a B,
}

impl<'a, S, B> Iterator for AsyncCursorWrapper<'a, S, B>
impl<S, B> Iterator for AsyncCursorWrapper<'_, S, B>
where
S: Stream,
B: BlockOn,
Expand Down
88 changes: 51 additions & 37 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,10 @@
use diesel::backend::Backend;
use diesel::connection::{CacheSize, Instrumentation};
use diesel::query_builder::{AsQuery, QueryFragment, QueryId};
use diesel::result::Error;
use diesel::row::Row;
use diesel::{ConnectionResult, QueryResult};
use futures_util::{Future, Stream};
use futures_util::future::BoxFuture;
use futures_util::{Future, FutureExt, Stream};
use std::fmt::Debug;

pub use scoped_futures;
Expand Down Expand Up @@ -115,21 +115,19 @@ pub use self::transaction_manager::{AnsiTransactionManager, TransactionManager};
/// Perform simple operations on a backend.
///
/// You should likely use [`AsyncConnection`] instead.
#[async_trait::async_trait]
pub trait SimpleAsyncConnection {
/// Execute multiple SQL statements within the same string.
///
/// This function is used to execute migrations,
/// which may contain more than one SQL statement.
async fn batch_execute(&mut self, query: &str) -> QueryResult<()>;
fn batch_execute(&mut self, query: &str) -> impl Future<Output = QueryResult<()>> + Send;
}

/// An async connection to a database
///
/// This trait represents a n async database connection. It can be used to query the database through
/// the query dsl provided by diesel, custom extensions or raw sql queries. It essentially mirrors
/// the sync diesel [`Connection`](diesel::connection::Connection) implementation
#[async_trait::async_trait]
pub trait AsyncConnection: SimpleAsyncConnection + Sized + Send {
/// The future returned by `AsyncConnection::execute`
type ExecuteFuture<'conn, 'query>: Future<Output = QueryResult<usize>> + Send;
Expand All @@ -151,7 +149,7 @@ pub trait AsyncConnection: SimpleAsyncConnection + Sized + Send {
/// The argument to this method and the method's behavior varies by backend.
/// See the documentation for that backend's connection class
/// for details about what it accepts and how it behaves.
async fn establish(database_url: &str) -> ConnectionResult<Self>;
fn establish(database_url: &str) -> impl Future<Output = ConnectionResult<Self>> + Send;

/// Executes the given function inside of a database transaction
///
Expand Down Expand Up @@ -230,34 +228,44 @@ pub trait AsyncConnection: SimpleAsyncConnection + Sized + Send {
/// # Ok(())
/// # }
/// ```
async fn transaction<'a, R, E, F>(&mut self, callback: F) -> Result<R, E>
fn transaction<'a, 'conn, R, E, F>(
&'conn mut self,
callback: F,
) -> BoxFuture<'conn, Result<R, E>>
// we cannot use `impl Trait` here due to bugs in rustc
// https://github.com/rust-lang/rust/issues/100013
//impl Future<Output = Result<R, E>> + Send + 'async_trait
where
F: for<'r> FnOnce(&'r mut Self) -> ScopedBoxFuture<'a, 'r, Result<R, E>> + Send + 'a,
E: From<diesel::result::Error> + Send + 'a,
R: Send + 'a,
'a: 'conn,
{
Self::TransactionManager::transaction(self, callback).await
Self::TransactionManager::transaction(self, callback).boxed()
}

/// Creates a transaction that will never be committed. This is useful for
/// tests. Panics if called while inside of a transaction or
/// if called with a connection containing a broken transaction
async fn begin_test_transaction(&mut self) -> QueryResult<()> {
fn begin_test_transaction(&mut self) -> impl Future<Output = QueryResult<()>> + Send {
use diesel::connection::TransactionManagerStatus;

match Self::TransactionManager::transaction_manager_status_mut(self) {
TransactionManagerStatus::Valid(valid_status) => {
assert_eq!(None, valid_status.transaction_depth())
}
TransactionManagerStatus::InError => panic!("Transaction manager in error"),
};
Self::TransactionManager::begin_transaction(self).await?;
// set the test transaction flag
// to prevent that this connection gets dropped in connection pools
// Tests commonly set the poolsize to 1 and use `begin_test_transaction`
// to prevent modifications to the schema
Self::TransactionManager::transaction_manager_status_mut(self).set_test_transaction_flag();
Ok(())
async {
match Self::TransactionManager::transaction_manager_status_mut(self) {
TransactionManagerStatus::Valid(valid_status) => {
assert_eq!(None, valid_status.transaction_depth())
}
TransactionManagerStatus::InError => panic!("Transaction manager in error"),
};
Self::TransactionManager::begin_transaction(self).await?;
// set the test transaction flag
// to prevent that this connection gets dropped in connection pools
// Tests commonly set the poolsize to 1 and use `begin_test_transaction`
// to prevent modifications to the schema
Self::TransactionManager::transaction_manager_status_mut(self)
.set_test_transaction_flag();
Ok(())
}
}

/// Executes the given function inside a transaction, but does not commit
Expand Down Expand Up @@ -297,27 +305,33 @@ pub trait AsyncConnection: SimpleAsyncConnection + Sized + Send {
/// # Ok(())
/// # }
/// ```
async fn test_transaction<'a, R, E, F>(&'a mut self, f: F) -> R
fn test_transaction<'conn, 'a, R, E, F>(
&'conn mut self,
f: F,
) -> impl Future<Output = R> + Send + 'conn
where
F: for<'r> FnOnce(&'r mut Self) -> ScopedBoxFuture<'a, 'r, Result<R, E>> + Send + 'a,
E: Debug + Send + 'a,
R: Send + 'a,
Self: 'a,
'a: 'conn,
{
use futures_util::TryFutureExt;

let mut user_result = None;
let _ = self
.transaction::<R, _, _>(|c| {
f(c).map_err(|_| Error::RollbackTransaction)
.and_then(|r| {
user_result = Some(r);
futures_util::future::ready(Err(Error::RollbackTransaction))
})
.scope_boxed()
})
.await;
user_result.expect("Transaction did not succeed")
let (user_result_tx, user_result_rx) = std::sync::mpsc::channel();
self.transaction::<R, _, _>(move |conn| {
f(conn)
.map_err(|_| diesel::result::Error::RollbackTransaction)
.and_then(move |r| {
let _ = user_result_tx.send(r);
futures_util::future::ready(Err(diesel::result::Error::RollbackTransaction))
})
.scope_boxed()
})
.then(move |_r| {
let r = user_result_rx
.try_recv()
.expect("Transaction did not succeed");
futures_util::future::ready(r)
})
}

#[doc(hidden)]
Expand Down
6 changes: 2 additions & 4 deletions src/mysql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ pub struct AsyncMysqlConnection {
instrumentation: DynInstrumentation,
}

#[async_trait::async_trait]
impl SimpleAsyncConnection for AsyncMysqlConnection {
async fn batch_execute(&mut self, query: &str) -> diesel::QueryResult<()> {
self.instrumentation()
Expand Down Expand Up @@ -63,7 +62,6 @@ const CONNECTION_SETUP_QUERIES: &[&str] = &[
"SET character_set_results = 'utf8mb4'",
];

#[async_trait::async_trait]
impl AsyncConnection for AsyncMysqlConnection {
type ExecuteFuture<'conn, 'query> = BoxFuture<'conn, QueryResult<usize>>;
type LoadFuture<'conn, 'query> = BoxFuture<'conn, QueryResult<Self::Stream<'conn, 'query>>>;
Expand Down Expand Up @@ -208,9 +206,9 @@ fn update_transaction_manager_status<T>(
query_result
}

fn prepare_statement_helper<'a, 'b>(
fn prepare_statement_helper<'a>(
conn: &'a mut mysql_async::Conn,
sql: &'b str,
sql: &str,
_is_for_cache: diesel::connection::statement_cache::PrepareForCache,
_metadata: &[MysqlType],
) -> CallbackHelper<impl Future<Output = QueryResult<(Statement, &'a mut mysql_async::Conn)>> + Send>
Expand Down
2 changes: 1 addition & 1 deletion src/mysql/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ pub struct MysqlField<'a> {
name: Cow<'a, str>,
}

impl<'a> diesel::row::Field<'a, Mysql> for MysqlField<'_> {
impl diesel::row::Field<'_, Mysql> for MysqlField<'_> {
fn field_name(&self) -> Option<&str> {
Some(&*self.name)
}
Expand Down
6 changes: 2 additions & 4 deletions src/pg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,6 @@ pub struct AsyncPgConnection {
instrumentation: Arc<std::sync::Mutex<DynInstrumentation>>,
}

#[async_trait::async_trait]
impl SimpleAsyncConnection for AsyncPgConnection {
async fn batch_execute(&mut self, query: &str) -> QueryResult<()> {
self.record_instrumentation(InstrumentationEvent::start_query(&StrQueryHelper::new(
Expand All @@ -154,7 +153,6 @@ impl SimpleAsyncConnection for AsyncPgConnection {
}
}

#[async_trait::async_trait]
impl AsyncConnection for AsyncPgConnection {
type LoadFuture<'conn, 'query> = BoxFuture<'query, QueryResult<Self::Stream<'conn, 'query>>>;
type ExecuteFuture<'conn, 'query> = BoxFuture<'query, QueryResult<usize>>;
Expand Down Expand Up @@ -306,9 +304,9 @@ fn update_transaction_manager_status<T>(
query_result
}

fn prepare_statement_helper<'a>(
fn prepare_statement_helper(
conn: Arc<tokio_postgres::Client>,
sql: &'a str,
sql: &str,
_is_for_cache: PrepareForCache,
metadata: &[PgTypeMetadata],
) -> CallbackHelper<
Expand Down
2 changes: 1 addition & 1 deletion src/pg/transaction_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ where
}
}

impl<'a, C> QueryFragment<Pg> for TransactionBuilder<'a, C> {
impl<C> QueryFragment<Pg> for TransactionBuilder<'_, C> {
fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
out.push_sql("BEGIN TRANSACTION");
if let Some(ref isolation_level) = self.isolation_level {
Expand Down
1 change: 0 additions & 1 deletion src/pooled_connection/bb8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
//! # Ok(())
//! # }
//! ```

use super::{AsyncDieselConnectionManager, PoolError, PoolableConnection};
use bb8::ManageConnection;
use diesel::query_builder::QueryFragment;
Expand Down
45 changes: 26 additions & 19 deletions src/pooled_connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ use crate::{TransactionManager, UpdateAndFetchResults};
use diesel::associations::HasTable;
use diesel::connection::{CacheSize, Instrumentation};
use diesel::QueryResult;
use futures_util::future::BoxFuture;
use futures_util::{future, FutureExt};
use std::borrow::Cow;
use std::fmt;
use std::future::Future;
use std::ops::DerefMut;

#[cfg(feature = "bb8")]
Expand Down Expand Up @@ -164,7 +166,6 @@ where
}
}

#[async_trait::async_trait]
impl<C> SimpleAsyncConnection for C
where
C: DerefMut + Send,
Expand All @@ -176,7 +177,6 @@ where
}
}

#[async_trait::async_trait]
impl<C> AsyncConnection for C
where
C: DerefMut + Send,
Expand Down Expand Up @@ -251,7 +251,6 @@ where
#[allow(missing_debug_implementations)]
pub struct PoolTransactionManager<TM>(std::marker::PhantomData<TM>);

#[async_trait::async_trait]
impl<C, TM> TransactionManager<C> for PoolTransactionManager<TM>
where
C: DerefMut + Send,
Expand Down Expand Up @@ -283,18 +282,22 @@ where
}
}

#[async_trait::async_trait]
impl<Changes, Output, Conn> UpdateAndFetchResults<Changes, Output> for Conn
where
Conn: DerefMut + Send,
Changes: diesel::prelude::Identifiable + HasTable + Send,
Conn::Target: UpdateAndFetchResults<Changes, Output>,
{
async fn update_and_fetch(&mut self, changeset: Changes) -> QueryResult<Output>
fn update_and_fetch<'conn, 'changes>(
&'conn mut self,
changeset: Changes,
) -> BoxFuture<'changes, QueryResult<Output>>
where
Changes: 'async_trait,
Changes: 'changes,
'conn: 'changes,
Self: 'changes,
{
self.deref_mut().update_and_fetch(changeset).await
self.deref_mut().update_and_fetch(changeset)
}
}

Expand All @@ -321,13 +324,15 @@ impl diesel::query_builder::Query for CheckConnectionQuery {
impl<C> diesel::query_dsl::RunQueryDsl<C> for CheckConnectionQuery {}

#[doc(hidden)]
#[async_trait::async_trait]
pub trait PoolableConnection: AsyncConnection {
/// Check if a connection is still valid
///
/// The default implementation will perform a check based on the provided
/// recycling method variant
async fn ping(&mut self, config: &RecyclingMethod<Self>) -> diesel::QueryResult<()>
fn ping(
&mut self,
config: &RecyclingMethod<Self>,
) -> impl Future<Output = diesel::QueryResult<()>> + Send
where
for<'a> Self: 'a,
diesel::dsl::select<diesel::dsl::AsExprOf<i32, diesel::sql_types::Integer>>:
Expand All @@ -337,19 +342,21 @@ pub trait PoolableConnection: AsyncConnection {
use crate::run_query_dsl::RunQueryDsl;
use diesel::IntoSql;

match config {
RecyclingMethod::Fast => Ok(()),
RecyclingMethod::Verified => {
diesel::select(1_i32.into_sql::<diesel::sql_types::Integer>())
async move {
match config {
RecyclingMethod::Fast => Ok(()),
RecyclingMethod::Verified => {
diesel::select(1_i32.into_sql::<diesel::sql_types::Integer>())
.execute(self)
.await
.map(|_| ())
}
RecyclingMethod::CustomQuery(query) => diesel::sql_query(query.as_ref())
.execute(self)
.await
.map(|_| ())
.map(|_| ()),
RecyclingMethod::CustomFunction(c) => c(self).await,
}
RecyclingMethod::CustomQuery(query) => diesel::sql_query(query.as_ref())
.execute(self)
.await
.map(|_| ()),
RecyclingMethod::CustomFunction(c) => c(self).await,
}
}

Expand Down
Loading
Loading