diff --git a/CLAUDE.md b/CLAUDE.md index 3fcb45ed6..d5b474174 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -108,3 +108,7 @@ This ensures imports resolve correctly across different build environments and p - Run `yarn check-types` regularly during development to catch type errors early. Prefer `yarn check-types` instead of `yarn build`. - Use `tsx` CLI to execute TypeScript scripts directly (e.g., `tsx script.ts` instead of `node script.js`). - Do not auto-commit changes + +## Test Guidelines + +- Do not check if errors are an instanceOf WorkerError in tests. Many error types do not have the same prototype chain when sent over the network, but still have the same properties so you can safely cast with `as`. diff --git a/docs/clients/python.mdx b/docs/clients/python.mdx index bc9f23ea3..3cf4f301f 100644 --- a/docs/clients/python.mdx +++ b/docs/clients/python.mdx @@ -47,7 +47,7 @@ The RivetKit Python client provides a way to connect to and interact with worker ```python Async import asyncio - from worker_core_client import AsyncClient + from rivetkit_client import AsyncClient async def main(): # Replace with your endpoint URL after deployment @@ -77,7 +77,7 @@ The RivetKit Python client provides a way to connect to and interact with worker ``` ```python Sync - from worker_core_client import Client + from rivetkit_client import Client # Replace with your endpoint URL after deployment client = Client("http://localhost:6420") @@ -128,4 +128,4 @@ The RivetKit Python client provides a way to connect to and interact with worker - \ No newline at end of file + diff --git a/docs/clients/rust.mdx b/docs/clients/rust.mdx index 1cc12d68d..fb97bf5a6 100644 --- a/docs/clients/rust.mdx +++ b/docs/clients/rust.mdx @@ -41,7 +41,7 @@ The RivetKit Rust client provides a way to connect to and interact with workers Modify `src/main.rs` to connect to your worker: ```rust src/main.rs - use worker_core_client::{Client, GetOptions, TransportKind, EncodingKind}; + use rivetkit_client::{Client, GetOptions, TransportKind, EncodingKind}; use serde_json::json; use std::time::Duration; diff --git a/docs/concepts/interacting-with-workers.mdx b/docs/concepts/interacting-with-workers.mdx index 6834f350b..5b3fc4a7d 100644 --- a/docs/concepts/interacting-with-workers.mdx +++ b/docs/concepts/interacting-with-workers.mdx @@ -19,7 +19,7 @@ const client = createClient(/* CONNECTION ADDRESS */); ``` ```rust Rust -use worker_core_client::{Client, TransportKind, EncodingKind}; +use rivetkit_client::{Client, TransportKind, EncodingKind}; // Create a client with connection address and configuration let client = Client::new( @@ -30,7 +30,7 @@ let client = Client::new( ``` ```python Python (Callbacks) -from worker_core_client import AsyncClient as WorkerClient +from rivetkit_client import AsyncClient as WorkerClient # Create a client with the connection address client = WorkerClient("http://localhost:6420") @@ -60,7 +60,7 @@ await room.sendMessage("Alice", "Hello everyone!"); ``` ```rust Rust -use worker_core_client::GetOptions; +use rivetkit_client::GetOptions; use serde_json::json; // Connect to a chat room for the "general" channel @@ -116,8 +116,8 @@ await doc.initializeDocument("My New Document"); ``` ```rust Rust -use worker_core_client::{CreateOptions}; -use worker_core_client::client::CreateRequestMetadata; +use rivetkit_client::{CreateOptions}; +use rivetkit_client::client::CreateRequestMetadata; use serde_json::json; // Create a new document worker @@ -169,7 +169,7 @@ await doc.updateContent("Updated content"); ``` ```rust Rust -use worker_core_client::GetWithIdOptions; +use rivetkit_client::GetWithIdOptions; // Connect to a specific worker by its ID let my_worker_id = "55425f42-82f8-451f-82c1-6227c83c9372"; @@ -374,7 +374,7 @@ const chatRoom = await client.chatRoom.get({ channel: "super-secret" }, { ```rust Rust use serde_json::json; -use worker_core_client::GetOptions; +use rivetkit_client::GetOptions; let tags = vec![ ("channel".to_string(), "super-secret".to_string()), @@ -471,7 +471,7 @@ const client = createClient( ``` ```rust Rust -use worker_core_client::{Client, TransportKind, EncodingKind}; +use rivetkit_client::{Client, TransportKind, EncodingKind}; // Create client with specific options let client = Client::new( @@ -484,7 +484,7 @@ let client = Client::new( ``` ```python Python (Callbacks) -from worker_core_client import AsyncClient as WorkerClient +from rivetkit_client import AsyncClient as WorkerClient # Example with all client options client = WorkerClient( diff --git a/docs/openapi.json b/docs/openapi.json index dd55f79fb..3dabdd1ef 100644 --- a/docs/openapi.json +++ b/docs/openapi.json @@ -117,27 +117,6 @@ }, "/workers/connect/websocket": { "get": { - "parameters": [ - { - "schema": { - "type": "string", - "description": "The encoding format to use for the response (json, cbor)", - "example": "json" - }, - "required": true, - "name": "encoding", - "in": "query" - }, - { - "schema": { - "type": "string", - "description": "Worker query information" - }, - "required": true, - "name": "query", - "in": "query" - } - ], "responses": { "101": { "description": "WebSocket upgrade" diff --git a/docs/workers/quickstart.mdx b/docs/workers/quickstart.mdx index c10e17c91..e8837c446 100644 --- a/docs/workers/quickstart.mdx +++ b/docs/workers/quickstart.mdx @@ -7,10 +7,11 @@ description: Start building awesome documentation in under 5 minutes ```ts registry.ts import { setup } from "rivetkit"; import { worker } from "rivetkit/worker"; -import { workflow } from "rivetkit/workflow"; -import { realtime } from "rivetkit/realtime"; const counter = worker({ + onAuth: async () => { + // Allow public access + }, state: { count: 0, }, @@ -36,10 +37,16 @@ export type Registry = typeof registry; ``` ```ts server.ts +// With router import { registry } from "./registry"; -const registry = new Hono(); -app.route("/registry", registry.handler); +const app = new Hono(); +app.route("/registry", c => registry.handler(c.req.raw)); +serve(app); + +// Without router +import { serve } from "@rivetkit/node"; + serve(registry); ``` @@ -47,7 +54,7 @@ serve(registry); import { createClient } from "rivetkit/client"; import type { Registry } from "./registry"; -const client = createClient("http://localhost:8080/registry"); +const client = createClient("http://localhost:8080"); ``` diff --git a/packages/core/fixtures/driver-test-suite/action-inputs.ts b/packages/core/fixtures/driver-test-suite/action-inputs.ts index 4a1d23d68..c57d9f4b9 100644 --- a/packages/core/fixtures/driver-test-suite/action-inputs.ts +++ b/packages/core/fixtures/driver-test-suite/action-inputs.ts @@ -7,6 +7,7 @@ export interface State { // Test worker that can capture input during creation export const inputWorker = worker({ + onAuth: () => {}, createState: (c, { input }): State => { return { initialInput: input, diff --git a/packages/core/fixtures/driver-test-suite/action-timeout.ts b/packages/core/fixtures/driver-test-suite/action-timeout.ts index fb18967e0..9127cf20e 100644 --- a/packages/core/fixtures/driver-test-suite/action-timeout.ts +++ b/packages/core/fixtures/driver-test-suite/action-timeout.ts @@ -2,6 +2,7 @@ import { worker } from "rivetkit"; // Short timeout worker export const shortTimeoutWorker = worker({ + onAuth: () => {}, state: { value: 0 }, options: { action: { @@ -22,6 +23,7 @@ export const shortTimeoutWorker = worker({ // Long timeout worker export const longTimeoutWorker = worker({ + onAuth: () => {}, state: { value: 0 }, options: { action: { @@ -39,6 +41,7 @@ export const longTimeoutWorker = worker({ // Default timeout worker export const defaultTimeoutWorker = worker({ + onAuth: () => {}, state: { value: 0 }, actions: { normalAction: async (c) => { @@ -50,6 +53,7 @@ export const defaultTimeoutWorker = worker({ // Sync worker (timeout shouldn't apply) export const syncTimeoutWorker = worker({ + onAuth: () => {}, state: { value: 0 }, options: { action: { diff --git a/packages/core/fixtures/driver-test-suite/action-types.ts b/packages/core/fixtures/driver-test-suite/action-types.ts index 006b36fb6..cf280060c 100644 --- a/packages/core/fixtures/driver-test-suite/action-types.ts +++ b/packages/core/fixtures/driver-test-suite/action-types.ts @@ -2,6 +2,7 @@ import { worker, UserError } from "rivetkit"; // Worker with synchronous actions export const syncActionWorker = worker({ + onAuth: () => {}, state: { value: 0 }, actions: { // Simple synchronous action that returns a value directly @@ -25,6 +26,7 @@ export const syncActionWorker = worker({ // Worker with asynchronous actions export const asyncActionWorker = worker({ + onAuth: () => {}, state: { value: 0, data: null as any }, actions: { // Async action with a delay @@ -57,6 +59,7 @@ export const asyncActionWorker = worker({ // Worker with promise actions export const promiseWorker = worker({ + onAuth: () => {}, state: { results: [] as string[] }, actions: { // Action that returns a resolved promise diff --git a/packages/core/fixtures/driver-test-suite/auth.ts b/packages/core/fixtures/driver-test-suite/auth.ts new file mode 100644 index 000000000..f4dd6657e --- /dev/null +++ b/packages/core/fixtures/driver-test-suite/auth.ts @@ -0,0 +1,105 @@ +import { worker, UserError } from "rivetkit"; + +// Basic auth worker - requires API key +export const authWorker = worker({ + state: { requests: 0 }, + onAuth: (opts) => { + const { req, intents, params } = opts; + const apiKey = (params as any)?.apiKey; + if (!apiKey) { + throw new UserError("API key required", { code: "missing_auth" }); + } + + if (apiKey !== "valid-api-key") { + throw new UserError("Invalid API key", { code: "invalid_auth" }); + } + + return { userId: "user123", token: apiKey }; + }, + actions: { + getRequests: (c) => { + c.state.requests++; + return c.state.requests; + }, + getUserAuth: (c) => c.conn.auth, + }, +}); + +// Intent-specific auth worker - checks different permissions for different intents +export const intentAuthWorker = worker({ + state: { value: 0 }, + onAuth: (opts) => { + const { req, intents, params } = opts; + console.log('intents', intents, params); + const role = (params as any)?.role; + + if (intents.has("create") && role !== "admin") { + throw new UserError("Admin role required for create operations", { code: "insufficient_permissions" }); + } + + if (intents.has("action") && !["admin", "user"].includes(role || "")) { + throw new UserError("User or admin role required for actions", { code: "insufficient_permissions" }); + } + + return { role, timestamp: Date.now() }; + }, + actions: { + getValue: (c) => c.state.value, + setValue: (c, value: number) => { + c.state.value = value; + return value; + }, + getAuth: (c) => c.conn.auth, + }, +}); + +// Public worker - empty onAuth to allow public access +export const publicWorker = worker({ + state: { visitors: 0 }, + onAuth: () => { + return null; // Allow public access + }, + actions: { + visit: (c) => { + c.state.visitors++; + return c.state.visitors; + }, + }, +}); + +// No auth worker - should fail when accessed publicly (no onAuth defined) +export const noAuthWorker = worker({ + state: { value: 42 }, + actions: { + getValue: (c) => c.state.value, + }, +}); + +// Async auth worker - tests promise-based authentication +export const asyncAuthWorker = worker({ + state: { count: 0 }, + onAuth: async (opts) => { + const { req, intents, params } = opts; + // Simulate async auth check (e.g., database lookup) + await new Promise(resolve => setTimeout(resolve, 10)); + + const token = (params as any)?.token; + if (!token) { + throw new UserError("Token required", { code: "missing_token" }); + } + + // Simulate token validation + if (token === "invalid") { + throw new UserError("Token is invalid", { code: "invalid_token" }); + } + + return { userId: `user-${token}`, validated: true }; + }, + actions: { + increment: (c) => { + c.state.count++; + return c.state.count; + }, + getAuthData: (c) => c.conn.auth, + }, +}); diff --git a/packages/core/fixtures/driver-test-suite/conn-params.ts b/packages/core/fixtures/driver-test-suite/conn-params.ts index 9ca3a94c2..44bc4160d 100644 --- a/packages/core/fixtures/driver-test-suite/conn-params.ts +++ b/packages/core/fixtures/driver-test-suite/conn-params.ts @@ -1,6 +1,7 @@ import { worker } from "rivetkit"; export const counterWithParams = worker({ + onAuth: () => {}, state: { count: 0, initializers: [] as string[] }, createConnState: (c, { params }: { params: { name?: string } }) => { return { diff --git a/packages/core/fixtures/driver-test-suite/conn-state.ts b/packages/core/fixtures/driver-test-suite/conn-state.ts index 279f3b5df..4c5a55e3a 100644 --- a/packages/core/fixtures/driver-test-suite/conn-state.ts +++ b/packages/core/fixtures/driver-test-suite/conn-state.ts @@ -8,6 +8,7 @@ export type ConnState = { }; export const connStateWorker = worker({ + onAuth: () => {}, state: { sharedCounter: 0, disconnectionCount: 0, diff --git a/packages/core/fixtures/driver-test-suite/counter.ts b/packages/core/fixtures/driver-test-suite/counter.ts index 0c0254a01..bb1aea230 100644 --- a/packages/core/fixtures/driver-test-suite/counter.ts +++ b/packages/core/fixtures/driver-test-suite/counter.ts @@ -1,6 +1,7 @@ import { worker } from "rivetkit"; export const counter = worker({ + onAuth: () => {}, state: { count: 0 }, actions: { increment: (c, x: number) => { diff --git a/packages/core/fixtures/driver-test-suite/error-handling.ts b/packages/core/fixtures/driver-test-suite/error-handling.ts index e84d5b050..1cbdf947d 100644 --- a/packages/core/fixtures/driver-test-suite/error-handling.ts +++ b/packages/core/fixtures/driver-test-suite/error-handling.ts @@ -1,6 +1,7 @@ import { worker, UserError } from "rivetkit"; export const errorHandlingWorker = worker({ + onAuth: () => {}, state: { errorLog: [] as string[], }, diff --git a/packages/core/fixtures/driver-test-suite/lifecycle.ts b/packages/core/fixtures/driver-test-suite/lifecycle.ts index f316c1f78..2b81da1b6 100644 --- a/packages/core/fixtures/driver-test-suite/lifecycle.ts +++ b/packages/core/fixtures/driver-test-suite/lifecycle.ts @@ -1,6 +1,7 @@ import { worker } from "rivetkit"; export const counterWithLifecycle = worker({ + onAuth: () => {}, state: { count: 0, events: [] as string[], diff --git a/packages/core/fixtures/driver-test-suite/metadata.ts b/packages/core/fixtures/driver-test-suite/metadata.ts index bf64c5142..0f3879a0c 100644 --- a/packages/core/fixtures/driver-test-suite/metadata.ts +++ b/packages/core/fixtures/driver-test-suite/metadata.ts @@ -3,6 +3,7 @@ import { worker } from "rivetkit"; // Note: For testing only - metadata API will need to be mocked // in tests since this is implementation-specific export const metadataWorker = worker({ + onAuth: () => {}, state: { lastMetadata: null as any, workerName: "", diff --git a/packages/core/fixtures/driver-test-suite/registry.ts b/packages/core/fixtures/driver-test-suite/registry.ts index 1b1e5ca1b..5c30dba0c 100644 --- a/packages/core/fixtures/driver-test-suite/registry.ts +++ b/packages/core/fixtures/driver-test-suite/registry.ts @@ -27,6 +27,13 @@ import { uniqueVarWorker, driverCtxWorker, } from "./vars"; +import { + authWorker, + intentAuthWorker, + publicWorker, + noAuthWorker, + asyncAuthWorker, +} from "./auth"; // Consolidated setup with all workers export const registry = setup({ @@ -63,6 +70,12 @@ export const registry = setup({ dynamicVarWorker, uniqueVarWorker, driverCtxWorker, + // From auth.ts + authWorker, + intentAuthWorker, + publicWorker, + noAuthWorker, + asyncAuthWorker, }, }); diff --git a/packages/core/fixtures/driver-test-suite/scheduled.ts b/packages/core/fixtures/driver-test-suite/scheduled.ts index a5ce59a7e..9590b0f49 100644 --- a/packages/core/fixtures/driver-test-suite/scheduled.ts +++ b/packages/core/fixtures/driver-test-suite/scheduled.ts @@ -1,6 +1,7 @@ import { worker } from "rivetkit"; export const scheduled = worker({ + onAuth: () => {}, state: { lastRun: 0, scheduledCount: 0, diff --git a/packages/core/fixtures/driver-test-suite/vars.ts b/packages/core/fixtures/driver-test-suite/vars.ts index a42b7f042..753e615a1 100644 --- a/packages/core/fixtures/driver-test-suite/vars.ts +++ b/packages/core/fixtures/driver-test-suite/vars.ts @@ -2,6 +2,7 @@ import { worker } from "rivetkit"; // Worker with static vars export const staticVarWorker = worker({ + onAuth: () => {}, state: { value: 0 }, connState: { hello: "world" }, vars: { counter: 42, name: "test-worker" }, @@ -17,6 +18,7 @@ export const staticVarWorker = worker({ // Worker with nested vars export const nestedVarWorker = worker({ + onAuth: () => {}, state: { value: 0 }, connState: { hello: "world" }, vars: { @@ -43,6 +45,7 @@ export const nestedVarWorker = worker({ // Worker with dynamic vars export const dynamicVarWorker = worker({ + onAuth: () => {}, state: { value: 0 }, connState: { hello: "world" }, createVars: () => { @@ -60,6 +63,7 @@ export const dynamicVarWorker = worker({ // Worker with unique vars per instance export const uniqueVarWorker = worker({ + onAuth: () => {}, state: { value: 0 }, connState: { hello: "world" }, createVars: () => { @@ -76,6 +80,7 @@ export const uniqueVarWorker = worker({ // Worker that uses driver context export const driverCtxWorker = worker({ + onAuth: () => {}, state: { value: 0 }, connState: { hello: "world" }, createVars: (c, driverCtx: any) => { diff --git a/packages/core/package.json b/packages/core/package.json index 1d0ee6514..6466b540b 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -71,6 +71,16 @@ "default": "./dist/driver-helpers/mod.cjs" } }, + "./driver-helpers/websocket": { + "import": { + "types": "./dist/common/websocket.d.ts", + "default": "./dist/common/websocket.js" + }, + "require": { + "types": "./dist/common/websocket.d.cts", + "default": "./dist/common/websocket.cjs" + } + }, "./driver-test-suite": { "import": { "types": "./dist/driver-test-suite/mod.d.ts", @@ -148,7 +158,7 @@ "sideEffects": false, "scripts": { "dev": "yarn build --watch", - "build": "tsup src/mod.ts src/client/mod.ts src/common/log.ts src/worker/errors.ts src/topologies/coordinate/mod.ts src/topologies/partition/mod.ts src/utils.ts src/driver-helpers/mod.ts src/driver-test-suite/mod.ts src/worker/protocol/inspector/mod.ts src/test/mod.ts src/inspector/protocol/worker/mod.ts src/inspector/protocol/manager/mod.ts src/inspector/mod.ts", + "build": "tsup src/mod.ts src/client/mod.ts src/common/log.ts src/common/websocket.ts src/worker/errors.ts src/topologies/coordinate/mod.ts src/topologies/partition/mod.ts src/utils.ts src/driver-helpers/mod.ts src/driver-test-suite/mod.ts src/worker/protocol/inspector/mod.ts src/test/mod.ts src/inspector/protocol/worker/mod.ts src/inspector/protocol/manager/mod.ts src/inspector/mod.ts", "check-types": "tsc --noEmit", "boop": "tsc --outDir dist/test -d", "test": "vitest run", diff --git a/packages/core/src/client/client.ts b/packages/core/src/client/client.ts index 63458c8e7..af6514f68 100644 --- a/packages/core/src/client/client.ts +++ b/packages/core/src/client/client.ts @@ -1,12 +1,7 @@ import type { Transport } from "@/worker/protocol/message/mod"; import type { Encoding } from "@/worker/protocol/serde"; import type { WorkerQuery } from "@/manager/protocol/query"; -import { - WorkerConn, - WorkerConnRaw, - CONNECT_SYMBOL, - SendHttpMessageOpts, -} from "./worker-conn"; +import { WorkerConn, WorkerConnRaw, CONNECT_SYMBOL } from "./worker-conn"; import { WorkerHandle, WorkerHandleRaw } from "./worker-handle"; import { WorkerActionFunction } from "./worker-common"; import { logger } from "./log"; @@ -15,8 +10,7 @@ import type { AnyWorkerDefinition } from "@/worker/definition"; import type * as wsToServer from "@/worker/protocol/message/to-server"; import type { EventSource } from "eventsource"; import type { Context as HonoContext } from "hono"; -import { createHttpClientDriver } from "./http-client-driver"; -import { HonoRequest } from "hono"; +import type { WebSocket } from "ws"; /** Extract the worker registry from the registry definition. */ export type ExtractWorkersFromRegistry> = @@ -172,6 +166,7 @@ export interface ClientDriver { c: HonoContext | undefined, workerQuery: WorkerQuery, encodingKind: Encoding, + params: unknown, ): Promise; connectWebSocket( c: HonoContext | undefined, @@ -364,6 +359,7 @@ export class ClientRaw { undefined, createQuery, this.#encodingKind, + opts?.params, ); logger().debug("created worker with ID", { name, @@ -481,11 +477,9 @@ export function createClientWithDriver>( key?: string | string[], opts?: GetOptions, ): WorkerHandle[typeof prop]> => { - return target.getOrCreate[typeof prop]>( - prop, - key, - opts, - ); + return target.getOrCreate< + ExtractWorkersFromRegistry[typeof prop] + >(prop, key, opts); }, getForId: ( workerId: string, @@ -500,12 +494,12 @@ export function createClientWithDriver>( create: async ( key: string | string[], opts: CreateOptions = {}, - ): Promise[typeof prop]>> => { - return await target.create[typeof prop]>( - prop, - key, - opts, - ); + ): Promise< + WorkerHandle[typeof prop]> + > => { + return await target.create< + ExtractWorkersFromRegistry[typeof prop] + >(prop, key, opts); }, } as WorkerAccessor[typeof prop]>; } diff --git a/packages/core/src/client/http-client-driver.ts b/packages/core/src/client/http-client-driver.ts index 86ef73b92..18919d8e9 100644 --- a/packages/core/src/client/http-client-driver.ts +++ b/packages/core/src/client/http-client-driver.ts @@ -26,6 +26,7 @@ import type { ActionRequest } from "@/worker/protocol/http/action"; import type { ActionResponse } from "@/worker/protocol/message/to-client"; import { ClientDriver } from "./client"; import { HonoRequest, Context as HonoContext } from "hono"; +import type { WebSocket } from "ws"; /** * Client driver that communicates with the manager via HTTP. @@ -82,6 +83,7 @@ export function createHttpClientDriver(managerEndpoint: string): ClientDriver { _c: HonoContext | undefined, workerQuery: WorkerQuery, encodingKind: Encoding, + params: unknown, ): Promise => { logger().debug("resolving worker ID", { query: workerQuery }); @@ -95,6 +97,9 @@ export function createHttpClientDriver(managerEndpoint: string): ClientDriver { headers: { [HEADER_ENCODING]: encodingKind, [HEADER_WORKER_QUERY]: JSON.stringify(workerQuery), + ...(params !== undefined + ? { [HEADER_CONN_PARAMS]: JSON.stringify(params) } + : {}), }, body: {}, encoding: encodingKind, @@ -122,36 +127,37 @@ export function createHttpClientDriver(managerEndpoint: string): ClientDriver { ): Promise => { const { WebSocket } = await dynamicImports; - const workerQueryStr = encodeURIComponent(JSON.stringify(workerQuery)); const endpoint = managerEndpoint .replace(/^http:/, "ws:") .replace(/^https:/, "wss:"); - const url = `${endpoint}/workers/connect/websocket?encoding=${encodingKind}&query=${workerQueryStr}`; + const url = `${endpoint}/workers/connect/websocket`; + + // Pass sensitive data via protocol + const protocol = [ + `query.${encodeURIComponent(JSON.stringify(workerQuery))}`, + `encoding.${encodingKind}`, + ]; + if (params) + protocol.push( + `conn_params.${encodeURIComponent(JSON.stringify(params))}`, + ); + + // HACK: See packages/platforms/cloudflare-workers/src/websocket.ts + protocol.push("rivetkit"); logger().debug("connecting to websocket", { url }); - const ws = new WebSocket(url); + const ws = new WebSocket(url, protocol); if (encodingKind === "cbor") { ws.binaryType = "arraybuffer"; } else if (encodingKind === "json") { // HACK: Bun bug prevents changing binary type, so we ignore the error https://github.com/oven-sh/bun/issues/17005 try { - ws.binaryType = "blob"; + ws.binaryType = "blob" as any; } catch (error) {} } else { assertUnreachable(encodingKind); } - ws.addEventListener("open", () => { - // Send init message with the initialization data - // - // We can't pass this data in the query string since it might include sensitive data which would get logged - const messageSerialized = serializeWithEncoding(encodingKind, { - b: { i: { p: params } }, - }); - ws.send(messageSerialized); - logger().debug("sent websocket init message"); - }); - return ws; }, diff --git a/packages/core/src/client/worker-conn.ts b/packages/core/src/client/worker-conn.ts index 3c749af06..a8213c3b1 100644 --- a/packages/core/src/client/worker-conn.ts +++ b/packages/core/src/client/worker-conn.ts @@ -23,6 +23,7 @@ import { type WebSocketMessage as ConnMessage, messageLength, serializeWithEncoding, + WebSocketMessage, } from "./utils"; import { HEADER_WORKER_ID, @@ -34,6 +35,7 @@ import { } from "@/worker/router-endpoints"; import type { EventSource } from "eventsource"; import { WorkerDefinitionActions } from "./worker-common"; +import type { WebSocket, CloseEvent, ErrorEvent } from "ws"; interface ActionInFlight { name: string; @@ -265,13 +267,13 @@ enc logger().debug("websocket open"); }; ws.onmessage = async (ev) => { - this.#handleOnMessage(ev); + this.#handleOnMessage(ev.data); }; ws.onclose = (ev) => { this.#handleOnClose(ev); }; ws.onerror = (ev) => { - this.#handleOnError(ev); + this.#handleOnError(); }; } @@ -288,7 +290,7 @@ enc // #handleOnOpen is called on "i" event }; eventSource.onmessage = (ev) => { - this.#handleOnMessage(ev); + this.#handleOnMessage(ev.data); }; eventSource.onerror = (ev) => { if (eventSource.readyState === eventSource.CLOSED) { @@ -296,7 +298,7 @@ enc this.#handleOnClose(ev); } else { // Log error since event source is still open - this.#handleOnError(ev); + this.#handleOnError(); } }; } @@ -330,14 +332,16 @@ enc } /** Called by the onmessage event from drivers. */ - async #handleOnMessage(event: MessageEvent) { + async #handleOnMessage(data: any) { logger().trace("received message", { - dataType: typeof event.data, - isBlob: event.data instanceof Blob, - isArrayBuffer: event.data instanceof ArrayBuffer, + dataType: typeof data, + isBlob: data instanceof Blob, + isArrayBuffer: data instanceof ArrayBuffer, }); - const response = (await this.#parse(event.data)) as wsToClient.ToClient; + const response = (await this.#parse( + data as ConnMessage, + )) as wsToClient.ToClient; logger().trace("parsed message", { response: JSON.stringify(response).substring(0, 100) + "...", }); @@ -459,11 +463,11 @@ enc } /** Called by the onerror event from drivers. */ - #handleOnError(event: Event) { + #handleOnError() { if (this.#disposed) return; // More detailed information will be logged in onclose - logger().warn("socket error", { event }); + logger().warn("socket error"); } #takeActionInFlight(id: number): ActionInFlight { diff --git a/packages/core/src/client/worker-handle.ts b/packages/core/src/client/worker-handle.ts index cf381b0fe..4aa43499c 100644 --- a/packages/core/src/client/worker-handle.ts +++ b/packages/core/src/client/worker-handle.ts @@ -109,6 +109,7 @@ export class WorkerHandleRaw { undefined, this.#workerQuery, this.#encodingKind, + this.#params, ); this.#workerQuery = { getForId: { workerId } }; return workerId; diff --git a/packages/core/src/common/router.ts b/packages/core/src/common/router.ts index a113ada5f..467ce0d46 100644 --- a/packages/core/src/common/router.ts +++ b/packages/core/src/common/router.ts @@ -1,11 +1,11 @@ import type { Context as HonoContext, Next } from "hono"; import { getLogger, Logger } from "./log"; -import { deconstructError } from "./utils"; +import { deconstructError, stringifyError } from "./utils"; import { getRequestEncoding, getRequestExposeInternalError, } from "@/worker/router-endpoints"; -import { serialize } from "@/worker/protocol/serde"; +import { Encoding, serialize } from "@/worker/protocol/serde"; import { ResponseError } from "@/worker/protocol/http/error"; export function logger() { @@ -21,7 +21,7 @@ export function loggerMiddleware(logger: Logger) { await next(); const duration = Date.now() - startTime; - logger.info("http request", { + logger.debug("http request", { method, path, status: c.res.status, @@ -48,7 +48,7 @@ export function handleRouteError( ) { const exposeInternalError = opts.enableExposeInternalError && - getRequestExposeInternalError(c.req, false); + getRequestExposeInternalError(c.req); const { statusCode, code, message, metadata } = deconstructError( error, @@ -60,7 +60,16 @@ export function handleRouteError( exposeInternalError, ); - const encoding = getRequestEncoding(c.req, false); + let encoding: Encoding; + try { + encoding = getRequestEncoding(c.req); + } catch (err) { + logger().debug("failed to extract encoding", { + error: stringifyError(err), + }); + encoding = "json"; + } + const output = serialize( { c: code, diff --git a/packages/core/src/common/utils.ts b/packages/core/src/common/utils.ts index 692902fa4..3270094cf 100644 --- a/packages/core/src/common/utils.ts +++ b/packages/core/src/common/utils.ts @@ -204,7 +204,7 @@ export function deconstructError( export function stringifyError(error: unknown): string { if (error instanceof Error) { - if (process.env._WORKER_CORE_ERROR_STACK === "1") { + if (process.env._RIVETKIT_ERROR_STACK === "1") { return `${error.name}: ${error.message}${error.stack ? `\n${error.stack}` : ""}`; } else { return `${error.name}: ${error.message}`; diff --git a/packages/core/src/common/websocket.ts b/packages/core/src/common/websocket.ts index 0b36cab4a..10ba790ea 100644 --- a/packages/core/src/common/websocket.ts +++ b/packages/core/src/common/websocket.ts @@ -1,4 +1,5 @@ import { logger } from "@/client/log"; +import type { WebSocket } from "ws"; // Global singleton promise that will be reused for subsequent calls let webSocketPromise: Promise | null = null; @@ -13,9 +14,9 @@ export async function importWebSocket(): Promise { webSocketPromise = (async () => { let _WebSocket: typeof WebSocket; - if (typeof WebSocket !== "undefined") { + if (typeof global.WebSocket !== "undefined") { // Browser environment - _WebSocket = WebSocket; + _WebSocket = global.WebSocket as unknown as typeof WebSocket; logger().debug("using native websocket"); } else { // Node.js environment diff --git a/packages/core/src/driver-helpers/mod.ts b/packages/core/src/driver-helpers/mod.ts index f49f60893..f5bf79f29 100644 --- a/packages/core/src/driver-helpers/mod.ts +++ b/packages/core/src/driver-helpers/mod.ts @@ -17,3 +17,13 @@ export { GetOrCreateWithKeyInput, WorkerOutput, } from "@/manager/driver"; +export { + HEADER_WORKER_QUERY, + HEADER_ENCODING, + HEADER_EXPOSE_INTERNAL_ERROR, + HEADER_CONN_PARAMS, + HEADER_AUTH_DATA, + HEADER_WORKER_ID, + HEADER_CONN_ID, + HEADER_CONN_TOKEN, +} from "@/worker/router-endpoints"; diff --git a/packages/core/src/driver-test-suite/mod.ts b/packages/core/src/driver-test-suite/mod.ts index 455a7b220..48b5b6727 100644 --- a/packages/core/src/driver-test-suite/mod.ts +++ b/packages/core/src/driver-test-suite/mod.ts @@ -21,6 +21,7 @@ import { runWorkerVarsTests } from "./tests/worker-vars"; import { runWorkerConnStateTests } from "./tests/worker-conn-state"; import { runWorkerMetadataTests } from "./tests/worker-metadata"; import { runWorkerErrorHandlingTests } from "./tests/worker-error-handling"; +import { runWorkerAuthTests } from "./tests/worker-auth"; export interface DriverTestConfig { /** Deploys an registry and returns the connection endpoint. */ @@ -90,6 +91,8 @@ export function runDriverTests( runWorkerMetadataTests(driverTestConfig); runWorkerErrorHandlingTests(driverTestConfig); + + runWorkerAuthTests(driverTestConfig); }); } } diff --git a/packages/core/src/driver-test-suite/test-inline-client-driver.ts b/packages/core/src/driver-test-suite/test-inline-client-driver.ts index 09b701159..60cbfaaff 100644 --- a/packages/core/src/driver-test-suite/test-inline-client-driver.ts +++ b/packages/core/src/driver-test-suite/test-inline-client-driver.ts @@ -13,6 +13,8 @@ import { import { assertUnreachable } from "@/worker/utils"; import * as cbor from "cbor-x"; import { WorkerError as ClientWorkerError } from "@/client/errors"; +import type { WebSocket } from "ws"; +import { importWebSocket } from "@/common/websocket"; /** * Creates a client driver used for testing the inline client driver. This will send a request to the HTTP server which will then internally call the internal client and return the response. @@ -43,13 +45,14 @@ export function createTestInlineClientDriver( c: HonoContext | undefined, workerQuery: WorkerQuery, encodingKind: Encoding, + params: unknown, ): Promise => { return makeInlineRequest( endpoint, encodingKind, transport, "resolveWorkerId", - [undefined, workerQuery, encodingKind], + [undefined, workerQuery, encodingKind, params], ); }, @@ -59,6 +62,8 @@ export function createTestInlineClientDriver( encodingKind: Encoding, params: unknown, ): Promise => { + const WebSocket = await importWebSocket(); + logger().info("creating websocket connection via test inline driver", { workerQuery, encodingKind, @@ -80,7 +85,10 @@ export function createTestInlineClientDriver( logger().debug("connecting to websocket", { url: finalWsUrl }); // Create and return the WebSocket - return new WebSocket(finalWsUrl); + return new WebSocket(finalWsUrl, [ + // HACK: See packages/platforms/cloudflare-workers/src/websocket.ts + "rivetkit", + ]); }, connectSse: async ( diff --git a/packages/core/src/driver-test-suite/tests/worker-auth.ts b/packages/core/src/driver-test-suite/tests/worker-auth.ts new file mode 100644 index 000000000..c7fcb61ce --- /dev/null +++ b/packages/core/src/driver-test-suite/tests/worker-auth.ts @@ -0,0 +1,318 @@ +import { describe, test, expect } from "vitest"; +import type { DriverTestConfig } from "../mod"; +import { setupDriverTest } from "../utils"; +import { WorkerError } from "@/client/errors"; + +export function runWorkerAuthTests(driverTestConfig: DriverTestConfig) { + describe("Worker Authentication Tests", () => { + describe("Basic Authentication", () => { + test("should allow access with valid auth", async (c) => { + const { client } = await setupDriverTest(c, driverTestConfig); + + // Create client with valid auth params + const instance = client.authWorker.getOrCreate(undefined, { + params: { apiKey: "valid-api-key" }, + }); + + // This should succeed with valid API key + const authData = await instance.getUserAuth(); + if (driverTestConfig.clientType === "inline") { + // Inline clients don't have auth data + expect(authData).toBeUndefined(); + } else { + // HTTP clients should have auth data + expect(authData).toEqual({ userId: "user123", token: "valid-api-key" }); + } + + // Should be able to call actions + const requests = await instance.getRequests(); + expect(requests).toBe(1); + }); + + test("should deny access with invalid auth", async (c) => { + const { client } = await setupDriverTest(c, driverTestConfig); + + // This should fail without proper authorization + const instance = client.authWorker.getOrCreate(); + + if (driverTestConfig.clientType === "inline") { + // Inline clients bypass authentication + const requests = await instance.getRequests(); + expect(typeof requests).toBe("number"); + } else { + // HTTP clients should enforce authentication + try { + await instance.getRequests(); + expect.fail("Expected authentication error"); + } catch (error) { + expect((error as WorkerError).code).toBe("missing_auth"); + } + } + }); + + test("should expose auth data on connection", async (c) => { + const { client } = await setupDriverTest(c, driverTestConfig); + + const instance = client.authWorker.getOrCreate(undefined, { + params: { apiKey: "valid-api-key" }, + }); + + // Auth data should be available via c.conn.auth + const authData = await instance.getUserAuth(); + if (driverTestConfig.clientType === "inline") { + // Inline clients don't have auth data + expect(authData).toBeUndefined(); + } else { + // HTTP clients should have auth data + expect(authData).toBeDefined(); + expect((authData as any).userId).toBe("user123"); + expect((authData as any).token).toBe("valid-api-key"); + } + }); + }); + + describe("Intent-Based Authentication", () => { + test("should allow get operations for any role", async (c) => { + const { client } = await setupDriverTest(c, driverTestConfig); + + const createdInstance = await client.intentAuthWorker.create(["foo"], { + params: { role: "admin" }, + }); + const workerId = await createdInstance.resolve(); + + if (driverTestConfig.clientType === "inline") { + // Inline clients bypass authentication + const instance = client.intentAuthWorker.getForId(workerId); + const value = await instance.getValue(); + expect(value).toBe(0); + } else { + // HTTP clients - actions require user or admin role + const instance = client.intentAuthWorker.getForId(workerId, { + params: { role: "user" }, // Actions require user or admin role + }); + const value = await instance.getValue(); + expect(value).toBe(0); + } + }); + + test("should require admin role for create operations", async (c) => { + const { client } = await setupDriverTest(c, driverTestConfig); + + if (driverTestConfig.clientType === "inline") { + // Inline clients bypass authentication - should succeed + const instance = client.intentAuthWorker.getOrCreate(undefined, { + params: { role: "user" }, + }); + const value = await instance.getValue(); + expect(value).toBe(0); + } else { + // HTTP clients should enforce authentication + try { + const instance = client.intentAuthWorker.getOrCreate(undefined, { + params: { role: "user" }, + }); + await instance.getValue(); + expect.fail("Expected permission error for create operation"); + } catch (error) { + expect((error as WorkerError).code).toBe("insufficient_permissions"); + expect((error as WorkerError).message).toContain( + "Admin role required", + ); + } + } + }); + + test("should allow actions for user and admin roles", async (c) => { + const { client } = await setupDriverTest(c, driverTestConfig); + + const createdInstance = await client.intentAuthWorker.create(["foo"], { + params: { role: "admin" }, + }); + const workerId = await createdInstance.resolve(); + + // This should fail - actions require user or admin role + const instance = client.intentAuthWorker.getForId(workerId, { + params: { role: "guest" }, + }); + + if (driverTestConfig.clientType === "inline") { + // Inline clients bypass authentication - should succeed + const result = await instance.setValue(42); + expect(result).toBe(42); + } else { + // HTTP clients should enforce authentication + try { + await instance.setValue(42); + expect.fail("Expected permission error for action"); + } catch (error) { + expect((error as WorkerError).code).toBe("insufficient_permissions"); + expect((error as WorkerError).message).toContain( + "User or admin role required", + ); + } + } + }); + }); + + describe("Public Access", () => { + test("should allow access with empty onAuth", async (c) => { + const { client } = await setupDriverTest(c, driverTestConfig); + + // Public worker should allow access without authentication + const instance = client.publicWorker.getOrCreate(); + + const visitors = await instance.visit(); + expect(visitors).toBe(1); + + // Should be able to call multiple times + const visitors2 = await instance.visit(); + expect(visitors2).toBe(2); + }); + + test("should deny access without onAuth defined", async (c) => { + const { client } = await setupDriverTest(c, driverTestConfig); + + // Worker without onAuth should be blocked + const instance = client.noAuthWorker.getOrCreate(); + + if (driverTestConfig.clientType === "inline") { + // Inline clients bypass authentication - should succeed + const value = await instance.getValue(); + expect(value).toBe(42); + } else { + // HTTP clients should enforce authentication + try { + await instance.getValue(); + expect.fail("Expected access to be denied for worker without onAuth"); + } catch (error) { + expect((error as WorkerError).code).toBe("forbidden"); + } + } + }); + }); + + describe("Async Authentication", () => { + test("should handle promise-based auth", async (c) => { + const { client } = await setupDriverTest(c, driverTestConfig); + + const instance = client.asyncAuthWorker.getOrCreate(undefined, { + params: { token: "valid" }, + }); + + // Should succeed with valid token + const result = await instance.increment(); + expect(result).toBe(1); + + // Auth data should be available + const authData = await instance.getAuthData(); + if (driverTestConfig.clientType === "inline") { + // Inline clients don't have auth data + expect(authData).toBeUndefined(); + } else { + // HTTP clients should have auth data + expect(authData).toBeDefined(); + expect((authData as any).userId).toBe("user-valid"); + expect((authData as any).validated).toBe(true); + } + }); + + test("should handle async auth failures", async (c) => { + const { client } = await setupDriverTest(c, driverTestConfig); + + const instance = client.asyncAuthWorker.getOrCreate(); + + if (driverTestConfig.clientType === "inline") { + // Inline clients bypass authentication - should succeed + const result = await instance.increment(); + expect(result).toBe(1); + } else { + // HTTP clients should enforce authentication + try { + await instance.increment(); + expect.fail("Expected async auth failure"); + } catch (error) { + expect((error as WorkerError).code).toBe("missing_token"); + } + } + }); + }); + + describe("Authentication Across Transports", () => { + if (driverTestConfig.transport === "websocket") { + test("should authenticate WebSocket connections", async (c) => { + const { client } = await setupDriverTest(c, driverTestConfig); + + // Test WebSocket connection auth + const instance = client.authWorker.getOrCreate(undefined, { + params: { apiKey: "valid-api-key" }, + }); + + // Should be able to establish connection and call actions + const authData = await instance.getUserAuth(); + expect(authData).toBeDefined(); + expect((authData as any).userId).toBe("user123"); + }); + } + + test("should authenticate HTTP actions", async (c) => { + const { client } = await setupDriverTest(c, driverTestConfig); + + // Test HTTP action auth + const instance = client.authWorker.getOrCreate(undefined, { + params: { apiKey: "valid-api-key" }, + }); + + // Actions should require authentication + const requests = await instance.getRequests(); + expect(typeof requests).toBe("number"); + }); + }); + + describe("Error Handling", () => { + test("should handle auth errors gracefully", async (c) => { + const { client } = await setupDriverTest(c, driverTestConfig); + + const instance = client.authWorker.getOrCreate(); + + if (driverTestConfig.clientType === "inline") { + // Inline clients bypass authentication - should succeed + const requests = await instance.getRequests(); + expect(typeof requests).toBe("number"); + } else { + // HTTP clients should enforce authentication + try { + await instance.getRequests(); + expect.fail("Expected authentication error"); + } catch (error) { + // Error should be properly structured + const workerError = error as WorkerError; + expect(workerError.code).toBeDefined(); + expect(workerError.message).toBeDefined(); + } + } + }); + + test("should preserve error details for debugging", async (c) => { + const { client } = await setupDriverTest(c, driverTestConfig); + + const instance = client.asyncAuthWorker.getOrCreate(); + + if (driverTestConfig.clientType === "inline") { + // Inline clients bypass authentication - should succeed + const result = await instance.increment(); + expect(result).toBe(1); + } else { + // HTTP clients should enforce authentication + try { + await instance.increment(); + expect.fail("Expected token error"); + } catch (error) { + const workerError = error as WorkerError; + expect(workerError.code).toBe("missing_token"); + expect(workerError.message).toBe("Token required"); + } + } + }); + }); + }); +} diff --git a/packages/core/src/globals.d.ts b/packages/core/src/globals.d.ts new file mode 100644 index 000000000..7e7ef7fcf --- /dev/null +++ b/packages/core/src/globals.d.ts @@ -0,0 +1 @@ +declare const navigator: any; diff --git a/packages/core/src/inline-client-driver/fake-event-source.ts b/packages/core/src/inline-client-driver/fake-event-source.ts index 7ca6dcaa3..148a35a4e 100644 --- a/packages/core/src/inline-client-driver/fake-event-source.ts +++ b/packages/core/src/inline-client-driver/fake-event-source.ts @@ -1,6 +1,6 @@ import { logger } from "./log"; import type { SSEStreamingApi } from "hono/streaming"; -import type { EventSource } from "eventsource"; +import type { EventListener, EventSource } from "eventsource"; /** * FakeEventSource provides a minimal implementation of an SSE stream diff --git a/packages/core/src/inline-client-driver/fake-websocket.ts b/packages/core/src/inline-client-driver/fake-websocket.ts index bd1184e39..61a6bcb96 100644 --- a/packages/core/src/inline-client-driver/fake-websocket.ts +++ b/packages/core/src/inline-client-driver/fake-websocket.ts @@ -4,24 +4,28 @@ import type { ConnectWebSocketOutput } from "@/worker/router-endpoints"; import type * as messageToServer from "@/worker/protocol/message/to-server"; import { parseMessage } from "@/worker/protocol/message/mod"; import type { InputData } from "@/worker/protocol/serde"; +import type { + Event, + CloseEvent, + MessageEvent, +} from "ws"; /** * FakeWebSocket implements a WebSocket-like interface * that connects to a ConnectWebSocketOutput handler */ -export class FakeWebSocket implements WebSocket { +export class FakeWebSocket { // WebSocket interface properties - binaryType: BinaryType = "arraybuffer"; - bufferedAmount: number = 0; - extensions: string = ""; - protocol: string = ""; - url: string = ""; + bufferedAmount = 0; + extensions = ""; + protocol = ""; + url = ""; // Event handlers - onclose: ((ev: CloseEvent) => any) | null = null; - onerror: ((ev: Event) => any) | null = null; - onmessage: ((ev: MessageEvent) => any) | null = null; - onopen: ((ev: Event) => any) | null = null; + onclose: ((ev: any) => void) | null = null; + onerror: ((ev: any) => void) | null = null; + onmessage: ((ev: any) => void) | null = null; + onopen: ((ev: any) => void) | null = null; // WebSocket readyState values readonly CONNECTING = 0 as const; @@ -117,13 +121,11 @@ export class FakeWebSocket implements WebSocket { logger().debug("fake websocket sending message", { messageType: message.b && - ("i" in message.b - ? "init" - : "ar" in message.b - ? "action" - : "sr" in message.b - ? "subscription" - : "unknown"), + ("ar" in message.b + ? "action" + : "sr" in message.b + ? "subscription" + : "unknown"), }); this.#handler.onMessage(message).catch((err) => { @@ -220,10 +222,7 @@ export class FakeWebSocket implements WebSocket { /** * Implementation of EventTarget methods (minimal implementation) */ - addEventListener( - type: string, - listener: EventListenerOrEventListenerObject, - ): void { + addEventListener(type: string, listener: any): void { // Map to the onXXX properties switch (type) { case "open": @@ -271,7 +270,7 @@ export class FakeWebSocket implements WebSocket { } } - dispatchEvent(event: Event): boolean { + dispatchEvent(event: any): boolean { // Dispatch to the corresponding handler switch (event.type) { case "open": diff --git a/packages/core/src/inline-client-driver/mod.ts b/packages/core/src/inline-client-driver/mod.ts index 34bf18259..7129a23ae 100644 --- a/packages/core/src/inline-client-driver/mod.ts +++ b/packages/core/src/inline-client-driver/mod.ts @@ -6,7 +6,6 @@ import type * as wsToServer from "@/worker/protocol/message/to-server"; import { type Encoding, serialize } from "@/worker/protocol/serde"; import { ConnectWebSocketOutput, - handleWebSocketConnect, HEADER_CONN_PARAMS, HEADER_ENCODING, HEADER_CONN_ID, @@ -32,6 +31,7 @@ import onChange from "on-change"; import { httpUserAgent } from "@/utils"; import { WorkerError as ClientWorkerError } from "@/client/errors"; import { deconstructError } from "@/common/utils"; +import type { WebSocket } from "ws"; /** * Client driver that calls the manager driver inline. @@ -58,13 +58,9 @@ export function createInlineClientDriver( ...args: Args ): Promise => { try { - // Get the worker ID and meta - const { workerId, meta } = await queryWorker( - c, - workerQuery, - managerDriver, - ); - logger().debug("found worker for action", { workerId, meta }); + // Get the worker ID + const { workerId } = await queryWorker(c, workerQuery, managerDriver); + logger().debug("found worker for action", { workerId }); invariant(workerId, "Missing worker ID"); // Invoke the action @@ -76,6 +72,8 @@ export function createInlineClientDriver( actionName, actionArgs: args, workerId, + // No auth data since this is from internal + authData: undefined, }); try { @@ -116,7 +114,6 @@ export function createInlineClientDriver( customFetch: routingHandler.custom.sendRequest.bind( undefined, workerId, - meta, ), }); @@ -142,7 +139,7 @@ export function createInlineClientDriver( workerQuery: WorkerQuery, _encodingKind: Encoding, ): Promise => { - // Get the worker ID and meta + // Get the worker ID const { workerId } = await queryWorker(c, workerQuery, managerDriver); logger().debug("resolved worker", { workerId }); invariant(workerId, "missing worker ID"); @@ -156,13 +153,9 @@ export function createInlineClientDriver( encodingKind: Encoding, params?: unknown, ): Promise => { - // Get the worker ID and meta - const { workerId, meta } = await queryWorker( - c, - workerQuery, - managerDriver, - ); - logger().debug("found worker for action", { workerId, meta }); + // Get the worker ID + const { workerId } = await queryWorker(c, workerQuery, managerDriver); + logger().debug("found worker for action", { workerId }); invariant(workerId, "Missing worker ID"); // Invoke the action @@ -182,15 +175,17 @@ export function createInlineClientDriver( const output = await routingHandler.inline.handlers.onConnectWebSocket({ req: c?.req, encoding: encodingKind, - params, workerId, + params, + // No auth data since this is from internal + authData: undefined, }); logger().debug("got ConnectWebSocketOutput, creating FakeWebSocket"); // TODO: There might be a bug where mutating data from the response of an action over a websocket will mutate the original data. See note about `structuredClone` in `action` // Create and initialize the FakeWebSocket, waiting for it to be ready - const webSocket = new FakeWebSocket(output); + const webSocket = new FakeWebSocket(output) as any as WebSocket; logger().debug("FakeWebSocket created and initialized"); return webSocket; @@ -198,19 +193,10 @@ export function createInlineClientDriver( // Open WebSocket const ws = await routingHandler.custom.openWebSocket( workerId, - meta, encodingKind, + params, ); - // Send init message with the initialization data - // - // We can't pass this data in the query string since it might include sensitive data which would get logged - const messageSerialized = serializeWithEncoding(encodingKind, { - b: { i: { p: params } }, - }); - ws.send(messageSerialized); - logger().debug("sent websocket init message"); - return ws; } else { assertUnreachable(routingHandler); @@ -223,13 +209,9 @@ export function createInlineClientDriver( encodingKind: Encoding, params: unknown, ): Promise => { - // Get the worker ID and meta - const { workerId, meta } = await queryWorker( - c, - workerQuery, - managerDriver, - ); - logger().debug("found worker for sse connection", { workerId, meta }); + // Get the worker ID + const { workerId } = await queryWorker(c, workerQuery, managerDriver); + logger().debug("found worker for sse connection", { workerId }); invariant(workerId, "Missing worker ID"); logger().debug("opening sse connection", { @@ -254,6 +236,8 @@ export function createInlineClientDriver( encoding: encodingKind, params, workerId, + // No auth data since this is from internal + authData: undefined, }); logger().debug("got ConnectSseOutput, creating FakeEventSource"); @@ -329,13 +313,6 @@ export function createInlineClientDriver( }, }); } else if ("custom" in routingHandler) { - // For custom routing handler, get the worker metadata first - const { meta } = await queryWorker( - c, - { getForId: { workerId } }, - managerDriver, - ); - // Send an HTTP request to the connections endpoint return sendHttpRequest({ url: "http://worker/connections/message", @@ -352,7 +329,6 @@ export function createInlineClientDriver( customFetch: routingHandler.custom.sendRequest.bind( undefined, workerId, - meta, ), }); } else { @@ -371,9 +347,9 @@ export async function queryWorker( c: HonoContext | undefined, query: WorkerQuery, driver: ManagerDriver, -): Promise<{ workerId: string; meta?: unknown }> { +): Promise<{ workerId: string }> { logger().debug("querying worker", { query }); - let workerOutput: { workerId: string; meta?: unknown }; + let workerOutput: { workerId: string }; if ("getForId" in query) { const output = await driver.getForId({ c, @@ -403,7 +379,6 @@ export async function queryWorker( }); workerOutput = { workerId: getOrCreateOutput.workerId, - meta: getOrCreateOutput.meta, }; } else if ("create" in query) { const createOutput = await driver.createWorker({ @@ -415,7 +390,6 @@ export async function queryWorker( }); workerOutput = { workerId: createOutput.workerId, - meta: createOutput.meta, }; } else { throw new errors.InvalidRequest("Invalid query format"); @@ -423,9 +397,8 @@ export async function queryWorker( logger().debug("worker query result", { workerId: workerOutput.workerId, - meta: workerOutput.meta, }); - return { workerId: workerOutput.workerId, meta: workerOutput.meta }; + return { workerId: workerOutput.workerId }; } /** diff --git a/packages/core/src/manager/auth.ts b/packages/core/src/manager/auth.ts new file mode 100644 index 000000000..80d3fbed1 --- /dev/null +++ b/packages/core/src/manager/auth.ts @@ -0,0 +1,121 @@ +import * as errors from "@/worker/errors"; +import type { Context as HonoContext } from "hono"; +import type { WorkerQuery } from "./protocol/query"; +import type { AuthIntent } from "@/worker/config"; +import type { AnyWorkerDefinition } from "@/worker/definition"; +import type { RegistryConfig } from "@/registry/config"; +import { ManagerDriver } from "./driver"; +import { stringifyError } from "@/utils"; +import { logger } from "./log"; + +/** + * Get authentication intents from a worker query + */ +export function getIntentsFromQuery(query: WorkerQuery): Set { + const intents = new Set(); + + if ("getForId" in query) { + intents.add("get"); + } else if ("getForKey" in query) { + intents.add("get"); + } else if ("getOrCreateForKey" in query) { + intents.add("get"); + intents.add("create"); + } else if ("create" in query) { + intents.add("create"); + } + + return intents; +} + +/** + * Get worker name from a worker query + */ +export async function getWorkerNameFromQuery( + c: HonoContext, + driver: ManagerDriver, + query: WorkerQuery, +): Promise { + if ("getForId" in query) { + // TODO: This will have a duplicate call to getForId between this and queryWorker + const output = await driver.getForId({ + c, + workerId: query.getForId.workerId, + }); + if (!output) throw new errors.WorkerNotFound(query.getForId.workerId); + return output.name; + } else if ("getForKey" in query) { + return query.getForKey.name; + } else if ("getOrCreateForKey" in query) { + return query.getOrCreateForKey.name; + } else if ("create" in query) { + return query.create.name; + } else { + throw new errors.InvalidRequest("Invalid query format"); + } +} + +/** + * Authenticate a request using the worker's onAuth function + */ +export async function authenticateRequest( + c: HonoContext, + workerDefinition: AnyWorkerDefinition, + intents: Set, + params: unknown, +): Promise { + if (!workerDefinition.config.onAuth) { + throw new errors.Forbidden( + "Worker requires authentication but no onAuth handler is defined", + ); + } + + try { + const dataOrPromise = workerDefinition.config.onAuth({ + req: c.req.raw, + intents, + params, + }); + if (dataOrPromise instanceof Promise) { + return await dataOrPromise; + } else { + return dataOrPromise; + } + } catch (error) { + logger().info("authentication error", { error: stringifyError(error) }); + if (errors.WorkerError.isWorkerError(error)) { + throw error; + } + throw new errors.Forbidden("Authentication failed"); + } +} + +/** + * Simplified authentication for endpoints that combines all auth steps + */ +export async function authenticateEndpoint( + c: HonoContext, + driver: ManagerDriver, + registryConfig: RegistryConfig, + query: WorkerQuery, + additionalIntents: AuthIntent[], + params: unknown, +): Promise { + // Get base intents from query + const intents = getIntentsFromQuery(query); + + // Add endpoint-specific intents + for (const intent of additionalIntents) { + intents.add(intent); + } + + // Get worker definition + const workerName = await getWorkerNameFromQuery(c, driver, query); + const workerDefinition = registryConfig.workers[workerName]; + if (!workerDefinition) { + throw new errors.WorkerNotFound(workerName); + } + + // Authenticate + return await authenticateRequest(c, workerDefinition, intents, params); +} diff --git a/packages/core/src/manager/driver.ts b/packages/core/src/manager/driver.ts index cbc63fae6..7e329f08d 100644 --- a/packages/core/src/manager/driver.ts +++ b/packages/core/src/manager/driver.ts @@ -42,5 +42,4 @@ export interface WorkerOutput { workerId: string; name: string; key: WorkerKey; - meta?: unknown; } diff --git a/packages/core/src/manager/protocol/query.ts b/packages/core/src/manager/protocol/query.ts index 675b20dd1..d15df3226 100644 --- a/packages/core/src/manager/protocol/query.ts +++ b/packages/core/src/manager/protocol/query.ts @@ -55,6 +55,7 @@ export const ConnectRequestSchema = z.object({ export const ConnectWebSocketRequestSchema = z.object({ query: WorkerQuerySchema.describe("query"), encoding: EncodingSchema.describe("encoding"), + connParams: z.unknown().optional().describe("conn_params"), }); export const ConnMessageRequestSchema = z.object({ @@ -66,6 +67,7 @@ export const ConnMessageRequestSchema = z.object({ export const ResolveRequestSchema = z.object({ query: WorkerQuerySchema.describe(HEADER_WORKER_QUERY), + connParams: z.string().optional().describe(HEADER_CONN_PARAMS), }); export type WorkerQuery = z.infer; diff --git a/packages/core/src/manager/router.ts b/packages/core/src/manager/router.ts index 57fac9c79..45b6f2180 100644 --- a/packages/core/src/manager/router.ts +++ b/packages/core/src/manager/router.ts @@ -20,8 +20,9 @@ import { HEADER_CONN_TOKEN, HEADER_ENCODING, HEADER_WORKER_QUERY, - ALL_HEADERS, + ALL_PUBLIC_HEADERS, getRequestQuery, + HEADER_AUTH_DATA, } from "@/worker/router-endpoints"; import { assertUnreachable } from "@/worker/utils"; import type { RegistryConfig } from "@/registry/config"; @@ -30,7 +31,11 @@ import { handleRouteNotFound, loggerMiddleware, } from "@/common/router"; -import { DeconstructedError, deconstructError } from "@/common/utils"; +import { + DeconstructedError, + deconstructError, + stringifyError, +} from "@/common/utils"; import type { DriverConfig } from "@/driver-helpers/config"; import { type ManagerInspectorConnHandler, @@ -55,8 +60,10 @@ import { import type { WorkerQuery } from "./protocol/query"; import { VERSION } from "@/utils"; import { ConnRoutingHandler } from "@/worker/conn-routing-handler"; -import { ClientDriver, createClientWithDriver } from "@/client/client"; -import { Transport, TransportSchema } from "@/worker/protocol/message/mod"; +import { ClientDriver } from "@/client/client"; +import { Transport } from "@/worker/protocol/message/mod"; +import { authenticateEndpoint } from "./auth"; +import type { WebSocket, MessageEvent, CloseEvent } from "ws"; type ManagerRouterHandler = { onConnectInspector?: ManagerInspectorConnHandler; @@ -141,7 +148,10 @@ export function createManagerRouter( return cors({ ...corsConfig, - allowHeaders: [...(registryConfig.cors?.allowHeaders ?? []), ...ALL_HEADERS], + allowHeaders: [ + ...(registryConfig.cors?.allowHeaders ?? []), + ...ALL_PUBLIC_HEADERS, + ], })(c, next); }); } @@ -194,7 +204,9 @@ export function createManagerRouter( responses: buildOpenApiResponses(ResolveResponseSchema), }); - router.openapi(resolveRoute, (c) => handleResolveRequest(c, driver)); + router.openapi(resolveRoute, (c) => + handleResolveRequest(c, registryConfig, driver), + ); } // GET /workers/connect/websocket @@ -202,12 +214,6 @@ export function createManagerRouter( const wsRoute = createRoute({ method: "get", path: "/workers/connect/websocket", - request: { - query: z.object({ - encoding: OPENAPI_ENCODING, - query: OPENAPI_WORKER_QUERY, - }), - }, responses: { 101: { description: "WebSocket upgrade", @@ -443,7 +449,7 @@ export function createManagerRouter( if (serverWs.readyState === 1) { // OPEN - serverWs.send(clientEvt.data); + serverWs.send(clientEvt.data as any); } }; @@ -580,9 +586,9 @@ export async function queryWorker( c: HonoContext, query: WorkerQuery, driver: ManagerDriver, -): Promise<{ workerId: string; meta?: unknown }> { +): Promise<{ workerId: string }> { logger().debug("querying worker", { query }); - let workerOutput: { workerId: string; meta?: unknown }; + let workerOutput: { workerId: string }; if ("getForId" in query) { const output = await driver.getForId({ c, @@ -612,7 +618,6 @@ export async function queryWorker( }); workerOutput = { workerId: getOrCreateOutput.workerId, - meta: getOrCreateOutput.meta, }; } else if ("create" in query) { const createOutput = await driver.createWorker({ @@ -624,7 +629,6 @@ export async function queryWorker( }); workerOutput = { workerId: createOutput.workerId, - meta: createOutput.meta, }; } else { throw new errors.InvalidRequest("Invalid query format"); @@ -632,9 +636,8 @@ export async function queryWorker( logger().debug("worker query result", { workerId: workerOutput.workerId, - meta: workerOutput.meta, }); - return { workerId: workerOutput.workerId, meta: workerOutput.meta }; + return { workerId: workerOutput.workerId }; } /** @@ -649,11 +652,11 @@ async function handleSseConnectRequest( ): Promise { let encoding: Encoding | undefined; try { - encoding = getRequestEncoding(c.req, false); + encoding = getRequestEncoding(c.req); logger().debug("sse connection request received", { encoding }); const params = ConnectRequestSchema.safeParse({ - query: getRequestQuery(c, false), + query: getRequestQuery(c), encoding: c.req.header(HEADER_ENCODING), connParams: c.req.header(HEADER_CONN_PARAMS), }); @@ -667,10 +670,25 @@ async function handleSseConnectRequest( const query = params.data.query; - // Get the worker ID and meta - const { workerId, meta } = await queryWorker(c, query, driver); + // Parse connection parameters for authentication + const connParams = params.data.connParams + ? JSON.parse(params.data.connParams) + : undefined; + + // Authenticate the request + const authData = await authenticateEndpoint( + c, + driver, + registryConfig, + query, + ["connect"], + connParams, + ); + + // Get the worker ID + const { workerId } = await queryWorker(c, query, driver); invariant(workerId, "Missing worker ID"); - logger().debug("sse connection to worker", { workerId, meta }); + logger().debug("sse connection to worker", { workerId }); // Handle based on mode if ("inline" in handler.routingHandler) { @@ -682,6 +700,7 @@ async function handleSseConnectRequest( driverConfig, handler.routingHandler.inline.handlers.onConnectSse, workerId, + authData, ); } else if ("custom" in handler.routingHandler) { logger().debug("using custom proxy mode for sse connection"); @@ -693,11 +712,13 @@ async function handleSseConnectRequest( if (params.data.connParams) { proxyRequest.headers.set(HEADER_CONN_PARAMS, params.data.connParams); } + if (authData) { + proxyRequest.headers.set(HEADER_AUTH_DATA, JSON.stringify(authData)); + } return await handler.routingHandler.custom.proxyRequest( c, proxyRequest, workerId, - meta, ); } else { assertUnreachable(handler.routingHandler); @@ -776,12 +797,63 @@ async function handleWebSocketConnectRequest( try { logger().debug("websocket connection request received"); + // Parse configuration from Sec-WebSocket-Protocol header + // + // We use this instead of query parameters since this is more secure than + // query parameters. Query parameters often get logged. + // + // Browsers don't support using headers, so this is the only way to + // pass data securely. + const protocols = c.req.header("sec-websocket-protocol"); + let queryRaw: string | undefined; + let encodingRaw: string | undefined; + let connParamsRaw: string | undefined; + + if (protocols) { + // Parse protocols for conn_params.{token} pattern + const protocolList = protocols.split(",").map((p) => p.trim()); + for (const protocol of protocolList) { + if (protocol.startsWith("query.")) { + queryRaw = decodeURIComponent(protocol.substring("query.".length)); + } else if (protocol.startsWith("encoding.")) { + encodingRaw = protocol.substring("encoding.".length); + } else if (protocol.startsWith("conn_params.")) { + connParamsRaw = decodeURIComponent( + protocol.substring("conn_params.".length), + ); + } + } + } + + // Parse query + let queryUnvalidated: unknown; + try { + queryUnvalidated = JSON.parse(queryRaw!); + } catch (error) { + logger().error("invalid query json", { error }); + throw new errors.InvalidQueryJSON(error); + } + + // Parse conn params + let connParamsUnvalidated: unknown = null; + try { + if (connParamsRaw) { + connParamsUnvalidated = JSON.parse(connParamsRaw!); + } + } catch (error) { + logger().error("invalid conn params", { error }); + throw new errors.InvalidParams( + `Invalid params JSON: ${stringifyError(error)}`, + ); + } + // We can't use the standard headers with WebSockets // // All other information will be sent over the socket itself, since that data needs to be E2EE const params = ConnectWebSocketRequestSchema.safeParse({ - query: getRequestQuery(c, true), - encoding: c.req.query("encoding"), + query: queryUnvalidated, + encoding: encodingRaw, + connParams: connParamsUnvalidated, }); if (!params.success) { logger().error("invalid connection parameters", { @@ -789,10 +861,23 @@ async function handleWebSocketConnectRequest( }); throw new errors.InvalidRequest(params.error); } + encoding = params.data.encoding; + + // Authenticate endpoint + const authData = await authenticateEndpoint( + c, + driver, + registryConfig, + params.data.query, + ["connect"], + connParamsRaw, + ); - // Get the worker ID and meta - const { workerId, meta } = await queryWorker(c, params.data.query, driver); - logger().debug("found worker for websocket connection", { workerId, meta }); + // Get the worker ID + const { workerId } = await queryWorker(c, params.data.query, driver); + logger().debug("found worker for websocket connection", { + workerId, + }); invariant(workerId, "missing worker id"); if ("inline" in handler.routingHandler) { @@ -808,24 +893,29 @@ async function handleWebSocketConnectRequest( return handleWebSocketConnect( c, registryConfig, - driverConfig, onConnectWebSocket, workerId, - )(); + params.data.encoding, + params.data.connParams, + authData, + ); })(c, noopNext()); } else if ("custom" in handler.routingHandler) { logger().debug("using custom proxy mode for websocket connection"); // Proxy the WebSocket connection to the worker + // // The proxyWebSocket handler will: // 1. Validate the WebSocket upgrade request // 2. Forward the request to the worker with the appropriate path // 3. Handle the WebSocket pair and proxy messages between client and worker return await handler.routingHandler.custom.proxyWebSocket( c, - `/connect/websocket?encoding=${params.data.encoding}`, + "/connect/websocket", workerId, - meta, + params.data.encoding, + params.data.connParams, + authData, upgradeWebSocket, ); } else { @@ -878,6 +968,9 @@ async function handleWebSocketConnectRequest( /** * Handle a connection message request to a worker + * + * There is no authentication handler on this request since the connection + * token is used to authenticate the message. */ async function handleMessageRequest( c: HonoContext, @@ -900,6 +993,22 @@ async function handleMessageRequest( } const { workerId, connId, encoding, connToken } = params.data; + // TODO: This endpoint can be used to exhause resources (DoS attack) on an worker if you know the worker ID: + // 1. Get the worker ID (usually this is reasonably secure, but we don't assume worker ID is sensitive) + // 2. Spam messages to the worker (the conn token can be invalid) + // 3. The worker will be exhausted processing messages — even if the token is invalid + // + // The solution is we need to move the authorization of the connection token to this request handler + // AND include the worker ID in the connection token so we can verify that it has permission to send + // a message to that worker. This would require changing the token to a JWT so we can include a secure + // payload, but this requires managing a private key & managing key rotations. + // + // All other solutions (e.g. include the worker name as a header or include the worker name in the worker ID) + // have exploits that allow the caller to send messages to arbitrary workers. + // + // Currently, we assume this is not a critical problem because requests will likely get rate + // limited before enough messages are passed to the worker to exhaust resources. + // Handle based on mode if ("inline" in handler.routingHandler) { logger().debug("using inline proxy mode for connection message"); @@ -960,7 +1069,7 @@ async function handleActionRequest( logger().debug("action call received", { actionName }); const params = ConnectRequestSchema.safeParse({ - query: getRequestQuery(c, false), + query: getRequestQuery(c), encoding: c.req.header(HEADER_ENCODING), connParams: c.req.header(HEADER_CONN_PARAMS), }); @@ -972,9 +1081,24 @@ async function handleActionRequest( throw new errors.InvalidRequest(params.error); } - // Get the worker ID and meta - const { workerId, meta } = await queryWorker(c, params.data.query, driver); - logger().debug("found worker for action", { workerId, meta }); + // Parse connection parameters for authentication + const connParams = params.data.connParams + ? JSON.parse(params.data.connParams) + : undefined; + + // Authenticate the request + const authData = await authenticateEndpoint( + c, + driver, + registryConfig, + params.data.query, + ["action"], + connParams, + ); + + // Get the worker ID + const { workerId } = await queryWorker(c, params.data.query, driver); + logger().debug("found worker for action", { workerId }); invariant(workerId, "Missing worker ID"); // Handle based on mode @@ -988,6 +1112,7 @@ async function handleActionRequest( handler.routingHandler.inline.handlers.onAction, actionName, workerId, + authData, ); } else if ("custom" in handler.routingHandler) { logger().debug("using custom proxy mode for action call"); @@ -1002,14 +1127,17 @@ async function handleActionRequest( body: c.req.raw.body, }); proxyRequest.headers.set(HEADER_ENCODING, params.data.encoding); - if (params.data.connParams) + if (params.data.connParams) { proxyRequest.headers.set(HEADER_CONN_PARAMS, params.data.connParams); + } + if (authData) { + proxyRequest.headers.set(HEADER_AUTH_DATA, JSON.stringify(authData)); + } return await handler.routingHandler.custom.proxyRequest( c, proxyRequest, workerId, - meta, ); } else { assertUnreachable(handler.routingHandler); @@ -1031,13 +1159,15 @@ async function handleActionRequest( */ async function handleResolveRequest( c: HonoContext, + registryConfig: RegistryConfig, driver: ManagerDriver, ): Promise { - const encoding = getRequestEncoding(c.req, false); + const encoding = getRequestEncoding(c.req); logger().debug("resolve request encoding", { encoding }); const params = ResolveRequestSchema.safeParse({ - query: getRequestQuery(c, false), + query: getRequestQuery(c), + connParams: c.req.header(HEADER_CONN_PARAMS), }); if (!params.success) { logger().error("invalid connection parameters", { @@ -1046,9 +1176,19 @@ async function handleResolveRequest( throw new errors.InvalidRequest(params.error); } - // Get the worker ID and meta - const { workerId, meta } = await queryWorker(c, params.data.query, driver); - logger().debug("resolved worker", { workerId, meta }); + // Parse connection parameters for authentication + const connParams = params.data.connParams + ? JSON.parse(params.data.connParams) + : undefined; + + const query = params.data.query; + + // Authenticate the request + await authenticateEndpoint(c, driver, registryConfig, query, [], connParams); + + // Get the worker ID + const { workerId } = await queryWorker(c, query, driver); + logger().debug("resolved worker", { workerId }); invariant(workerId, "Missing worker ID"); // Format response according to protocol diff --git a/packages/core/src/topologies/common/generic-conn-driver.ts b/packages/core/src/topologies/common/generic-conn-driver.ts index 6b9df94c5..a9e107835 100644 --- a/packages/core/src/topologies/common/generic-conn-driver.ts +++ b/packages/core/src/topologies/common/generic-conn-driver.ts @@ -7,6 +7,7 @@ import * as messageToClient from "@/worker/protocol/message/to-client"; import { encodeDataToString } from "@/worker/protocol/serde"; import { WSContext } from "hono/ws"; import { SSEStreamingApi } from "hono/streaming"; +import type { WebSocket } from "ws"; // This state is different than `PersistedConn` state since the connection-specific state is persisted & must be serializable. This is also part of the connection driver, not part of the core worker. // diff --git a/packages/core/src/topologies/coordinate/conn/mod.ts b/packages/core/src/topologies/coordinate/conn/mod.ts index 4890dd5a5..7c6d9bac1 100644 --- a/packages/core/src/topologies/coordinate/conn/mod.ts +++ b/packages/core/src/topologies/coordinate/conn/mod.ts @@ -9,6 +9,7 @@ import { generateConnId, generateConnToken } from "@/worker/connection"; import type { WorkerDriver } from "@/worker/driver"; import { DriverConfig } from "@/driver-helpers/config"; import { RegistryConfig } from "@/registry/config"; +import { unknown } from "zod"; export interface RelayConnDriver { sendMessage(message: messageToClient.ToClient): void; @@ -27,6 +28,7 @@ export class RelayConn { #driver: RelayConnDriver; #workerId: string; #parameters: unknown; + #authData: unknown; #workerPeer?: WorkerPeer; @@ -56,6 +58,7 @@ export class RelayConn { driver: RelayConnDriver, workerId: string, parameters: unknown, + authData: unknown, ) { this.#registryConfig = registryConfig; this.#driverConfig = driverConfig; @@ -65,6 +68,7 @@ export class RelayConn { this.#globalState = globalState; this.#workerId = workerId; this.#parameters = parameters; + this.#authData = authData; } async start() { @@ -108,6 +112,7 @@ export class RelayConn { ci: connId, ct: connToken, p: this.#parameters, + ad: this.#authData, }, }, }, diff --git a/packages/core/src/topologies/coordinate/node/mod.ts b/packages/core/src/topologies/coordinate/node/mod.ts index ed30e4077..1a2c417e1 100644 --- a/packages/core/src/topologies/coordinate/node/mod.ts +++ b/packages/core/src/topologies/coordinate/node/mod.ts @@ -105,6 +105,7 @@ export class Node { ci: connId, ct: connToken, p: connParams, + ad: authData, }: ToLeaderConnectionOpen, ) { if (!nodeId) { @@ -133,6 +134,7 @@ export class Node { connState, CONN_DRIVER_COORDINATE_RELAY, { nodeId } satisfies CoordinateRelayState, + authData, ); // Connection init will be sent by `Worker` diff --git a/packages/core/src/topologies/coordinate/node/protocol.ts b/packages/core/src/topologies/coordinate/node/protocol.ts index cc9f22320..f33c15566 100644 --- a/packages/core/src/topologies/coordinate/node/protocol.ts +++ b/packages/core/src/topologies/coordinate/node/protocol.ts @@ -18,6 +18,8 @@ export const ToLeaderConnectionOpenSchema = z.object({ ct: z.string(), // Parameters p: z.unknown(), + // Auth data + ad: z.unknown(), }); export type ToLeaderConnectionOpen = z.infer; diff --git a/packages/core/src/topologies/coordinate/router/sse.ts b/packages/core/src/topologies/coordinate/router/sse.ts index 19e6df9a0..739c32a08 100644 --- a/packages/core/src/topologies/coordinate/router/sse.ts +++ b/packages/core/src/topologies/coordinate/router/sse.ts @@ -15,7 +15,7 @@ export async function serveSse( CoordinateDriver: CoordinateDriver, globalState: GlobalState, workerId: string, - { encoding, params }: ConnectSseOpts, + { encoding, params, authData }: ConnectSseOpts, ): Promise { let conn: RelayConn | undefined; return { @@ -39,6 +39,7 @@ export async function serveSse( }, workerId, params, + authData, ); await conn.start(); }, diff --git a/packages/core/src/topologies/coordinate/router/websocket.ts b/packages/core/src/topologies/coordinate/router/websocket.ts index 59e4d92c6..a8950a133 100644 --- a/packages/core/src/topologies/coordinate/router/websocket.ts +++ b/packages/core/src/topologies/coordinate/router/websocket.ts @@ -10,7 +10,10 @@ import { publishMessageToLeader } from "../node/message"; import type { WorkerDriver } from "@/worker/driver"; import type { DriverConfig } from "@/driver-helpers/config"; import type { RegistryConfig } from "@/registry/config"; -import { ConnectWebSocketOpts, ConnectWebSocketOutput } from "@/worker/router-endpoints"; +import { + ConnectWebSocketOpts, + ConnectWebSocketOutput, +} from "@/worker/router-endpoints"; export async function serveWebSocket( registryConfig: RegistryConfig, @@ -19,7 +22,7 @@ export async function serveWebSocket( CoordinateDriver: CoordinateDriver, globalState: GlobalState, workerId: string, - { req, encoding, params }: ConnectWebSocketOpts, + { req, encoding, params, authData }: ConnectWebSocketOpts, ): Promise { let conn: RelayConn | undefined; return { @@ -41,6 +44,7 @@ export async function serveWebSocket( }, workerId, params, + authData, ); await conn.start(); }, diff --git a/packages/core/src/topologies/partition/topology.ts b/packages/core/src/topologies/partition/topology.ts index a40d158c2..1a6f8cd36 100644 --- a/packages/core/src/topologies/partition/topology.ts +++ b/packages/core/src/topologies/partition/topology.ts @@ -46,17 +46,16 @@ import { ConnRoutingHandlerCustom, } from "@/worker/conn-routing-handler"; import invariant from "invariant"; +import type { WebSocket } from "ws"; export type SendRequestHandler = ( workerRequest: Request, workerId: string, - meta?: unknown, ) => Promise; export type OpenWebSocketHandler = ( path: string, workerId: string, - meta?: unknown, ) => Promise; export class PartitionTopologyManager { @@ -177,6 +176,7 @@ export class PartitionTopologyWorker { { encoding: opts.encoding, } satisfies GenericWebSocketDriverState, + opts.authData, ); }, onMessage: async (message) => { @@ -228,6 +228,7 @@ export class PartitionTopologyWorker { connState, CONN_DRIVER_GENERIC_SSE, { encoding: opts.encoding } satisfies GenericSseDriverState, + opts.authData, ); }, onClose: async () => { @@ -261,6 +262,7 @@ export class PartitionTopologyWorker { connState, CONN_DRIVER_GENERIC_HTTP, {} satisfies GenericHttpDriverState, + opts.authData, ); // Call action diff --git a/packages/core/src/topologies/partition/worker-router.ts b/packages/core/src/topologies/partition/worker-router.ts index 22585803f..0443f669b 100644 --- a/packages/core/src/topologies/partition/worker-router.ts +++ b/packages/core/src/topologies/partition/worker-router.ts @@ -27,8 +27,13 @@ import { handleConnectionMessage, HEADER_CONN_TOKEN, HEADER_CONN_ID, - ALL_HEADERS, + ALL_PUBLIC_HEADERS, + HEADER_CONN_PARAMS, + HEADER_AUTH_DATA, + HEADER_ENCODING, } from "@/worker/router-endpoints"; +import invariant from "invariant"; +import { EncodingSchema } from "@/worker/protocol/serde"; export type { ConnectWebSocketOpts, @@ -79,7 +84,10 @@ export function createWorkerRouter( return cors({ ...corsConfig, - allowHeaders: [...(registryConfig.cors?.allowHeaders ?? []), ...ALL_HEADERS], + allowHeaders: [ + ...(registryConfig.cors?.allowHeaders ?? []), + ...ALL_PUBLIC_HEADERS, + ], })(c, next); }); } @@ -97,18 +105,30 @@ export function createWorkerRouter( // Use the handlers from connectionHandlers const handlers = handler.connectionHandlers; - if (upgradeWebSocket && handlers.onConnectWebSocket) { + if (upgradeWebSocket) { router.get( "/connect/websocket", upgradeWebSocket(async (c) => { const workerId = await handler.getWorkerId(); + const encodingRaw = c.req.header(HEADER_ENCODING); + const connParamsRaw = c.req.header(HEADER_CONN_PARAMS); + const authDataRaw = c.req.header(HEADER_AUTH_DATA); + + const encoding = EncodingSchema.parse(encodingRaw); + const connParams = connParamsRaw + ? JSON.parse(connParamsRaw) + : undefined; + const authData = authDataRaw ? JSON.parse(authDataRaw) : undefined; + return handleWebSocketConnect( c as HonoContext, registryConfig, - driverConfig, handlers.onConnectWebSocket!, workerId, - )(); + encoding, + connParams, + authData, + ); }), ); } else { @@ -125,12 +145,20 @@ export function createWorkerRouter( throw new Error("onConnectSse handler is required"); } const workerId = await handler.getWorkerId(); + + const authDataRaw = c.req.header(HEADER_AUTH_DATA); + let authData: unknown = undefined; + if (authDataRaw) { + authData = JSON.parse(authDataRaw); + } + return handleSseConnect( c, registryConfig, driverConfig, handlers.onConnectSse, workerId, + authData, ); }); @@ -140,6 +168,13 @@ export function createWorkerRouter( } const actionName = c.req.param("action"); const workerId = await handler.getWorkerId(); + + const authDataRaw = c.req.header(HEADER_AUTH_DATA); + let authData: unknown = undefined; + if (authDataRaw) { + authData = JSON.parse(authDataRaw); + } + return handleAction( c, registryConfig, @@ -147,6 +182,7 @@ export function createWorkerRouter( handlers.onAction, actionName, workerId, + authData, ); }); diff --git a/packages/core/src/topologies/standalone/topology.ts b/packages/core/src/topologies/standalone/topology.ts index a74ff0547..af4b4c7c7 100644 --- a/packages/core/src/topologies/standalone/topology.ts +++ b/packages/core/src/topologies/standalone/topology.ts @@ -157,6 +157,7 @@ export class StandaloneTopology { connState, CONN_DRIVER_GENERIC_WEBSOCKET, { encoding: opts.encoding } satisfies GenericWebSocketDriverState, + opts.authData, ); }, onMessage: async (message) => { @@ -199,6 +200,7 @@ export class StandaloneTopology { connState, CONN_DRIVER_GENERIC_SSE, { encoding: opts.encoding } satisfies GenericSseDriverState, + opts.authData, ); }, onClose: async () => { @@ -224,6 +226,7 @@ export class StandaloneTopology { connState, CONN_DRIVER_GENERIC_HTTP, {} satisfies GenericHttpDriverState, + opts.authData, ); // Call action diff --git a/packages/core/src/worker/config.ts b/packages/core/src/worker/config.ts index 6fbbc9455..cb3684867 100644 --- a/packages/core/src/worker/config.ts +++ b/packages/core/src/worker/config.ts @@ -10,6 +10,7 @@ import { z } from "zod"; // (b) it makes the type definitions incredibly difficult to read as opposed to vanilla TypeScript. export const WorkerConfigSchema = z .object({ + onAuth: z.function().optional(), onCreate: z.function().optional(), onStart: z.function().optional(), onStateChange: z.function().optional(), @@ -95,6 +96,8 @@ export interface OnConnectOptions { // Creates state config // // This must have only one or the other or else S will not be able to be inferred +// +// Data returned from this handler will be available on `c.state`. type CreateState = | { state: S } | { @@ -108,6 +111,8 @@ type CreateState = // Creates connection state config // // This must have only one or the other or else S will not be able to be inferred +// +// Data returned from this handler will be available on `c.conn.state`. type CreateConnState = | { connState: CS } | { @@ -151,7 +156,48 @@ export interface Actions { // CreateState & // CreateConnState; +/** + * @experimental + */ +export type AuthIntent = "get" | "create" | "connect" | "action" | "message"; + +interface OnAuthOptions { + req: Request; + /** + * @experimental + */ + intents: Set; + params: CP; +} + interface BaseWorkerConfig> { + /** + * Called on the HTTP server before clients can interact with the worker. + * + * Only called for public endpoints. Calls to workers from within the backend + * do not trigger this handler. + * + * Data returned from this handler will be available on `c.conn.auth`. + * + * This function is required for any public HTTP endpoint access. Use this hook + * to validate client credentials and return authentication data that will be + * available on connections. This runs on the HTTP server (not the worker) + * in order to reduce load on the worker & prevent denial of server attacks + * against individual workers. + * + * If you need access to worker state for authentication, use onBeforeConnect + * with an empty onAuth function instead. + * + * You can also provide your own authentication middleware on your router if you + * choose, then use onAuth to pass the authentication data (e.g. user ID) to the + * worker itself. + * + * @param opts Authentication options including request and intent + * @returns Authentication data to attach to connections (must be serializable) + * @throws Throw an error to deny access to the worker + */ + onAuth?: (opts: OnAuthOptions) => unknown | Promise; + /** * Called when the worker is first initialized. * @@ -186,8 +232,18 @@ interface BaseWorkerConfig> { /** * Called before a client connects to the worker. * + * Unlike onAuth, this handler is still called for both internal and + * public clients. + * * Use this hook to determine if a connection should be accepted - * and to initialize connection-specific state. + * and to initialize connection-specific state. Unlike onAuth, this runs + * on the worker and has access to worker state, but uses slightly + * more resources on the worker rather than authenticating with onAuth. + * + * For authentication without worker state access, prefer onAuth. + * + * For authentication with worker state, use onBeforeConnect with an empty + * onAuth handler. * * @param opts Connection parameters including client-provided data * @returns The initial connection state or a Promise that resolves to it @@ -254,6 +310,7 @@ interface BaseWorkerConfig> { export type WorkerConfig = Omit< z.infer, | "actions" + | "onAuth" | "onCreate" | "onStart" | "onStateChange" @@ -283,6 +340,7 @@ export type WorkerConfigInput< > = Omit< z.input, | "actions" + | "onAuth" | "onCreate" | "onStart" | "onStateChange" diff --git a/packages/core/src/worker/conn-routing-handler.ts b/packages/core/src/worker/conn-routing-handler.ts index 5edbdc5e9..c64277ab8 100644 --- a/packages/core/src/worker/conn-routing-handler.ts +++ b/packages/core/src/worker/conn-routing-handler.ts @@ -1,7 +1,8 @@ -import { UpgradeWebSocket } from "@/utils"; -import { Encoding } from "./protocol/serde"; +import type { UpgradeWebSocket } from "@/utils"; +import type { Encoding } from "./protocol/serde"; import type { ConnectionHandlers as ConnHandlers } from "./router-endpoints"; -import type { Context as HonoContext, HonoRequest } from "hono"; +import type { Context as HonoContext } from "hono"; +import type { WebSocket } from "ws"; /** * Deterines how requests to workers should be routed. @@ -31,27 +32,27 @@ export type BuildProxyEndpoint = (c: HonoContext, workerId: string) => string; export type SendRequestHandler = ( workerId: string, - meta: unknown | undefined, workerRequest: Request, ) => Promise; export type OpenWebSocketHandler = ( workerId: string, - meta: unknown | undefined, encodingKind: Encoding, + params: unknown ) => Promise; export type ProxyRequestHandler = ( c: HonoContext, workerRequest: Request, workerId: string, - meta?: unknown, ) => Promise; export type ProxyWebSocketHandler = ( c: HonoContext, path: string, workerId: string, - meta?: unknown, - upgradeWebSocket?: UpgradeWebSocket, + encoding: Encoding, + connParams: unknown, + authData: unknown, + upgradeWebSocket: UpgradeWebSocket, ) => Promise; diff --git a/packages/core/src/worker/connection.ts b/packages/core/src/worker/connection.ts index b4f2418e3..8639c6d4e 100644 --- a/packages/core/src/worker/connection.ts +++ b/packages/core/src/worker/connection.ts @@ -52,6 +52,10 @@ export class Conn { return this.__persist.p; } + public get auth(): unknown { + return this.__persist.a; + } + public get _stateEnabled() { return this.#stateEnabled; } diff --git a/packages/core/src/worker/definition.ts b/packages/core/src/worker/definition.ts index 7c45c28a4..6c5f4947a 100644 --- a/packages/core/src/worker/definition.ts +++ b/packages/core/src/worker/definition.ts @@ -31,6 +31,10 @@ export class WorkerDefinition> { this.#config = config; } + get config(): WorkerConfig { + return this.#config; + } + instantiate(): WorkerInstance { return new WorkerInstance(this.#config); } diff --git a/packages/core/src/worker/errors.ts b/packages/core/src/worker/errors.ts index 5d85ad262..fbbe87077 100644 --- a/packages/core/src/worker/errors.ts +++ b/packages/core/src/worker/errors.ts @@ -24,14 +24,6 @@ export class WorkerError extends Error { public static isWorkerError( error: unknown, ): error is WorkerError | DeconstructedError { - console.trace( - "checking error", - error, - typeof error, - (error as WorkerError | DeconstructedError).__type, - typeof error === "object" && - (error as WorkerError | DeconstructedError).__type === "WorkerError", - ); return ( typeof error === "object" && (error as WorkerError | DeconstructedError).__type === "WorkerError" @@ -296,3 +288,10 @@ export class InvalidParams extends WorkerError { super("invalid_params", message, { public: true }); } } + +export class Forbidden extends WorkerError { + constructor(message?: string) { + super("forbidden", message ?? "Access denied", { public: true }); + this.statusCode = 403; + } +} diff --git a/packages/core/src/worker/instance.ts b/packages/core/src/worker/instance.ts index b1e01109e..b55f38fd3 100644 --- a/packages/core/src/worker/instance.ts +++ b/packages/core/src/worker/instance.ts @@ -666,6 +666,7 @@ export class WorkerInstance { state: CS, driverId: string, driverState: unknown, + authData: unknown, ): Promise> { if (this.#connections.has(connectionId)) { throw new Error(`Connection already exists: ${connectionId}`); @@ -680,6 +681,7 @@ export class WorkerInstance { ds: driverState, p: params, s: state, + a: authData, su: [], }; const conn = new Conn( diff --git a/packages/core/src/worker/persisted.ts b/packages/core/src/worker/persisted.ts index 2043ada55..cd33563f5 100644 --- a/packages/core/src/worker/persisted.ts +++ b/packages/core/src/worker/persisted.ts @@ -22,6 +22,8 @@ export interface PersistedConn { p: CP; // State s: CS; + // Auth data + a?: unknown; // Subscriptions su: PersistedSubscription[]; } diff --git a/packages/core/src/worker/protocol/message/mod.ts b/packages/core/src/worker/protocol/message/mod.ts index fc014f8aa..fa8e330d3 100644 --- a/packages/core/src/worker/protocol/message/mod.ts +++ b/packages/core/src/worker/protocol/message/mod.ts @@ -92,9 +92,7 @@ export async function processMessage( let actionName: string | undefined; try { - if ("i" in message.b) { - invariant(false, "should not be notified of init event"); - } else if ("ar" in message.b) { + if ("ar" in message.b) { // Action request if (handler.onExecuteAction === undefined) { diff --git a/packages/core/src/worker/protocol/message/to-server.ts b/packages/core/src/worker/protocol/message/to-server.ts index 4197a3000..257f2792b 100644 --- a/packages/core/src/worker/protocol/message/to-server.ts +++ b/packages/core/src/worker/protocol/message/to-server.ts @@ -1,10 +1,5 @@ import { z } from "zod"; -const InitSchema = z.object({ - // Conn Params - p: z.unknown({}).optional(), -}); - const ActionRequestSchema = z.object({ // ID i: z.number().int(), @@ -24,7 +19,6 @@ const SubscriptionRequestSchema = z.object({ export const ToServerSchema = z.object({ // Body b: z.union([ - z.object({ i: InitSchema }), z.object({ ar: ActionRequestSchema }), z.object({ sr: SubscriptionRequestSchema }), ]), diff --git a/packages/core/src/worker/router-endpoints.ts b/packages/core/src/worker/router-endpoints.ts index 15fb105f4..c0e38a2d2 100644 --- a/packages/core/src/worker/router-endpoints.ts +++ b/packages/core/src/worker/router-endpoints.ts @@ -1,6 +1,6 @@ -import { type HonoRequest, type Context as HonoContext } from "hono"; +import type { HonoRequest, Context as HonoContext } from "hono"; import { type SSEStreamingApi, streamSSE } from "hono/streaming"; -import { type WSContext } from "hono/ws"; +import type { WSContext } from "hono/ws"; import * as errors from "./errors"; import { logger } from "./log"; import { @@ -8,23 +8,22 @@ import { EncodingSchema, serialize, deserialize, - CachedSerializer, } from "@/worker/protocol/serde"; import { parseMessage } from "@/worker/protocol/message/mod"; import * as protoHttpAction from "@/worker/protocol/http/action"; import type * as messageToServer from "@/worker/protocol/message/to-server"; -import type { InputData, OutputData } from "@/worker/protocol/serde"; +import type { InputData } from "@/worker/protocol/serde"; import { assertUnreachable } from "./utils"; import { deconstructError, stringifyError } from "@/common/utils"; import type { RegistryConfig } from "@/registry/config"; import type { DriverConfig } from "@/driver-helpers/config"; -import invariant from "invariant"; export interface ConnectWebSocketOpts { req?: HonoRequest; encoding: Encoding; - params: unknown; workerId: string; + params: unknown; + authData: unknown; } export interface ConnectWebSocketOutput { @@ -38,6 +37,7 @@ export interface ConnectSseOpts { encoding: Encoding; params: unknown; workerId: string; + authData: unknown; } export interface ConnectSseOutput { @@ -51,6 +51,7 @@ export interface ActionOpts { actionName: string; actionArgs: unknown[]; workerId: string; + authData: unknown; } export interface ActionOutput { @@ -83,166 +84,129 @@ export interface ConnectionHandlers { export function handleWebSocketConnect( context: HonoContext, registryConfig: RegistryConfig, - driverConfig: DriverConfig, handler: (opts: ConnectWebSocketOpts) => Promise, workerId: string, + encoding: Encoding, + params: unknown, + authData: unknown, ) { - return async () => { - const encoding = getRequestEncoding(context.req, true); - const exposeInternalError = getRequestExposeInternalError( - context.req, - false, - ); - - let sharedWs: WSContext | undefined = undefined; + const exposeInternalError = getRequestExposeInternalError(context.req); - // Setup promise for the init message since all other behavior depends on this - const { - promise: onInitPromise, - resolve: onInitResolve, - reject: onInitReject, - } = Promise.withResolvers(); + // Setup promise for the init message since all other behavior depends on this + const { + promise: wsHandlerPromise, + resolve: wsHandlerResolve, + reject: wsHandlerReject, + } = Promise.withResolvers(); - let didTimeOut = false; - let didInit = false; + return { + onOpen: async (_evt: any, ws: WSContext) => { + logger().debug("websocket open"); - // Add timeout waiting for init - const initTimeout = setTimeout(() => { - logger().warn("timed out waiting for init"); - - sharedWs?.close(1001, "timed out waiting for init message"); - didTimeOut = true; - onInitReject("init timed out"); - }, registryConfig.webSocketInitTimeout); - - return { - onOpen: async (_evt: any, ws: WSContext) => { - sharedWs = ws; - - logger().debug("websocket open"); - - // Close WS immediately if init timed out. This indicates a long delay at the protocol level in sending the init message. - if (didTimeOut) ws.close(1001, "timed out waiting for init message"); - }, - onMessage: async (evt: { data: any }, ws: WSContext) => { - try { - const value = evt.data.valueOf() as InputData; - const message = await parseMessage(value, { - encoding: encoding, - maxIncomingMessageSize: registryConfig.maxIncomingMessageSize, - }); - - if ("i" in message.b) { - // Handle init message - // - // Parameters must go over the init message instead of a query parameter so it receives full E2EE - - logger().debug("received init ws message"); - - invariant( - !didInit, - "should not have already received init message", - ); - didInit = true; - clearTimeout(initTimeout); - - try { - // Create connection handler - const wsHandler = await handler({ - req: context.req, - encoding, - params: message.b.i.p, - workerId, - }); - - // Notify socket open - // TODO: Add timeout to this - await wsHandler.onOpen(ws); - - // Allow all other events to proceed - onInitResolve(wsHandler); - } catch (error) { - deconstructError( - error, - logger(), - { wsEvent: "open" }, - exposeInternalError, - ); - onInitReject(error); - ws.close(1011, "internal error"); - } - } else { - // Handle all other messages - - logger().debug("received regular ws message"); - - const wsHandler = await onInitPromise; - await wsHandler.onMessage(message); - } - } catch (error) { - const { code } = deconstructError( - error, - logger(), - { - wsEvent: "message", - }, - exposeInternalError, - ); - ws.close(1011, code); - } + try { + // Create connection handler + const wsHandler = await handler({ + req: context.req, + encoding, + params, + workerId, + authData, + }); + + // Notify socket open + // TODO: Add timeout to this + await wsHandler.onOpen(ws); + + // Unblock other uses of WS handler + wsHandlerResolve(wsHandler); + } catch (error) { + wsHandlerReject(error); + + const { code } = deconstructError( + error, + logger(), + { + wsEvent: "message", + }, + exposeInternalError, + ); + ws.close(1011, code); + } + }, + onMessage: async (evt: { data: any }, ws: WSContext) => { + try { + const wsHandler = await wsHandlerPromise; + + const value = evt.data.valueOf() as InputData; + const message = await parseMessage(value, { + encoding: encoding, + maxIncomingMessageSize: registryConfig.maxIncomingMessageSize, + }); + + await wsHandler.onMessage(message); + } catch (error) { + const { code } = deconstructError( + error, + logger(), + { + wsEvent: "message", + }, + exposeInternalError, + ); + ws.close(1011, code); + } + }, + onClose: async ( + event: { + wasClean: boolean; + code: number; + reason: string; }, - onClose: async ( - event: { - wasClean: boolean; - code: number; - reason: string; - }, - ws: WSContext, - ) => { - if (event.wasClean) { - logger().info("websocket closed", { - code: event.code, - reason: event.reason, - wasClean: event.wasClean, - }); - } else { - logger().warn("websocket closed", { - code: event.code, - reason: event.reason, - wasClean: event.wasClean, - }); - } + ws: WSContext, + ) => { + if (event.wasClean) { + logger().info("websocket closed", { + code: event.code, + reason: event.reason, + wasClean: event.wasClean, + }); + } else { + logger().warn("websocket closed", { + code: event.code, + reason: event.reason, + wasClean: event.wasClean, + }); + } - // HACK: Close socket in order to fix bug with Cloudflare leaving WS in closing state - // https://github.com/cloudflare/workerd/issues/2569 - ws.close(1000, "hack_force_close"); + // HACK: Close socket in order to fix bug with Cloudflare leaving WS in closing state + // https://github.com/cloudflare/workerd/issues/2569 + ws.close(1000, "hack_force_close"); - try { - const wsHandler = await onInitPromise; - await wsHandler.onClose(); - } catch (error) { - deconstructError( - error, - logger(), - { wsEvent: "close" }, - exposeInternalError, - ); - } - }, - onError: async (_error: unknown) => { - try { - // Workers don't need to know about this, since it's abstracted away - logger().warn("websocket error"); - } catch (error) { - deconstructError( - error, - logger(), - { wsEvent: "error" }, - exposeInternalError, - ); - } - }, - }; + try { + const wsHandler = await wsHandlerPromise; + await wsHandler.onClose(); + } catch (error) { + deconstructError( + error, + logger(), + { wsEvent: "close" }, + exposeInternalError, + ); + } + }, + onError: async (_error: unknown) => { + try { + // Workers don't need to know about this, since it's abstracted away + logger().warn("websocket error"); + } catch (error) { + deconstructError( + error, + logger(), + { wsEvent: "error" }, + exposeInternalError, + ); + } + }, }; } @@ -255,8 +219,9 @@ export async function handleSseConnect( driverConfig: DriverConfig, handler: (opts: ConnectSseOpts) => Promise, workerId: string, + authData: unknown, ) { - const encoding = getRequestEncoding(c.req, false); + const encoding = getRequestEncoding(c.req); const parameters = getRequestConnParams(c.req, registryConfig, driverConfig); const sseHandler = await handler({ @@ -264,6 +229,7 @@ export async function handleSseConnect( encoding, params: parameters, workerId, + authData, }); return streamSSE(c, async (stream) => { @@ -312,8 +278,9 @@ export async function handleAction( handler: (opts: ActionOpts) => Promise, actionName: string, workerId: string, + authData: unknown, ) { - const encoding = getRequestEncoding(c.req, false); + const encoding = getRequestEncoding(c.req); const parameters = getRequestConnParams(c.req, registryConfig, driverConfig); logger().debug("handling action", { actionName, encoding }); @@ -365,6 +332,7 @@ export async function handleAction( actionName: actionName, actionArgs: actionArgs, workerId, + authData, }); // Encode the response @@ -396,7 +364,7 @@ export async function handleConnectionMessage( connToken: string, workerId: string, ) { - const encoding = getRequestEncoding(c.req, false); + const encoding = getRequestEncoding(c.req); // Validate incoming request let message: messageToServer.ToServer; @@ -435,15 +403,10 @@ export async function handleConnectionMessage( } // Helper to get the connection encoding from a request -export function getRequestEncoding( - req: HonoRequest, - useQuery: boolean, -): Encoding { - const encodingParam = useQuery - ? req.query("encoding") - : req.header(HEADER_ENCODING); +export function getRequestEncoding(req: HonoRequest): Encoding { + const encodingParam = req.header(HEADER_ENCODING); if (!encodingParam) { - return "json"; + throw new errors.InvalidEncoding("undefined"); } const result = EncodingSchema.safeParse(encodingParam); @@ -454,13 +417,8 @@ export function getRequestEncoding( return result.data; } -export function getRequestExposeInternalError( - req: HonoRequest, - useQuery: boolean, -): boolean { - const param = useQuery - ? req.query("expose-internal-error") - : req.header(HEADER_EXPOSE_INTERNAL_ERROR); +export function getRequestExposeInternalError(req: HonoRequest): boolean { + const param = req.header(HEADER_EXPOSE_INTERNAL_ERROR); if (!param) { return false; } @@ -468,11 +426,9 @@ export function getRequestExposeInternalError( return param === "true"; } -export function getRequestQuery(c: HonoContext, useQuery: boolean): unknown { +export function getRequestQuery(c: HonoContext): unknown { // Get query parameters for worker lookup - const queryParam = useQuery - ? c.req.query("query") - : c.req.header(HEADER_WORKER_QUERY); + const queryParam = c.req.header(HEADER_WORKER_QUERY); if (!queryParam) { logger().error("missing query parameter"); throw new errors.InvalidRequest("missing query"); @@ -491,18 +447,28 @@ export function getRequestQuery(c: HonoContext, useQuery: boolean): unknown { export const HEADER_WORKER_QUERY = "X-AC-Query"; export const HEADER_ENCODING = "X-AC-Encoding"; + +// Internal header export const HEADER_EXPOSE_INTERNAL_ERROR = "X-AC-Expose-Internal-Error"; // IMPORTANT: Params must be in headers or in an E2EE part of the request (i.e. NOT the URL or query string) in order to ensure that tokens can be securely passed in params. export const HEADER_CONN_PARAMS = "X-AC-Conn-Params"; +// Internal header +export const HEADER_AUTH_DATA = "X-AC-Auth-Data"; + export const HEADER_WORKER_ID = "X-AC-Worker"; export const HEADER_CONN_ID = "X-AC-Conn"; export const HEADER_CONN_TOKEN = "X-AC-Conn-Token"; -export const ALL_HEADERS = [ +/** + * Headers that publics can send from public clients. + * + * Used for CORS. + **/ +export const ALL_PUBLIC_HEADERS = [ HEADER_WORKER_QUERY, HEADER_ENCODING, HEADER_CONN_PARAMS, diff --git a/packages/platforms/cloudflare-workers/src/handler.ts b/packages/platforms/cloudflare-workers/src/handler.ts index 522383d6c..c6a6d5bc7 100644 --- a/packages/platforms/cloudflare-workers/src/handler.ts +++ b/packages/platforms/cloudflare-workers/src/handler.ts @@ -5,6 +5,12 @@ import { } from "./worker-handler-do"; import { ConfigSchema, type InputConfig } from "./config"; import { assertUnreachable } from "rivetkit/utils"; +import { + HEADER_AUTH_DATA, + HEADER_CONN_PARAMS, + HEADER_ENCODING, + HEADER_EXPOSE_INTERNAL_ERROR, +} from "rivetkit/driver-helpers"; import type { Hono } from "hono"; import { PartitionTopologyManager } from "rivetkit/topologies/partition"; import { logger } from "./log"; @@ -28,6 +34,15 @@ export interface Bindings { */ export const CF_AMBIENT_ENV = new AsyncLocalStorage(); +const STANDARD_WEBSOCKET_HEADERS = [ + "connection", + "upgrade", + "sec-websocket-key", + "sec-websocket-version", + "sec-websocket-protocol", + "sec-websocket-extensions", +]; + export function getCloudflareAmbientEnv(): Bindings { const env = CF_AMBIENT_ENV.getStore(); invariant(env, "missing CF_AMBIENT_ENV"); @@ -83,11 +98,7 @@ export function createRouter( registry.config, driverConfig, { - sendRequest: async ( - workerId, - meta, - workerRequest, - ): Promise => { + sendRequest: async (workerId, workerRequest): Promise => { const env = getCloudflareAmbientEnv(); logger().debug("sending request to durable object", { @@ -104,8 +115,8 @@ export function createRouter( openWebSocket: async ( workerId, - meta, encodingKind: Encoding, + params: unknown, ): Promise => { const env = getCloudflareAmbientEnv(); @@ -115,13 +126,20 @@ export function createRouter( const id = env.WORKER_DO.idFromString(workerId); const stub = env.WORKER_DO.get(id); - // TODO: this doesn't call on open - const url = `http://worker/connect/websocket?encoding=${encodingKind}&expose-internal-error=true`; - const response = await stub.fetch(url, { - headers: { - Upgrade: "websocket", - Connection: "Upgrade", - }, + const headers: Record = { + Upgrade: "websocket", + Connection: "Upgrade", + [HEADER_EXPOSE_INTERNAL_ERROR]: "true", + [HEADER_ENCODING]: encodingKind, + }; + if (params) { + headers[HEADER_CONN_PARAMS] = JSON.stringify(params); + } + // HACK: See packages/platforms/cloudflare-workers/src/websocket.ts + headers["sec-websocket-protocol"] = "rivetkit"; + + const response = await stub.fetch("http://worker/connect/websocket", { + headers, }); const webSocket = response.webSocket; @@ -137,6 +155,12 @@ export function createRouter( webSocket.accept(); + // HACK: Cloudflare does not call onopen automatically, so we need + // to call this on the next tick + setTimeout(() => { + webSocket.onopen?.(new Event("open")); + }, 100); + return webSocket; }, @@ -152,7 +176,14 @@ export function createRouter( return await stub.fetch(workerRequest); }, - proxyWebSocket: async (c, path, workerId) => { + proxyWebSocket: async ( + c, + path, + workerId, + encoding, + params, + authData, + ) => { logger().debug("forwarding websocket to durable object", { workerId, path, @@ -166,10 +197,37 @@ export function createRouter( }); } - // Update path on URL + // TODO: strip headers const newUrl = new URL(`http://worker${path}`); const workerRequest = new Request(newUrl, c.req.raw); + // Always build fresh request to prevent forwarding unwanted headers + // HACK: Since we can't build a new request, we need to remove + // non-standard headers manually + const headerKeys: string[] = []; + workerRequest.headers.forEach((v, k) => headerKeys.push(k)); + for (const k of headerKeys) { + if (!STANDARD_WEBSOCKET_HEADERS.includes(k)) { + workerRequest.headers.delete(k); + } + } + + // Add RivetKit headers + workerRequest.headers.set(HEADER_EXPOSE_INTERNAL_ERROR, "true"); + workerRequest.headers.set(HEADER_ENCODING, encoding); + if (params) { + workerRequest.headers.set( + HEADER_CONN_PARAMS, + JSON.stringify(params), + ); + } + if (authData) { + workerRequest.headers.set( + HEADER_AUTH_DATA, + JSON.stringify(authData), + ); + } + const id = c.env.WORKER_DO.idFromString(workerId); const stub = c.env.WORKER_DO.get(id); diff --git a/packages/platforms/cloudflare-workers/src/manager-driver.ts b/packages/platforms/cloudflare-workers/src/manager-driver.ts index d010dc6fc..3e5f54325 100644 --- a/packages/platforms/cloudflare-workers/src/manager-driver.ts +++ b/packages/platforms/cloudflare-workers/src/manager-driver.ts @@ -49,14 +49,10 @@ export class CloudflareWorkersManagerDriver implements ManagerDriver { return undefined; } - // Generate durable ID from workerId for meta - const durableId = env.WORKER_DO.idFromString(workerId); - return { workerId, name: workerData.name, key: workerData.key, - meta: durableId, }; } @@ -74,8 +70,7 @@ export class CloudflareWorkersManagerDriver implements ManagerDriver { // Generate deterministic ID from the name and key // This is aligned with how createWorker generates IDs const nameKeyString = serializeNameAndKey(name, key); - const durableId = env.WORKER_DO.idFromName(nameKeyString); - const workerId = durableId.toString(); + const workerId = env.WORKER_DO.idFromName(nameKeyString).toString(); // Check if the worker metadata exists const workerData = await env.WORKER_KV.get(KEYS.WORKER.metadata(workerId), { @@ -128,11 +123,11 @@ export class CloudflareWorkersManagerDriver implements ManagerDriver { // Create a deterministic ID from the worker name and key // This ensures that workers with the same name and key will have the same ID const nameKeyString = serializeNameAndKey(name, key); - const durableId = env.WORKER_DO.idFromName(nameKeyString); - const workerId = durableId.toString(); + const doId = env.WORKER_DO.idFromName(nameKeyString); + const workerId = doId.toString(); // Init worker - const worker = env.WORKER_DO.get(durableId); + const worker = env.WORKER_DO.get(doId); await worker.initialize({ name, key, @@ -153,7 +148,6 @@ export class CloudflareWorkersManagerDriver implements ManagerDriver { workerId, name, key, - meta: durableId, }; } @@ -172,14 +166,10 @@ export class CloudflareWorkersManagerDriver implements ManagerDriver { return undefined; } - // Generate durable ID for meta - const durableId = env.WORKER_DO.idFromString(workerId); - return { workerId, name: workerData.name, key: workerData.key, - meta: durableId, }; } } diff --git a/packages/platforms/cloudflare-workers/src/websocket.ts b/packages/platforms/cloudflare-workers/src/websocket.ts index db3313d3b..37229452e 100644 --- a/packages/platforms/cloudflare-workers/src/websocket.ts +++ b/packages/platforms/cloudflare-workers/src/websocket.ts @@ -59,6 +59,12 @@ export const upgradeWebSocket: UpgradeWebSocket< return new Response(null, { status: 101, + headers: { + // HACK: Required in order for Cloudflare to not error with "Network connection lost" + // + // This bug undocumented. Cannot easily reproduce outside of RivetKit. + "Sec-WebSocket-Protocol": "rivetkit", + }, webSocket: client, }); }); diff --git a/packages/platforms/rivet/src/globals.d.ts b/packages/platforms/rivet/src/globals.d.ts new file mode 100644 index 000000000..bac4a2804 --- /dev/null +++ b/packages/platforms/rivet/src/globals.d.ts @@ -0,0 +1 @@ +declare const Deno: any; diff --git a/packages/platforms/rivet/src/manager-driver.ts b/packages/platforms/rivet/src/manager-driver.ts index 7805465fc..3b5f36995 100644 --- a/packages/platforms/rivet/src/manager-driver.ts +++ b/packages/platforms/rivet/src/manager-driver.ts @@ -9,8 +9,22 @@ import type { CreateInput, } from "rivetkit/driver-helpers"; import { logger } from "./log"; -import { type RivetClientConfig, rivetRequest } from "./rivet-client"; -import { serializeKeyForTag, deserializeKeyFromTag } from "./util"; +import { + RivetActor, + type RivetClientConfig, + rivetRequest, +} from "./rivet-client"; +import { + serializeKeyForTag, + deserializeKeyFromTag, + convertKeyToRivetTags, +} from "./util"; +import { + getWorkerMeta, + getWorkerMetaWithKey, + populateCache, +} from "./worker-meta"; +import invariant from "invariant"; export interface WorkerState { key: string[]; @@ -32,39 +46,16 @@ export class RivetManagerDriver implements ManagerDriver { workerId, }: GetForIdInput): Promise { try { - // Get actor - const res = await rivetRequest( - this.#clientConfig, - "GET", - `/actors/${encodeURIComponent(workerId)}`, - ); - - // Check if worker exists and not destroyed - if (res.actor.destroyedAt) { - return undefined; - } - - // Ensure worker has required tags - if (!("name" in res.actor.tags)) { - throw new Error(`Worker ${res.actor.id} missing 'name' in tags.`); - } - if (res.actor.tags.role !== "worker") { - throw new Error(`Worker ${res.actor.id} does not have a worker role.`); - } - if (res.actor.tags.framework !== "rivetkit") { - throw new Error(`Worker ${res.actor.id} is not an RivetKit worker.`); - } + const meta = await getWorkerMeta(this.#clientConfig, workerId); + if (!meta) return undefined; return { - workerId: res.actor.id, - name: res.actor.tags.name, - key: this.#extractKeyFromRivetTags(res.actor.tags), - meta: { - endpoint: buildWorkerEndpoint(res.actor), - } satisfies GetWorkerMeta, + workerId, + name: meta.name, + key: meta.key, }; } catch (error) { - // Handle not found or other errors + // TODO: Handle not found or other errors gracefully return undefined; } } @@ -73,48 +64,13 @@ export class RivetManagerDriver implements ManagerDriver { name, key, }: GetWithKeyInput): Promise { - // Convert key array to Rivet's tag format - const rivetTags = this.#convertKeyToRivetTags(name, key); - - // Query actors with matching tags - const { actors } = await rivetRequest( - this.#clientConfig, - "GET", - `/actors?tags_json=${encodeURIComponent(JSON.stringify(rivetTags))}`, - ); - - // Filter workers to ensure they're valid - const validActors = actors.filter((a: RivetActor) => { - // Verify all ports have hostname and port - for (const portName in a.network.ports) { - const port = a.network.ports[portName]; - if (!port.hostname || !port.port) return false; - } - return true; - }); - - if (validActors.length === 0) { - return undefined; - } - - // For consistent results, sort by ID if multiple actors match - const actor = - validActors.length > 1 - ? validActors.sort((a, b) => a.id.localeCompare(b.id))[0] - : validActors[0]; - - // Ensure actor has required tags - if (!("name" in actor.tags)) { - throw new Error(`Worker ${actor.id} missing 'name' in tags.`); - } + const meta = await getWorkerMetaWithKey(this.#clientConfig, name, key); + if (!meta) return undefined; return { - workerId: actor.id, - name: actor.tags.name, - key: this.#extractKeyFromRivetTags(actor.tags), - meta: { - endpoint: buildWorkerEndpoint(actor), - } satisfies GetWorkerMeta, + workerId: meta.workerId, + name: meta.name, + key: meta.key, }; } @@ -151,7 +107,7 @@ export class RivetManagerDriver implements ManagerDriver { } const createRequest = { - tags: this.#convertKeyToRivetTags(name, key), + tags: convertKeyToRivetTags(name, key), build_tags: { role: "worker", framework: "rivetkit", @@ -186,10 +142,12 @@ export class RivetManagerDriver implements ManagerDriver { { actor: RivetActor } >(this.#clientConfig, "POST", "/actors", createRequest); + const meta = populateCache(actor); + invariant(meta, "actor just created, should not be destroyed"); + // Initialize the worker try { - const endpoint = buildWorkerEndpoint(actor); - const url = `${endpoint}/initialize`; + const url = `${meta.endpoint}/initialize`; logger().debug("initializing worker", { url, input: JSON.stringify(input), @@ -225,84 +183,8 @@ export class RivetManagerDriver implements ManagerDriver { return { workerId: actor.id, - name, - key: this.#extractKeyFromRivetTags(actor.tags), - meta: { - endpoint: buildWorkerEndpoint(actor), - } satisfies GetWorkerMeta, + name: meta.name, + key: meta.key, }; } - - // Helper method to convert a key array to Rivet's tag-based format - #convertKeyToRivetTags(name: string, key: string[]): Record { - return { - name, - key: serializeKeyForTag(key), - role: "worker", - framework: "rivetkit", - }; - } - - // Helper method to extract key array from Rivet's tag-based format - #extractKeyFromRivetTags(tags: Record): string[] { - return deserializeKeyFromTag(tags.key); - } - - async #getBuildWithTags( - buildTags: Record, - ): Promise { - // Query builds with matching tags - const { builds } = await rivetRequest( - this.#clientConfig, - "GET", - `/builds?tags_json=${encodeURIComponent(JSON.stringify(buildTags))}`, - ); - - if (builds.length === 0) { - return undefined; - } - - // For consistent results, sort by ID if multiple builds match - return builds.length > 1 - ? builds.sort((a, b) => a.id.localeCompare(b.id))[0] - : builds[0]; - } } - -function buildWorkerEndpoint(worker: RivetActor): string { - // Fetch port - const httpPort = worker.network.ports.http; - if (!httpPort) throw new Error("missing http port"); - let hostname = httpPort.hostname; - if (!hostname) throw new Error("missing hostname"); - const port = httpPort.port; - if (!port) throw new Error("missing port"); - - let isTls = false; - switch (httpPort.protocol) { - case "https": - isTls = true; - break; - case "http": - case "tcp": - isTls = false; - break; - case "tcp_tls": - case "udp": - throw new Error(`Invalid protocol ${httpPort.protocol}`); - default: - assertUnreachable(httpPort.protocol as never); - } - - const path = httpPort.path ?? ""; - - // HACK: Fix hostname inside of Docker Compose - if (hostname === "127.0.0.1") hostname = "rivet-guard"; - - return `${isTls ? "https" : "http"}://${hostname}:${port}${path}`; -} - -// biome-ignore lint/suspicious/noExplicitAny: will add api types later -type RivetActor = any; -// biome-ignore lint/suspicious/noExplicitAny: will add api types later -type RivetBuild = any; diff --git a/packages/platforms/rivet/src/manager.ts b/packages/platforms/rivet/src/manager.ts index 4f0ca3d99..99b654259 100644 --- a/packages/platforms/rivet/src/manager.ts +++ b/packages/platforms/rivet/src/manager.ts @@ -10,6 +10,14 @@ import invariant from "invariant"; import { ConfigSchema, InputConfig } from "./config"; import type { Registry } from "rivetkit"; import { createWebSocketProxy } from "./ws-proxy"; +import { flushCache, getWorkerMeta } from "./worker-meta"; +import { + HEADER_AUTH_DATA, + HEADER_CONN_PARAMS, + HEADER_ENCODING, + HEADER_EXPOSE_INTERNAL_ERROR, +} from "rivetkit/driver-helpers"; +import { importWebSocket } from "rivetkit/driver-helpers/websocket"; export async function startManager( registry: Registry, @@ -96,12 +104,12 @@ export async function startManager( registry.config, driverConfig, { - sendRequest: async (workerId, meta, workerRequest) => { - invariant(meta, "meta not provided"); - const workerMeta = meta as GetWorkerMeta; + sendRequest: async (workerId, workerRequest) => { + const meta = await getWorkerMeta(clientConfig, workerId); + invariant(meta, "worker should exist"); const parsedRequestUrl = new URL(workerRequest.url); - const workerUrl = `${workerMeta.endpoint}${parsedRequestUrl.pathname}${parsedRequestUrl.search}`; + const workerUrl = `${meta.endpoint}${parsedRequestUrl.pathname}${parsedRequestUrl.search}`; logger().debug("proxying request to rivet worker", { method: workerRequest.method, @@ -111,25 +119,35 @@ export async function startManager( const proxyRequest = new Request(workerUrl, workerRequest); return await fetch(proxyRequest); }, - openWebSocket: async (workerId, meta, encodingKind) => { - invariant(meta, "meta not provided"); - const workerMeta = meta as GetWorkerMeta; + openWebSocket: async (workerId, encodingKind, params: unknown) => { + const WebSocket = await importWebSocket(); + + const meta = await getWorkerMeta(clientConfig, workerId); + invariant(meta, "worker should exist"); - // Create WebSocket URL with encoding parameter - const wsEndpoint = workerMeta.endpoint.replace(/^http/, "ws"); - const url = `${wsEndpoint}/connect/websocket?encoding=${encodingKind}&expose-internal-error=true`; + const wsEndpoint = meta.endpoint.replace(/^http/, "ws"); + const url = `${wsEndpoint}/connect/websocket`; + + const headers: Record = { + Upgrade: "websocket", + Connection: "Upgrade", + [HEADER_EXPOSE_INTERNAL_ERROR]: "true", + [HEADER_ENCODING]: encodingKind, + }; + if (params) { + headers[HEADER_CONN_PARAMS] = JSON.stringify(params); + } logger().debug("opening websocket to worker", { workerId, url, }); - // Open WebSocket connection - return new WebSocket(url); + return new WebSocket(url, { headers }); }, - proxyRequest: async (c, workerRequest, _workerId, metaRaw) => { - invariant(metaRaw, "meta not provided"); - const meta = metaRaw as GetWorkerMeta; + proxyRequest: async (c, workerRequest, workerId) => { + const meta = await getWorkerMeta(clientConfig, workerId); + invariant(meta, "worker should exist"); const parsedRequestUrl = new URL(workerRequest.url); const workerUrl = `${meta.endpoint}${parsedRequestUrl.pathname}${parsedRequestUrl.search}`; @@ -142,9 +160,17 @@ export async function startManager( const proxyRequest = new Request(workerUrl, workerRequest); return await proxy(proxyRequest); }, - proxyWebSocket: async (c, path, _workerId, metaRaw, upgradeWebSocket) => { - invariant(metaRaw, "meta not provided"); - const meta = metaRaw as GetWorkerMeta; + proxyWebSocket: async ( + c, + path, + workerId, + encoding, + connParmas, + authData, + upgradeWebSocket, + ) => { + const meta = await getWorkerMeta(clientConfig, workerId); + invariant(meta, "worker should exist"); const workerUrl = `${meta.endpoint}${path}`; @@ -152,18 +178,33 @@ export async function startManager( url: workerUrl, }); - const handlers = createWebSocketProxy(workerUrl); + // Build headers + const headers: Record = { + [HEADER_EXPOSE_INTERNAL_ERROR]: "true", + [HEADER_ENCODING]: encoding, + }; + if (connParmas) { + headers[HEADER_CONN_PARAMS] = JSON.stringify(connParmas); + } + if (authData) { + headers[HEADER_AUTH_DATA] = JSON.stringify(authData); + } + + const handlers = await createWebSocketProxy(workerUrl, headers); // upgradeWebSocket is middleware, so we need to pass fake handlers invariant(upgradeWebSocket, "missing upgradeWebSocket"); - return upgradeWebSocket((c) => createWebSocketProxy(workerUrl))( - c, - async () => {}, - ); + return upgradeWebSocket((c) => handlers)(c, async () => {}); }, }, ); + // HACK: Expose endpoint for tests to flush cache + managerTopology.router.post("/.test/rivet/flush-cache", (c) => { + flushCache(); + return c.text("ok"); + }); + // Start server with ambient env wrapper logger().info("server running", { port }); const server = honoServe({ @@ -174,37 +215,3 @@ export async function startManager( if (!injectWebSocket) throw new Error("injectWebSocket not defined"); injectWebSocket(server); } - -// import { Hono } from "hono"; -// import { serve } from "@hono/node-server"; -// import { upgradeWebSocket } from "hono/cloudflare-workers"; -// import { logger as honoLogger } from "hono/logger"; -// -// export async function startManager( -// registry: Registry, -// inputConfig?: InputConfig, -// ): Promise { -// const port = parseInt(process.env.PORT_HTTP!); -// -// const router = new Hono(); -// router.use(honoLogger()); -// -// const { injectWebSocket, upgradeWebSocket } = createNodeWebSocket({ -// app: router, -// }); -// -// router.get("/", (c) => { -// return c.text("Hello Hono!"); -// }); -// -// console.log(`Server is running on port ${port}`); -// -// const server = serve({ -// fetch: router.fetch, -// hostname: "0.0.0.0", -// port, -// }); -// injectWebSocket(server); -// -// console.log(`WS injected`); -// } diff --git a/packages/platforms/rivet/src/rivet-client.ts b/packages/platforms/rivet/src/rivet-client.ts index 28f42035b..4e39cbe0a 100644 --- a/packages/platforms/rivet/src/rivet-client.ts +++ b/packages/platforms/rivet/src/rivet-client.ts @@ -7,6 +7,11 @@ export interface RivetClientConfig { environment?: string; } +// biome-ignore lint/suspicious/noExplicitAny: will add api types later +export type RivetActor = any; +// biome-ignore lint/suspicious/noExplicitAny: will add api types later +export type RivetBuild = any; + export async function rivetRequest( config: RivetClientConfig, method: string, @@ -32,11 +37,11 @@ export async function rivetRequest( }); if (!response.ok) { - const errorData = await response.json().catch(() => ({})); + const errorData: any = await response.json().catch(() => ({})); throw new Error( `Rivet API error (${response.status}, ${method} ${url}): ${errorData.message || response.statusText}`, ); } - return response.json(); + return (await response.json()) as ResponseBody; } diff --git a/packages/platforms/rivet/src/util.ts b/packages/platforms/rivet/src/util.ts index 83a975c8a..841db8a93 100644 --- a/packages/platforms/rivet/src/util.ts +++ b/packages/platforms/rivet/src/util.ts @@ -42,9 +42,9 @@ export function serializeKeyForTag(key: string[]): string { * @param keyString The serialized key string from a tag * @returns Array of key strings */ -export function deserializeKeyFromTag(keyString: string): string[] { +export function deserializeKeyFromTag(keyString?: string): string[] { // Check for special empty key marker - if (keyString === EMPTY_KEY) { + if (!keyString || keyString === EMPTY_KEY) { return []; } @@ -80,3 +80,16 @@ export function deserializeKeyFromTag(keyString: string): string[] { return parts; } + +// Helper method to convert a key array to Rivet's tag-based format +export function convertKeyToRivetTags( + name: string, + key: string[], +): Record { + return { + name, + key: serializeKeyForTag(key), + role: "worker", + framework: "rivetkit", + }; +} diff --git a/packages/platforms/rivet/src/worker-meta.ts b/packages/platforms/rivet/src/worker-meta.ts new file mode 100644 index 000000000..ebe8e1f50 --- /dev/null +++ b/packages/platforms/rivet/src/worker-meta.ts @@ -0,0 +1,240 @@ +import { assertUnreachable } from "rivetkit/utils"; +import { RivetActor, RivetClientConfig, rivetRequest } from "./rivet-client"; +import { deserializeKeyFromTag, convertKeyToRivetTags } from "./util"; +import invariant from "invariant"; + +interface WorkerMeta { + name: string; + key: string[]; + endpoint: string; +} + +interface WorkerMetaWithId extends WorkerMeta { + workerId: string; +} + +// TODO: Implement LRU cache +// Cache for worker ID -> worker meta +const WORKER_META_CACHE = new Map>(); + +// TODO: Implement LRU cache +// Cache for worker name+key -> worker ID +const WORKER_KEY_CACHE = new Map>(); + +/** + * Creates a cache key for worker name and key combination. + */ +function createKeysCacheKey(name: string, key: string[]): string { + return `${name}:${JSON.stringify(key)}`; +} + +/** + * Returns worker metadata with an in-memory cache. + */ +export async function getWorkerMeta( + clientConfig: RivetClientConfig, + workerId: string, +): Promise { + // TODO: This does not refresh cache when workers are destroyed. This + // will be replaced with hot pulls from the Rivet API once (a) worker + // IDs include the datacenter in order to build endpoints without + // hitting the API and (b) we update the API to hit the regional + // endpoints. + + const workerMetaPromise = WORKER_META_CACHE.get(workerId); + if (workerMetaPromise) { + return await workerMetaPromise; + } else { + // Fetch meta + const promise = (async () => { + const { actor } = await rivetRequest( + clientConfig, + "GET", + `/actors/${encodeURIComponent(workerId)}`, + ); + + return convertActorToMeta(actor); + })(); + WORKER_META_CACHE.set(workerId, promise); + + // Remove from cache on failure so it can be retried + promise.catch(() => { + WORKER_META_CACHE.delete(workerId); + }); + + return await promise; + } +} + +/** + * Returns worker metadata for a worker with the given name and key. + */ +export async function getWorkerMetaWithKey( + clientConfig: RivetClientConfig, + name: string, + key: string[], +): Promise { + const cacheKey = createKeysCacheKey(name, key); + + // Check if we have the worker ID cached + const cachedWorkerIdPromise = WORKER_KEY_CACHE.get(cacheKey); + if (cachedWorkerIdPromise) { + const workerId = await cachedWorkerIdPromise; + if (workerId) { + // Try to get the worker metadata from the ID cache + const meta = await getWorkerMeta(clientConfig, workerId); + if (meta) { + return { + ...meta, + workerId, + }; + } + // If metadata is not available, remove from key cache and continue with fresh lookup + WORKER_KEY_CACHE.delete(cacheKey); + } + } + + // Cache miss or invalid cached data, perform fresh lookup + const promise = (async () => { + // Convert key array to Rivet's tag format + const rivetTags = convertKeyToRivetTags(name, key); + + // Query actors with matching tags + const { actors } = await rivetRequest( + clientConfig, + "GET", + `/actors?tags_json=${encodeURIComponent(JSON.stringify(rivetTags))}`, + ); + + // Filter workers to ensure they're valid + const validActors = actors.filter((a: RivetActor) => { + // Verify all ports have hostname and port + for (const portName in a.network.ports) { + const port = a.network.ports[portName]; + if (!port.hostname || !port.port) return false; + } + return true; + }); + + if (validActors.length === 0) { + // Remove from cache if not found since we might create an actor + // with this key + WORKER_KEY_CACHE.delete(cacheKey); + + return undefined; + } + + // For consistent results, sort by ID if multiple actors match + const actor = + validActors.length > 1 + ? validActors.sort((a, b) => a.id.localeCompare(b.id))[0] + : validActors[0]; + + // Populate both caches + const meta = populateCache(actor); + invariant(meta, "actor should not be destroyed"); + + return actor.id; + })(); + + WORKER_KEY_CACHE.set(cacheKey, promise); + + // Remove from cache on failure so it can be retried + promise.catch(() => { + WORKER_KEY_CACHE.delete(cacheKey); + }); + + const workerId = await promise; + if (!workerId) { + return undefined; + } + + const meta = await getWorkerMeta(clientConfig, workerId); + invariant(meta, "worker metadata should be available after populating cache"); + + return { + ...meta, + workerId, + }; +} + +/** + * Preemptively adds an entry to the cache. + */ +export function populateCache(actor: RivetActor): WorkerMeta | undefined { + const meta = convertActorToMeta(actor); + if (meta) { + // Populate the worker ID -> metadata cache + WORKER_META_CACHE.set(actor.id, Promise.resolve(meta)); + + // Populate the name+key -> worker ID cache + const cacheKey = createKeysCacheKey(meta.name, meta.key); + WORKER_KEY_CACHE.set(cacheKey, Promise.resolve(actor.id)); + } + return meta; +} + +/** + * Converts actor data from the Rivet API to worker metadata. + */ +function convertActorToMeta(actor: RivetActor): WorkerMeta | undefined { + // Check if worker exists and not destroyed + if (actor.destroyedAt) { + return undefined; + } + + // Ensure worker has required tags + if (!("name" in actor.tags)) { + throw new Error(`Worker ${actor.id} missing 'name' in tags.`); + } + if (actor.tags.role !== "worker") { + throw new Error(`Worker ${actor.id} does not have a worker role.`); + } + if (actor.tags.framework !== "rivetkit") { + throw new Error(`Worker ${actor.id} is not an RivetKit worker.`); + } + + return { + name: actor.tags.name, + key: deserializeKeyFromTag(actor.tags.key), + endpoint: buildWorkerEndpoint(actor), + }; +} + +function buildWorkerEndpoint(actor: RivetActor): string { + // Fetch port + const httpPort = actor.network.ports.http; + if (!httpPort) throw new Error("missing http port"); + let hostname = httpPort.hostname; + if (!hostname) throw new Error("missing hostname"); + const port = httpPort.port; + if (!port) throw new Error("missing port"); + + let isTls = false; + switch (httpPort.protocol) { + case "https": + isTls = true; + break; + case "http": + case "tcp": + isTls = false; + break; + case "tcp_tls": + case "udp": + throw new Error(`Invalid protocol ${httpPort.protocol}`); + default: + assertUnreachable(httpPort.protocol as never); + } + + const path = httpPort.path ?? ""; + + // HACK: Fix hostname inside of Docker Compose + if (hostname === "127.0.0.1") hostname = "rivet-guard"; + + return `${isTls ? "https" : "http"}://${hostname}:${port}${path}`; +} + +export function flushCache() { + WORKER_META_CACHE.clear(); + WORKER_KEY_CACHE.clear(); +} diff --git a/packages/platforms/rivet/src/ws-proxy.ts b/packages/platforms/rivet/src/ws-proxy.ts index 0895c08d4..7bba7eeed 100644 --- a/packages/platforms/rivet/src/ws-proxy.ts +++ b/packages/platforms/rivet/src/ws-proxy.ts @@ -1,7 +1,8 @@ import { WSContext } from "hono/ws"; -import { Context } from "hono"; import { logger } from "./log"; import invariant from "invariant"; +import type { WebSocket, CloseEvent } from "ws"; +import { importWebSocket } from "rivetkit/driver-helpers/websocket"; /** * Creates a WebSocket proxy to forward connections to a target endpoint @@ -10,14 +11,19 @@ import invariant from "invariant"; * @param targetUrl Target WebSocket URL to proxy to * @returns Response with upgraded WebSocket */ -export function createWebSocketProxy(targetUrl: string) { +export async function createWebSocketProxy( + targetUrl: string, + headers: Record, +) { + const WebSocket = await importWebSocket(); + let targetWs: WebSocket | undefined = undefined; const messageQueue: any[] = []; return { onOpen: (_evt: any, wsContext: WSContext) => { // Create target WebSocket connection - targetWs = new WebSocket(targetUrl); + targetWs = new WebSocket(targetUrl, { headers }); // Set up target websocket handlers targetWs.onopen = () => { @@ -34,7 +40,7 @@ export function createWebSocketProxy(targetUrl: string) { }; targetWs.onmessage = (event) => { - wsContext.send(event.data); + wsContext.send(event.data as any); }; targetWs.onclose = (event) => { diff --git a/packages/platforms/rivet/tests/driver-tests.test.ts b/packages/platforms/rivet/tests/driver-tests.test.ts index 6706ed5b2..ed0dfa74a 100644 --- a/packages/platforms/rivet/tests/driver-tests.test.ts +++ b/packages/platforms/rivet/tests/driver-tests.test.ts @@ -1,6 +1,7 @@ import { runDriverTests } from "rivetkit/driver-test-suite"; import { deployToRivet, rivetClientConfig } from "./rivet-deploy"; import { RivetClientConfig, rivetRequest } from "../src/rivet-client"; +import invariant from "invariant"; let deployProjectOnce: Promise | undefined = undefined; @@ -18,6 +19,12 @@ runDriverTests({ // Cleanup workers from previous tests await deleteAllWorkers(rivetClientConfig); + // Flush cache since we manually updated the workers + const res = await fetch(`${endpoint}/.test/rivet/flush-cache`, { + method: "POST", + }); + invariant(res.ok, `request failed: ${res.status}`); + return { endpoint, async cleanup() { diff --git a/packages/platforms/rivet/tsconfig.json b/packages/platforms/rivet/tsconfig.json index 0c116dfe0..666d38420 100644 --- a/packages/platforms/rivet/tsconfig.json +++ b/packages/platforms/rivet/tsconfig.json @@ -1,7 +1,7 @@ { "extends": "../../../tsconfig.base.json", "compilerOptions": { - "types": ["deno", "node"], + "types": ["node"], "paths": { "@/*": ["./src/*"] } diff --git a/tsconfig.base.json b/tsconfig.base.json index 180de2f47..5719fb942 100644 --- a/tsconfig.base.json +++ b/tsconfig.base.json @@ -8,7 +8,7 @@ "allowSyntheticDefaultImports": true, "stripInternal": true, "moduleResolution": "bundler", - "lib": ["ESNext", "DOM"], + "lib": ["ESNext"], "types": ["node"] } } diff --git a/turbo.json b/turbo.json index 7d236df3c..719e5beda 100644 --- a/turbo.json +++ b/turbo.json @@ -1,4 +1,5 @@ -{ "$schema": "https://turbo.build/schema.json", +{ + "$schema": "https://turbo.build/schema.json", "tasks": { "//#fmt": { "cache": false @@ -9,7 +10,13 @@ "outputs": ["dist/**"] }, "check-types": { - "inputs": ["src/**", "tests/**", "tsconfig.json", "tsup.config.ts", "package.json"], + "inputs": [ + "src/**", + "tests/**", + "tsconfig.json", + "tsup.config.ts", + "package.json" + ], "dependsOn": ["^build"] }, "dev": { @@ -17,7 +24,8 @@ "dependsOn": ["build", "^check-types", "check-types"] }, "test": { - "dependsOn": ["^build", "check-types"] + "dependsOn": ["^build", "check-types"], + "env": ["_RIVETKIT_ERROR_STACK"] } } } diff --git a/yarn.lock b/yarn.lock index 1cb9c575b..097a5afd7 100644 --- a/yarn.lock +++ b/yarn.lock @@ -1737,7 +1737,7 @@ __metadata: languageName: unknown linkType: soft -"@rivetkit/nodejs@workspace:packages/platforms/nodejs": +"@rivetkit/nodejs@workspace:*, @rivetkit/nodejs@workspace:packages/platforms/nodejs": version: 0.0.0-use.local resolution: "@rivetkit/nodejs@workspace:packages/platforms/nodejs" dependencies: @@ -2684,6 +2684,7 @@ __metadata: version: 0.0.0-use.local resolution: "chat-room-python@workspace:examples/chat-room-python" dependencies: + "@rivetkit/nodejs": "workspace:*" "@types/node": "npm:^22.13.9" rivetkit: "workspace:*" tsx: "npm:^3.12.7" @@ -2695,6 +2696,7 @@ __metadata: version: 0.0.0-use.local resolution: "chat-room@workspace:examples/chat-room" dependencies: + "@rivetkit/nodejs": "workspace:*" "@types/node": "npm:^22.13.9" "@types/prompts": "npm:^2" prompts: "npm:^2.4.2" @@ -2920,6 +2922,7 @@ __metadata: version: 0.0.0-use.local resolution: "counter@workspace:examples/counter" dependencies: + "@rivetkit/nodejs": "workspace:*" "@types/node": "npm:^22.13.9" rivetkit: "workspace:*" tsx: "npm:^3.12.7" @@ -4620,6 +4623,7 @@ __metadata: "@hono/node-server": "npm:^1.14.1" "@linear/sdk": "npm:^7.0.0" "@octokit/rest": "npm:^19.0.13" + "@rivetkit/nodejs": "workspace:*" "@types/dotenv": "npm:^8.2.3" "@types/express": "npm:^5" "@types/node": "npm:^22.13.9" @@ -5581,6 +5585,7 @@ __metadata: resolution: "resend-streaks@workspace:examples/resend-streaks" dependencies: "@date-fns/tz": "npm:^1.2.0" + "@rivetkit/nodejs": "workspace:*" "@types/node": "npm:^22.13.9" date-fns: "npm:^4.1.0" resend: "npm:^2.0.0"