Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 150 additions & 51 deletions scylla-rust-wrapper/src/future.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use futures::future;
use std::future::Future;
use std::mem;
use std::os::raw::c_void;
use std::sync::{Arc, Condvar, Mutex};
use std::sync::{Arc, Condvar, Mutex, MutexGuard};
use tokio::task::JoinHandle;
use tokio::time::Duration;

Expand Down Expand Up @@ -48,14 +48,120 @@ impl BoundCallback {
}
}

#[derive(Default)]
/// State of the execution of the [CassFuture],
/// together with a join handle of the tokio task that is executing it.
struct CassFutureState {
value: Option<CassFutureResult>,
err_string: Option<String>,
callback: Option<BoundCallback>,
execution_state: CassFutureExecution,
/// Presence of this handle while `execution_state` is not `Completed` indicates
/// that no thread is currently blocked on the future. This means that it might
/// not be executed (especially in case of the current-thread executor).
/// Absence means that some thread has blocked on the future, so it is necessarily
/// being executed.
join_handle: Option<JoinHandle<()>>,
}

/// State of the execution of the [CassFuture].
enum CassFutureExecution {
RunningWithoutCallback,
RunningWithCallback { callback: BoundCallback },
Completed(CassFutureCompleted),
}

impl CassFutureExecution {
fn completed(&self) -> bool {
match self {
Self::Completed(_) => true,
Self::RunningWithCallback { .. } | Self::RunningWithoutCallback => false,
}
}

/// Sets callback for the [CassFuture]. If the future has not completed yet,
/// the callback will be invoked once the future is completed, by the executor thread.
/// If the future has already completed, the callback will be invoked immediately.
unsafe fn set_callback(
mut state_lock: MutexGuard<CassFutureState>,
fut_ptr: CassBorrowedSharedPtr<CassFuture, CMut>,
cb: CassFutureCallback,
data: *mut c_void,
) -> CassError {
let bound_cb = BoundCallback { cb, data };

match state_lock.execution_state {
Self::RunningWithoutCallback => {
// Store the callback.
state_lock.execution_state = Self::RunningWithCallback { callback: bound_cb };
CassError::CASS_OK
}
Self::RunningWithCallback { .. } =>
// Another callback has been already set.
{
CassError::CASS_ERROR_LIB_CALLBACK_ALREADY_SET
}
Self::Completed { .. } => {
// The value is already available, we need to call the callback ourselves.
mem::drop(state_lock);
bound_cb.invoke(fut_ptr);
CassError::CASS_OK
}
}
}

/// Sets the [CassFuture] as completed. This function is called by the executor thread
/// once it completes the underlying Rust future. If there's a callback set,
/// it will be invoked immediately.
fn complete(
mut state_lock: MutexGuard<CassFutureState>,
value: CassFutureResult,
cass_fut: &Arc<CassFuture>,
) {
let prev_state = mem::replace(
&mut state_lock.execution_state,
Self::Completed(CassFutureCompleted::new(value)),
);

// This is because we mustn't hold the lock while invoking the callback.
mem::drop(state_lock);

let maybe_cb = match prev_state {
Self::RunningWithoutCallback => None,
Self::RunningWithCallback { callback } => Some(callback),
Self::Completed { .. } => unreachable!(
"Exactly one dedicated tokio task is expected to execute and complete the CassFuture."
),
};

if let Some(bound_cb) = maybe_cb {
let fut_ptr = ArcFFI::as_ptr::<CMut>(cass_fut);
// Safety: pointer is valid, because we get it from arc allocation.
bound_cb.invoke(fut_ptr);
}
}
}

/// The result of a completed [CassFuture].
struct CassFutureCompleted {
/// The result of the future, either a value or an error.
value: CassFutureResult,
/// Just a cache for the error message. Needed because the C API exposes a pointer to the
/// error message, and we need to keep it alive until the future is freed.
/// Initially, it's `None`, and it is set to `Some` when the error message is requested
/// by `cass_future_error_message()`.
cached_err_string: Option<String>,
}

impl CassFutureCompleted {
fn new(value: CassFutureResult) -> Self {
Self {
value,
cached_err_string: None,
}
}
}

