diff --git a/src/pg/mod.rs b/src/pg/mod.rs index 1c5bb46..6a7899a 100644 --- a/src/pg/mod.rs +++ b/src/pg/mod.rs @@ -1166,46 +1166,42 @@ mod tests { .await .unwrap(); - fn erase<'a, T: Future + Send + 'a>(t: T) -> impl Future + Send + 'a { - t - } - async fn fn12(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> { let f1 = diesel::select(1_i32.into_sql::()).get_result::(&mut conn); let f2 = diesel::select(2_i32.into_sql::()).get_result::(&mut conn); - erase(try_join(f1, f2)).await - } - - async fn fn34(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> { - let f3 = diesel::select(3_i32.into_sql::()).get_result::(&mut conn); - let f4 = diesel::select(4_i32.into_sql::()).get_result::(&mut conn); - - try_join(f3, f4).boxed().await + try_join(f1, f2).await } - async fn fn56(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> { + async fn fn37( + mut conn: &AsyncPgConnection, + ) -> QueryResult<(usize, (Vec, (i32, (Vec, i32))))> { + let f3 = diesel::select(0_i32.into_sql::()).execute(&mut conn); + let f4 = diesel::select(4_i32.into_sql::()).load::(&mut conn); let f5 = diesel::select(5_i32.into_sql::()).get_result::(&mut conn); - let f6 = diesel::select(6_i32.into_sql::()).get_result::(&mut conn); + let f6 = diesel::select(6_i32.into_sql::()).get_results::(&mut conn); + let f7 = diesel::select(7_i32.into_sql::()).first::(&mut conn); - try_join(f5.boxed(), f6.boxed()).await + try_join(f3, try_join(f4, try_join(f5, try_join(f6, f7)))).await } conn.transaction(|conn| { async move { let f12 = fn12(conn); - let f34 = fn34(conn); - let f56 = fn56(conn); + let f37 = fn37(conn); - let ((r1, r2), ((r3, r4), (r5, r6))) = - try_join(f12, try_join(f34, f56)).await.unwrap(); + let ((r1, r2), (r3, (r4, (r5, (r6, r7))))) = try_join(f12, f37).await.unwrap(); assert_eq!(r1, 1); assert_eq!(r2, 2); - assert_eq!(r3, 3); - assert_eq!(r4, 4); + assert_eq!(r3, 1); + assert_eq!(r4, vec![4]); assert_eq!(r5, 5); - assert_eq!(r6, 6); + assert_eq!(r6, vec![6]); + assert_eq!(r7, 7); + + fn12(conn).await?; + fn37(conn).await?; QueryResult::<_>::Ok(()) } diff --git a/src/run_query_dsl/mod.rs b/src/run_query_dsl/mod.rs index 437d2a2..18a9fda 100644 --- a/src/run_query_dsl/mod.rs +++ b/src/run_query_dsl/mod.rs @@ -1,13 +1,15 @@ +mod utils; + use crate::AsyncConnectionCore; use diesel::associations::HasTable; use diesel::query_builder::IntoUpdateTarget; use diesel::result::QueryResult; use diesel::AsChangeset; use futures_core::future::BoxFuture; -use futures_core::Stream; -use futures_util::{future, stream, FutureExt, StreamExt, TryFutureExt, TryStreamExt}; +#[cfg(any(feature = "mysql", feature = "postgres"))] +use futures_util::FutureExt; +use futures_util::{stream, StreamExt, TryStreamExt}; use std::future::Future; -use std::pin::Pin; /// The traits used by `QueryDsl`. /// @@ -22,7 +24,7 @@ pub mod methods { use diesel::expression::QueryMetadata; use diesel::query_builder::{AsQuery, QueryFragment, QueryId}; use diesel::query_dsl::CompatibleType; - use futures_util::{Future, Stream, TryFutureExt}; + use futures_util::{Future, Stream}; /// The `execute` method /// @@ -74,6 +76,7 @@ pub mod methods { type LoadFuture<'conn>: Future>> + Send where Conn: 'conn; + /// The inner stream returned by [`LoadQuery::internal_load`] type Stream<'conn>: Stream> + Send where @@ -96,10 +99,7 @@ pub mod methods { ST: 'static, { type LoadFuture<'conn> - = future::MapOk< - Conn::LoadFuture<'conn, 'query>, - fn(Conn::Stream<'conn, 'query>) -> Self::Stream<'conn>, - > + = utils::MapOk, Self::Stream<'conn>> where Conn: 'conn; @@ -112,33 +112,13 @@ pub mod methods { Conn: 'conn; fn internal_load(self, conn: &mut Conn) -> Self::LoadFuture<'_> { - conn.load(self) - .map_ok(map_result_stream_future::) + utils::MapOk::new(conn.load(self), |stream| { + stream.map(|row| { + U::build_from_row(&row?).map_err(diesel::result::Error::DeserializationError) + }) + }) } } - - #[allow(clippy::type_complexity)] - fn map_result_stream_future<'s, 'a, U, S, R, DB, ST>( - stream: S, - ) -> stream::Map) -> QueryResult> - where - S: Stream> + Send + 's, - R: diesel::row::Row<'a, DB> + 's, - DB: Backend + 'static, - U: FromSqlRow + 'static, - ST: 'static, - { - stream.map(map_row_helper::<_, DB, U, ST>) - } - - fn map_row_helper<'a, R, DB, U, ST>(row: QueryResult) -> QueryResult - where - U: FromSqlRow, - R: diesel::row::Row<'a, DB>, - DB: Backend, - { - U::build_from_row(&row?).map_err(diesel::result::Error::DeserializationError) - } } /// The return types produced by the various [`RunQueryDsl`] methods @@ -149,37 +129,24 @@ pub mod methods { // the same connection #[allow(type_alias_bounds)] // we need these bounds otherwise we cannot use GAT's pub mod return_futures { + use crate::run_query_dsl::utils; + use super::methods::LoadQuery; - use diesel::QueryResult; - use futures_util::{future, stream}; + use futures_util::stream; use std::pin::Pin; /// The future returned by [`RunQueryDsl::load`](super::RunQueryDsl::load) /// and [`RunQueryDsl::get_results`](super::RunQueryDsl::get_results) /// /// This is essentially `impl Future>>` - pub type LoadFuture<'conn, 'query, Q: LoadQuery<'query, Conn, U>, Conn, U> = future::AndThen< - Q::LoadFuture<'conn>, - stream::TryCollect, Vec>, - fn(Q::Stream<'conn>) -> stream::TryCollect, Vec>, - >; + pub type LoadFuture<'conn, 'query, Q: LoadQuery<'query, Conn, U>, Conn, U> = + utils::AndThen, stream::TryCollect, Vec>>; /// The future returned by [`RunQueryDsl::get_result`](super::RunQueryDsl::get_result) /// /// This is essentially `impl Future>` - pub type GetResult<'conn, 'query, Q: LoadQuery<'query, Conn, U>, Conn, U> = future::AndThen< - Q::LoadFuture<'conn>, - future::Map< - stream::StreamFuture>>>, - fn((Option>, Pin>>)) -> QueryResult, - >, - fn( - Q::Stream<'conn>, - ) -> future::Map< - stream::StreamFuture>>>, - fn((Option>, Pin>>)) -> QueryResult, - >, - >; + pub type GetResult<'conn, 'query, Q: LoadQuery<'query, Conn, U>, Conn, U> = + utils::AndThen, utils::LoadNext>>>>; } /// Methods used to execute queries. @@ -346,13 +313,7 @@ pub trait RunQueryDsl: Sized { Conn: AsyncConnectionCore, Self: methods::LoadQuery<'query, Conn, U> + 'query, { - fn collect_result(stream: S) -> stream::TryCollect> - where - S: Stream>, - { - stream.try_collect() - } - self.internal_load(conn).and_then(collect_result::) + utils::AndThen::new(self.internal_load(conn), |stream| stream.try_collect()) } /// Executes the given query, returning a [`Stream`] with the returned rows. @@ -547,29 +508,9 @@ pub trait RunQueryDsl: Sized { Conn: AsyncConnectionCore, Self: methods::LoadQuery<'query, Conn, U> + 'query, { - #[allow(clippy::type_complexity)] - fn get_next_stream_element( - stream: S, - ) -> future::Map< - stream::StreamFuture>>, - fn((Option>, Pin>)) -> QueryResult, - > - where - S: Stream>, - { - fn map_option_to_result( - (o, _): (Option>, Pin>), - ) -> QueryResult { - match o { - Some(s) => s, - None => Err(diesel::result::Error::NotFound), - } - } - - Box::pin(stream).into_future().map(map_option_to_result) - } - - self.load_stream(conn).and_then(get_next_stream_element) + utils::AndThen::new(self.internal_load(conn), |stream| { + utils::LoadNext::new(Box::pin(stream)) + }) } /// Runs the command, returning an `Vec` with the affected rows. diff --git a/src/run_query_dsl/utils.rs b/src/run_query_dsl/utils.rs new file mode 100644 index 0000000..22f5891 --- /dev/null +++ b/src/run_query_dsl/utils.rs @@ -0,0 +1,112 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use diesel::QueryResult; +use futures_core::{ready, TryFuture, TryStream}; +use futures_util::{TryFutureExt, TryStreamExt}; + +// We use a custom future implementation here to erase some lifetimes +// that otherwise need to be specified explicitly +// +// Specifying these lifetimes results in the compiler not beeing +// able to look through the generic code and emit +// lifetime erros for pipelined queries. See +// https://github.com/weiznich/diesel_async/issues/249 for more context +#[repr(transparent)] +pub struct MapOk { + future: futures_util::future::MapOk T>, +} + +impl Future for MapOk +where + F: TryFuture, + futures_util::future::MapOk T>: Future>, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + unsafe { + // SAFETY: This projects pinning to the only inner field, so it + // should be safe + self.map_unchecked_mut(|s| &mut s.future) + } + .poll(cx) + } +} + +impl MapOk { + pub(crate) fn new(future: Fut, f: fn(Fut::Ok) -> T) -> Self { + Self { + future: future.map_ok(f), + } + } +} + +// similar to `MapOk` above this mainly exists to hide the lifetime +#[repr(transparent)] +pub struct AndThen { + future: futures_util::future::AndThen F2>, +} + +impl AndThen +where + Fut1: TryFuture, + Fut2: TryFuture, +{ + pub(crate) fn new(fut1: Fut1, f: fn(Fut1::Ok) -> Fut2) -> AndThen { + Self { + future: fut1.and_then(f), + } + } +} + +impl Future for AndThen +where + F1: TryFuture, + F2: TryFuture, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + unsafe { + // SAFETY: This projects pinning to the only inner field, so it + // should be safe + self.map_unchecked_mut(|s| &mut s.future) + } + .poll(cx) + } +} + +/// Converts a stream into a future, only yielding the first element. +/// Based on [`futures_util::stream::StreamFuture`]. +pub struct LoadNext { + stream: Option, +} + +impl LoadNext { + pub(crate) fn new(stream: St) -> Self { + Self { + stream: Some(stream), + } + } +} + +impl Future for LoadNext +where + St: TryStream + Unpin, +{ + type Output = QueryResult; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let first = { + let s = self.stream.as_mut().expect("polling LoadNext twice"); + ready!(s.try_poll_next_unpin(cx)) + }; + self.stream = None; + match first { + Some(first) => Poll::Ready(first), + None => Poll::Ready(Err(diesel::result::Error::NotFound)), + } + } +}