@@ -19,6 +19,7 @@ use rivet_types::runner_configs::RunnerConfigKind;
1919use tokio:: { sync:: oneshot, task:: JoinHandle , time:: Duration } ;
2020use universaldb:: options:: StreamingMode ;
2121use universaldb:: utils:: IsolationLevel :: * ;
22+ use universalpubsub:: PublishOpts ;
2223use vbare:: OwnedVersionedData ;
2324
2425const X_RIVET_ENDPOINT : HeaderName = HeaderName :: from_static ( "x-rivet-endpoint" ) ;
@@ -27,6 +28,8 @@ const X_RIVET_TOTAL_SLOTS: HeaderName = HeaderName::from_static("x-rivet-total-s
2728const X_RIVET_RUNNER_NAME : HeaderName = HeaderName :: from_static ( "x-rivet-runner-name" ) ;
2829const X_RIVET_NAMESPACE_ID : HeaderName = HeaderName :: from_static ( "x-rivet-namespace-id" ) ;
2930
31+ const DRAIN_GRACE_PERIOD : Duration = Duration :: from_secs ( 10 ) ;
32+
3033struct OutboundConnection {
3134 handle : JoinHandle < ( ) > ,
3235 shutdown_tx : oneshot:: Sender < ( ) > ,
@@ -377,12 +380,14 @@ async fn outbound_handler(
377380 anyhow:: Ok ( ( ) )
378381 } ;
379382
383+ let sleep_until_drop = request_lifespan. saturating_sub ( DRAIN_GRACE_PERIOD ) ;
380384 tokio:: select! {
381385 res = stream_handler => return res. map_err( Into :: into) ,
382- _ = tokio:: time:: sleep( request_lifespan ) => { }
386+ _ = tokio:: time:: sleep( sleep_until_drop ) => { }
383387 _ = shutdown_rx => { }
384388 }
385389
390+ // Stop runner
386391 draining. store ( true , Ordering :: SeqCst ) ;
387392
388393 ctx. msg ( rivet_types:: msgs:: pegboard:: BumpServerlessAutoscaler { } )
@@ -394,34 +399,56 @@ async fn outbound_handler(
394399 }
395400
396401 // Continue waiting on req while draining
397- while let Some ( event) = source. next ( ) . await {
398- match event {
399- Ok ( sse:: Event :: Open ) => { }
400- Ok ( sse:: Event :: Message ( msg) ) => {
401- tracing:: debug!( %msg. data, "received outbound req message" ) ;
402-
403- // If runner_id is none at this point it means we did not send the stopping signal yet, so
404- // send it now
405- if runner_id. is_none ( ) {
406- let data = BASE64 . decode ( msg. data ) . context ( "invalid base64 message" ) ?;
407- let payload =
402+ let wait_for_shutdown_fut = async move {
403+ while let Some ( event) = source. next ( ) . await {
404+ match event {
405+ Ok ( sse:: Event :: Open ) => { }
406+ Ok ( sse:: Event :: Message ( msg) ) => {
407+ tracing:: debug!( %msg. data, "received outbound req message" ) ;
408+
409+ // If runner_id is none at this point it means we did not send the stopping signal yet, so
410+ // send it now
411+ if runner_id. is_none ( ) {
412+ let data = BASE64 . decode ( msg. data ) . context ( "invalid base64 message" ) ?;
413+ let payload =
408414 protocol:: versioned:: ToServerlessServer :: deserialize_with_embedded_version (
409415 & data,
410416 )
411417 . context ( "invalid payload" ) ?;
412418
413- match payload {
414- protocol:: ToServerlessServer :: ToServerlessServerInit ( init) => {
415- let runner_id =
416- Id :: parse ( & init. runner_id ) . context ( "invalid runner id" ) ?;
417- stop_runner ( ctx, runner_id) . await ?;
419+ match payload {
420+ protocol:: ToServerlessServer :: ToServerlessServerInit ( init) => {
421+ let runner_id_local =
422+ Id :: parse ( & init. runner_id ) . context ( "invalid runner id" ) ?;
423+ runner_id = Some ( runner_id_local) ;
424+ stop_runner ( ctx, runner_id_local) . await ?;
425+ }
418426 }
419427 }
420428 }
429+ Err ( sse:: Error :: StreamEnded ) => break ,
430+ Err ( err) => return Err ( err. into ( ) ) ,
421431 }
422- Err ( sse:: Error :: StreamEnded ) => break ,
423- Err ( err) => return Err ( err. into ( ) ) ,
424432 }
433+
434+ Result :: < ( ) > :: Ok ( ( ) )
435+ } ;
436+
437+ // Wait for runner to shut down
438+ tokio:: select! {
439+ res = wait_for_shutdown_fut => return res. map_err( Into :: into) ,
440+ _ = tokio:: time:: sleep( DRAIN_GRACE_PERIOD ) => {
441+ tracing:: debug!( "reached drain grace period before runner shut down" )
442+ }
443+
444+ }
445+
446+ // Close connection
447+ //
448+ // This will force the runner to stop the request in order to avoid hitting the serverless
449+ // timeout threshold
450+ if let Some ( runner_id) = runner_id {
451+ publish_to_client_stop ( ctx, runner_id) . await ?;
425452 }
426453
427454 tracing:: debug!( "outbound req stopped" ) ;
@@ -454,3 +481,22 @@ async fn stop_runner(ctx: &StandaloneCtx, runner_id: Id) -> Result<()> {
454481
455482 Ok ( ( ) )
456483}
484+
485+ /// Send a stop message to the client.
486+ ///
487+ /// This will close the runner's WebSocket..
488+ async fn publish_to_client_stop ( ctx : & StandaloneCtx , runner_id : Id ) -> Result < ( ) > {
489+ let receiver_subject =
490+ pegboard:: pubsub_subjects:: RunnerReceiverSubject :: new ( runner_id) . to_string ( ) ;
491+
492+ let message_serialized = rivet_runner_protocol:: versioned:: ToClient :: latest (
493+ rivet_runner_protocol:: ToClient :: ToClientClose ,
494+ )
495+ . serialize_with_embedded_version ( rivet_runner_protocol:: PROTOCOL_VERSION ) ?;
496+
497+ ctx. ups ( ) ?
498+ . publish ( & receiver_subject, & message_serialized, PublishOpts :: one ( ) )
499+ . await ?;
500+
501+ Ok ( ( ) )
502+ }
0 commit comments