/// The C-API representation of a future. Implemented as a wrapper around a Rust future
/// that can be awaited and has a callback mechanism. It's **eager** in a way that
/// its execution starts possibly immediately (unless the executor thread pool is nempty,
/// which is the case for the current-thread executor).
pub struct CassFuture {
state: Mutex<CassFutureState>,
wait_for_value: Condvar,
Expand Down Expand Up @@ -86,23 +192,18 @@ impl CassFuture {
fut: impl Future<Output = CassFutureResult> + Send + 'static,
) -> Arc<CassFuture> {
let cass_fut = Arc::new(CassFuture {
state: Mutex::new(Default::default()),
state: Mutex::new(CassFutureState {
join_handle: None,
execution_state: CassFutureExecution::RunningWithoutCallback,
}),
wait_for_value: Condvar::new(),
});
let cass_fut_clone = Arc::clone(&cass_fut);
let join_handle = RUNTIME.spawn(async move {
let r = fut.await;
let maybe_cb = {
let mut guard = cass_fut_clone.state.lock().unwrap();
guard.value = Some(r);
// Take the callback and call it after releasing the lock
guard.callback.take()
};
if let Some(bound_cb) = maybe_cb {
let fut_ptr = ArcFFI::as_ptr::<CMut>(&cass_fut_clone);
// Safety: pointer is valid, because we get it from arc allocation.
bound_cb.invoke(fut_ptr);
}

let guard = cass_fut_clone.state.lock().unwrap();
CassFutureExecution::complete(guard, r, &cass_fut_clone);

cass_fut_clone.wait_for_value.notify_all();
});
Expand All @@ -116,15 +217,15 @@ impl CassFuture {
pub fn new_ready(r: CassFutureResult) -> Arc<Self> {
Arc::new(CassFuture {
state: Mutex::new(CassFutureState {
value: Some(r),
..Default::default()
join_handle: None,
execution_state: CassFutureExecution::Completed(CassFutureCompleted::new(r)),
}),
wait_for_value: Condvar::new(),
})
}

pub fn with_waited_result<T>(&self, f: impl FnOnce(&mut CassFutureResult) -> T) -> T {
self.with_waited_state(|s| f(s.value.as_mut().unwrap()))
self.with_waited_state(|s| f(&mut s.value))
}

/// Awaits the future until completion.
Expand All @@ -140,7 +241,7 @@ impl CassFuture {
/// - JoinHandle is Some -> some other thread was working on the future, but
/// timed out (see [CassFuture::with_waited_state_timed]). We need to
/// take the ownership of the handle, and complete the work.
fn with_waited_state<T>(&self, f: impl FnOnce(&mut CassFutureState) -> T) -> T {
fn with_waited_state<T>(&self, f: impl FnOnce(&mut CassFutureCompleted) -> T) -> T {
let mut guard = self.state.lock().unwrap();
loop {
let handle = guard.join_handle.take();
Expand All @@ -153,7 +254,7 @@ impl CassFuture {
guard = self
.wait_for_value
.wait_while(guard, |state| {
state.value.is_none() && state.join_handle.is_none()
!state.execution_state.completed() && state.join_handle.is_none()
})
// unwrap: Error appears only when mutex is poisoned.
.unwrap();
Expand All @@ -165,7 +266,15 @@ impl CassFuture {
continue;
}
}
return f(&mut guard);

// If we had ended up in either the handle's or with the condvar's `if` branch,
// we awaited the future and it is now completed.
let completed = match &mut guard.execution_state {
CassFutureExecution::RunningWithoutCallback
| CassFutureExecution::RunningWithCallback { .. } => unreachable!(),
CassFutureExecution::Completed(completed) => completed,
};
return f(completed);
}
}

Expand All @@ -174,7 +283,7 @@ impl CassFuture {
f: impl FnOnce(&mut CassFutureResult) -> T,
timeout_duration: Duration,
) -> Result<T, FutureError> {
self.with_waited_state_timed(|s| f(s.value.as_mut().unwrap()), timeout_duration)
self.with_waited_state_timed(|s| f(&mut s.value), timeout_duration)
}

/// Tries to await the future with a given timeout.
Expand All @@ -194,7 +303,7 @@ impl CassFuture {
/// take the ownership of the handle, and continue the work.
fn with_waited_state_timed<T>(
&self,
f: impl FnOnce(&mut CassFutureState) -> T,
f: impl FnOnce(&mut CassFutureCompleted) -> T,
timeout_duration: Duration,
) -> Result<T, FutureError> {
let mut guard = self.state.lock().unwrap();
Expand Down Expand Up @@ -242,7 +351,7 @@ impl CassFuture {
let (guard_result, timeout_result) = self
.wait_for_value
.wait_timeout_while(guard, remaining_timeout, |state| {
state.value.is_none() && state.join_handle.is_none()
!state.execution_state.completed() && state.join_handle.is_none()
})
// unwrap: Error appears only when mutex is poisoned.
.unwrap();
Expand All @@ -259,7 +368,14 @@ impl CassFuture {
}
}

return Ok(f(&mut guard));
// If we had ended up in either the handle's or with the condvar's `if` branch
// and we didn't return `TimeoutError`, we awaited the future and it is now completed.
let completed = match &mut guard.execution_state {
CassFutureExecution::RunningWithoutCallback
| CassFutureExecution::RunningWithCallback { .. } => unreachable!(),
CassFutureExecution::Completed(completed) => completed,
};
return Ok(f(completed));
}
}

Expand All @@ -269,21 +385,8 @@ impl CassFuture {
cb: CassFutureCallback,
data: *mut c_void,
) -> CassError {
let mut lock = self.state.lock().unwrap();
if lock.callback.is_some() {
// Another callback has been already set
return CassError::CASS_ERROR_LIB_CALLBACK_ALREADY_SET;
}
let bound_cb = BoundCallback { cb, data };
if lock.value.is_some() {
// The value is already available, we need to call the callback ourselves
mem::drop(lock);
bound_cb.invoke(self_ptr);
return CassError::CASS_OK;
}
// Store the callback
lock.callback = Some(bound_cb);
CassError::CASS_OK
let lock = self.state.lock().unwrap();
unsafe { CassFutureExecution::set_callback(lock, self_ptr, cb, data) }
}

fn into_raw(self: Arc<Self>) -> CassOwnedSharedPtr<Self, CMut> {
Expand Down Expand Up @@ -346,10 +449,7 @@ pub unsafe extern "C" fn cass_future_ready(
};

let state_guard = future.state.lock().unwrap();
match state_guard.value {
None => cass_false,
Some(_) => cass_true,
}
state_guard.execution_state.completed() as cass_bool_t
}

#[unsafe(no_mangle)]
Expand Down Expand Up @@ -379,11 +479,10 @@ pub unsafe extern "C" fn cass_future_error_message(
return;
};

future.with_waited_state(|state: &mut CassFutureState| {
let value = &state.value;
let msg = state
.err_string
.get_or_insert_with(|| match value.as_ref().unwrap() {
future.with_waited_state(|completed: &mut CassFutureCompleted| {
let msg = completed
.cached_err_string
.get_or_insert_with(|| match &completed.value {
Ok(CassResultValue::QueryError(err)) => err.msg(),
Err((_, s)) => s.msg(),
_ => "".to_string(),
Expand Down