diff --git a/Cargo.lock b/Cargo.lock index 2918af3a57..c919d3b8b4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3355,6 +3355,7 @@ dependencies = [ "rivet-util", "tracing", "universaldb", + "universalpubsub", "vbare", ] diff --git a/packages/core/pegboard-serverless/Cargo.toml b/packages/core/pegboard-serverless/Cargo.toml index ad985a1b4c..6112a75754 100644 --- a/packages/core/pegboard-serverless/Cargo.toml +++ b/packages/core/pegboard-serverless/Cargo.toml @@ -18,6 +18,7 @@ rivet-types.workspace = true rivet-util.workspace = true tracing.workspace = true universaldb.workspace = true +universalpubsub.workspace = true vbare.workspace = true namespace.workspace = true diff --git a/packages/core/pegboard-serverless/src/lib.rs b/packages/core/pegboard-serverless/src/lib.rs index 363de6ab7c..6ba8a5cd5f 100644 --- a/packages/core/pegboard-serverless/src/lib.rs +++ b/packages/core/pegboard-serverless/src/lib.rs @@ -19,6 +19,7 @@ use rivet_types::runner_configs::RunnerConfigKind; use tokio::{sync::oneshot, task::JoinHandle, time::Duration}; use universaldb::options::StreamingMode; use universaldb::utils::IsolationLevel::*; +use universalpubsub::PublishOpts; use vbare::OwnedVersionedData; const X_RIVET_ENDPOINT: HeaderName = HeaderName::from_static("x-rivet-endpoint"); @@ -27,6 +28,8 @@ const X_RIVET_TOTAL_SLOTS: HeaderName = HeaderName::from_static("x-rivet-total-s const X_RIVET_RUNNER_NAME: HeaderName = HeaderName::from_static("x-rivet-runner-name"); const X_RIVET_NAMESPACE_ID: HeaderName = HeaderName::from_static("x-rivet-namespace-id"); +const DRAIN_GRACE_PERIOD: Duration = Duration::from_secs(10); + struct OutboundConnection { handle: JoinHandle<()>, shutdown_tx: oneshot::Sender<()>, @@ -377,12 +380,14 @@ async fn outbound_handler( anyhow::Ok(()) }; + let sleep_until_drop = request_lifespan.saturating_sub(DRAIN_GRACE_PERIOD); tokio::select! 
{ res = stream_handler => return res.map_err(Into::into), - _ = tokio::time::sleep(request_lifespan) => {} + _ = tokio::time::sleep(sleep_until_drop) => {} _ = shutdown_rx => {} } + // Stop runner draining.store(true, Ordering::SeqCst); ctx.msg(rivet_types::msgs::pegboard::BumpServerlessAutoscaler {}) @@ -394,34 +399,56 @@ async fn outbound_handler( } // Continue waiting on req while draining - while let Some(event) = source.next().await { - match event { - Ok(sse::Event::Open) => {} - Ok(sse::Event::Message(msg)) => { - tracing::debug!(%msg.data, "received outbound req message"); - - // If runner_id is none at this point it means we did not send the stopping signal yet, so - // send it now - if runner_id.is_none() { - let data = BASE64.decode(msg.data).context("invalid base64 message")?; - let payload = + let wait_for_shutdown_fut = async move { + while let Some(event) = source.next().await { + match event { + Ok(sse::Event::Open) => {} + Ok(sse::Event::Message(msg)) => { + tracing::debug!(%msg.data, "received outbound req message"); + + // If runner_id is none at this point it means we did not send the stopping signal yet, so + // send it now + if runner_id.is_none() { + let data = BASE64.decode(msg.data).context("invalid base64 message")?; + let payload = protocol::versioned::ToServerlessServer::deserialize_with_embedded_version( &data, ) .context("invalid payload")?; - match payload { - protocol::ToServerlessServer::ToServerlessServerInit(init) => { - let runner_id = - Id::parse(&init.runner_id).context("invalid runner id")?; - stop_runner(ctx, runner_id).await?; + match payload { + protocol::ToServerlessServer::ToServerlessServerInit(init) => { + let runner_id_local = + Id::parse(&init.runner_id).context("invalid runner id")?; + runner_id = Some(runner_id_local); + stop_runner(ctx, runner_id_local).await?; + } } } } + Err(sse::Error::StreamEnded) => break, + Err(err) => return Err(err.into()), } - Err(sse::Error::StreamEnded) => break, - Err(err) => return 
Err(err.into()), } + + Result::<()>::Ok(()) + }; + + // Wait for runner to shut down + tokio::select! { + res = wait_for_shutdown_fut => return res.map_err(Into::into), + _ = tokio::time::sleep(DRAIN_GRACE_PERIOD) => { + tracing::debug!("reached drain grace period before runner shut down") + } + + } + + // Close connection + // + // This will force the runner to stop the request in order to avoid hitting the serverless + // timeout threshold + if let Some(runner_id) = runner_id { + publish_to_client_stop(ctx, runner_id).await?; } tracing::debug!("outbound req stopped"); @@ -454,3 +481,22 @@ async fn stop_runner(ctx: &StandaloneCtx, runner_id: Id) -> Result<()> { Ok(()) } + +/// Send a stop message to the client. +/// +/// This will close the runner's WebSocket. +async fn publish_to_client_stop(ctx: &StandaloneCtx, runner_id: Id) -> Result<()> { + let receiver_subject = + pegboard::pubsub_subjects::RunnerReceiverSubject::new(runner_id).to_string(); + + let message_serialized = rivet_runner_protocol::versioned::ToClient::latest( + rivet_runner_protocol::ToClient::ToClientClose, + ) + .serialize_with_embedded_version(rivet_runner_protocol::PROTOCOL_VERSION)?; + + ctx.ups()?
+ .publish(&receiver_subject, &message_serialized, PublishOpts::one()) + .await?; + + Ok(()) +} diff --git a/packages/infra/engine/src/commands/start.rs b/packages/infra/engine/src/commands/start.rs index d25735f9b8..5316072da8 100644 --- a/packages/infra/engine/src/commands/start.rs +++ b/packages/infra/engine/src/commands/start.rs @@ -14,7 +14,7 @@ pub struct Opts { /// Exclude the specified services instead of including them #[arg(long)] - exclude_services: bool, + except_services: Vec<ServiceKind>, } #[derive(clap::ValueEnum, Clone, PartialEq)] @@ -55,34 +55,37 @@ impl Opts { } // Select services to run - let services = if self.services.is_empty() { + let services = if self.services.is_empty() && self.except_services.is_empty() { // Run all services run_config.services.clone() + } else if !self.except_services.is_empty() { + // Exclude specified services + let except_service_kinds = self + .except_services + .iter() + .map(|x| x.clone().into()) + .collect::<Vec<_>>(); + + run_config + .services + .iter() + .filter(|x| !except_service_kinds.iter().any(|y| y.eq(&x.kind))) + .cloned() + .collect::<Vec<_>>() } else { - // Filter services + // Include only specified services let service_kinds = self .services .iter() .map(|x| x.clone().into()) .collect::<Vec<_>>(); - if self.exclude_services { - // Exclude specified services - run_config - .services - .iter() - .filter(|x| !service_kinds.iter().any(|y| y.eq(&x.kind))) - .cloned() - .collect::<Vec<_>>() - } else { - // Include only specified services - run_config - .services - .iter() - .filter(|x| service_kinds.iter().any(|y| y.eq(&x.kind))) - .cloned() - .collect::<Vec<_>>() - } + run_config + .services + .iter() + .filter(|x| service_kinds.iter().any(|y| y.eq(&x.kind))) + .cloned() + .collect::<Vec<_>>() }; // Start server diff --git a/packages/infra/engine/src/run_config.rs b/packages/infra/engine/src/run_config.rs index a149e9a257..e14d223c6a 100644 --- a/packages/infra/engine/src/run_config.rs +++ b/packages/infra/engine/src/run_config.rs @@ -6,7 +6,7 @@ pub fn
config(_rivet_config: rivet_config::Config) -> Result<RunConfig> { Service::new("api_peer", ServiceKind::ApiPeer, |config, pools| { Box::pin(rivet_api_peer::start(config, pools)) }), - Service::new("guard", ServiceKind::Standalone, |config, pools| { + Service::new("guard", ServiceKind::ApiPublic, |config, pools| { Box::pin(rivet_guard::start(config, pools)) }), Service::new( @@ -19,7 +19,8 @@ pub fn config(_rivet_config: rivet_config::Config) -> Result<RunConfig> { }), Service::new( "pegboard_serverless", - ServiceKind::Standalone, + // There should only be one of these, since it's auto-scaling requests + ServiceKind::Singleton, |config, pools| Box::pin(pegboard_serverless::start(config, pools)), ), Service::new( diff --git a/packages/services/epoxy/src/workflows/coordinator/replica_status_change.rs b/packages/services/epoxy/src/workflows/coordinator/replica_status_change.rs index 6d8fde554d..7445616136 100644 --- a/packages/services/epoxy/src/workflows/coordinator/replica_status_change.rs +++ b/packages/services/epoxy/src/workflows/coordinator/replica_status_change.rs @@ -38,9 +38,9 @@ pub async fn replica_status_change( } #[tracing::instrument(skip_all)] -pub async fn replica_reconfigure( - ctx: &mut WorkflowCtx, -) -> Result<()> { +pub async fn replica_reconfigure(ctx: &mut WorkflowCtx) -> Result<()> { + ctx.activity(UpdateReplicaUrlsInput {}).await?; + let notify_out = ctx.activity(NotifyAllReplicasInput {}).await?; let replica_id = ctx.config().epoxy_replica_id(); @@ -108,6 +108,37 @@ pub async fn increment_epoch(ctx: &ActivityCtx, _input: &IncrementEpochInput) -> Ok(()) } +#[derive(Debug, Clone, Serialize, Deserialize, Hash)] +pub struct UpdateReplicaUrlsInput {} + +#[activity(UpdateReplicaUrls)] +pub async fn update_replica_urls(ctx: &ActivityCtx, _input: &UpdateReplicaUrlsInput) -> Result<()> { + let mut state = ctx.state::<State>()?; + + // Update URLs for all replicas based on topology + for replica in state.config.replicas.iter_mut() { + let Some(dc) =
ctx.config().dc_for_label(replica.replica_id as u16) else { + tracing::warn!( + replica_id = ?replica.replica_id, + "datacenter not found for replica, skipping url update" + ); + continue; + }; + + replica.api_peer_url = dc.peer_url.to_string(); + replica.guard_url = dc.public_url.to_string(); + + tracing::info!( + replica_id = ?replica.replica_id, + api_peer_url = ?dc.peer_url, + guard_url = ?dc.public_url, + "updated replica urls" + ); + } + + Ok(()) +} + #[derive(Debug, Clone, Serialize, Deserialize, Hash)] pub struct NotifyAllReplicasInput {}