Skip to content

Commit 62f907e

Browse files
committed
fix(udb): optimize pg driver
1 parent 2a55f88 commit 62f907e

File tree

10 files changed

+117
-38
lines changed

10 files changed

+117
-38
lines changed

docker/dev/rivet-engine/config.jsonc

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,6 @@
2424
"postgres": {
2525
"url": "postgresql://postgres:postgres@postgres:5432/rivet_engine"
2626
},
27-
"memory": {
28-
"channel": "default"
29-
},
3027
"cache": {
3128
"driver": "in_memory"
3229
},

out/openapi.json

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

packages/common/universaldb/src/driver/postgres/database.rs

Lines changed: 54 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1-
use std::sync::{Arc, Mutex};
1+
use std::{
2+
sync::{Arc, Mutex},
3+
time::Duration,
4+
};
25

36
use anyhow::{Context, Result};
47
use deadpool_postgres::{Config, ManagerConfig, Pool, PoolConfig, RecyclingMethod, Runtime};
8+
use tokio::task::JoinHandle;
59
use tokio_postgres::NoTls;
610

711
use crate::{
@@ -14,9 +18,13 @@ use crate::{
1418

1519
use super::transaction::PostgresTransactionDriver;
1620

21+
const TXN_TIMEOUT: Duration = Duration::from_secs(5);
22+
const GC_INTERVAL: Duration = Duration::from_secs(5);
23+
1724
pub struct PostgresDatabaseDriver {
1825
pool: Arc<Pool>,
1926
max_retries: Arc<Mutex<i32>>,
27+
gc_handle: JoinHandle<()>,
2028
}
2129

2230
impl PostgresDatabaseDriver {
@@ -53,7 +61,7 @@ impl PostgresDatabaseDriver {
5361
.context("failed to create btree_gist extension")?;
5462

5563
conn.execute(
56-
"CREATE SEQUENCE IF NOT EXISTS global_version_seq START WITH 1 INCREMENT BY 1 MINVALUE 1",
64+
"CREATE UNLOGGED SEQUENCE IF NOT EXISTS global_version_seq START WITH 1 INCREMENT BY 1 MINVALUE 1",
5765
&[],
5866
)
5967
.await
@@ -123,12 +131,39 @@ impl PostgresDatabaseDriver {
123131
.await
124132
.context("failed to create conflict_ranges table")?;
125133

126-
// Connection is automatically returned to the pool when dropped
127-
drop(conn);
134+
// Create index on ts column for efficient garbage collection
135+
conn.execute(
136+
"CREATE INDEX IF NOT EXISTS idx_conflict_ranges_ts ON conflict_ranges (ts)",
137+
&[],
138+
)
139+
.await
140+
.context("failed to create index on conflict_ranges ts column")?;
141+
142+
let gc_handle = tokio::spawn(async move {
143+
let mut interval = tokio::time::interval(GC_INTERVAL);
144+
interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
145+
146+
loop {
147+
interval.tick().await;
148+
149+
// NOTE: Transactions have a max limit of 5 seconds, we delete after 10 seconds for extra padding
150+
// Delete old conflict ranges
151+
if let Err(err) = conn
152+
.execute(
153+
"DELETE FROM conflict_ranges where ts < now() - interval '10 seconds'",
154+
&[],
155+
)
156+
.await
157+
{
158+
tracing::error!(?err, "failed postgres gc task");
159+
}
160+
}
161+
});
128162

129163
Ok(PostgresDatabaseDriver {
130164
pool: Arc::new(pool),
131165
max_retries: Arc::new(Mutex::new(100)),
166+
gc_handle,
132167
})
133168
}
134169
}
@@ -155,13 +190,15 @@ impl DatabaseDriver for PostgresDatabaseDriver {
155190
retryable.maybe_committed = maybe_committed;
156191

157192
// Execute transaction
158-
let error = match closure(retryable.clone()).await {
159-
Ok(res) => match retryable.inner.driver.commit_ref().await {
160-
Ok(_) => return Ok(res),
161-
Err(e) => e,
162-
},
163-
Err(e) => e,
164-
};
193+
let error =
194+
match tokio::time::timeout(TXN_TIMEOUT, closure(retryable.clone())).await {
195+
Ok(Ok(res)) => match retryable.inner.driver.commit_ref().await {
196+
Ok(_) => return Ok(res),
197+
Err(e) => e,
198+
},
199+
Ok(Err(e)) => e,
200+
Err(e) => anyhow::Error::from(DatabaseError::TransactionTooOld),
201+
};
165202

166203
let chain = error
167204
.chain()
@@ -196,3 +233,9 @@ impl DatabaseDriver for PostgresDatabaseDriver {
196233
}
197234
}
198235
}
236+
237+
impl Drop for PostgresDatabaseDriver {
238+
fn drop(&mut self) {
239+
self.gc_handle.abort();
240+
}
241+
}

packages/common/universaldb/src/driver/postgres/transaction_task.rs

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -330,11 +330,20 @@ impl TransactionTask {
330330
tx: Transaction<'_>,
331331
start_version: i64,
332332
operations: Vec<Operation>,
333-
conflict_ranges: Vec<(Vec<u8>, Vec<u8>, ConflictRangeType)>,
333+
mut conflict_ranges: Vec<(Vec<u8>, Vec<u8>, ConflictRangeType)>,
334334
) -> Result<()> {
335-
let commit_version = tx
336-
.query_one("SELECT nextval('global_version_seq')", &[])
337-
.await
335+
// // Defer all constraint checks until commit
336+
// tx.execute("SET CONSTRAINTS ALL DEFERRED", &[])
337+
// .await
338+
// .map_err(map_postgres_error)?;
339+
340+
let (_, _, version_res) = tokio::join!(
341+
tx.execute("SET LOCAL lock_timeout = '0'", &[],),
342+
tx.execute("SET LOCAL deadlock_timeout = '10ms'", &[],),
343+
tx.query_one("SELECT nextval('global_version_seq')", &[]),
344+
);
345+
346+
let commit_version = version_res
338347
.context("failed to get postgres txn commit_version")?
339348
.get::<_, i64>(0);
340349

@@ -355,7 +364,7 @@ impl TransactionTask {
355364

356365
let query = "
357366
INSERT INTO conflict_ranges (range_data, conflict_type, start_version, commit_version)
358-
SELECT
367+
SELECT
359368
bytearange(begin_key, end_key, '[)'),
360369
conflict_type::range_type,
361370
$4,
@@ -377,13 +386,22 @@ impl TransactionTask {
377386
.await
378387
.map_err(map_postgres_error)?;
379388

380-
// TODO: Parallelize
381389
for op in operations {
382390
match op {
383391
Operation::Set { key, value } => {
384392
// TODO: versionstamps need to be calculated on the sql side, not in rust
385393
let value = substitute_versionstamp_if_incomplete(value.clone(), 0);
386394

395+
// // Poor man's upsert, you cant use ON CONFLICT with deferred constraints
396+
// let query = "WITH updated AS (
397+
// UPDATE kv
398+
// SET value = $2
399+
// WHERE key = $1
400+
// RETURNING 1
401+
// )
402+
// INSERT INTO kv (key, value)
403+
// SELECT $1, $2
404+
// WHERE NOT EXISTS (SELECT 1 FROM updated)";
387405
let query = "INSERT INTO kv (key, value) VALUES ($1, $2) ON CONFLICT (key) DO UPDATE SET value = $2";
388406
let stmt = tx.prepare_cached(query).await.map_err(map_postgres_error)?;
389407

@@ -435,6 +453,16 @@ impl TransactionTask {
435453

436454
// Store the result
437455
if let Some(new_value) = new_value {
456+
// // Poor man's upsert, you cant use ON CONFLICT with deferred constraints
457+
// let update_query = "WITH updated AS (
458+
// UPDATE kv
459+
// SET value = $2
460+
// WHERE key = $1
461+
// RETURNING 1
462+
// )
463+
// INSERT INTO kv (key, value)
464+
// SELECT $1, $2
465+
// WHERE NOT EXISTS (SELECT 1 FROM updated)";
438466
let update_query = "INSERT INTO kv (key, value) VALUES ($1, $2) ON CONFLICT (key) DO UPDATE SET value = $2";
439467
let stmt = tx
440468
.prepare_cached(update_query)

packages/common/universaldb/src/options.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ pub enum MutationType {
264264
/// Performs an atomic ``compare and clear`` operation. If the existing value in the database is equal to the given value, then given key is cleared.
265265
CompareAndClear,
266266
}
267-
#[derive(Clone, Copy, Debug)]
267+
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
268268
#[non_exhaustive]
269269
pub enum ConflictRangeType {
270270
/// Used to add a read conflict range

packages/common/universaldb/src/tx_ops.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,17 @@ pub enum Operation {
3434
},
3535
}
3636

37+
impl Operation {
38+
pub fn sorting_key(&self) -> &[u8] {
39+
match self {
40+
Operation::Set { key, .. } => key,
41+
Operation::Clear { key } => key,
42+
Operation::ClearRange { begin, .. } => begin,
43+
Operation::AtomicOp { key, .. } => key,
44+
}
45+
}
46+
}
47+
3748
#[derive(Debug, Clone)]
3849
pub enum GetOutput {
3950
Value(Vec<u8>),

packages/core/guard/core/src/websocket_handle.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ impl WebSocketHandleInner {
5353
let mut state = self.state.lock().await;
5454
match &mut *state {
5555
WebSocketState::Unaccepted { .. } | WebSocketState::Accepting => {
56-
bail!("websocket has not been accepted")
56+
bail!("websocket has not been accepted");
5757
}
5858
WebSocketState::Split { ws_tx } => {
5959
ws_tx.send(message).await?;

scripts/tests/actor_e2e.ts

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env tsx
22

3-
import { RIVET_ENDPOINT, createActor, destroyActor } from "./utils";
3+
import { RIVET_ENDPOINT, RIVET_TOKEN, createActor, destroyActor } from "./utils";
44

55
async function main() {
66
try {
@@ -31,9 +31,7 @@ async function main() {
3131

3232
console.log("Actor ping response:", pingResult);
3333

34-
// Test WebSocket connection
35-
console.log("Testing WebSocket connection to actor...");
36-
// await testWebSocket(actorResponse.actor.actor_id);
34+
await testWebSocket(actorResponse.actor.actor_id);
3735

3836
console.log("Destroying actor...");
3937
await destroyActor("default", actorResponse.actor.actor_id);
@@ -49,6 +47,8 @@ async function main() {
4947
}
5048

5149
function testWebSocket(actorId: string): Promise<void> {
50+
console.log("Testing WebSocket connection to actor...");
51+
5252
return new Promise((resolve, reject) => {
5353
// Parse the RIVET_ENDPOINT to get WebSocket URL
5454
const wsEndpoint = RIVET_ENDPOINT.replace("http://", "ws://").replace(
@@ -59,7 +59,7 @@ function testWebSocket(actorId: string): Promise<void> {
5959

6060
console.log(`Connecting WebSocket to: ${wsUrl}`);
6161

62-
const protocols = ["rivet", "rivet_target.actor", `rivet_actor.${actorId}`];
62+
const protocols = ["rivet", "rivet_target.actor", `rivet_actor.${actorId}`, `rivet_token.${RIVET_TOKEN}`];
6363
const ws = new WebSocket(wsUrl, protocols);
6464

6565
let pingReceived = false;
@@ -81,9 +81,9 @@ function testWebSocket(actorId: string): Promise<void> {
8181
ws.send("ping");
8282
});
8383

84-
ws.addEventListener("message", (data) => {
85-
const message = data.toString();
86-
console.log(`WebSocket received raw data:`, data);
84+
ws.addEventListener("message", (ev) => {
85+
const message = ev.data.toString();
86+
console.log(`WebSocket received raw data:`, ev.data);
8787
console.log(`WebSocket received message: "${message}"`);
8888

8989
if (

scripts/tests/spam_actors.ts

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#!/usr/bin/env tsx
22

3-
import { RIVET_ENDPOINT, createActor, destroyActor } from "./utils";
3+
import { RIVET_ENDPOINT, RIVET_TOKEN, createActor, destroyActor } from "./utils";
44

55
const ACTORS = parseInt(process.argv[2]) || 15;
66

@@ -44,7 +44,7 @@ async function testActor(i: number) {
4444

4545
console.log(`Actor ${i} ping response:`, pingResult);
4646

47-
// await testWebSocket(actorResponse.actor.actor_id);
47+
await testWebSocket(actorResponse.actor.actor_id);
4848

4949
console.log(`Destroying actor ${i}...`);
5050
await destroyActor("default", actorResponse.actor.actor_id);
@@ -66,7 +66,7 @@ function testWebSocket(actorId: string): Promise<void> {
6666

6767
console.log(`Connecting WebSocket to: ${wsUrl}`);
6868

69-
const protocols = ["rivet", "rivet_target.actor", `rivet_actor.${actorId}`];
69+
const protocols = ["rivet", "rivet_target.actor", `rivet_actor.${actorId}`, `rivet_token.${RIVET_TOKEN}`];
7070
const ws = new WebSocket(wsUrl, protocols);
7171

7272
let pingReceived = false;
@@ -88,9 +88,9 @@ function testWebSocket(actorId: string): Promise<void> {
8888
ws.send("ping");
8989
});
9090

91-
ws.addEventListener("message", (data) => {
92-
const message = data.toString();
93-
console.log(`WebSocket received raw data:`, data);
91+
ws.addEventListener("message", (ev) => {
92+
const message = ev.data.toString();
93+
console.log(`WebSocket received raw data:`, ev.data);
9494
console.log(`WebSocket received message: "${message}"`);
9595

9696
if (

scripts/tests/utils.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
export const RIVET_ENDPOINT =
22
process.env.RIVET_ENDPOINT ?? "http://localhost:6420";
3-
const RIVET_TOKEN = process.env.RIVET_TOKEN ?? "dev";
3+
export const RIVET_TOKEN = process.env.RIVET_TOKEN ?? "dev";
44

55
export async function createActor(
66
namespaceName: string,

0 commit comments

Comments
 (0)