From d98968e7ea3ddfea0618ee12751c137e02fa6dc0 Mon Sep 17 00:00:00 2001 From: Kacper Wojciechowski <39823706+jog1t@users.noreply.github.com> Date: Tue, 29 Jul 2025 00:04:40 +0200 Subject: [PATCH] feat: subscribe to events before connect --- .../fixtures/driver-test-suite/counter.ts | 4 ++ packages/rivetkit/src/actor/instance.ts | 14 +++++++ packages/rivetkit/src/actor/protocol/serde.ts | 1 + .../rivetkit/src/actor/router-endpoints.ts | 8 ++++ packages/rivetkit/src/actor/router.ts | 25 ++++++++++++- packages/rivetkit/src/client/actor-conn.ts | 18 +++++++-- packages/rivetkit/src/client/actor-handle.ts | 37 ++++++++++++++++++- packages/rivetkit/src/client/client.ts | 24 ++++++++++++ .../src/driver-test-suite/tests/actor-conn.ts | 29 ++++++++++++++- .../src/drivers/engine/actor-driver.ts | 11 +++++- .../src/drivers/file-system/manager.ts | 4 ++ packages/rivetkit/src/manager/driver.ts | 2 + .../rivetkit/src/manager/protocol/query.ts | 22 ++++++++++- 13 files changed, 188 insertions(+), 11 deletions(-) diff --git a/packages/rivetkit/fixtures/driver-test-suite/counter.ts b/packages/rivetkit/fixtures/driver-test-suite/counter.ts index fd653c007..54dfa7916 100644 --- a/packages/rivetkit/fixtures/driver-test-suite/counter.ts +++ b/packages/rivetkit/fixtures/driver-test-suite/counter.ts @@ -2,6 +2,10 @@ import { actor } from "rivetkit"; export const counter = actor({ state: { count: 0 }, + onConnect: (c, conn) => { + c.broadcast("onconnect:broadcast", "Hello!"); + conn.send("onconnect:msg", "Welcome to the counter actor!"); + }, actions: { increment: (c, x: number) => { c.state.count += x; diff --git a/packages/rivetkit/src/actor/instance.ts b/packages/rivetkit/src/actor/instance.ts index ae90e88e9..b5242faac 100644 --- a/packages/rivetkit/src/actor/instance.ts +++ b/packages/rivetkit/src/actor/instance.ts @@ -917,6 +917,7 @@ export class ActorInstance { state: CS, driverId: ConnectionDriver, driverState: unknown, + subscriptions: string[], authData: unknown, ): Promise> { this.#assertReady(); @@ -950,6 +951,11 @@ export class ActorInstance { // // Do this immediately after adding connection & before any async logic in order to avoid race conditions with sleep timeouts this.#resetSleepTimer(); + if (subscriptions) { + for (const sub of subscriptions) { + this.#addSubscription(sub, conn, true); + } + } // Add to persistence & save immediately this.#persist.connections.push(persist); @@ -1017,6 +1023,7 @@ export class ActorInstance { return await this.executeAction(ctx, name, args); }, onSubscribe: async (eventName, conn) => { + console.log("subscribing to event", { eventName, connId: conn.id }); this.inspector.emitter.emit("eventFired", { type: "subscribe", eventName, @@ -1489,6 +1496,13 @@ export class ActorInstance { _broadcast>(name: string, ...args: Args) { this.#assertReady(); + console.log("broadcasting event", { + name, + args, + actorId: this.id, + subscriptions: this.#subscriptionIndex.size, + connections: this.conns.size, + }); this.inspector.emitter.emit("eventFired", { type: "broadcast", eventName: name, diff --git a/packages/rivetkit/src/actor/protocol/serde.ts b/packages/rivetkit/src/actor/protocol/serde.ts index bda7cf419..aa3c6401a 100644 --- a/packages/rivetkit/src/actor/protocol/serde.ts +++ b/packages/rivetkit/src/actor/protocol/serde.ts @@ -13,6 +13,7 @@ export type InputData = string | Buffer | Blob | ArrayBufferLike | Uint8Array; export type OutputData = string | Uint8Array; export const EncodingSchema = z.enum(["json", "cbor", "bare"]); +export const SubscriptionsListSchema = z.array(z.string()); /** * Encoding used to communicate between the client & actor. diff --git a/packages/rivetkit/src/actor/router-endpoints.ts b/packages/rivetkit/src/actor/router-endpoints.ts index ba49e3ea0..0985cd60f 100644 --- a/packages/rivetkit/src/actor/router-endpoints.ts +++ b/packages/rivetkit/src/actor/router-endpoints.ts @@ -113,6 +113,7 @@ export async function handleWebSocketConnect( encoding: Encoding, parameters: unknown, authData: unknown, + subscriptions: string[], ): Promise { const exposeInternalError = req ? getRequestExposeInternalError(req) : false; @@ -182,6 +183,7 @@ export async function handleWebSocketConnect( connState, CONNECTION_DRIVER_WEBSOCKET, { encoding } satisfies GenericWebSocketDriverState, + subscriptions, authData, ); @@ -332,6 +334,7 @@ export async function handleSseConnect( _runConfig: RunConfig, actorDriver: ActorDriver, actorId: string, + subscriptions: string[], authData: unknown, ) { const encoding = getRequestEncoding(c.req); @@ -367,6 +370,7 @@ export async function handleSseConnect( connState, CONNECTION_DRIVER_SSE, { encoding } satisfies GenericSseDriverState, + subscriptions, authData, ); @@ -463,6 +467,7 @@ export async function handleAction( connState, CONNECTION_DRIVER_HTTP, {} satisfies GenericHttpDriverState, + [], authData, ); @@ -655,6 +660,8 @@ export const HEADER_CONN_ID = "X-RivetKit-Conn"; export const HEADER_CONN_TOKEN = "X-RivetKit-Conn-Token"; +export const HEADER_CONN_SUBS = "X-RivetKit-Conn-Subs"; + /** * Headers that publics can send from public clients. * @@ -669,6 +676,7 @@ export const ALLOWED_PUBLIC_HEADERS = [ HEADER_ACTOR_ID, HEADER_CONN_ID, HEADER_CONN_TOKEN, + HEADER_CONN_SUBS, ]; // Helper to get connection parameters for the request diff --git a/packages/rivetkit/src/actor/router.ts b/packages/rivetkit/src/actor/router.ts index 53025ba59..538f1df05 100644 --- a/packages/rivetkit/src/actor/router.ts +++ b/packages/rivetkit/src/actor/router.ts @@ -1,7 +1,10 @@ import { Hono, type Context as HonoContext } from "hono"; import { cors } from "hono/cors"; import invariant from "invariant"; -import { EncodingSchema } from "@/actor/protocol/serde"; +import { + EncodingSchema, + SubscriptionsListSchema, +} from "@/actor/protocol/serde"; import { type ActionOpts, type ActionOutput, @@ -13,6 +16,7 @@ import { HEADER_AUTH_DATA, HEADER_CONN_ID, HEADER_CONN_PARAMS, + HEADER_CONN_SUBS, HEADER_CONN_TOKEN, HEADER_ENCODING, handleAction, @@ -84,12 +88,16 @@ export function createActorRouter( const encodingRaw = c.req.header(HEADER_ENCODING); const connParamsRaw = c.req.header(HEADER_CONN_PARAMS); const authDataRaw = c.req.header(HEADER_AUTH_DATA); + const subsRaw = c.req.header(HEADER_CONN_SUBS); const encoding = EncodingSchema.parse(encodingRaw); const connParams = connParamsRaw ? JSON.parse(connParamsRaw) : undefined; const authData = authDataRaw ? JSON.parse(authDataRaw) : undefined; + const subs = subsRaw + ? SubscriptionsListSchema.parse(JSON.parse(subsRaw)) + : []; return await handleWebSocketConnect( c.req.raw, @@ -98,6 +106,7 @@ export function createActorRouter( c.env.actorId, encoding, connParams, + subs, authData, ); })(c, noopNext()); @@ -115,8 +124,20 @@ export function createActorRouter( if (authDataRaw) { authData = JSON.parse(authDataRaw); } + const subsRaw = c.req.header(HEADER_CONN_SUBS); + + const subscriptions = subsRaw + ? SubscriptionsListSchema.parse(JSON.parse(subsRaw)) + : []; - return handleSseConnect(c, runConfig, actorDriver, c.env.actorId, authData); + return handleSseConnect( + c, + runConfig, + actorDriver, + c.env.actorId, + subscriptions, + authData, + ); }); router.post("/action/:action", async (c) => { diff --git a/packages/rivetkit/src/client/actor-conn.ts b/packages/rivetkit/src/client/actor-conn.ts index dcd578455..c6986481b 100644 --- a/packages/rivetkit/src/client/actor-conn.ts +++ b/packages/rivetkit/src/client/actor-conn.ts @@ -113,7 +113,7 @@ export class ActorConnRaw { /** * Interval that keeps the NodeJS process alive if this is the only thing running. * - * See ttps://github.com/nodejs/node/issues/22088 + * @see https://github.com/nodejs/node/issues/22088 */ #keepNodeAliveInterval: NodeJS.Timeout; @@ -126,8 +126,6 @@ export class ActorConnRaw { #encoding: Encoding; #actorQuery: ActorQuery; - // TODO: ws message queue - /** * Do not call this directly. * @@ -203,7 +201,6 @@ export class ActorConnRaw { /** * Do not call this directly. -enc * Establishes a connection to the server using the specified endpoint & encoding & driver. * * @protected @@ -281,6 +278,7 @@ enc actorId, this.#encoding, this.#params, + Array.from(this.#eventSubscriptions.keys()), ); this.#transport = { websocket: ws }; ws.addEventListener("open", () => { @@ -863,3 +861,15 @@ enc */ export type ActorConn = ActorConnRaw & ActorDefinitionActions; + +/** + * Connection to a actor. Allows calling actor's remote procedure calls with inferred types. See {@link ActorConnRaw} for underlying methods. + * Needs to be established manually using #connect. + * + * @template AD The actor class that this connection is for. + * @see {@link ActorConnRaw} + * @see {@link ActorConn} + */ +export type ActorManualConn = ActorConnRaw & { + connect: () => void; +} & ActorDefinitionActions; diff --git a/packages/rivetkit/src/client/actor-handle.ts b/packages/rivetkit/src/client/actor-handle.ts index 484b5ef4d..647176b8f 100644 --- a/packages/rivetkit/src/client/actor-handle.ts +++ b/packages/rivetkit/src/client/actor-handle.ts @@ -18,7 +18,11 @@ import { } from "@/schemas/client-protocol/versioned"; import { bufferToArrayBuffer } from "@/utils"; import type { ActorDefinitionActions } from "./actor-common"; -import { type ActorConn, ActorConnRaw } from "./actor-conn"; +import { + type ActorConn, + ActorConnRaw, + type ActorManualConn, +} from "./actor-conn"; import { queryActor } from "./actor-query"; import { type ClientRaw, CREATE_ACTOR_CONN_PROXY } from "./client"; import { ActorError } from "./errors"; @@ -160,6 +164,33 @@ export class ActorHandleRaw { ) as ActorConn; } + /** + * Creates a new connection to the actor, that should be manually connected. + * This is useful for creating connections that are not immediately connected, + * such as when you want to set up event listeners before connecting. + * + * @param AD - The actor definition for the connection. + * @returns {ActorConn} A connection to the actor. + */ + create(): ActorManualConn { + logger().debug({ + msg: "creating a connection from handle", + query: this.#actorQuery, + }); + + const conn = new ActorConnRaw( + this.#client, + this.#driver, + this.#params, + this.#encoding, + this.#actorQuery, + ); + + return this.#client[CREATE_ACTOR_CONN_PROXY]( + conn, + ) as ActorManualConn; + } + /** * Makes a raw HTTP request to the actor. * @@ -259,10 +290,12 @@ export class ActorHandleRaw { */ export type ActorHandle = Omit< ActorHandleRaw, - "connect" + "connect" | "create" > & { // Add typed version of ActorConn (instead of using AnyActorDefinition) connect(): ActorConn; // Resolve method returns the actor ID resolve(): Promise; + // Add typed version of create + create(): ActorManualConn; } & ActorDefinitionActions; diff --git a/packages/rivetkit/src/client/client.ts b/packages/rivetkit/src/client/client.ts index eed988038..3bbdc257a 100644 --- a/packages/rivetkit/src/client/client.ts +++ b/packages/rivetkit/src/client/client.ts @@ -8,6 +8,7 @@ import type { ActorActionFunction } from "./actor-common"; import { type ActorConn, type ActorConnRaw, + type ActorManualConn, CONNECT_SYMBOL, } from "./actor-conn"; import { type ActorHandle, ActorHandleRaw } from "./actor-handle"; @@ -149,6 +150,7 @@ export interface Region { export const ACTOR_CONNS_SYMBOL = Symbol("actorConns"); export const CREATE_ACTOR_CONN_PROXY = Symbol("createActorConnProxy"); +export const CREATE_ACTOR_PROXY = Symbol("createActorProxy"); export const TRANSPORT_SYMBOL = Symbol("transport"); /** @@ -359,12 +361,34 @@ export class ClientRaw { // Save to connection list this[ACTOR_CONNS_SYMBOL].add(conn); + logger().debug({ + msg: "creating actor proxy for connection and connecting", + conn, + }); + // Start connection conn[CONNECT_SYMBOL](); return createActorProxy(conn) as ActorConn; } + [CREATE_ACTOR_PROXY]( + conn: ActorConnRaw, + ): ActorConn { + // Save to connection list + this[ACTOR_CONNS_SYMBOL].add(conn); + + logger().debug({ msg: "creating actor proxy for connection", conn }); + + Object.assign(conn, { + connect: () => { + conn[CONNECT_SYMBOL](); + }, + }); + + return createActorProxy(conn) as ActorManualConn; + } + /** * Disconnects from all actors. * diff --git a/packages/rivetkit/src/driver-test-suite/tests/actor-conn.ts b/packages/rivetkit/src/driver-test-suite/tests/actor-conn.ts index a0e3c9cfb..2f2143e42 100644 --- a/packages/rivetkit/src/driver-test-suite/tests/actor-conn.ts +++ b/packages/rivetkit/src/driver-test-suite/tests/actor-conn.ts @@ -1,4 +1,4 @@ -import { describe, expect, test } from "vitest"; +import { describe, expect, test, vi } from "vitest"; import type { DriverTestConfig } from "../mod"; import { FAKE_TIME, setupDriverTest, waitFor } from "../utils"; @@ -190,6 +190,33 @@ export function runActorConnTests(driverTestConfig: DriverTestConfig) { // Clean up await connection.dispose(); }); + + test("should handle events sent during onConnect", async (c) => { + const { client } = await setupDriverTest(c, driverTestConfig); + + // Create actor with onConnect event + const connection = client.counter + .getOrCreate(["test-onconnect"]) + .create(); + + // Set up event listener for onConnect + const onBroadcastFn = vi.fn(); + connection.on("onconnect:broadcast", onBroadcastFn); + + // Set up event listener for onConnect message + const onMsgFn = vi.fn(); + connection.on("onconnect:msg", onMsgFn); + + connection.connect(); + + // Verify the onConnect event was received + await vi.waitFor(() => { + expect(onBroadcastFn).toHaveBeenCalled(); + expect(onMsgFn).toHaveBeenCalled(); + }); + // Clean up + await connection.dispose(); + }); }); describe("Connection Parameters", () => { diff --git a/packages/rivetkit/src/drivers/engine/actor-driver.ts b/packages/rivetkit/src/drivers/engine/actor-driver.ts index 3d42566d8..5f841df41 100644 --- a/packages/rivetkit/src/drivers/engine/actor-driver.ts +++ b/packages/rivetkit/src/drivers/engine/actor-driver.ts @@ -7,7 +7,10 @@ import * as cbor from "cbor-x"; import { WSContext } from "hono/ws"; import invariant from "invariant"; import { deserializeActorKey } from "@/actor/keys"; -import { EncodingSchema } from "@/actor/protocol/serde"; +import { + EncodingSchema, + SubscriptionsListSchema, +} from "@/actor/protocol/serde"; import type { Client } from "@/client/client"; import { getLogger } from "@/common/log"; import { @@ -15,6 +18,7 @@ import { type AnyActorInstance, HEADER_AUTH_DATA, HEADER_CONN_PARAMS, + HEADER_CONN_TOKEN, HEADER_ENCODING, type ManagerDriver, serializeEmptyPersistData, @@ -296,10 +300,14 @@ export class EngineActorDriver implements ActorDriver { const encodingRaw = request.headers.get(HEADER_ENCODING); const connParamsRaw = request.headers.get(HEADER_CONN_PARAMS); const authDataRaw = request.headers.get(HEADER_AUTH_DATA); + const subsDataRaw = request.headers.get(HEADER_CONN_TOKEN); const encoding = EncodingSchema.parse(encodingRaw); const connParams = connParamsRaw ? JSON.parse(connParamsRaw) : undefined; const authData = authDataRaw ? JSON.parse(authDataRaw) : undefined; + const subsData = subsDataRaw + ? SubscriptionsListSchema.parse(JSON.parse(subsDataRaw)) + : []; // Fetch WS handler // @@ -314,6 +322,7 @@ export class EngineActorDriver implements ActorDriver { encoding, connParams, authData, + subsData, ); } else if (url.pathname.startsWith(PATH_RAW_WEBSOCKET_PREFIX)) { wsHandlerPromise = handleRawWebSocketHandler( diff --git a/packages/rivetkit/src/drivers/file-system/manager.ts b/packages/rivetkit/src/drivers/file-system/manager.ts index 6a8345900..1f9ee39f0 100644 --- a/packages/rivetkit/src/drivers/file-system/manager.ts +++ b/packages/rivetkit/src/drivers/file-system/manager.ts @@ -141,6 +141,7 @@ export class FileSystemManagerDriver implements ManagerDriver { actorId: string, encoding: Encoding, params: unknown, + subs: string[], ): Promise { // TODO: @@ -155,6 +156,7 @@ export class FileSystemManagerDriver implements ManagerDriver { encoding, params, undefined, + subs, ); return new InlineWebSocketAdapter2(wsHandler); } else if ( @@ -195,6 +197,7 @@ export class FileSystemManagerDriver implements ManagerDriver { encoding: Encoding, connParams: unknown, authData: unknown, + subscriptions: string[], ): Promise { const upgradeWebSocket = this.#runConfig.getUpgradeWebSocket?.(); invariant(upgradeWebSocket, "missing getUpgradeWebSocket"); @@ -210,6 +213,7 @@ export class FileSystemManagerDriver implements ManagerDriver { encoding, connParams, authData, + subscriptions, ); return upgradeWebSocket(() => wsHandler)(c, noopNext()); diff --git a/packages/rivetkit/src/manager/driver.ts b/packages/rivetkit/src/manager/driver.ts index 125ade201..f1d5c5ce4 100644 --- a/packages/rivetkit/src/manager/driver.ts +++ b/packages/rivetkit/src/manager/driver.ts @@ -21,6 +21,7 @@ export interface ManagerDriver { actorId: string, encoding: Encoding, params: unknown, + subscriptions?: string[], ): Promise; proxyRequest( c: HonoContext, @@ -34,6 +35,7 @@ export interface ManagerDriver { encoding: Encoding, params: unknown, authData: unknown, + subscriptions?: string[], ): Promise; displayInformation(): ManagerDisplayInformation; diff --git a/packages/rivetkit/src/manager/protocol/query.ts b/packages/rivetkit/src/manager/protocol/query.ts index 756a84617..056f5bffd 100644 --- a/packages/rivetkit/src/manager/protocol/query.ts +++ b/packages/rivetkit/src/manager/protocol/query.ts @@ -1,10 +1,14 @@ import { z } from "zod"; -import { EncodingSchema } from "@/actor/protocol/serde"; +import { + EncodingSchema, + SubscriptionsListSchema, +} from "@/actor/protocol/serde"; import { HEADER_ACTOR_ID, HEADER_ACTOR_QUERY, HEADER_CONN_ID, HEADER_CONN_PARAMS, + HEADER_CONN_SUBS, HEADER_CONN_TOKEN, HEADER_ENCODING, } from "@/actor/router-endpoints"; @@ -55,16 +59,32 @@ export const ActorQuerySchema = z.union([ }), ]); +const json = z.string().transform((value) => { + try { + return JSON.parse(value); + } catch { + throw new Error(`Invalid JSON: ${value}`); + } +}); + export const ConnectRequestSchema = z.object({ query: ActorQuerySchema.describe(HEADER_ACTOR_QUERY), encoding: EncodingSchema.describe(HEADER_ENCODING), connParams: z.string().optional().describe(HEADER_CONN_PARAMS), + subscriptions: json + .pipe(SubscriptionsListSchema) + .optional() + .describe(HEADER_CONN_SUBS), }); export const ConnectWebSocketRequestSchema = z.object({ query: ActorQuerySchema.describe("query"), encoding: EncodingSchema.describe("encoding"), connParams: z.unknown().optional().describe("conn_params"), + subscriptions: json + .pipe(SubscriptionsListSchema) + .optional() + .describe("subscriptions"), }); export const ConnMessageRequestSchema = z.object({