Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions packages/rivetkit/fixtures/driver-test-suite/counter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ import { actor } from "rivetkit";

export const counter = actor({
state: { count: 0 },
onConnect: (c, conn) => {
c.broadcast("onconnect:broadcast", "Hello!");
conn.send("onconnect:msg", "Welcome to the counter actor!");
},
actions: {
increment: (c, x: number) => {
c.state.count += x;
Expand Down
14 changes: 14 additions & 0 deletions packages/rivetkit/src/actor/instance.ts
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,7 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
state: CS,
driverId: ConnectionDriver,
driverState: unknown,
subscriptions: string[],
authData: unknown,
): Promise<Conn<S, CP, CS, V, I, DB>> {
this.#assertReady();
Expand Down Expand Up @@ -950,6 +951,11 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
//
// Do this immediately after adding connection & before any async logic in order to avoid race conditions with sleep timeouts
this.#resetSleepTimer();
if (subscriptions) {
for (const sub of subscriptions) {
this.#addSubscription(sub, conn, true);
}
}

// Add to persistence & save immediately
this.#persist.connections.push(persist);
Expand Down Expand Up @@ -1017,6 +1023,7 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
return await this.executeAction(ctx, name, args);
},
onSubscribe: async (eventName, conn) => {
console.log("subscribing to event", { eventName, connId: conn.id });
this.inspector.emitter.emit("eventFired", {
type: "subscribe",
eventName,
Expand Down Expand Up @@ -1489,6 +1496,13 @@ export class ActorInstance<S, CP, CS, V, I, DB extends AnyDatabaseProvider> {
_broadcast<Args extends Array<unknown>>(name: string, ...args: Args) {
this.#assertReady();

console.log("broadcasting event", {
name,
args,
actorId: this.id,
subscriptions: this.#subscriptionIndex.size,
connections: this.conns.size,
});
this.inspector.emitter.emit("eventFired", {
type: "broadcast",
eventName: name,
Expand Down
1 change: 1 addition & 0 deletions packages/rivetkit/src/actor/protocol/serde.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ export type InputData = string | Buffer | Blob | ArrayBufferLike | Uint8Array;
export type OutputData = string | Uint8Array;

export const EncodingSchema = z.enum(["json", "cbor", "bare"]);
export const SubscriptionsListSchema = z.array(z.string());

/**
* Encoding used to communicate between the client & actor.
Expand Down
8 changes: 8 additions & 0 deletions packages/rivetkit/src/actor/router-endpoints.ts
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ export async function handleWebSocketConnect(
encoding: Encoding,
parameters: unknown,
authData: unknown,
subscriptions: string[],
): Promise<UpgradeWebSocketArgs> {
const exposeInternalError = req ? getRequestExposeInternalError(req) : false;

Expand Down Expand Up @@ -182,6 +183,7 @@ export async function handleWebSocketConnect(
connState,
CONNECTION_DRIVER_WEBSOCKET,
{ encoding } satisfies GenericWebSocketDriverState,
subscriptions,
authData,
);

Expand Down Expand Up @@ -332,6 +334,7 @@ export async function handleSseConnect(
_runConfig: RunConfig,
actorDriver: ActorDriver,
actorId: string,
subscriptions: string[],
authData: unknown,
) {
const encoding = getRequestEncoding(c.req);
Expand Down Expand Up @@ -367,6 +370,7 @@ export async function handleSseConnect(
connState,
CONNECTION_DRIVER_SSE,
{ encoding } satisfies GenericSseDriverState,
subscriptions,
authData,
);

Expand Down Expand Up @@ -463,6 +467,7 @@ export async function handleAction(
connState,
CONNECTION_DRIVER_HTTP,
{} satisfies GenericHttpDriverState,
[],
authData,
);

Expand Down Expand Up @@ -655,6 +660,8 @@ export const HEADER_CONN_ID = "X-RivetKit-Conn";

export const HEADER_CONN_TOKEN = "X-RivetKit-Conn-Token";

export const HEADER_CONN_SUBS = "X-RivetKit-Conn-Subs";

/**
* Headers that publics can send from public clients.
*
Expand All @@ -669,6 +676,7 @@ export const ALLOWED_PUBLIC_HEADERS = [
HEADER_ACTOR_ID,
HEADER_CONN_ID,
HEADER_CONN_TOKEN,
HEADER_CONN_SUBS,
];

// Helper to get connection parameters for the request
Expand Down
25 changes: 23 additions & 2 deletions packages/rivetkit/src/actor/router.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import { Hono, type Context as HonoContext } from "hono";
import { cors } from "hono/cors";
import invariant from "invariant";
import { EncodingSchema } from "@/actor/protocol/serde";
import {
EncodingSchema,
SubscriptionsListSchema,
} from "@/actor/protocol/serde";
import {
type ActionOpts,
type ActionOutput,
Expand All @@ -13,6 +16,7 @@ import {
HEADER_AUTH_DATA,
HEADER_CONN_ID,
HEADER_CONN_PARAMS,
HEADER_CONN_SUBS,
HEADER_CONN_TOKEN,
HEADER_ENCODING,
handleAction,
Expand Down Expand Up @@ -84,12 +88,16 @@ export function createActorRouter(
const encodingRaw = c.req.header(HEADER_ENCODING);
const connParamsRaw = c.req.header(HEADER_CONN_PARAMS);
const authDataRaw = c.req.header(HEADER_AUTH_DATA);
const subsRaw = c.req.header(HEADER_CONN_SUBS);

const encoding = EncodingSchema.parse(encodingRaw);
const connParams = connParamsRaw
? JSON.parse(connParamsRaw)
: undefined;
const authData = authDataRaw ? JSON.parse(authDataRaw) : undefined;
const subs = subsRaw
? SubscriptionsListSchema.parse(JSON.parse(subsRaw))
: [];

return await handleWebSocketConnect(
c.req.raw,
Expand All @@ -98,6 +106,7 @@ export function createActorRouter(
c.env.actorId,
encoding,
connParams,
subs,
authData,
);
})(c, noopNext());
Expand All @@ -115,8 +124,20 @@ export function createActorRouter(
if (authDataRaw) {
authData = JSON.parse(authDataRaw);
}
const subsRaw = c.req.header(HEADER_CONN_SUBS);

const subscriptions = subsRaw
? SubscriptionsListSchema.parse(JSON.parse(subsRaw))
: [];

return handleSseConnect(c, runConfig, actorDriver, c.env.actorId, authData);
return handleSseConnect(
c,
runConfig,
actorDriver,
c.env.actorId,
subscriptions,
authData,
);
});

router.post("/action/:action", async (c) => {
Expand Down
18 changes: 14 additions & 4 deletions packages/rivetkit/src/client/actor-conn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ export class ActorConnRaw {
/**
* Interval that keeps the NodeJS process alive if this is the only thing running.
*
* See ttps://github.com/nodejs/node/issues/22088
* @see https://github.com/nodejs/node/issues/22088
*/
#keepNodeAliveInterval: NodeJS.Timeout;

Expand All @@ -126,8 +126,6 @@ export class ActorConnRaw {
#encoding: Encoding;
#actorQuery: ActorQuery;

// TODO: ws message queue

/**
* Do not call this directly.
*
Expand Down Expand Up @@ -203,7 +201,6 @@ export class ActorConnRaw {

/**
* Do not call this directly.
enc
* Establishes a connection to the server using the specified endpoint & encoding & driver.
*
* @protected
Expand Down Expand Up @@ -281,6 +278,7 @@ enc
actorId,
this.#encoding,
this.#params,
Array.from(this.#eventSubscriptions.keys()),
);
this.#transport = { websocket: ws };
ws.addEventListener("open", () => {
Expand Down Expand Up @@ -863,3 +861,15 @@ enc
*/
export type ActorConn<AD extends AnyActorDefinition> = ActorConnRaw &
ActorDefinitionActions<AD>;

/**
* Connection to a actor. Allows calling actor's remote procedure calls with inferred types. See {@link ActorConnRaw} for underlying methods.
* Needs to be established manually using #connect.
*
* @template AD The actor class that this connection is for.
* @see {@link ActorConnRaw}
* @see {@link ActorConn}
*/
export type ActorManualConn<AD extends AnyActorDefinition> = ActorConnRaw & {
connect: () => void;
} & ActorDefinitionActions<AD>;
37 changes: 35 additions & 2 deletions packages/rivetkit/src/client/actor-handle.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@ import {
} from "@/schemas/client-protocol/versioned";
import { bufferToArrayBuffer } from "@/utils";
import type { ActorDefinitionActions } from "./actor-common";
import { type ActorConn, ActorConnRaw } from "./actor-conn";
import {
type ActorConn,
ActorConnRaw,
type ActorManualConn,
} from "./actor-conn";
import { queryActor } from "./actor-query";
import { type ClientRaw, CREATE_ACTOR_CONN_PROXY } from "./client";
import { ActorError } from "./errors";
Expand Down Expand Up @@ -160,6 +164,33 @@ export class ActorHandleRaw {
) as ActorConn<AnyActorDefinition>;
}

/**
* Creates a new connection to the actor, that should be manually connected.
* This is useful for creating connections that are not immediately connected,
* such as when you want to set up event listeners before connecting.
*
* @param AD - The actor definition for the connection.
* @returns {ActorConn<AD>} A connection to the actor.
*/
create(): ActorManualConn<AnyActorDefinition> {
logger().debug({
msg: "creating a connection from handle",
query: this.#actorQuery,
});

const conn = new ActorConnRaw(
this.#client,
this.#driver,
this.#params,
this.#encoding,
this.#actorQuery,
);

return this.#client[CREATE_ACTOR_CONN_PROXY](
conn,
) as ActorManualConn<AnyActorDefinition>;
}

/**
* Makes a raw HTTP request to the actor.
*
Expand Down Expand Up @@ -259,10 +290,12 @@ export class ActorHandleRaw {
*/
export type ActorHandle<AD extends AnyActorDefinition> = Omit<
ActorHandleRaw,
"connect"
"connect" | "create"
> & {
// Add typed version of ActorConn (instead of using AnyActorDefinition)
connect(): ActorConn<AD>;
// Resolve method returns the actor ID
resolve(): Promise<string>;
// Add typed version of create
create(): ActorManualConn<AD>;
} & ActorDefinitionActions<AD>;
24 changes: 24 additions & 0 deletions packages/rivetkit/src/client/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import type { ActorActionFunction } from "./actor-common";
import {
type ActorConn,
type ActorConnRaw,
type ActorManualConn,
CONNECT_SYMBOL,
} from "./actor-conn";
import { type ActorHandle, ActorHandleRaw } from "./actor-handle";
Expand Down Expand Up @@ -149,6 +150,7 @@ export interface Region {

export const ACTOR_CONNS_SYMBOL = Symbol("actorConns");
export const CREATE_ACTOR_CONN_PROXY = Symbol("createActorConnProxy");
export const CREATE_ACTOR_PROXY = Symbol("createActorProxy");
export const TRANSPORT_SYMBOL = Symbol("transport");

/**
Expand Down Expand Up @@ -359,12 +361,34 @@ export class ClientRaw {
// Save to connection list
this[ACTOR_CONNS_SYMBOL].add(conn);

logger().debug({
msg: "creating actor proxy for connection and connecting",
conn,
});

// Start connection
conn[CONNECT_SYMBOL]();

return createActorProxy(conn) as ActorConn<AD>;
}

[CREATE_ACTOR_PROXY]<AD extends AnyActorDefinition>(
conn: ActorConnRaw,
): ActorConn<AD> {
// Save to connection list
this[ACTOR_CONNS_SYMBOL].add(conn);

logger().debug({ msg: "creating actor proxy for connection", conn });

Object.assign(conn, {
connect: () => {
conn[CONNECT_SYMBOL]();
},
});

return createActorProxy(conn) as ActorManualConn<AD>;
}

/**
* Disconnects from all actors.
*
Expand Down
29 changes: 28 additions & 1 deletion packages/rivetkit/src/driver-test-suite/tests/actor-conn.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { describe, expect, test } from "vitest";
import { describe, expect, test, vi } from "vitest";
import type { DriverTestConfig } from "../mod";
import { FAKE_TIME, setupDriverTest, waitFor } from "../utils";

Expand Down Expand Up @@ -190,6 +190,33 @@ export function runActorConnTests(driverTestConfig: DriverTestConfig) {
// Clean up
await connection.dispose();
});

test("should handle events sent during onConnect", async (c) => {
const { client } = await setupDriverTest(c, driverTestConfig);

// Create actor with onConnect event
const connection = client.counter
.getOrCreate(["test-onconnect"])
.create();

// Set up event listener for onConnect
const onBroadcastFn = vi.fn();
connection.on("onconnect:broadcast", onBroadcastFn);

// Set up event listener for onConnect message
const onMsgFn = vi.fn();
connection.on("onconnect:msg", onMsgFn);

connection.connect();

// Verify the onConnect event was received
await vi.waitFor(() => {
expect(onBroadcastFn).toHaveBeenCalled();
expect(onMsgFn).toHaveBeenCalled();
});
// Clean up
await connection.dispose();
});
});

describe("Connection Parameters", () => {
Expand Down
Loading
Loading