Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/rivetkit/fixtures/driver-test-suite/sleep.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { actor, type UniversalWebSocket } from "rivetkit";
import { promiseWithResolvers } from "rivetkit/utils";

export const SLEEP_TIMEOUT = 500;
export const SLEEP_TIMEOUT = 1000;

export const sleep = actor({
state: { startCount: 0, sleepCount: 0 },
Expand Down
19 changes: 16 additions & 3 deletions packages/rivetkit/src/actor/instance.ts
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,9 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
this.#config.options.connectionLivenessInterval,
);
this.#checkConnectionsLiveness();

// Trigger any pending alarms
await this._onAlarm();
}

async #scheduleEventInner(newEvent: PersistedScheduleEvent) {
Expand Down Expand Up @@ -401,6 +404,12 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
}
}

/**
* Triggers any pending alarms.
*
* This method is idempotent. It's called automatically when the actor wakes
* in order to trigger any pending alarms.
*/
async _onAlarm() {
const now = Date.now();
this.actorContext.log.debug({
Expand All @@ -424,7 +433,7 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
this.#rLog.warn({ msg: "no events are due yet, time may have broken" });
if (this.#persist.scheduledEvents.length > 0) {
const nextTs = this.#persist.scheduledEvents[0].timestamp;
this.actorContext.log.warn({
this.actorContext.log.debug({
msg: "alarm fired early, rescheduling for next event",
now,
nextTs,
Expand Down Expand Up @@ -786,7 +795,7 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
}

/**
* Connection disconnected.
* Call when conn is disconnected. Used by transports.
*
* If a clean diconnect, will be removed immediately.
*
Expand All @@ -800,7 +809,7 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
// If socket ID is provided, check if it matches the current socket ID
// If it doesn't match, this is a stale disconnect event from an old socket
if (socketId && conn.__socket && socketId !== conn.__socket.socketId) {
this.rLog.debug({
this.#rLog.debug({
msg: "ignoring stale disconnect event",
connId: conn.id,
eventSocketId: socketId,
Expand All @@ -825,6 +834,9 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {

// Remove socket
conn.__socket = undefined;

// Update sleep
this.#resetSleepTimer();
}
}

Expand All @@ -848,6 +860,7 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {

// Remove from state
this.#connections.delete(conn.id);
this.#rLog.debug({ msg: "removed conn", connId: conn.id });

// Remove subscriptions
for (const eventName of [...conn.subscriptions.values()]) {
Expand Down
3 changes: 2 additions & 1 deletion packages/rivetkit/src/actor/router-endpoints.ts
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,8 @@ export async function handleWebSocketConnect(
// Handle cleanup asynchronously
handlersPromise
.then(({ conn, actor }) => {
actor.__connDisconnected(conn, event.wasClean, socketId);
const wasClean = event.wasClean || event.code === 1000;
actor.__connDisconnected(conn, wasClean, socketId);
})
.catch((error) => {
deconstructError(
Expand Down
51 changes: 27 additions & 24 deletions packages/rivetkit/src/actor/router.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ export type ActorRouter = Hono<{ Bindings: ActorRouterBindings }>;
export function createActorRouter(
runConfig: RunConfig,
actorDriver: ActorDriver,
isTest: boolean,
): ActorRouter {
const router = new Hono<{ Bindings: ActorRouterBindings }>({ strict: false });

Expand All @@ -84,37 +85,39 @@ export function createActorRouter(
return c.text("ok");
});

// Test endpoint to force disconnect a connection non-cleanly
router.post("/.test/force-disconnect", async (c) => {
const connId = c.req.query("conn");
if (isTest) {
// Test endpoint to force disconnect a connection non-cleanly
router.post("/.test/force-disconnect", async (c) => {
const connId = c.req.query("conn");

if (!connId) {
return c.text("Missing conn query parameter", 400);
}
if (!connId) {
return c.text("Missing conn query parameter", 400);
}

const actor = await actorDriver.loadActor(c.env.actorId);
const conn = actor.__getConnForId(connId);
const actor = await actorDriver.loadActor(c.env.actorId);
const conn = actor.__getConnForId(connId);

if (!conn) {
return c.text(`Connection not found: ${connId}`, 404);
}
if (!conn) {
return c.text(`Connection not found: ${connId}`, 404);
}

// Force close the websocket/SSE connection without clean shutdown
const driverState = conn.__driverState;
if (driverState && ConnDriverKind.WEBSOCKET in driverState) {
const ws = driverState[ConnDriverKind.WEBSOCKET].websocket;
// Force close the websocket/SSE connection without clean shutdown
const driverState = conn.__driverState;
if (driverState && ConnDriverKind.WEBSOCKET in driverState) {
const ws = driverState[ConnDriverKind.WEBSOCKET].websocket;

// Force close without sending close frame
(ws.raw as any).terminate();
} else if (driverState && ConnDriverKind.SSE in driverState) {
const stream = driverState[ConnDriverKind.SSE].stream;
// Force close without sending close frame
(ws.raw as any).terminate();
} else if (driverState && ConnDriverKind.SSE in driverState) {
const stream = driverState[ConnDriverKind.SSE].stream;

// Force close the SSE stream
stream.abort();
}
// Force close the SSE stream
stream.abort();
}

return c.json({ success: true });
});
return c.json({ success: true });
});
}

router.get(PATH_CONNECT_WEBSOCKET, async (c) => {
const upgradeWebSocket = runConfig.getUpgradeWebSocket?.();
Expand Down
6 changes: 5 additions & 1 deletion packages/rivetkit/src/client/actor-conn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -914,6 +914,8 @@ enc
if (!this.#transport) {
// Nothing to do
} else if ("websocket" in this.#transport) {
logger().debug("closing ws");

const ws = this.#transport.websocket;
// Check if WebSocket is already closed or closing
if (
Expand All @@ -927,10 +929,12 @@ enc
logger().debug({ msg: "ws closed" });
resolve(undefined);
});
ws.close();
ws.close(1000, "Normal closure");
await promise;
}
} else if ("sse" in this.#transport) {
logger().debug("closing sse");

// Send close request to server for SSE connections
if (this.#connectionId && this.#connectionToken) {
try {
Expand Down
13 changes: 11 additions & 2 deletions packages/rivetkit/src/driver-test-suite/mod.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ export interface SkipTests {
schedule?: boolean;
sleep?: boolean;
sse?: boolean;
inline?: boolean;
}

export interface DriverTestConfig {
Expand Down Expand Up @@ -79,7 +80,10 @@ export interface DriverDeployOutput {
export function runDriverTests(
driverTestConfigPartial: Omit<DriverTestConfig, "clientType" | "transport">,
) {
for (const clientType of ["http", "inline"] as ClientType[]) {
const clientTypes: ClientType[] = driverTestConfigPartial.skip?.inline
? ["http"]
: ["http", "inline"];
for (const clientType of clientTypes) {
const driverTestConfig: DriverTestConfig = {
...driverTestConfigPartial,
clientType,
Expand Down Expand Up @@ -148,7 +152,12 @@ export function runDriverTests(
export async function createTestRuntime(
registryPath: string,
driverFactory: (registry: Registry<any>) => Promise<{
rivetEngine?: { endpoint: string; namespace: string; runnerName: string };
rivetEngine?: {
endpoint: string;
namespace: string;
runnerName: string;
token: string;
};
driver: DriverConfig;
cleanup?: () => Promise<void>;
}>,
Expand Down
51 changes: 14 additions & 37 deletions packages/rivetkit/src/driver-test-suite/tests/actor-schedule.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,21 @@ export function runActorScheduleTests(driverTestConfig: DriverTestConfig) {
describe.skipIf(driverTestConfig.skip?.schedule)(
"Actor Schedule Tests",
() => {
// See alarm + actor sleeping test in actor-sleep.ts

describe("Scheduled Alarms", () => {
test("executes c.schedule.at() with specific timestamp", async (c) => {
const { client } = await setupDriverTest(c, driverTestConfig);

// Create instance
const scheduled = client.scheduled.getOrCreate();

// Schedule a task to run in 100ms using timestamp
const timestamp = Date.now() + 100;
// Schedule a task to run using timestamp
const timestamp = Date.now() + 250;
await scheduled.scheduleTaskAt(timestamp);

// Wait for longer than the scheduled time
await waitFor(driverTestConfig, 200);
await waitFor(driverTestConfig, 500);

// Verify the scheduled task ran
const lastRun = await scheduled.getLastRun();
Expand All @@ -34,11 +36,11 @@ export function runActorScheduleTests(driverTestConfig: DriverTestConfig) {
// Create instance
const scheduled = client.scheduled.getOrCreate();

// Schedule a task to run in 100ms using delay
await scheduled.scheduleTaskAfter(100);
// Schedule a task to run using delay
await scheduled.scheduleTaskAfter(250);

// Wait for longer than the scheduled time
await waitFor(driverTestConfig, 200);
await waitFor(driverTestConfig, 500);

// Verify the scheduled task ran
const lastRun = await scheduled.getLastRun();
Expand All @@ -48,31 +50,6 @@ export function runActorScheduleTests(driverTestConfig: DriverTestConfig) {
expect(scheduledCount).toBe(1);
});

test("scheduled tasks persist across actor restarts", async (c) => {
const { client } = await setupDriverTest(c, driverTestConfig);

// Create instance and schedule
const scheduled = client.scheduled.getOrCreate();
await scheduled.scheduleTaskAfter(200);

// Wait a little so the schedule is stored but hasn't triggered yet
await waitFor(driverTestConfig, 100);

// Get a new reference to simulate actor restart
const newInstance = client.scheduled.getOrCreate();

// Verify the schedule still exists but hasn't run yet
const initialCount = await newInstance.getScheduledCount();
expect(initialCount).toBe(0);

// Wait for the scheduled task to execute
await waitFor(driverTestConfig, 200);

// Verify the scheduled task ran after "restart"
const scheduledCount = await newInstance.getScheduledCount();
expect(scheduledCount).toBe(1);
});

test("multiple scheduled tasks execute in order", async (c) => {
const { client } = await setupDriverTest(c, driverTestConfig);

Expand All @@ -83,22 +60,22 @@ export function runActorScheduleTests(driverTestConfig: DriverTestConfig) {
await scheduled.clearHistory();

// Schedule multiple tasks with different delays
await scheduled.scheduleTaskAfterWithId("first", 100);
await scheduled.scheduleTaskAfterWithId("second", 300);
await scheduled.scheduleTaskAfterWithId("third", 500);
await scheduled.scheduleTaskAfterWithId("first", 250);
await scheduled.scheduleTaskAfterWithId("second", 750);
await scheduled.scheduleTaskAfterWithId("third", 1250);

// Wait for first task only
await waitFor(driverTestConfig, 200);
await waitFor(driverTestConfig, 500);
const history1 = await scheduled.getTaskHistory();
expect(history1).toEqual(["first"]);

// Wait for second task
await waitFor(driverTestConfig, 200);
await waitFor(driverTestConfig, 500);
const history2 = await scheduled.getTaskHistory();
expect(history2).toEqual(["first", "second"]);

// Wait for third task
await waitFor(driverTestConfig, 200);
await waitFor(driverTestConfig, 500);
const history3 = await scheduled.getTaskHistory();
expect(history3).toEqual(["first", "second", "third"]);
});
Expand Down
Loading
Loading