diff --git a/scylla-rust-wrapper/src/future.rs b/scylla-rust-wrapper/src/future.rs index 0b279a3c..2ca995f3 100644 --- a/scylla-rust-wrapper/src/future.rs +++ b/scylla-rust-wrapper/src/future.rs @@ -12,7 +12,7 @@ use futures::future; use std::future::Future; use std::mem; use std::os::raw::c_void; -use std::sync::{Arc, Condvar, Mutex}; +use std::sync::{Arc, Condvar, Mutex, MutexGuard}; use tokio::task::JoinHandle; use tokio::time::Duration; @@ -48,14 +48,120 @@ impl BoundCallback { } } -#[derive(Default)] +/// State of the execution of the [CassFuture], +/// together with a join handle of the tokio task that is executing it. struct CassFutureState { - value: Option, - err_string: Option, - callback: Option, + execution_state: CassFutureExecution, + /// Presence of this handle while `execution_state` is not `Completed` indicates + /// that no thread is currently blocked on the future. This means that it might + /// not be executed (especially in case of the current-thread executor). + /// Absence means that some thread has blocked on the future, so it is necessarily + /// being executed. join_handle: Option>, } +/// State of the execution of the [CassFuture]. +enum CassFutureExecution { + RunningWithoutCallback, + RunningWithCallback { callback: BoundCallback }, + Completed(CassFutureCompleted), +} + +impl CassFutureExecution { + fn completed(&self) -> bool { + match self { + Self::Completed(_) => true, + Self::RunningWithCallback { .. } | Self::RunningWithoutCallback => false, + } + } + + /// Sets callback for the [CassFuture]. If the future has not completed yet, + /// the callback will be invoked once the future is completed, by the executor thread. + /// If the future has already completed, the callback will be invoked immediately. + unsafe fn set_callback( + mut state_lock: MutexGuard, + fut_ptr: CassBorrowedSharedPtr, + cb: CassFutureCallback, + data: *mut c_void, + ) -> CassError { + let bound_cb = BoundCallback { cb, data }; + + match state_lock.execution_state { + Self::RunningWithoutCallback => { + // Store the callback. + state_lock.execution_state = Self::RunningWithCallback { callback: bound_cb }; + CassError::CASS_OK + } + Self::RunningWithCallback { .. } => + // Another callback has been already set. + { + CassError::CASS_ERROR_LIB_CALLBACK_ALREADY_SET + } + Self::Completed { .. } => { + // The value is already available, we need to call the callback ourselves. + mem::drop(state_lock); + bound_cb.invoke(fut_ptr); + CassError::CASS_OK + } + } + } + + /// Sets the [CassFuture] as completed. This function is called by the executor thread + /// once it completes the underlying Rust future. If there's a callback set, + /// it will be invoked immediately. + fn complete( + mut state_lock: MutexGuard, + value: CassFutureResult, + cass_fut: &Arc, + ) { + let prev_state = mem::replace( + &mut state_lock.execution_state, + Self::Completed(CassFutureCompleted::new(value)), + ); + + // This is because we mustn't hold the lock while invoking the callback. + mem::drop(state_lock); + + let maybe_cb = match prev_state { + Self::RunningWithoutCallback => None, + Self::RunningWithCallback { callback } => Some(callback), + Self::Completed { .. } => unreachable!( + "Exactly one dedicated tokio task is expected to execute and complete the CassFuture." + ), + }; + + if let Some(bound_cb) = maybe_cb { + let fut_ptr = ArcFFI::as_ptr::(cass_fut); + // Safety: pointer is valid, because we get it from arc allocation. + bound_cb.invoke(fut_ptr); + } + } +} + +/// The result of a completed [CassFuture]. +struct CassFutureCompleted { + /// The result of the future, either a value or an error. + value: CassFutureResult, + /// Just a cache for the error message. Needed because the C API exposes a pointer to the + /// error message, and we need to keep it alive until the future is freed. + /// Initially, it's `None`, and it is set to `Some` when the error message is requested + /// by `cass_future_error_message()`. + cached_err_string: Option, +} + +impl CassFutureCompleted { + fn new(value: CassFutureResult) -> Self { + Self { + value, + cached_err_string: None, + } + } +} + +/// The C-API representation of a future. Implemented as a wrapper around a Rust future +/// that can be awaited and has a callback mechanism. It's **eager** in a way that +/// its execution starts possibly immediately (unless the executor thread pool is nempty, +/// which is the case for the current-thread executor). pub struct CassFuture { state: Mutex, wait_for_value: Condvar, @@ -86,23 +192,18 @@ impl CassFuture { fut: impl Future + Send + 'static, ) -> Arc { let cass_fut = Arc::new(CassFuture { - state: Mutex::new(Default::default()), + state: Mutex::new(CassFutureState { + join_handle: None, + execution_state: CassFutureExecution::RunningWithoutCallback, + }), wait_for_value: Condvar::new(), }); let cass_fut_clone = Arc::clone(&cass_fut); let join_handle = RUNTIME.spawn(async move { let r = fut.await; - let maybe_cb = { - let mut guard = cass_fut_clone.state.lock().unwrap(); - guard.value = Some(r); - // Take the callback and call it after releasing the lock - guard.callback.take() - }; - if let Some(bound_cb) = maybe_cb { - let fut_ptr = ArcFFI::as_ptr::(&cass_fut_clone); - // Safety: pointer is valid, because we get it from arc allocation. - bound_cb.invoke(fut_ptr); - } + + let guard = cass_fut_clone.state.lock().unwrap(); + CassFutureExecution::complete(guard, r, &cass_fut_clone); cass_fut_clone.wait_for_value.notify_all(); }); @@ -116,15 +217,15 @@ impl CassFuture { pub fn new_ready(r: CassFutureResult) -> Arc { Arc::new(CassFuture { state: Mutex::new(CassFutureState { - value: Some(r), - ..Default::default() + join_handle: None, + execution_state: CassFutureExecution::Completed(CassFutureCompleted::new(r)), }), wait_for_value: Condvar::new(), }) } pub fn with_waited_result(&self, f: impl FnOnce(&mut CassFutureResult) -> T) -> T { - self.with_waited_state(|s| f(s.value.as_mut().unwrap())) + self.with_waited_state(|s| f(&mut s.value)) } /// Awaits the future until completion. @@ -140,7 +241,7 @@ impl CassFuture { /// - JoinHandle is Some -> some other thread was working on the future, but /// timed out (see [CassFuture::with_waited_state_timed]). We need to /// take the ownership of the handle, and complete the work. - fn with_waited_state(&self, f: impl FnOnce(&mut CassFutureState) -> T) -> T { + fn with_waited_state(&self, f: impl FnOnce(&mut CassFutureCompleted) -> T) -> T { let mut guard = self.state.lock().unwrap(); loop { let handle = guard.join_handle.take(); @@ -153,7 +254,7 @@ impl CassFuture { guard = self .wait_for_value .wait_while(guard, |state| { - state.value.is_none() && state.join_handle.is_none() + !state.execution_state.completed() && state.join_handle.is_none() }) // unwrap: Error appears only when mutex is poisoned. .unwrap(); @@ -165,7 +266,15 @@ impl CassFuture { continue; } } - return f(&mut guard); + + // If we had ended up in either the handle's or with the condvar's `if` branch, + // we awaited the future and it is now completed. + let completed = match &mut guard.execution_state { + CassFutureExecution::RunningWithoutCallback + | CassFutureExecution::RunningWithCallback { .. } => unreachable!(), + CassFutureExecution::Completed(completed) => completed, + }; + return f(completed); } } @@ -174,7 +283,7 @@ impl CassFuture { f: impl FnOnce(&mut CassFutureResult) -> T, timeout_duration: Duration, ) -> Result { - self.with_waited_state_timed(|s| f(s.value.as_mut().unwrap()), timeout_duration) + self.with_waited_state_timed(|s| f(&mut s.value), timeout_duration) } /// Tries to await the future with a given timeout. @@ -194,7 +303,7 @@ impl CassFuture { /// take the ownership of the handle, and continue the work. fn with_waited_state_timed( &self, - f: impl FnOnce(&mut CassFutureState) -> T, + f: impl FnOnce(&mut CassFutureCompleted) -> T, timeout_duration: Duration, ) -> Result { let mut guard = self.state.lock().unwrap(); @@ -242,7 +351,7 @@ impl CassFuture { let (guard_result, timeout_result) = self .wait_for_value .wait_timeout_while(guard, remaining_timeout, |state| { - state.value.is_none() && state.join_handle.is_none() + !state.execution_state.completed() && state.join_handle.is_none() }) // unwrap: Error appears only when mutex is poisoned. .unwrap(); @@ -259,7 +368,14 @@ impl CassFuture { } } - return Ok(f(&mut guard)); + // If we had ended up in either the handle's or with the condvar's `if` branch + // and we didn't return `TimeoutError`, we awaited the future and it is now completed. + let completed = match &mut guard.execution_state { + CassFutureExecution::RunningWithoutCallback + | CassFutureExecution::RunningWithCallback { .. } => unreachable!(), + CassFutureExecution::Completed(completed) => completed, + }; + return Ok(f(completed)); } } @@ -269,21 +385,8 @@ impl CassFuture { cb: CassFutureCallback, data: *mut c_void, ) -> CassError { - let mut lock = self.state.lock().unwrap(); - if lock.callback.is_some() { - // Another callback has been already set - return CassError::CASS_ERROR_LIB_CALLBACK_ALREADY_SET; - } - let bound_cb = BoundCallback { cb, data }; - if lock.value.is_some() { - // The value is already available, we need to call the callback ourselves - mem::drop(lock); - bound_cb.invoke(self_ptr); - return CassError::CASS_OK; - } - // Store the callback - lock.callback = Some(bound_cb); - CassError::CASS_OK + let lock = self.state.lock().unwrap(); + unsafe { CassFutureExecution::set_callback(lock, self_ptr, cb, data) } } fn into_raw(self: Arc) -> CassOwnedSharedPtr { @@ -346,10 +449,7 @@ pub unsafe extern "C" fn cass_future_ready( }; let state_guard = future.state.lock().unwrap(); - match state_guard.value { - None => cass_false, - Some(_) => cass_true, - } + state_guard.execution_state.completed() as cass_bool_t } #[unsafe(no_mangle)] @@ -379,11 +479,10 @@ pub unsafe extern "C" fn cass_future_error_message( return; }; - future.with_waited_state(|state: &mut CassFutureState| { - let value = &state.value; - let msg = state - .err_string - .get_or_insert_with(|| match value.as_ref().unwrap() { + future.with_waited_state(|completed: &mut CassFutureCompleted| { + let msg = completed + .cached_err_string + .get_or_insert_with(|| match &completed.value { Ok(CassResultValue::QueryError(err)) => err.msg(), Err((_, s)) => s.msg(), _ => "".to_string(),