From a641385f5484fa310bd14dc5d411bf0022f0c7c4 Mon Sep 17 00:00:00 2001 From: Kevin GRONDIN Date: Fri, 19 Sep 2025 20:55:19 +0200 Subject: [PATCH 1/2] Use custom Future combinators to avoid GAT errors --- Cargo.toml | 1 + src/pg/mod.rs | 40 +++++------- src/run_query_dsl/mod.rs | 107 +++++++----------------------- src/run_query_dsl/utils.rs | 130 +++++++++++++++++++++++++++++++++++++ 4 files changed, 173 insertions(+), 105 deletions(-) create mode 100644 src/run_query_dsl/utils.rs diff --git a/Cargo.toml b/Cargo.toml index e79f77f..ff13ccf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,7 @@ deadpool = { version = "0.12", optional = true, default-features = false, featur ] } mobc = { version = ">=0.7,<0.10", optional = true } scoped-futures = { version = "0.1", features = ["std"] } +pin-project-lite = "0.2.16" [dependencies.diesel] version = "~2.3.0" diff --git a/src/pg/mod.rs b/src/pg/mod.rs index 1c5bb46..33eb02f 100644 --- a/src/pg/mod.rs +++ b/src/pg/mod.rs @@ -1166,46 +1166,42 @@ mod tests { .await .unwrap(); - fn erase<'a, T: Future + Send + 'a>(t: T) -> impl Future + Send + 'a { - t - } - async fn fn12(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> { let f1 = diesel::select(1_i32.into_sql::()).get_result::(&mut conn); let f2 = diesel::select(2_i32.into_sql::()).get_result::(&mut conn); - erase(try_join(f1, f2)).await - } - - async fn fn34(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> { - let f3 = diesel::select(3_i32.into_sql::()).get_result::(&mut conn); - let f4 = diesel::select(4_i32.into_sql::()).get_result::(&mut conn); - - try_join(f3, f4).boxed().await + try_join(f1, f2).await } - async fn fn56(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> { + async fn fn37( + mut conn: &AsyncPgConnection, + ) -> QueryResult<(usize, Vec, i32, Vec, i32)> { + let f3 = diesel::select(0_i32.into_sql::()).execute(&mut conn); + let f4 = diesel::select(4_i32.into_sql::()).load::(&mut conn); let f5 = diesel::select(5_i32.into_sql::()).get_result::(&mut conn); - let f6 = diesel::select(6_i32.into_sql::()).get_result::(&mut conn); + let f6 = diesel::select(6_i32.into_sql::()).get_results::(&mut conn); + let f7 = diesel::select(7_i32.into_sql::()).first::(&mut conn); - try_join(f5.boxed(), f6.boxed()).await + try_join!(f3, f4, f5, f6, f7) } conn.transaction(|conn| { async move { let f12 = fn12(conn); - let f34 = fn34(conn); - let f56 = fn56(conn); + let f37 = fn37(conn); - let ((r1, r2), ((r3, r4), (r5, r6))) = - try_join(f12, try_join(f34, f56)).await.unwrap(); + let ((r1, r2), (r3, r4, r5, r6, r7)) = try_join!(f12, f37).unwrap(); assert_eq!(r1, 1); assert_eq!(r2, 2); - assert_eq!(r3, 3); - assert_eq!(r4, 4); + assert_eq!(r3, 1); + assert_eq!(r4, vec![4]); assert_eq!(r5, 5); - assert_eq!(r6, 6); + assert_eq!(r6, vec![6]); + assert_eq!(r7, 7); + + fn12(conn).await?; + fn37(conn).await?; QueryResult::<_>::Ok(()) } diff --git a/src/run_query_dsl/mod.rs b/src/run_query_dsl/mod.rs index 437d2a2..3963ba2 100644 --- a/src/run_query_dsl/mod.rs +++ b/src/run_query_dsl/mod.rs @@ -1,13 +1,15 @@ +mod utils; + use crate::AsyncConnectionCore; use diesel::associations::HasTable; use diesel::query_builder::IntoUpdateTarget; use diesel::result::QueryResult; use diesel::AsChangeset; use futures_core::future::BoxFuture; -use futures_core::Stream; -use futures_util::{future, stream, FutureExt, StreamExt, TryFutureExt, TryStreamExt}; +#[cfg(any(feature = "mysql", feature = "postgres"))] +use futures_util::FutureExt; +use futures_util::{stream, StreamExt, TryStreamExt}; use std::future::Future; -use std::pin::Pin; /// The traits used by `QueryDsl`. /// @@ -22,7 +24,7 @@ pub mod methods { use diesel::expression::QueryMetadata; use diesel::query_builder::{AsQuery, QueryFragment, QueryId}; use diesel::query_dsl::CompatibleType; - use futures_util::{Future, Stream, TryFutureExt}; + use futures_util::{Future, Stream}; /// The `execute` method /// @@ -74,6 +76,7 @@ pub mod methods { type LoadFuture<'conn>: Future>> + Send where Conn: 'conn; + /// The inner stream returned by [`LoadQuery::internal_load`] type Stream<'conn>: Stream> + Send where @@ -96,10 +99,7 @@ pub mod methods { ST: 'static, { type LoadFuture<'conn> - = future::MapOk< - Conn::LoadFuture<'conn, 'query>, - fn(Conn::Stream<'conn, 'query>) -> Self::Stream<'conn>, - > + = utils::Map, Self::Stream<'conn>> where Conn: 'conn; @@ -112,33 +112,13 @@ pub mod methods { Conn: 'conn; fn internal_load(self, conn: &mut Conn) -> Self::LoadFuture<'_> { - conn.load(self) - .map_ok(map_result_stream_future::) + utils::Map::new(conn.load(self), |stream| { + Ok(stream?.map(|row| { + U::build_from_row(&row?).map_err(diesel::result::Error::DeserializationError) + })) + }) } } - - #[allow(clippy::type_complexity)] - fn map_result_stream_future<'s, 'a, U, S, R, DB, ST>( - stream: S, - ) -> stream::Map) -> QueryResult> - where - S: Stream> + Send + 's, - R: diesel::row::Row<'a, DB> + 's, - DB: Backend + 'static, - U: FromSqlRow + 'static, - ST: 'static, - { - stream.map(map_row_helper::<_, DB, U, ST>) - } - - fn map_row_helper<'a, R, DB, U, ST>(row: QueryResult) -> QueryResult - where - U: FromSqlRow, - R: diesel::row::Row<'a, DB>, - DB: Backend, - { - U::build_from_row(&row?).map_err(diesel::result::Error::DeserializationError) - } } /// The return types produced by the various [`RunQueryDsl`] methods @@ -149,37 +129,24 @@ pub mod methods { // the same connection #[allow(type_alias_bounds)] // we need these bounds otherwise we cannot use GAT's pub mod return_futures { + use crate::run_query_dsl::utils; + use super::methods::LoadQuery; - use diesel::QueryResult; - use futures_util::{future, stream}; + use futures_util::stream; use std::pin::Pin; /// The future returned by [`RunQueryDsl::load`](super::RunQueryDsl::load) /// and [`RunQueryDsl::get_results`](super::RunQueryDsl::get_results) /// /// This is essentially `impl Future>>` - pub type LoadFuture<'conn, 'query, Q: LoadQuery<'query, Conn, U>, Conn, U> = future::AndThen< - Q::LoadFuture<'conn>, - stream::TryCollect, Vec>, - fn(Q::Stream<'conn>) -> stream::TryCollect, Vec>, - >; + pub type LoadFuture<'conn, 'query, Q: LoadQuery<'query, Conn, U>, Conn, U> = + utils::AndThen, stream::TryCollect, Vec>>; /// The future returned by [`RunQueryDsl::get_result`](super::RunQueryDsl::get_result) /// /// This is essentially `impl Future>` - pub type GetResult<'conn, 'query, Q: LoadQuery<'query, Conn, U>, Conn, U> = future::AndThen< - Q::LoadFuture<'conn>, - future::Map< - stream::StreamFuture>>>, - fn((Option>, Pin>>)) -> QueryResult, - >, - fn( - Q::Stream<'conn>, - ) -> future::Map< - stream::StreamFuture>>>, - fn((Option>, Pin>>)) -> QueryResult, - >, - >; + pub type GetResult<'conn, 'query, Q: LoadQuery<'query, Conn, U>, Conn, U> = + utils::AndThen, utils::LoadNext>>>>; } /// Methods used to execute queries. @@ -346,13 +313,7 @@ pub trait RunQueryDsl: Sized { Conn: AsyncConnectionCore, Self: methods::LoadQuery<'query, Conn, U> + 'query, { - fn collect_result(stream: S) -> stream::TryCollect> - where - S: Stream>, - { - stream.try_collect() - } - self.internal_load(conn).and_then(collect_result::) + utils::AndThen::new(self.internal_load(conn), |stream| Ok(stream?.try_collect())) } /// Executes the given query, returning a [`Stream`] with the returned rows. @@ -547,29 +508,9 @@ pub trait RunQueryDsl: Sized { Conn: AsyncConnectionCore, Self: methods::LoadQuery<'query, Conn, U> + 'query, { - #[allow(clippy::type_complexity)] - fn get_next_stream_element( - stream: S, - ) -> future::Map< - stream::StreamFuture>>, - fn((Option>, Pin>)) -> QueryResult, - > - where - S: Stream>, - { - fn map_option_to_result( - (o, _): (Option>, Pin>), - ) -> QueryResult { - match o { - Some(s) => s, - None => Err(diesel::result::Error::NotFound), - } - } - - Box::pin(stream).into_future().map(map_option_to_result) - } - - self.load_stream(conn).and_then(get_next_stream_element) + utils::AndThen::new(self.internal_load(conn), |stream| { + Ok(utils::LoadNext::new(Box::pin(stream?))) + }) } /// Runs the command, returning an `Vec` with the affected rows. diff --git a/src/run_query_dsl/utils.rs b/src/run_query_dsl/utils.rs new file mode 100644 index 0000000..57e7995 --- /dev/null +++ b/src/run_query_dsl/utils.rs @@ -0,0 +1,130 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use diesel::QueryResult; +use futures_core::{ready, TryFuture, TryStream}; +use futures_util::TryStreamExt; +use pin_project_lite::pin_project; + +pin_project! { + /// Reimplementation of [`futures_util::future::Map`] without the generic closure argument + #[project = MapProj] + #[project_replace = MapProjReplace] + pub enum Map { + Incomplete { + #[pin] + future: Fut, + f: fn(Fut::Output) -> QueryResult, + }, + Complete, + } +} + +impl Map { + pub(crate) fn new(future: Fut, f: fn(Fut::Output) -> QueryResult) -> Self { + Self::Incomplete { future, f } + } +} + +impl Future for Map { + type Output = QueryResult; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.as_mut().project() { + MapProj::Incomplete { future, .. } => { + let output = ready!(future.poll(cx)); + match self.as_mut().project_replace(Map::Complete) { + MapProjReplace::Incomplete { f, .. } => Poll::Ready(f(output)), + MapProjReplace::Complete => unreachable!(), + } + } + MapProj::Complete => panic!("Map polled after completion"), + } + } +} + +pin_project! { + /// Reimplementation of [`futures_util::future::AndThen`] without the generic closure argument + #[project = AndThenProj] + pub enum AndThen { + First { + #[pin] + future1: Map, + }, + Second { + #[pin] + future2: Fut2, + }, + Empty, + } +} + +impl AndThen { + pub(crate) fn new(fut1: Fut1, f: fn(Fut1::Output) -> QueryResult) -> AndThen { + Self::First { + future1: Map::new(fut1, f), + } + } +} + +impl Future for AndThen +where + Fut1: TryFuture, + Fut2: TryFuture, +{ + type Output = QueryResult; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + loop { + match self.as_mut().project() { + AndThenProj::First { future1 } => match ready!(future1.try_poll(cx)) { + Ok(future2) => self.set(Self::Second { future2 }), + Err(error) => { + self.set(Self::Empty); + break Poll::Ready(Err(error)); + } + }, + AndThenProj::Second { future2 } => { + let output = ready!(future2.try_poll(cx)); + self.set(Self::Empty); + break Poll::Ready(output); + } + AndThenProj::Empty => panic!("AndThen polled after completion"), + } + } + } +} + +/// Converts a stream into a future, only yielding the first element. +/// Based on [`futures_util::stream::StreamFuture`]. +pub struct LoadNext { + stream: Option, +} + +impl LoadNext { + pub(crate) fn new(stream: St) -> Self { + Self { + stream: Some(stream), + } + } +} + +impl Future for LoadNext +where + St: TryStream + Unpin, +{ + type Output = QueryResult; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let first = { + let s = self.stream.as_mut().expect("polling LoadNext twice"); + ready!(s.try_poll_next_unpin(cx)) + }; + self.stream = None; + match first { + Some(first) => Poll::Ready(first), + None => Poll::Ready(Err(diesel::result::Error::NotFound)), + } + } +} From 661de8c4442a123ccde74001b44fc6a154dbc4b2 Mon Sep 17 00:00:00 2001 From: Georg Semmler Date: Thu, 2 Oct 2025 12:18:08 +0200 Subject: [PATCH 2/2] Replace hand written future impls This commit replaces the hand written future impls with such ones that wrap the types from futures-util. This hopefully reduces the complexity and also allows using the well tested types instead of hand written ones. I also removed the dependency on pin-project-light in favour of two lines of unsafe code as pin-project would do the same thing internally as well and these two instances are trivial to reason about. --- Cargo.toml | 1 - src/pg/mod.rs | 6 +- src/run_query_dsl/mod.rs | 12 ++-- src/run_query_dsl/utils.rs | 120 ++++++++++++++++--------------------- 4 files changed, 60 insertions(+), 79 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ff13ccf..e79f77f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,7 +36,6 @@ deadpool = { version = "0.12", optional = true, default-features = false, featur ] } mobc = { version = ">=0.7,<0.10", optional = true } scoped-futures = { version = "0.1", features = ["std"] } -pin-project-lite = "0.2.16" [dependencies.diesel] version = "~2.3.0" diff --git a/src/pg/mod.rs b/src/pg/mod.rs index 33eb02f..6a7899a 100644 --- a/src/pg/mod.rs +++ b/src/pg/mod.rs @@ -1175,14 +1175,14 @@ mod tests { async fn fn37( mut conn: &AsyncPgConnection, - ) -> QueryResult<(usize, Vec, i32, Vec, i32)> { + ) -> QueryResult<(usize, (Vec, (i32, (Vec, i32))))> { let f3 = diesel::select(0_i32.into_sql::()).execute(&mut conn); let f4 = diesel::select(4_i32.into_sql::()).load::(&mut conn); let f5 = diesel::select(5_i32.into_sql::()).get_result::(&mut conn); let f6 = diesel::select(6_i32.into_sql::()).get_results::(&mut conn); let f7 = diesel::select(7_i32.into_sql::()).first::(&mut conn); - try_join!(f3, f4, f5, f6, f7) + try_join(f3, try_join(f4, try_join(f5, try_join(f6, f7)))).await } conn.transaction(|conn| { @@ -1190,7 +1190,7 @@ mod tests { let f12 = fn12(conn); let f37 = fn37(conn); - let ((r1, r2), (r3, r4, r5, r6, r7)) = try_join!(f12, f37).unwrap(); + let ((r1, r2), (r3, (r4, (r5, (r6, r7))))) = try_join(f12, f37).await.unwrap(); assert_eq!(r1, 1); assert_eq!(r2, 2); diff --git a/src/run_query_dsl/mod.rs b/src/run_query_dsl/mod.rs index 3963ba2..18a9fda 100644 --- a/src/run_query_dsl/mod.rs +++ b/src/run_query_dsl/mod.rs @@ -99,7 +99,7 @@ pub mod methods { ST: 'static, { type LoadFuture<'conn> - = utils::Map, Self::Stream<'conn>> + = utils::MapOk, Self::Stream<'conn>> where Conn: 'conn; @@ -112,10 +112,10 @@ pub mod methods { Conn: 'conn; fn internal_load(self, conn: &mut Conn) -> Self::LoadFuture<'_> { - utils::Map::new(conn.load(self), |stream| { - Ok(stream?.map(|row| { + utils::MapOk::new(conn.load(self), |stream| { + stream.map(|row| { U::build_from_row(&row?).map_err(diesel::result::Error::DeserializationError) - })) + }) }) } } @@ -313,7 +313,7 @@ pub trait RunQueryDsl: Sized { Conn: AsyncConnectionCore, Self: methods::LoadQuery<'query, Conn, U> + 'query, { - utils::AndThen::new(self.internal_load(conn), |stream| Ok(stream?.try_collect())) + utils::AndThen::new(self.internal_load(conn), |stream| stream.try_collect()) } /// Executes the given query, returning a [`Stream`] with the returned rows. @@ -509,7 +509,7 @@ pub trait RunQueryDsl: Sized { Self: methods::LoadQuery<'query, Conn, U> + 'query, { utils::AndThen::new(self.internal_load(conn), |stream| { - Ok(utils::LoadNext::new(Box::pin(stream?))) + utils::LoadNext::new(Box::pin(stream)) }) } diff --git a/src/run_query_dsl/utils.rs b/src/run_query_dsl/utils.rs index 57e7995..22f5891 100644 --- a/src/run_query_dsl/utils.rs +++ b/src/run_query_dsl/utils.rs @@ -4,95 +4,77 @@ use std::task::{Context, Poll}; use diesel::QueryResult; use futures_core::{ready, TryFuture, TryStream}; -use futures_util::TryStreamExt; -use pin_project_lite::pin_project; +use futures_util::{TryFutureExt, TryStreamExt}; -pin_project! { - /// Reimplementation of [`futures_util::future::Map`] without the generic closure argument - #[project = MapProj] - #[project_replace = MapProjReplace] - pub enum Map { - Incomplete { - #[pin] - future: Fut, - f: fn(Fut::Output) -> QueryResult, - }, - Complete, - } +// We use a custom future implementation here to erase some lifetimes +// that otherwise need to be specified explicitly +// +// Specifying these lifetimes results in the compiler not beeing +// able to look through the generic code and emit +// lifetime erros for pipelined queries. See +// https://github.com/weiznich/diesel_async/issues/249 for more context +#[repr(transparent)] +pub struct MapOk { + future: futures_util::future::MapOk T>, } -impl Map { - pub(crate) fn new(future: Fut, f: fn(Fut::Output) -> QueryResult) -> Self { - Self::Incomplete { future, f } +impl Future for MapOk +where + F: TryFuture, + futures_util::future::MapOk T>: Future>, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { + unsafe { + // SAFETY: This projects pinning to the only inner field, so it + // should be safe + self.map_unchecked_mut(|s| &mut s.future) + } + .poll(cx) } } -impl Future for Map { - type Output = QueryResult; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.as_mut().project() { - MapProj::Incomplete { future, .. } => { - let output = ready!(future.poll(cx)); - match self.as_mut().project_replace(Map::Complete) { - MapProjReplace::Incomplete { f, .. } => Poll::Ready(f(output)), - MapProjReplace::Complete => unreachable!(), - } - } - MapProj::Complete => panic!("Map polled after completion"), +impl MapOk { + pub(crate) fn new(future: Fut, f: fn(Fut::Ok) -> T) -> Self { + Self { + future: future.map_ok(f), } } } -pin_project! { - /// Reimplementation of [`futures_util::future::AndThen`] without the generic closure argument - #[project = AndThenProj] - pub enum AndThen { - First { - #[pin] - future1: Map, - }, - Second { - #[pin] - future2: Fut2, - }, - Empty, - } +// similar to `MapOk` above this mainly exists to hide the lifetime +#[repr(transparent)] +pub struct AndThen { + future: futures_util::future::AndThen F2>, } -impl AndThen { - pub(crate) fn new(fut1: Fut1, f: fn(Fut1::Output) -> QueryResult) -> AndThen { - Self::First { - future1: Map::new(fut1, f), +impl AndThen +where + Fut1: TryFuture, + Fut2: TryFuture, +{ + pub(crate) fn new(fut1: Fut1, f: fn(Fut1::Ok) -> Fut2) -> AndThen { + Self { + future: fut1.and_then(f), } } } -impl Future for AndThen +impl Future for AndThen where - Fut1: TryFuture, - Fut2: TryFuture, + F1: TryFuture, + F2: TryFuture, { - type Output = QueryResult; + type Output = Result; - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - loop { - match self.as_mut().project() { - AndThenProj::First { future1 } => match ready!(future1.try_poll(cx)) { - Ok(future2) => self.set(Self::Second { future2 }), - Err(error) => { - self.set(Self::Empty); - break Poll::Ready(Err(error)); - } - }, - AndThenProj::Second { future2 } => { - let output = ready!(future2.try_poll(cx)); - self.set(Self::Empty); - break Poll::Ready(output); - } - AndThenProj::Empty => panic!("AndThen polled after completion"), - } + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + unsafe { + // SAFETY: This projects pinning to the only inner field, so it + // should be safe + self.map_unchecked_mut(|s| &mut s.future) } + .poll(cx) } }