diff --git a/bench/bench_select.ml b/bench/bench_select.ml new file mode 100644 index 000000000..ef91398ca --- /dev/null +++ b/bench/bench_select.ml @@ -0,0 +1,57 @@ + +open Eio.Stdenv +open Eio + +let sender_fibers = 4 +let cap = 10 + +let message = 1234 + +(* Send [n_msgs] items to streams in a round-robin way. *) +let sender ~n_msgs streams = + let msgs = Seq.take n_msgs (Seq.ints 0) in + let streams = Seq.cycle (List.to_seq streams) in + let zipped = Seq.zip msgs streams in + ignore (Seq.iter (fun (_i, stream) -> + Stream.add stream message) zipped) + +(* Start one sender fiber for each stream, and let it send n_msgs messages. + Each fiber sends to all streams in a round-robin way. *) +let run_senders ~dom_mgr ?(n_msgs = 100) streams = + Switch.run @@ fun sw -> + ignore @@ List.iter (fun _stream -> + Fiber.fork ~sw (fun () -> + Domain_manager.run dom_mgr (fun () -> + sender ~n_msgs streams))) streams + +(* Receive messages from all streams. *) +let receiver ~n_msgs streams = + for _i = 1 to n_msgs do + assert (Int.equal message (Stream.select streams)); + done + +(* Create [n] streams. *) +let make_streams cap n = + let unfolder i = if i == 0 then None else Some (Stream.create cap, i-1) in + let seq = Seq.unfold unfolder n in + List.of_seq seq + +let run env = + let dom_mgr = domain_mgr env in + let clock = clock env in + let streams = make_streams cap sender_fibers in + let selector = List.map (fun s -> (s, fun i -> i)) streams in + let n_msgs = 10000 in + Switch.run @@ fun sw -> + Fiber.fork ~sw (fun () -> run_senders ~dom_mgr ~n_msgs streams); + let before = Time.now clock in + receiver ~n_msgs:(sender_fibers * n_msgs) selector; + let after = Time.now clock in + let elapsed = after -. before in + let time_per_iter = elapsed /. (Float.of_int @@ sender_fibers * n_msgs) in + [Metric.create + (Printf.sprintf "sync:true senders:%d msgs_per_sender:%d" sender_fibers n_msgs) + (`Float (1e9 *. time_per_iter)) "ns" + "Time per transmitted int"] + + diff --git a/bench/main.ml b/bench/main.ml index 707253019..4d4b0bd07 100644 --- a/bench/main.ml +++ b/bench/main.ml @@ -9,6 +9,7 @@ let benchmarks = [ "Stream", Bench_stream.run; "HTTP", Bench_http.run; "Eio_unix.Fd", Bench_fd.run; + "StreamSelect", Bench_select.run; ] let usage_error () = diff --git a/lib_eio/stream.ml b/lib_eio/stream.ml index 974cfa3b7..8ebbcb8de 100644 --- a/lib_eio/stream.ml +++ b/lib_eio/stream.ml @@ -94,6 +94,53 @@ module Locking = struct Mutex.unlock t.mutex; Some v + let select_of_many streams_fns = + let finished = Atomic.make false in + let cancel_fns = ref [] in + let add_cancel_fn fn = cancel_fns := fn :: !cancel_fns in + let cancel_all () = List.iter (fun fn -> fn ()) !cancel_fns in + let wait ctx enqueue (t, f) = begin + Mutex.lock t.mutex; + (* First check if any items are already available and return early if there are. *) + if not (Queue.is_empty t.items) + then ( + (* If no other stream has yielded already, we are the first one. *) + if Atomic.compare_and_set finished false true + then ( + (* Therefore, cancel all other waiters and take available item. *) + cancel_all (); + let item = Queue.take t.items in + ignore (Waiters.wake_one t.writers ()); + enqueue (Ok (f item))); + Mutex.unlock t.mutex + ) + else add_cancel_fn @@ + (* Otherwise, register interest in this stream. *) + Waiters.cancellable_await_internal ~mutex:(Some t.mutex) t.readers t.id ctx (fun r -> + if Result.is_ok r then ( + if not (Atomic.compare_and_set finished false true) then ( + (* Another stream has yielded an item in the meantime. However, as + we have been waiting on this stream it must have been empty. + + As the stream's mutex was held since before last checking for an item, + the queue must be empty. + *) + assert ((Queue.length t.items) < t.capacity); + Queue.add (Result.get_ok r) t.items + ) else ( + (* remove all other entries of this fiber in other streams' waiters. *) + ignore (Waiters.wake_one t.writers ()); + cancel_all (); + (* item is returned to waiting caller through enqueue and enter_unchecked. *) + enqueue (Result.map f r)) + )); + end in + (* Register interest in all streams and return first available item. *) + let wait_for_stream streams_fns = begin + Suspend.enter_unchecked (fun ctx enqueue -> List.iter (wait ctx enqueue) streams_fns) + end in + wait_for_stream streams_fns + let length t = Mutex.lock t.mutex; let len = Queue.length t.items in @@ -125,6 +172,13 @@ let take_nonblocking = function | Sync x -> Sync.take_nonblocking x | Locking x -> Locking.take_nonblocking x +let select streams = + let filter s = match s with + | (Sync _, _) -> assert false + | (Locking x, f) -> (x, f) + in + Locking.select_of_many (List.map filter streams) + let length = function | Sync _ -> 0 | Locking x -> Locking.length x diff --git a/lib_eio/stream.mli b/lib_eio/stream.mli index 6554cac1a..79b7075b6 100644 --- a/lib_eio/stream.mli +++ b/lib_eio/stream.mli @@ -40,6 +40,10 @@ val take_nonblocking : 'a t -> 'a option Note that if another domain may add to the stream then a [None] result may already be out-of-date by the time this returns. *) +val select : ('a t * ('a -> 'b)) list -> 'b +(** [select] returns the first item yielded by any stream. This only + works for streams with non-zero capacity. *) + val length : 'a t -> int (** [length t] returns the number of items currently in [t]. *) diff --git a/lib_eio/waiters.ml b/lib_eio/waiters.ml index c0cbd4624..99c21155e 100644 --- a/lib_eio/waiters.ml +++ b/lib_eio/waiters.ml @@ -38,11 +38,12 @@ let rec wake_one t v = let is_empty = Lwt_dllist.is_empty -let await_internal ~mutex (t:'a t) id ctx enqueue = +let cancellable_await_internal ~mutex (t:'a t) id ctx enqueue = match Fiber_context.get_error ctx with | Some ex -> Option.iter Mutex.unlock mutex; - enqueue (Error ex) + enqueue (Error ex); + fun () -> () | None -> let resolved_waiter = ref Hook.null in let finished = Atomic.make false in @@ -56,14 +57,24 @@ let await_internal ~mutex (t:'a t) id ctx enqueue = enqueue (Error ex) ) in + let unwait () = + if Atomic.compare_and_set finished false true + then Hook.remove !resolved_waiter + in Fiber_context.set_cancel_fn ctx cancel; let waiter = { enqueue; finished } in match mutex with | None -> - resolved_waiter := add_waiter t waiter + resolved_waiter := add_waiter t waiter; + unwait | Some mutex -> resolved_waiter := add_waiter_protected ~mutex t waiter; - Mutex.unlock mutex + Mutex.unlock mutex; + unwait + +let await_internal ~mutex (t: 'a t) id ctx enqueue = + let _cancel = (cancellable_await_internal ~mutex t id ctx enqueue) in + () (* Returns a result if the wait succeeds, or raises if cancelled. *) let await ~mutex waiters id = diff --git a/lib_eio/waiters.mli b/lib_eio/waiters.mli index 724cf96e7..04b8d4557 100644 --- a/lib_eio/waiters.mli +++ b/lib_eio/waiters.mli @@ -27,8 +27,8 @@ val await : If [t] can be used from multiple domains: - [mutex] must be set to the mutex to use to unlock it. - [mutex] must be already held when calling this function, which will unlock it before blocking. - When [await] returns, [mutex] will have been unlocked. - @raise Cancel.Cancelled if the fiber's context is cancelled *) + When [await] returns, [mutex] will have been unlocked. + @raise Cancel.Cancelled if the fiber's context is cancelled *) val await_internal : mutex:Mutex.t option -> @@ -40,3 +40,12 @@ val await_internal : Note: [enqueue] is called from the triggering domain, which is currently calling {!wake_one} or {!wake_all} and must therefore be holding [mutex]. *) + +val cancellable_await_internal : + mutex:Mutex.t option -> + 'a t -> Ctf.id -> Fiber_context.t -> + (('a, exn) result -> unit) -> (unit -> unit) +(** Like [await_internal], but returns a function which, when called, + removes the current fiber continuation from the waiters list. + This is used when a fiber is waiting for multiple [Waiter]s simultaneously, + and needs to remove itself from other waiters once it has been enqueued by one.*) diff --git a/tests/stream.md b/tests/stream.md index c5a035e3b..10771d00d 100644 --- a/tests/stream.md +++ b/tests/stream.md @@ -357,3 +357,22 @@ Non-blocking take with zero-capacity stream: +Got None from stream - : unit = () ``` + +Selecting from multiple channels: + +```ocaml +# run @@ fun () -> Switch.run (fun sw -> + let t1, t2 = (S.create 2), (S.create 2) in + let selector = [(t1, fun x -> x); (t2, fun x -> x)] in + Fiber.fork ~sw (fun () -> S.add t2 "foo"); + Fiber.fork ~sw (fun () -> traceln "%s" (S.select selector)); + Fiber.fork ~sw (fun () -> traceln "%s" (S.select selector)); + Fiber.fork ~sw (fun () -> traceln "%s" (S.select selector)); + Fiber.fork ~sw (fun () -> S.add t2 "bar"); + Fiber.fork ~sw (fun () -> S.add t1 "baz"); + ) ++foo ++bar ++baz +- : unit = () +```