diff --git a/rust/bridge/node/futures/src/executor.rs b/rust/bridge/node/futures/src/executor.rs index cde991fd02..271fbeda59 100644 --- a/rust/bridge/node/futures/src/executor.rs +++ b/rust/bridge/node/futures/src/executor.rs @@ -14,25 +14,42 @@ use std::task::{Poll, Wake}; /// [EventQueue]: https://docs.rs/neon/0.7.0-napi.3/neon/event/struct.EventQueue.html pub trait EventQueueEx { /// Schedules the future to run on the JavaScript main thread until complete. - fn send_future(&self, future: impl Future + 'static + Send); + fn send_future(self: Arc, future: impl Future + 'static + Send); + /// Polls the future synchronously, then schedules it to run on the JavaScript main thread from + /// then on. + fn start_future(self: Arc, future: impl Future + 'static + Send); } impl EventQueueEx for EventQueue { - fn send_future(&self, future: impl Future + 'static + Send) { - self.send(move |mut cx| { + fn send_future(self: Arc, future: impl Future + 'static + Send) { + let self_for_task = self.clone(); + self.send(move |_| { let task = Arc::new(FutureTask { - queue: cx.queue(), + queue: self_for_task, future: Mutex::new(Some(Box::pin(future))), }); task.poll(); Ok(()) }) } + + fn start_future(self: Arc, future: impl Future + 'static + Send) { + let task = Arc::new(FutureTask { + queue: self, + future: Mutex::new(Some(Box::pin(future))), + }); + task.poll(); + } } /// Used to "send" a task from a thread to itself through a multi-threaded interface. -struct AssertSendSafe(T); +pub(crate) struct AssertSendSafe(T); unsafe impl Send for AssertSendSafe {} +impl AssertSendSafe { + pub unsafe fn wrap(value: T) -> Self { + Self(value) + } +} impl Future for AssertSendSafe { type Output = T::Output; @@ -43,33 +60,6 @@ impl Future for AssertSendSafe { } } -/// Adds support for executing closures and futures on the JavaScript main thread's microtask queue. -pub trait ContextEx<'a>: Context<'a> { - /// Schedules `f` to run on the microtask queue. - /// - /// Equivalent to `cx.queue().send(f)` except that `f` doesn't need to be `Send`. - fn run_on_queue(&mut self, f: impl FnOnce(TaskContext<'_>) -> NeonResult<()> + 'static) { - // Because we're currently in a JavaScript context, - // and `f` will run on the event queue associated with the current context, - // we can assert that it's safe to Send `f` to the queue. - let f = AssertSendSafe(f); - self.queue().send(move |cx| f.0(cx)); - } - - /// Schedules `f` to run on the microtask queue. - /// - /// Equivalent to `cx.queue().send_future(f)` except that `f` doesn't need to be `Send`. - fn run_future_on_queue(&mut self, f: impl Future + 'static) { - // Because we're currently in a JavaScript context, - // and `f` will run on the event queue associated with the current context, - // we can assert that it's safe to Send `f` to the queue. - let f = AssertSendSafe(f); - self.queue().send_future(f); - } -} - -impl<'a, T: Context<'a>> ContextEx<'a> for T {} - /// Implements waking for futures scheduled on the JavaScript microtask queue. /// /// When the task is awoken, it reschedules itself on the task queue to re-poll the top-level Future. @@ -77,7 +67,7 @@ struct FutureTask where F: Future + 'static + Send, { - queue: EventQueue, + queue: Arc, future: Mutex>>>, } @@ -88,7 +78,7 @@ where /// Polls the top-level future, while setting `self` up as the waker once more. /// /// When the future completes, it is replaced by `None` to avoid accidentally polling twice. - fn poll(self: Arc) { + fn poll(self: &Arc) { let future = &mut *self.future.lock().expect("Lock can be taken"); if let Some(active_future) = future { match active_future @@ -107,9 +97,9 @@ where F: Future + 'static + Send, { fn wake(self: Arc) { - let self_for_closure = self.clone(); - self.queue.send(move |_cx| { - self_for_closure.poll(); + let queue = self.queue.clone(); + queue.send(move |_cx| { + self.poll(); Ok(()) }) } diff --git a/rust/bridge/node/futures/src/lib.rs b/rust/bridge/node/futures/src/lib.rs index 0fd0428416..61fb02b01e 100644 --- a/rust/bridge/node/futures/src/lib.rs +++ b/rust/bridge/node/futures/src/lib.rs @@ -23,7 +23,7 @@ #![warn(clippy::unwrap_used)] mod executor; -pub use executor::{ContextEx, EventQueueEx}; +pub use executor::EventQueueEx; mod exception; pub use exception::PersistentException; diff --git a/rust/bridge/node/futures/src/promise.rs b/rust/bridge/node/futures/src/promise.rs index 8b863eeb5e..3be5608bcb 100644 --- a/rust/bridge/node/futures/src/promise.rs +++ b/rust/bridge/node/futures/src/promise.rs @@ -7,8 +7,9 @@ use futures::FutureExt; use neon::prelude::*; use std::future::Future; use std::panic::{catch_unwind, AssertUnwindSafe, UnwindSafe}; +use std::sync::Arc; -use crate::executor::ContextEx; +use crate::executor::{AssertSendSafe, EventQueueEx}; use crate::util::describe_panic; use crate::*; @@ -79,13 +80,14 @@ where let promise = promise_ctor.construct(cx, vec![bound_save_promise_callbacks])?; let callbacks_object_root = callbacks_object.root(cx); - let queue = cx.queue(); + let queue = Arc::new(cx.queue()); + let queue_for_future = queue.clone(); - cx.run_future_on_queue(async move { + let future = async move { let result: std::thread::Result> = future.catch_unwind().await; - queue.send(move |mut cx| -> NeonResult<()> { + queue_for_future.send(move |mut cx| -> NeonResult<()> { let settled_result: std::thread::Result, Handle>> = match result { Ok(Ok(settle)) => { @@ -126,7 +128,12 @@ where Ok(()) }); - }); + }; + + // AssertSendSafe because `queue` is running on the same thread as the current context `cx`, + // so in practice we are always on the same thread. + let future = unsafe { AssertSendSafe::wrap(future) }; + queue.start_future(future); Ok(promise) } diff --git a/rust/bridge/node/futures/tests-node-module/src/lib.rs b/rust/bridge/node/futures/tests-node-module/src/lib.rs index ae85c749bc..3323e9a3b7 100644 --- a/rust/bridge/node/futures/tests-node-module/src/lib.rs +++ b/rust/bridge/node/futures/tests-node-module/src/lib.rs @@ -5,6 +5,7 @@ use neon::prelude::*; use signal_neon_futures::*; +use std::sync::Arc; mod panics_and_throws; use panics_and_throws::*; @@ -17,7 +18,7 @@ fn increment_async(mut cx: FunctionContext) -> JsResult { // A complicated test that manually calls a callback at its conclusion. let promise = cx.argument::(0)?; let completion_callback = cx.argument::(1)?.root(&mut cx); - let queue = cx.queue(); + let queue = Arc::new(cx.queue()); let future = JsFuture::from_promise(&mut cx, promise, |cx, result| match result { Ok(value) => Ok(value @@ -27,7 +28,7 @@ fn increment_async(mut cx: FunctionContext) -> JsResult { Err(err) => Err(err.to_string(cx).unwrap().value(cx)), })?; - cx.run_future_on_queue(async move { + queue.clone().start_future(async move { let value_or_error = future.await; queue.send(move |mut cx| { let new_value = match value_or_error {