Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions packages/core/pegboard-serverless/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ rivet-types.workspace = true
rivet-util.workspace = true
tracing.workspace = true
universaldb.workspace = true
universalpubsub.workspace = true
vbare.workspace = true

namespace.workspace = true
Expand Down
84 changes: 65 additions & 19 deletions packages/core/pegboard-serverless/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use rivet_types::runner_configs::RunnerConfigKind;
use tokio::{sync::oneshot, task::JoinHandle, time::Duration};
use universaldb::options::StreamingMode;
use universaldb::utils::IsolationLevel::*;
use universalpubsub::PublishOpts;
use vbare::OwnedVersionedData;

const X_RIVET_ENDPOINT: HeaderName = HeaderName::from_static("x-rivet-endpoint");
Expand All @@ -27,6 +28,8 @@ const X_RIVET_TOTAL_SLOTS: HeaderName = HeaderName::from_static("x-rivet-total-s
const X_RIVET_RUNNER_NAME: HeaderName = HeaderName::from_static("x-rivet-runner-name");
const X_RIVET_NAMESPACE_ID: HeaderName = HeaderName::from_static("x-rivet-namespace-id");

const DRAIN_GRACE_PERIOD: Duration = Duration::from_secs(10);

struct OutboundConnection {
handle: JoinHandle<()>,
shutdown_tx: oneshot::Sender<()>,
Expand Down Expand Up @@ -377,12 +380,14 @@ async fn outbound_handler(
anyhow::Ok(())
};

let sleep_until_drop = request_lifespan.saturating_sub(DRAIN_GRACE_PERIOD);
tokio::select! {
res = stream_handler => return res.map_err(Into::into),
_ = tokio::time::sleep(request_lifespan) => {}
_ = tokio::time::sleep(sleep_until_drop) => {}
_ = shutdown_rx => {}
}

// Stop runner
draining.store(true, Ordering::SeqCst);

ctx.msg(rivet_types::msgs::pegboard::BumpServerlessAutoscaler {})
Expand All @@ -394,34 +399,56 @@ async fn outbound_handler(
}

// Continue waiting on req while draining
while let Some(event) = source.next().await {
match event {
Ok(sse::Event::Open) => {}
Ok(sse::Event::Message(msg)) => {
tracing::debug!(%msg.data, "received outbound req message");

// If runner_id is none at this point it means we did not send the stopping signal yet, so
// send it now
if runner_id.is_none() {
let data = BASE64.decode(msg.data).context("invalid base64 message")?;
let payload =
let wait_for_shutdown_fut = async move {
while let Some(event) = source.next().await {
match event {
Ok(sse::Event::Open) => {}
Ok(sse::Event::Message(msg)) => {
tracing::debug!(%msg.data, "received outbound req message");

// If runner_id is none at this point it means we did not send the stopping signal yet, so
// send it now
if runner_id.is_none() {
let data = BASE64.decode(msg.data).context("invalid base64 message")?;
let payload =
protocol::versioned::ToServerlessServer::deserialize_with_embedded_version(
&data,
)
.context("invalid payload")?;

match payload {
protocol::ToServerlessServer::ToServerlessServerInit(init) => {
let runner_id =
Id::parse(&init.runner_id).context("invalid runner id")?;
stop_runner(ctx, runner_id).await?;
match payload {
protocol::ToServerlessServer::ToServerlessServerInit(init) => {
let runner_id_local =
Id::parse(&init.runner_id).context("invalid runner id")?;
runner_id = Some(runner_id_local);
stop_runner(ctx, runner_id_local).await?;
}
}
}
}
Err(sse::Error::StreamEnded) => break,
Err(err) => return Err(err.into()),
}
Err(sse::Error::StreamEnded) => break,
Err(err) => return Err(err.into()),
}

Result::<()>::Ok(())
};

// Wait for runner to shut down
tokio::select! {
res = wait_for_shutdown_fut => return res.map_err(Into::into),
_ = tokio::time::sleep(DRAIN_GRACE_PERIOD) => {
tracing::debug!("reached drain grace period before runner shut down")
}

}

// Close connection
//
// This will force the runner to stop the request in order to avoid hitting the serverless
// timeout threshold
if let Some(runner_id) = runner_id {
publish_to_client_stop(ctx, runner_id).await?;
}

tracing::debug!("outbound req stopped");
Expand Down Expand Up @@ -454,3 +481,22 @@ async fn stop_runner(ctx: &StandaloneCtx, runner_id: Id) -> Result<()> {

Ok(())
}

/// Send a stop message to the client.
///
/// This will close the runner's WebSocket..
async fn publish_to_client_stop(ctx: &StandaloneCtx, runner_id: Id) -> Result<()> {
let receiver_subject =
pegboard::pubsub_subjects::RunnerReceiverSubject::new(runner_id).to_string();

let message_serialized = rivet_runner_protocol::versioned::ToClient::latest(
rivet_runner_protocol::ToClient::ToClientClose,
)
.serialize_with_embedded_version(rivet_runner_protocol::PROTOCOL_VERSION)?;

ctx.ups()?
.publish(&receiver_subject, &message_serialized, PublishOpts::one())
.await?;

Ok(())
}
43 changes: 23 additions & 20 deletions packages/infra/engine/src/commands/start.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub struct Opts {

/// Exclude the specified services instead of including them
#[arg(long)]
exclude_services: bool,
except_services: Vec<ServiceKind>,
}

#[derive(clap::ValueEnum, Clone, PartialEq)]
Expand Down Expand Up @@ -55,34 +55,37 @@ impl Opts {
}

// Select services to run
let services = if self.services.is_empty() {
let services = if self.services.is_empty() && self.except_services.is_empty() {
// Run all services
run_config.services.clone()
} else if !self.except_services.is_empty() {
// Exclude specified services
let except_service_kinds = self
.except_services
.iter()
.map(|x| x.clone().into())
.collect::<Vec<rivet_service_manager::ServiceKind>>();

run_config
.services
.iter()
.filter(|x| !except_service_kinds.iter().any(|y| y.eq(&x.kind)))
.cloned()
.collect::<Vec<_>>()
} else {
// Filter services
// Include only specified services
let service_kinds = self
.services
.iter()
.map(|x| x.clone().into())
.collect::<Vec<rivet_service_manager::ServiceKind>>();

if self.exclude_services {
// Exclude specified services
run_config
.services
.iter()
.filter(|x| !service_kinds.iter().any(|y| y.eq(&x.kind)))
.cloned()
.collect::<Vec<_>>()
} else {
// Include only specified services
run_config
.services
.iter()
.filter(|x| service_kinds.iter().any(|y| y.eq(&x.kind)))
.cloned()
.collect::<Vec<_>>()
}
run_config
.services
.iter()
.filter(|x| service_kinds.iter().any(|y| y.eq(&x.kind)))
.cloned()
.collect::<Vec<_>>()
};

// Start server
Expand Down
5 changes: 3 additions & 2 deletions packages/infra/engine/src/run_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pub fn config(_rivet_config: rivet_config::Config) -> Result<RunConfigData> {
Service::new("api_peer", ServiceKind::ApiPeer, |config, pools| {
Box::pin(rivet_api_peer::start(config, pools))
}),
Service::new("guard", ServiceKind::Standalone, |config, pools| {
Service::new("guard", ServiceKind::ApiPublic, |config, pools| {
Box::pin(rivet_guard::start(config, pools))
}),
Service::new(
Expand All @@ -19,7 +19,8 @@ pub fn config(_rivet_config: rivet_config::Config) -> Result<RunConfigData> {
}),
Service::new(
"pegboard_serverless",
ServiceKind::Standalone,
// There should only be one of these, since it's auto-scaling requests
ServiceKind::Singleton,
|config, pools| Box::pin(pegboard_serverless::start(config, pools)),
),
Service::new(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ pub async fn replica_status_change(
}

#[tracing::instrument(skip_all)]
pub async fn replica_reconfigure(
ctx: &mut WorkflowCtx,
) -> Result<()> {
pub async fn replica_reconfigure(ctx: &mut WorkflowCtx) -> Result<()> {
ctx.activity(UpdateReplicaUrlsInput {}).await?;

let notify_out = ctx.activity(NotifyAllReplicasInput {}).await?;

let replica_id = ctx.config().epoxy_replica_id();
Expand Down Expand Up @@ -108,6 +108,37 @@ pub async fn increment_epoch(ctx: &ActivityCtx, _input: &IncrementEpochInput) ->
Ok(())
}

#[derive(Debug, Clone, Serialize, Deserialize, Hash)]
pub struct UpdateReplicaUrlsInput {}

#[activity(UpdateReplicaUrls)]
pub async fn update_replica_urls(ctx: &ActivityCtx, _input: &UpdateReplicaUrlsInput) -> Result<()> {
let mut state = ctx.state::<State>()?;

// Update URLs for all replicas based on topology
for replica in state.config.replicas.iter_mut() {
let Some(dc) = ctx.config().dc_for_label(replica.replica_id as u16) else {
tracing::warn!(
replica_id = ?replica.replica_id,
"datacenter not found for replica, skipping url update"
);
continue;
};

replica.api_peer_url = dc.peer_url.to_string();
replica.guard_url = dc.public_url.to_string();

tracing::info!(
replica_id = ?replica.replica_id,
api_peer_url = ?dc.peer_url,
guard_url = ?dc.public_url,
"updated replica urls"
);
}

Ok(())
}

#[derive(Debug, Clone, Serialize, Deserialize, Hash)]
pub struct NotifyAllReplicasInput {}

Expand Down
Loading