Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 18 additions & 22 deletions src/pg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1166,46 +1166,42 @@ mod tests {
.await
.unwrap();

fn erase<'a, T: Future + Send + 'a>(t: T) -> impl Future<Output = T::Output> + Send + 'a {
t
}

async fn fn12(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> {
let f1 = diesel::select(1_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
let f2 = diesel::select(2_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);

erase(try_join(f1, f2)).await
}

async fn fn34(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> {
let f3 = diesel::select(3_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
let f4 = diesel::select(4_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);

try_join(f3, f4).boxed().await
try_join(f1, f2).await
}

async fn fn56(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> {
async fn fn37(
mut conn: &AsyncPgConnection,
) -> QueryResult<(usize, (Vec<i32>, (i32, (Vec<i32>, i32))))> {
let f3 = diesel::select(0_i32.into_sql::<Integer>()).execute(&mut conn);
let f4 = diesel::select(4_i32.into_sql::<Integer>()).load::<i32>(&mut conn);
let f5 = diesel::select(5_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
let f6 = diesel::select(6_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
let f6 = diesel::select(6_i32.into_sql::<Integer>()).get_results::<i32>(&mut conn);
let f7 = diesel::select(7_i32.into_sql::<Integer>()).first::<i32>(&mut conn);

try_join(f5.boxed(), f6.boxed()).await
try_join(f3, try_join(f4, try_join(f5, try_join(f6, f7)))).await
}

conn.transaction(|conn| {
async move {
let f12 = fn12(conn);
let f34 = fn34(conn);
let f56 = fn56(conn);
let f37 = fn37(conn);

let ((r1, r2), ((r3, r4), (r5, r6))) =
try_join(f12, try_join(f34, f56)).await.unwrap();
let ((r1, r2), (r3, (r4, (r5, (r6, r7))))) = try_join(f12, f37).await.unwrap();

assert_eq!(r1, 1);
assert_eq!(r2, 2);
assert_eq!(r3, 3);
assert_eq!(r4, 4);
assert_eq!(r3, 1);
assert_eq!(r4, vec![4]);
assert_eq!(r5, 5);
assert_eq!(r6, 6);
assert_eq!(r6, vec![6]);
assert_eq!(r7, 7);

fn12(conn).await?;
fn37(conn).await?;

QueryResult::<_>::Ok(())
}
Expand Down
107 changes: 24 additions & 83 deletions src/run_query_dsl/mod.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
mod utils;

use crate::AsyncConnectionCore;
use diesel::associations::HasTable;
use diesel::query_builder::IntoUpdateTarget;
use diesel::result::QueryResult;
use diesel::AsChangeset;
use futures_core::future::BoxFuture;
use futures_core::Stream;
use futures_util::{future, stream, FutureExt, StreamExt, TryFutureExt, TryStreamExt};
#[cfg(any(feature = "mysql", feature = "postgres"))]
use futures_util::FutureExt;
use futures_util::{stream, StreamExt, TryStreamExt};
use std::future::Future;
use std::pin::Pin;

/// The traits used by `QueryDsl`.
///
Expand All @@ -22,7 +24,7 @@ pub mod methods {
use diesel::expression::QueryMetadata;
use diesel::query_builder::{AsQuery, QueryFragment, QueryId};
use diesel::query_dsl::CompatibleType;
use futures_util::{Future, Stream, TryFutureExt};
use futures_util::{Future, Stream};

/// The `execute` method
///
Expand Down Expand Up @@ -74,6 +76,7 @@ pub mod methods {
type LoadFuture<'conn>: Future<Output = QueryResult<Self::Stream<'conn>>> + Send
where
Conn: 'conn;

/// The inner stream returned by [`LoadQuery::internal_load`]
type Stream<'conn>: Stream<Item = QueryResult<U>> + Send
where
Expand All @@ -96,10 +99,7 @@ pub mod methods {
ST: 'static,
{
type LoadFuture<'conn>
= future::MapOk<
Conn::LoadFuture<'conn, 'query>,
fn(Conn::Stream<'conn, 'query>) -> Self::Stream<'conn>,
>
= utils::MapOk<Conn::LoadFuture<'conn, 'query>, Self::Stream<'conn>>
where
Conn: 'conn;

Expand All @@ -112,33 +112,13 @@ pub mod methods {
Conn: 'conn;

fn internal_load(self, conn: &mut Conn) -> Self::LoadFuture<'_> {
conn.load(self)
.map_ok(map_result_stream_future::<U, _, _, DB, ST>)
utils::MapOk::new(conn.load(self), |stream| {
stream.map(|row| {
U::build_from_row(&row?).map_err(diesel::result::Error::DeserializationError)
})
})
}
}

#[allow(clippy::type_complexity)]
fn map_result_stream_future<'s, 'a, U, S, R, DB, ST>(
stream: S,
) -> stream::Map<S, fn(QueryResult<R>) -> QueryResult<U>>
where
S: Stream<Item = QueryResult<R>> + Send + 's,
R: diesel::row::Row<'a, DB> + 's,
DB: Backend + 'static,
U: FromSqlRow<ST, DB> + 'static,
ST: 'static,
{
stream.map(map_row_helper::<_, DB, U, ST>)
}

fn map_row_helper<'a, R, DB, U, ST>(row: QueryResult<R>) -> QueryResult<U>
where
U: FromSqlRow<ST, DB>,
R: diesel::row::Row<'a, DB>,
DB: Backend,
{
U::build_from_row(&row?).map_err(diesel::result::Error::DeserializationError)
}
}

/// The return types produced by the various [`RunQueryDsl`] methods
Expand All @@ -149,37 +129,24 @@ pub mod methods {
// the same connection
#[allow(type_alias_bounds)] // we need these bounds otherwise we cannot use GAT's
pub mod return_futures {
use crate::run_query_dsl::utils;

use super::methods::LoadQuery;
use diesel::QueryResult;
use futures_util::{future, stream};
use futures_util::stream;
use std::pin::Pin;

/// The future returned by [`RunQueryDsl::load`](super::RunQueryDsl::load)
/// and [`RunQueryDsl::get_results`](super::RunQueryDsl::get_results)
///
/// This is essentially `impl Future<Output = QueryResult<Vec<U>>>`
pub type LoadFuture<'conn, 'query, Q: LoadQuery<'query, Conn, U>, Conn, U> = future::AndThen<
Q::LoadFuture<'conn>,
stream::TryCollect<Q::Stream<'conn>, Vec<U>>,
fn(Q::Stream<'conn>) -> stream::TryCollect<Q::Stream<'conn>, Vec<U>>,
>;
pub type LoadFuture<'conn, 'query, Q: LoadQuery<'query, Conn, U>, Conn, U> =
utils::AndThen<Q::LoadFuture<'conn>, stream::TryCollect<Q::Stream<'conn>, Vec<U>>>;

/// The future returned by [`RunQueryDsl::get_result`](super::RunQueryDsl::get_result)
///
/// This is essentially `impl Future<Output = QueryResult<U>>`
pub type GetResult<'conn, 'query, Q: LoadQuery<'query, Conn, U>, Conn, U> = future::AndThen<
Q::LoadFuture<'conn>,
future::Map<
stream::StreamFuture<Pin<Box<Q::Stream<'conn>>>>,
fn((Option<QueryResult<U>>, Pin<Box<Q::Stream<'conn>>>)) -> QueryResult<U>,
>,
fn(
Q::Stream<'conn>,
) -> future::Map<
stream::StreamFuture<Pin<Box<Q::Stream<'conn>>>>,
fn((Option<QueryResult<U>>, Pin<Box<Q::Stream<'conn>>>)) -> QueryResult<U>,
>,
>;
pub type GetResult<'conn, 'query, Q: LoadQuery<'query, Conn, U>, Conn, U> =
utils::AndThen<Q::LoadFuture<'conn>, utils::LoadNext<Pin<Box<Q::Stream<'conn>>>>>;
}

/// Methods used to execute queries.
Expand Down Expand Up @@ -346,13 +313,7 @@ pub trait RunQueryDsl<Conn>: Sized {
Conn: AsyncConnectionCore,
Self: methods::LoadQuery<'query, Conn, U> + 'query,
{
fn collect_result<U, S>(stream: S) -> stream::TryCollect<S, Vec<U>>
where
S: Stream<Item = QueryResult<U>>,
{
stream.try_collect()
}
self.internal_load(conn).and_then(collect_result::<U, _>)
utils::AndThen::new(self.internal_load(conn), |stream| stream.try_collect())
}

/// Executes the given query, returning a [`Stream`] with the returned rows.
Expand Down Expand Up @@ -547,29 +508,9 @@ pub trait RunQueryDsl<Conn>: Sized {
Conn: AsyncConnectionCore,
Self: methods::LoadQuery<'query, Conn, U> + 'query,
{
#[allow(clippy::type_complexity)]
fn get_next_stream_element<S, U>(
stream: S,
) -> future::Map<
stream::StreamFuture<Pin<Box<S>>>,
fn((Option<QueryResult<U>>, Pin<Box<S>>)) -> QueryResult<U>,
>
where
S: Stream<Item = QueryResult<U>>,
{
fn map_option_to_result<U, S>(
(o, _): (Option<QueryResult<U>>, Pin<Box<S>>),
) -> QueryResult<U> {
match o {
Some(s) => s,
None => Err(diesel::result::Error::NotFound),
}
}

Box::pin(stream).into_future().map(map_option_to_result)
}

self.load_stream(conn).and_then(get_next_stream_element)
utils::AndThen::new(self.internal_load(conn), |stream| {
utils::LoadNext::new(Box::pin(stream))
})
}

/// Runs the command, returning an `Vec` with the affected rows.
Expand Down
112 changes: 112 additions & 0 deletions src/run_query_dsl/utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

use diesel::QueryResult;
use futures_core::{ready, TryFuture, TryStream};
use futures_util::{TryFutureExt, TryStreamExt};

// We use a custom future implementation here to erase some lifetimes
// that otherwise need to be specified explicitly
//
// Specifying these lifetimes results in the compiler not beeing
// able to look through the generic code and emit
// lifetime erros for pipelined queries. See
// https://github.com/weiznich/diesel_async/issues/249 for more context
#[repr(transparent)]
pub struct MapOk<F: TryFutureExt, T> {
future: futures_util::future::MapOk<F, fn(F::Ok) -> T>,
}

impl<F, T> Future for MapOk<F, T>
where
F: TryFuture,
futures_util::future::MapOk<F, fn(F::Ok) -> T>: Future<Output = Result<T, F::Error>>,
{
type Output = Result<T, F::Error>;

fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
unsafe {
// SAFETY: This projects pinning to the only inner field, so it
// should be safe
self.map_unchecked_mut(|s| &mut s.future)
}
.poll(cx)
}
}

impl<Fut: TryFutureExt, T> MapOk<Fut, T> {
pub(crate) fn new(future: Fut, f: fn(Fut::Ok) -> T) -> Self {
Self {
future: future.map_ok(f),
}
}
}

// similar to `MapOk` above this mainly exists to hide the lifetime
#[repr(transparent)]
pub struct AndThen<F1: TryFuture, F2> {
future: futures_util::future::AndThen<F1, F2, fn(F1::Ok) -> F2>,
}

impl<Fut1, Fut2> AndThen<Fut1, Fut2>
where
Fut1: TryFuture,
Fut2: TryFuture<Error = Fut1::Error>,
{
pub(crate) fn new(fut1: Fut1, f: fn(Fut1::Ok) -> Fut2) -> AndThen<Fut1, Fut2> {
Self {
future: fut1.and_then(f),
}
}
}

impl<F1, F2> Future for AndThen<F1, F2>
where
F1: TryFuture,
F2: TryFuture<Error = F1::Error>,
{
type Output = Result<F2::Ok, F2::Error>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
unsafe {
// SAFETY: This projects pinning to the only inner field, so it
// should be safe
self.map_unchecked_mut(|s| &mut s.future)
}
.poll(cx)
}
}

/// Converts a stream into a future, only yielding the first element.
/// Based on [`futures_util::stream::StreamFuture`].
pub struct LoadNext<St> {
stream: Option<St>,
}

impl<St> LoadNext<St> {
pub(crate) fn new(stream: St) -> Self {
Self {
stream: Some(stream),
}
}
}

impl<St> Future for LoadNext<St>
where
St: TryStream<Error = diesel::result::Error> + Unpin,
{
type Output = QueryResult<St::Ok>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let first = {
let s = self.stream.as_mut().expect("polling LoadNext twice");
ready!(s.try_poll_next_unpin(cx))
};
self.stream = None;
match first {
Some(first) => Poll::Ready(first),
None => Poll::Ready(Err(diesel::result::Error::NotFound)),
}
}
}
Loading