@@ -385,8 +385,25 @@ pi_result cuda_piEventRetain(pi_event event);
 
 /// \endcond
 
+void _pi_queue::compute_stream_wait_for_barrier_if_needed(CUstream stream,
+                                                          pi_uint32 stream_i) {
+  if (barrier_event_ && !compute_applied_barrier_[stream_i]) {
+    PI_CHECK_ERROR(cuStreamWaitEvent(stream, barrier_event_, 0));
+    compute_applied_barrier_[stream_i] = true;
+  }
+}
+
+void _pi_queue::transfer_stream_wait_for_barrier_if_needed(CUstream stream,
+                                                           pi_uint32 stream_i) {
+  if (barrier_event_ && !transfer_applied_barrier_[stream_i]) {
+    PI_CHECK_ERROR(cuStreamWaitEvent(stream, barrier_event_, 0));
+    transfer_applied_barrier_[stream_i] = true;
+  }
+}
+
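These two helpers make a stream wait on the queue-wide `barrier_event_` at most once: the per-stream `compute_applied_barrier_` / `transfer_applied_barrier_` flags record which streams have already issued the `cuStreamWaitEvent`, and they are cleared whenever a new barrier is recorded (see `cuda_piEnqueueEventsWaitWithBarrier` further down in this diff). Below is a host-only sketch of that bookkeeping; the names (`MiniQueue`, `record_barrier`) are invented for illustration and an `int` stands in for `CUstream`.

```cpp
// Illustrative host-only sketch; not part of the patch.
#include <cstdio>
#include <vector>

using StreamHandle = int; // stand-in for CUstream

struct MiniQueue {
  std::vector<bool> applied_barrier; // cf. compute_applied_barrier_
  bool have_barrier = false;         // cf. barrier_event_ != nullptr

  // cf. compute_stream_wait_for_barrier_if_needed: wait at most once.
  void wait_for_barrier_if_needed(StreamHandle stream, unsigned idx) {
    if (have_barrier && !applied_barrier[idx]) {
      std::printf("stream %d: cuStreamWaitEvent(barrier_event_)\n", stream);
      applied_barrier[idx] = true;
    }
  }

  // cf. the flag-reset loops in cuda_piEnqueueEventsWaitWithBarrier.
  void record_barrier() {
    have_barrier = true;
    applied_barrier.assign(applied_barrier.size(), false);
  }
};

int main() {
  MiniQueue queue;
  queue.applied_barrier.assign(2, false);

  queue.record_barrier();
  queue.wait_for_barrier_if_needed(/*stream=*/101, /*idx=*/1); // inserts the wait
  queue.wait_for_barrier_if_needed(101, 1);                    // no-op, already applied

  queue.record_barrier();                   // a new barrier resets every flag
  queue.wait_for_barrier_if_needed(101, 1); // waits again, for the new barrier
  return 0;
}
```

The transfer-stream helper follows the identical pattern with its own flag array.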
 CUstream _pi_queue::get_next_compute_stream(pi_uint32 *stream_token) {
   pi_uint32 stream_i;
+  pi_uint32 token;
   while (true) {
     if (num_compute_streams_ < compute_streams_.size()) {
       // the check above is for performance - so as not to lock mutex every time
@@ -398,40 +415,46 @@ CUstream _pi_queue::get_next_compute_stream(pi_uint32 *stream_token) {
             cuStreamCreate(&compute_streams_[num_compute_streams_++], flags_));
       }
     }
-    stream_i = compute_stream_idx_++;
+    token = compute_stream_idx_++;
+    stream_i = token % compute_streams_.size();
     // if a stream has been reused before it was next selected round-robin
     // fashion, we want to delay its next use and instead select another one
     // that is more likely to have completed all the enqueued work.
-    if (delay_compute_[stream_i % compute_streams_.size()]) {
-      delay_compute_[stream_i % compute_streams_.size()] = false;
+    if (delay_compute_[stream_i]) {
+      delay_compute_[stream_i] = false;
     } else {
       break;
     }
   }
   if (stream_token) {
-    *stream_token = stream_i;
+    *stream_token = token;
   }
-  return compute_streams_[stream_i % compute_streams_.size()];
+  CUstream res = compute_streams_[stream_i];
+  compute_stream_wait_for_barrier_if_needed(res, stream_i);
+  return res;
 }
 
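The selection loop hands out a monotonically increasing token (`compute_stream_idx_++`) and maps it onto a stream index with a modulo; a stream whose `delay_compute_` flag is set is skipped once, so a stream that was just reused out of turn is not immediately handed out again. Note that the raw token, not the modulo-reduced index, is what gets returned through `stream_token`. A host-only sketch of this policy, with invented names and plain integers standing in for stream handles:

```cpp
// Host-only illustration of the round-robin-with-delay policy; StreamPool
// and pick are invented for this sketch.
#include <cstdio>
#include <vector>

struct StreamPool {
  std::vector<int> streams; // stand-ins for CUstream handles
  std::vector<bool> delay;  // "skip me once" flags, cf. delay_compute_
  unsigned next_token = 0;  // cf. compute_stream_idx_

  int pick(unsigned *token_out) {
    unsigned token, index;
    while (true) {
      token = next_token++;
      index = token % streams.size();
      if (delay[index]) {
        delay[index] = false; // consume the delay and try the next stream
      } else {
        break;
      }
    }
    if (token_out)
      *token_out = token; // callers keep the raw token, not the index
    return streams[index];
  }
};

int main() {
  StreamPool pool;
  pool.streams = {100, 101, 102};
  pool.delay = {false, true, false};

  unsigned token;
  for (int i = 0; i < 4; ++i) {
    int s = pool.pick(&token);
    std::printf("token %u -> stream %d\n", token, s);
  }
  // Stream 101 is skipped on its first turn because its delay flag was set.
  return 0;
}
```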
 CUstream _pi_queue::get_next_compute_stream(pi_uint32 num_events_in_wait_list,
                                             const pi_event *event_wait_list,
                                             _pi_stream_guard &guard,
                                             pi_uint32 *stream_token) {
   for (pi_uint32 i = 0; i < num_events_in_wait_list; i++) {
-    pi_uint32 token = event_wait_list[i]->get_stream_token();
+    pi_uint32 token = event_wait_list[i]->get_compute_stream_token();
     if (event_wait_list[i]->get_queue() == this && can_reuse_stream(token)) {
       std::unique_lock<std::mutex> compute_sync_guard(
           compute_stream_sync_mutex_);
       // redo the check after lock to avoid data races on
       // last_sync_compute_streams_
       if (can_reuse_stream(token)) {
-        delay_compute_[token % delay_compute_.size()] = true;
+        pi_uint32 stream_i = token % delay_compute_.size();
+        delay_compute_[stream_i] = true;
         if (stream_token) {
           *stream_token = token;
         }
         guard = _pi_stream_guard{std::move(compute_sync_guard)};
-        return event_wait_list[i]->get_stream();
+        CUstream res = event_wait_list[i]->get_stream();
+        compute_stream_wait_for_barrier_if_needed(res, stream_i);
+        return res;
       }
     }
   }
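This overload tries to reuse the stream that produced a wait-list event: if the event comes from this queue and its stream has not been synchronized since, the new work is enqueued on that same stream and CUDA's in-order stream semantics cover the dependency without an explicit wait. The `can_reuse_stream(token)` predicate is evaluated twice, once cheaply before taking `compute_stream_sync_mutex_` and again under the lock, because the sync state can change concurrently. A generic host-only sketch of that check-lock-recheck idiom (the predicate and state below are invented stand-ins, not the plugin's actual members):

```cpp
// Host-only sketch of the check-lock-recheck idiom used above.
#include <atomic>
#include <cstdio>
#include <mutex>

struct ReuseTracker {
  std::mutex sync_mutex;
  std::atomic<unsigned> last_synced_token{0}; // stand-in for last_sync_compute_streams_

  bool can_reuse(unsigned token) const { return token >= last_synced_token.load(); }

  // Returns true if the caller may keep using the stream behind `token`.
  bool try_reuse(unsigned token) {
    if (!can_reuse(token)) // cheap check, no lock
      return false;
    std::lock_guard<std::mutex> lock(sync_mutex);
    return can_reuse(token); // recheck: state may have changed meanwhile
  }

  void synchronize_up_to(unsigned token) {
    std::lock_guard<std::mutex> lock(sync_mutex);
    last_synced_token.store(token + 1); // streams up to `token` are now synced
  }
};

int main() {
  ReuseTracker tracker;
  std::printf("reuse token 3: %s\n", tracker.try_reuse(3) ? "yes" : "no");
  tracker.synchronize_up_to(5);
  std::printf("reuse token 3: %s\n", tracker.try_reuse(3) ? "yes" : "no");
  return 0;
}
```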
@@ -453,7 +476,10 @@ CUstream _pi_queue::get_next_transfer_stream() {
           cuStreamCreate(&transfer_streams_[num_transfer_streams_++], flags_));
     }
   }
-  return transfer_streams_[transfer_stream_idx_++ % transfer_streams_.size()];
+  pi_uint32 stream_i = transfer_stream_idx_++ % transfer_streams_.size();
+  CUstream res = transfer_streams_[stream_i];
+  transfer_stream_wait_for_barrier_if_needed(res, stream_i);
+  return res;
 }
 
 _pi_event::_pi_event(pi_command_type type, pi_context context, pi_queue queue,
@@ -2549,7 +2575,7 @@ pi_result cuda_piQueueFinish(pi_queue command_queue) {
          nullptr); // need PI_ERROR_INVALID_EXTERNAL_HANDLE error code
   ScopedContext active(command_queue->get_context());
 
-  command_queue->sync_streams([&result](CUstream s) {
+  command_queue->sync_streams</*ResetUsed=*/true>([&result](CUstream s) {
     result = PI_CHECK_ERROR(cuStreamSynchronize(s));
   });
 
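`cuda_piQueueFinish` now instantiates `sync_streams` with a `ResetUsed = true` template argument. The member's definition is not part of this diff; presumably it visits the queue's compute and transfer streams, invokes the callback on each, and, when `ResetUsed` is set, resets the bookkeeping of which streams still have outstanding work so later synchronizations can skip them. The sketch below is a guess at such a helper's shape, not the actual `_pi_queue` code:

```cpp
// Speculative host-only sketch of a sync_streams<ResetUsed>() helper; the
// real _pi_queue member is not shown in this diff and may differ.
#include <cstdio>
#include <vector>

using StreamHandle = int; // stand-in for CUstream

struct MiniQueue {
  std::vector<StreamHandle> compute_streams{1, 2};
  std::vector<StreamHandle> transfer_streams{10};
  unsigned num_used_compute = 2; // streams that received work since last reset
  unsigned num_used_transfer = 1;

  template <bool ResetUsed = false, typename F> void sync_streams(F &&sync) {
    for (unsigned i = 0; i < num_used_compute; ++i)
      sync(compute_streams[i]);
    for (unsigned i = 0; i < num_used_transfer; ++i)
      sync(transfer_streams[i]);
    if (ResetUsed) {
      // After a full queue finish there is nothing left to wait for, so the
      // next synchronization can start from a clean slate.
      num_used_compute = 0;
      num_used_transfer = 0;
    }
  }
};

int main() {
  MiniQueue queue;
  // A barrier-style sync visits the streams but keeps the bookkeeping.
  queue.sync_streams([](StreamHandle s) { std::printf("wait on %d\n", s); });
  // A piQueueFinish-style sync also resets the "used" counters.
  queue.sync_streams</*ResetUsed=*/true>(
      [](StreamHandle s) { std::printf("finish %d\n", s); });
  return 0;
}
```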
@@ -3875,35 +3901,70 @@ pi_result cuda_piEnqueueEventsWaitWithBarrier(pi_queue command_queue,
                                               pi_uint32 num_events_in_wait_list,
                                               const pi_event *event_wait_list,
                                               pi_event *event) {
+  // This function makes one stream wait for all previously enqueued work (or
+  // for the work represented by the input events), and then makes all future
+  // work in the queue wait on that stream.
   if (!command_queue) {
     return PI_ERROR_INVALID_QUEUE;
   }
 
+  pi_result result;
+
   try {
     ScopedContext active(command_queue->get_context());
+    pi_uint32 stream_token;
+    _pi_stream_guard guard;
+    CUstream cuStream = command_queue->get_next_compute_stream(
+        num_events_in_wait_list, event_wait_list, guard, &stream_token);
+    {
+      std::lock_guard<std::mutex> lock(command_queue->barrier_mutex_);
+      if (command_queue->barrier_event_ == nullptr) {
+        PI_CHECK_ERROR(cuEventCreate(&command_queue->barrier_event_,
+                                     CU_EVENT_DISABLE_TIMING));
+      }
+      if (num_events_in_wait_list == 0) { // wait on all work
+        if (command_queue->barrier_tmp_event_ == nullptr) {
+          PI_CHECK_ERROR(cuEventCreate(&command_queue->barrier_tmp_event_,
+                                       CU_EVENT_DISABLE_TIMING));
+        }
+        command_queue->sync_streams(
+            [cuStream,
+             tmp_event = command_queue->barrier_tmp_event_](CUstream s) {
+              if (cuStream != s) {
+                // record a new CUDA event on every stream and make one stream
+                // wait for these events
+                PI_CHECK_ERROR(cuEventRecord(tmp_event, s));
+                PI_CHECK_ERROR(cuStreamWaitEvent(cuStream, tmp_event, 0));
+              }
+            });
+      } else { // wait just on given events
+        forLatestEvents(event_wait_list, num_events_in_wait_list,
+                        [cuStream](pi_event event) -> pi_result {
+                          if (event->get_queue()->has_been_synchronized(
+                                  event->get_compute_stream_token())) {
+                            return PI_SUCCESS;
+                          } else {
+                            return PI_CHECK_ERROR(
+                                cuStreamWaitEvent(cuStream, event->get(), 0));
+                          }
+                        });
+      }
 
-    if (event_wait_list) {
-      auto result =
-          forLatestEvents(event_wait_list, num_events_in_wait_list,
-                          [command_queue](pi_event event) -> pi_result {
-                            if (event->get_queue()->has_been_synchronized(
-                                    event->get_stream_token())) {
-                              return PI_SUCCESS;
-                            } else {
-                              return enqueueEventWait(command_queue, event);
-                            }
-                          });
-
-      if (result != PI_SUCCESS) {
-        return result;
+      result = PI_CHECK_ERROR(
+          cuEventRecord(command_queue->barrier_event_, cuStream));
+      for (unsigned int i = 0;
+           i < command_queue->compute_applied_barrier_.size(); i++) {
+        command_queue->compute_applied_barrier_[i] = false;
+      }
+      for (unsigned int i = 0;
+           i < command_queue->transfer_applied_barrier_.size(); i++) {
+        command_queue->transfer_applied_barrier_[i] = false;
       }
     }
+    if (result != PI_SUCCESS) {
+      return result;
+    }
 
     if (event) {
-      pi_uint32 stream_token;
-      _pi_stream_guard guard;
-      CUstream cuStream = command_queue->get_next_compute_stream(
-          num_events_in_wait_list, event_wait_list, guard, &stream_token);
       *event = _pi_event::make_native(PI_COMMAND_TYPE_MARKER, command_queue,
                                       cuStream, stream_token);
       (*event)->start();
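Taken together, the barrier works in two phases: the chosen compute stream first waits for everything that must precede the barrier (every other stream in the queue via the reusable `barrier_tmp_event_`, or just the given wait-list events), then `barrier_event_` is recorded on that stream and all `*_applied_barrier_` flags are cleared so the helpers at the top of this diff lazily insert a wait on the barrier before each stream's next submission. A standalone CUDA driver API sketch of the same two-phase flow (stream and event names are illustrative, and real work submissions are omitted):

```cpp
// Standalone illustration of the two-phase barrier flow; not the plugin code.
#include <cuda.h>
#include <cstdio>
#include <cstdlib>
#include <vector>

static void check(CUresult res, const char *what) {
  if (res != CUDA_SUCCESS) {
    std::fprintf(stderr, "%s failed (%d)\n", what, static_cast<int>(res));
    std::exit(EXIT_FAILURE);
  }
}

int main() {
  check(cuInit(0), "cuInit");
  CUdevice dev;
  check(cuDeviceGet(&dev, 0), "cuDeviceGet");
  CUcontext ctx;
  check(cuCtxCreate(&ctx, 0, dev), "cuCtxCreate");

  std::vector<CUstream> streams(3);
  for (CUstream &s : streams)
    check(cuStreamCreate(&s, CU_STREAM_NON_BLOCKING), "cuStreamCreate");

  CUevent barrier_event, tmp_event;
  check(cuEventCreate(&barrier_event, CU_EVENT_DISABLE_TIMING), "cuEventCreate");
  check(cuEventCreate(&tmp_event, CU_EVENT_DISABLE_TIMING), "cuEventCreate");

  // Phase 1: pick one stream and make it wait for all the others, reusing a
  // single temporary event, as in the num_events_in_wait_list == 0 case.
  CUstream barrier_stream = streams[0];
  for (CUstream s : streams) {
    if (s != barrier_stream) {
      check(cuEventRecord(tmp_event, s), "cuEventRecord");
      check(cuStreamWaitEvent(barrier_stream, tmp_event, 0), "cuStreamWaitEvent");
    }
  }

  // Phase 2: record the barrier on that stream; anything submitted to any
  // stream afterwards first waits on barrier_event (the real plugin inserts
  // these waits lazily via the *_wait_for_barrier_if_needed helpers).
  check(cuEventRecord(barrier_event, barrier_stream), "cuEventRecord");
  for (CUstream s : streams)
    check(cuStreamWaitEvent(s, barrier_event, 0), "cuStreamWaitEvent");

  check(cuCtxSynchronize(), "cuCtxSynchronize");
  std::printf("barrier established and drained\n");

  for (CUstream s : streams)
    cuStreamDestroy(s);
  cuEventDestroy(barrier_event);
  cuEventDestroy(tmp_event);
  cuCtxDestroy(ctx);
  return 0;
}
```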