Skip to content

Commit 0127964

Browse files
committed
feat: subscribe to events before connect
1 parent d43e771 commit 0127964

File tree

16 files changed

+258
-13
lines changed

16 files changed

+258
-13
lines changed

packages/core/fixtures/driver-test-suite/counter.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@ import { actor } from "@rivetkit/core";
33
export const counter = actor({
44
onAuth: () => {},
55
state: { count: 0 },
6+
onConnect: (c, conn) => {
7+
c.broadcast("onconnect:broadcast", "Hello!");
8+
conn.send("onconnect:msg", "Welcome to the counter actor!");
9+
},
610
actions: {
711
increment: (c, x: number) => {
812
c.state.count += x;

packages/core/src/actor/instance.ts

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -811,6 +811,7 @@ export class ActorInstance<
811811
state: CS,
812812
driverId: ConnectionDriver,
813813
driverState: unknown,
814+
subscriptions: string[],
814815
authData: unknown,
815816
): Promise<Conn<S, CP, CS, V, I, AD, DB>> {
816817
this.#assertReady();
@@ -840,6 +841,12 @@ export class ActorInstance<
840841
);
841842
this.#connections.set(conn.id, conn);
842843

844+
if (subscriptions) {
845+
for (const sub of subscriptions) {
846+
this.#addSubscription(sub, conn, true);
847+
}
848+
}
849+
843850
// Add to persistence & save immediately
844851
this.#persist.c.push(persist);
845852
this.saveState({ immediate: true });
@@ -901,6 +908,7 @@ export class ActorInstance<
901908
return await this.executeAction(ctx, name, args);
902909
},
903910
onSubscribe: async (eventName, conn) => {
911+
console.log("subscribing to event", { eventName, connId: conn.id });
904912
this.inspector.emitter.emit("eventFired", {
905913
type: "subscribe",
906914
eventName,
@@ -1334,6 +1342,13 @@ export class ActorInstance<
13341342
_broadcast<Args extends Array<unknown>>(name: string, ...args: Args) {
13351343
this.#assertReady();
13361344

1345+
console.log("broadcasting event", {
1346+
name,
1347+
args,
1348+
actorId: this.id,
1349+
subscriptions: this.#subscriptionIndex.size,
1350+
connections: this.conns.size,
1351+
});
13371352
this.inspector.emitter.emit("eventFired", {
13381353
type: "broadcast",
13391354
eventName: name,

packages/core/src/actor/protocol/serde.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ export const EncodingSchema = z.enum(["json", "cbor"]);
1717
*/
1818
export type Encoding = z.infer<typeof EncodingSchema>;
1919

20+
export const SubscriptionsListSchema = z.array(z.string());
21+
export type SubscriptionsList = z.infer<typeof SubscriptionsListSchema>;
22+
2023
/**
2124
* Helper class that helps serialize data without re-serializing for the same encoding.
2225
*/

packages/core/src/actor/router-endpoints.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ export async function handleWebSocketConnect(
108108
encoding: Encoding,
109109
parameters: unknown,
110110
authData: unknown,
111+
subscriptions: string[],
111112
): Promise<UpgradeWebSocketArgs> {
112113
const exposeInternalError = c ? getRequestExposeInternalError(c.req) : false;
113114

@@ -176,6 +177,7 @@ export async function handleWebSocketConnect(
176177
connState,
177178
CONNECTION_DRIVER_WEBSOCKET,
178179
{ encoding } satisfies GenericWebSocketDriverState,
180+
subscriptions,
179181
authData,
180182
);
181183

@@ -322,6 +324,7 @@ export async function handleSseConnect(
322324
runConfig: RunConfig,
323325
actorDriver: ActorDriver,
324326
actorId: string,
327+
subscriptions: string[],
325328
authData: unknown,
326329
) {
327330
const encoding = getRequestEncoding(c.req);
@@ -357,6 +360,7 @@ export async function handleSseConnect(
357360
connState,
358361
CONNECTION_DRIVER_SSE,
359362
{ encoding } satisfies GenericSseDriverState,
363+
subscriptions,
360364
authData,
361365
);
362366

@@ -487,6 +491,7 @@ export async function handleAction(
487491
connState,
488492
CONNECTION_DRIVER_HTTP,
489493
{} satisfies GenericHttpDriverState,
494+
[],
490495
authData,
491496
);
492497

@@ -706,6 +711,8 @@ export const HEADER_CONN_ID = "X-RivetKit-Conn";
706711

707712
export const HEADER_CONN_TOKEN = "X-RivetKit-Conn-Token";
708713

714+
export const HEADER_CONN_SUBS = "X-RivetKit-Conn-Subs";
715+
709716
/**
710717
* Headers that publics can send from public clients.
711718
*
@@ -720,6 +727,7 @@ export const ALLOWED_PUBLIC_HEADERS = [
720727
HEADER_ACTOR_ID,
721728
HEADER_CONN_ID,
722729
HEADER_CONN_TOKEN,
730+
HEADER_CONN_SUBS,
723731
];
724732

725733
// Helper to get connection parameters for the request

packages/core/src/actor/router.ts

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import { Hono, type Context as HonoContext } from "hono";
22
import invariant from "invariant";
3-
import { EncodingSchema } from "@/actor/protocol/serde";
3+
import {
4+
EncodingSchema,
5+
SubscriptionsListSchema,
6+
} from "@/actor/protocol/serde";
47
import {
58
type ActionOpts,
69
type ActionOutput,
@@ -12,6 +15,7 @@ import {
1215
HEADER_AUTH_DATA,
1316
HEADER_CONN_ID,
1417
HEADER_CONN_PARAMS,
18+
HEADER_CONN_SUBS,
1519
HEADER_CONN_TOKEN,
1620
HEADER_ENCODING,
1721
handleAction,
@@ -83,12 +87,16 @@ export function createActorRouter(
8387
const encodingRaw = c.req.header(HEADER_ENCODING);
8488
const connParamsRaw = c.req.header(HEADER_CONN_PARAMS);
8589
const authDataRaw = c.req.header(HEADER_AUTH_DATA);
90+
const subsRaw = c.req.header(HEADER_CONN_SUBS);
8691

8792
const encoding = EncodingSchema.parse(encodingRaw);
8893
const connParams = connParamsRaw
8994
? JSON.parse(connParamsRaw)
9095
: undefined;
9196
const authData = authDataRaw ? JSON.parse(authDataRaw) : undefined;
97+
const subs = subsRaw
98+
? SubscriptionsListSchema.parse(JSON.parse(subsRaw))
99+
: [];
92100

93101
return await handleWebSocketConnect(
94102
c as HonoContext,
@@ -97,6 +105,7 @@ export function createActorRouter(
97105
c.env.actorId,
98106
encoding,
99107
connParams,
108+
subs,
100109
authData,
101110
);
102111
})(c, noopNext());
@@ -114,8 +123,20 @@ export function createActorRouter(
114123
if (authDataRaw) {
115124
authData = JSON.parse(authDataRaw);
116125
}
126+
const subsRaw = c.req.header(HEADER_CONN_SUBS);
127+
128+
const subscriptions = subsRaw
129+
? SubscriptionsListSchema.parse(JSON.parse(subsRaw))
130+
: [];
117131

118-
return handleSseConnect(c, runConfig, actorDriver, c.env.actorId, authData);
132+
return handleSseConnect(
133+
c,
134+
runConfig,
135+
actorDriver,
136+
c.env.actorId,
137+
subscriptions,
138+
authData,
139+
);
119140
});
120141

121142
router.post("/action/:action", async (c) => {

packages/core/src/client/actor-conn.ts

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ export class ActorConnRaw {
9898
/**
9999
* Interval that keeps the NodeJS process alive if this is the only thing running.
100100
*
101-
* See ttps://github.com/nodejs/node/issues/22088
101+
* @see https://github.com/nodejs/node/issues/22088
102102
*/
103103
#keepNodeAliveInterval: NodeJS.Timeout;
104104

@@ -111,8 +111,6 @@ export class ActorConnRaw {
111111
#encodingKind: Encoding;
112112
#actorQuery: ActorQuery;
113113

114-
// TODO: ws message queue
115-
116114
/**
117115
* Do not call this directly.
118116
*
@@ -187,7 +185,6 @@ export class ActorConnRaw {
187185

188186
/**
189187
* Do not call this directly.
190-
enc
191188
* Establishes a connection to the server using the specified endpoint & encoding & driver.
192189
*
193190
* @protected
@@ -259,6 +256,7 @@ enc
259256
this.#actorQuery,
260257
this.#encodingKind,
261258
this.#params,
259+
Array.from(this.#eventSubscriptions.keys()),
262260
signal ? { signal } : undefined,
263261
);
264262
this.#transport = { websocket: ws };
@@ -282,6 +280,7 @@ enc
282280
this.#actorQuery,
283281
this.#encodingKind,
284282
this.#params,
283+
Array.from(this.#eventSubscriptions.keys()),
285284
signal ? { signal } : undefined,
286285
);
287286
this.#transport = { sse: eventSource };
@@ -807,3 +806,15 @@ enc
807806
*/
808807
export type ActorConn<AD extends AnyActorDefinition> = ActorConnRaw &
809808
ActorDefinitionActions<AD>;
809+
810+
/**
811+
* Connection to a actor. Allows calling actor's remote procedure calls with inferred types. See {@link ActorConnRaw} for underlying methods.
812+
* Needs to be established manually using #connect.
813+
*
814+
* @template AD The actor class that this connection is for.
815+
* @see {@link ActorConnRaw}
816+
* @see {@link ActorConn}
817+
*/
818+
export type ActorManualConn<AD extends AnyActorDefinition> = ActorConnRaw & {
819+
connect: () => void;
820+
} & ActorDefinitionActions<AD>;

packages/core/src/client/actor-handle.ts

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,16 @@ import { assertUnreachable } from "@/actor/utils";
55
import { importWebSocket } from "@/common/websocket";
66
import type { ActorQuery } from "@/manager/protocol/query";
77
import type { ActorDefinitionActions } from "./actor-common";
8-
import { type ActorConn, ActorConnRaw } from "./actor-conn";
8+
import {
9+
type ActorConn,
10+
ActorConnRaw,
11+
type ActorManualConn,
12+
} from "./actor-conn";
913
import {
1014
type ClientDriver,
1115
type ClientRaw,
1216
CREATE_ACTOR_CONN_PROXY,
17+
CREATE_ACTOR_PROXY,
1318
} from "./client";
1419
import { logger } from "./log";
1520
import { rawHttpFetch, rawWebSocket } from "./raw-utils";
@@ -98,6 +103,32 @@ export class ActorHandleRaw {
98103
) as ActorConn<AnyActorDefinition>;
99104
}
100105

106+
/**
107+
* Creates a new connection to the actor, that should be manually connected.
108+
* This is useful for creating connections that are not immediately connected,
109+
* such as when you want to set up event listeners before connecting.
110+
*
111+
* @param AD - The actor definition for the connection.
112+
* @returns {ActorConn<AD>} A connection to the actor.
113+
*/
114+
create(): ActorManualConn<AnyActorDefinition> {
115+
logger().debug("creating a connection from handle", {
116+
query: this.#actorQuery,
117+
});
118+
119+
const conn = new ActorConnRaw(
120+
this.#client,
121+
this.#driver,
122+
this.#params,
123+
this.#encodingKind,
124+
this.#actorQuery,
125+
);
126+
127+
return this.#client[CREATE_ACTOR_PROXY](
128+
conn,
129+
) as ActorManualConn<AnyActorDefinition>;
130+
}
131+
101132
/**
102133
* Makes a raw HTTP request to the actor.
103134
*
@@ -188,10 +219,12 @@ export class ActorHandleRaw {
188219
*/
189220
export type ActorHandle<AD extends AnyActorDefinition> = Omit<
190221
ActorHandleRaw,
191-
"connect"
222+
"connect" | "create"
192223
> & {
193224
// Add typed version of ActorConn (instead of using AnyActorDefinition)
194225
connect(): ActorConn<AD>;
195226
// Resolve method returns the actor ID
196227
resolve(): Promise<string>;
228+
// Add typed version of create
229+
create(): ActorManualConn<AD>;
197230
} & ActorDefinitionActions<AD>;

packages/core/src/client/client.ts

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import type { ActorActionFunction } from "./actor-common";
1111
import {
1212
type ActorConn,
1313
type ActorConnRaw,
14+
type ActorManualConn,
1415
CONNECT_SYMBOL,
1516
} from "./actor-conn";
1617
import { type ActorHandle, ActorHandleRaw } from "./actor-handle";
@@ -157,6 +158,7 @@ export interface Region {
157158

158159
export const ACTOR_CONNS_SYMBOL = Symbol("actorConns");
159160
export const CREATE_ACTOR_CONN_PROXY = Symbol("createActorConnProxy");
161+
export const CREATE_ACTOR_PROXY = Symbol("createActorProxy");
160162
export const TRANSPORT_SYMBOL = Symbol("transport");
161163

162164
export interface ClientDriver {
@@ -181,13 +183,15 @@ export interface ClientDriver {
181183
actorQuery: ActorQuery,
182184
encodingKind: Encoding,
183185
params: unknown,
186+
subscriptions: string[],
184187
opts: { signal?: AbortSignal } | undefined,
185188
): Promise<WebSocket>;
186189
connectSse(
187190
c: HonoContext | undefined,
188191
actorQuery: ActorQuery,
189192
encodingKind: Encoding,
190193
params: unknown,
194+
subscriptions: string[],
191195
opts: { signal?: AbortSignal } | undefined,
192196
): Promise<UniversalEventSource>;
193197
sendHttpMessage(
@@ -426,12 +430,35 @@ export class ClientRaw {
426430
// Save to connection list
427431
this[ACTOR_CONNS_SYMBOL].add(conn);
428432

433+
logger().debug("creating actor proxy for connection and connecting", {
434+
conn,
435+
});
436+
429437
// Start connection
430438
conn[CONNECT_SYMBOL]();
431439

432440
return createActorProxy(conn) as ActorConn<AD>;
433441
}
434442

443+
[CREATE_ACTOR_PROXY]<AD extends AnyActorDefinition>(
444+
conn: ActorConnRaw,
445+
): ActorConn<AD> {
446+
// Save to connection list
447+
this[ACTOR_CONNS_SYMBOL].add(conn);
448+
449+
logger().debug("creating actor proxy for connection", {
450+
conn,
451+
});
452+
453+
Object.assign(conn, {
454+
connect: () => {
455+
conn[CONNECT_SYMBOL]();
456+
},
457+
});
458+
459+
return createActorProxy(conn) as ActorManualConn<AD>;
460+
}
461+
435462
/**
436463
* Disconnects from all actors.
437464
*

packages/core/src/client/http-client-driver.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ export function createHttpClientDriver(managerEndpoint: string): ClientDriver {
121121
actorQuery: ActorQuery,
122122
encodingKind: Encoding,
123123
params: unknown,
124+
subs: string[] | undefined,
124125
): Promise<WebSocket> => {
125126
const { WebSocket } = await dynamicImports;
126127

@@ -138,6 +139,9 @@ export function createHttpClientDriver(managerEndpoint: string): ClientDriver {
138139
protocol.push(
139140
`conn_params.${encodeURIComponent(JSON.stringify(params))}`,
140141
);
142+
if (subs) {
143+
protocol.push(`subs.${encodeURIComponent(JSON.stringify(subs))}`);
144+
}
141145

142146
// HACK: See packages/drivers/cloudflare-workers/src/websocket.ts
143147
protocol.push("rivetkit");

0 commit comments

Comments
 (0)