Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 31 additions & 6 deletions src/engine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ import { Conversation, compareConversationObject, getConversation } from "./conv
* @param modelId The model to load, needs to either be in `webllm.prebuiltAppConfig`, or in
* `engineConfig.appConfig`.
* @param engineConfig Optionally configures the engine, see `webllm.EngineConfig`.
* @returns An intialized `WebLLM.Engine` with `modelId` loaded.
* @returns An initialized `WebLLM.Engine` with `modelId` loaded.
* @throws Throws an error when the device is lost (mostly due to OOM); users should re-call
* `CreateEngine()`, potentially with a smaller model or smaller context window size.
*/
export async function CreateEngine(
modelId: string,
Expand Down Expand Up @@ -70,7 +72,7 @@ export class Engine implements EngineInterface {
private pipeline?: LLMChatPipeline;
private initProgressCallback?: InitProgressCallback;
private interruptSignal = false;
private deviceLostIsError = false; // whether device.lost is due to actual error or model reload
private deviceLostIsError = true; // whether device.lost is due to actual error or model reload
private config?: ChatConfig;

constructor() {
Expand All @@ -89,8 +91,16 @@ export class Engine implements EngineInterface {
this.logitProcessorRegistry = logitProcessorRegistry;
}

/**
* Reload model `modelId`.
* @param modelId The model to load, needs to either be in `webllm.prebuiltAppConfig`, or in
* `engineConfig.appConfig`.
* @param chatOpts To optionally override the `mlc-chat-config.json` of `modelId`.
* @param appConfig Configure the app with the list of models and whether to use IndexedDB cache.
* @throws Throws an error when the device is lost (mostly due to OOM); users should re-call
* `reload()`, potentially with a smaller model or smaller context window size.
*/
async reload(modelId: string, chatOpts?: ChatOptions, appConfig?: AppConfig): Promise<void> {
this.deviceLostIsError = false; // so that unload() does not trigger device.lost warning
this.unload();

this.logitProcessor = this.logitProcessorRegistry?.get(modelId);
Expand Down Expand Up @@ -195,15 +205,21 @@ export class Engine implements EngineInterface {
}
}

tvm.initWebGPU(gpuDetectOutput.device);
// Most device lost happens in `reload()` since we allocate memory ahead of time. So we can
// use this flag at the end of `reload()` to make the error handling synchronous.
// This `.then()` exists throughout the lifetime of the device. Though we have not
// experienced device error outside of `reload()`, it is still possible this `.then()` is
// triggered outside of `reload()`. TODO: does this cause unexpected behavior?
let deviceLostInReload = false;
gpuDetectOutput.device.lost.then((info: any) => {
// `fetchNDArrayCache` may exceed available memory; use `lost.then` to prevent crashing
if (this.deviceLostIsError) {
console.error("Device was lost, please try to initialize again. ", info);
this.unload();
deviceLostInReload = true;
}
});
this.deviceLostIsError = true;
tvm.initWebGPU(gpuDetectOutput.device);

const tokenizer = await this.asyncLoadTokenizer(modelUrl, this.config, appConfig);
const cacheType = appConfig.useIndexedDBCache ? "indexeddb" : "cache";
await tvm.fetchNDArrayCache(modelUrl, tvm.webgpu(), "webllm/model", cacheType);
Expand All @@ -220,6 +236,13 @@ export class Engine implements EngineInterface {
})
}
this.currentModelId = modelId;

if (deviceLostInReload) {
throw Error(
"WebGPU device lost during `reload()`.\n This is probably due to OOM, try reload with a " +
"model that has less parameters or a smaller context length."
);
}
}

async generate(
Expand Down Expand Up @@ -479,9 +502,11 @@ export class Engine implements EngineInterface {
}

/**
 * Unload the currently loaded model: dispose the chat pipeline (if any) and clear
 * `pipeline` and `currentModelId`.
 *
 * `deviceLostIsError` is set to `false` while the pipeline is disposed and restored to
 * `true` afterwards, with the intent that a device loss caused by this intentional
 * teardown is not reported as an error by the `device.lost` handler.
 *
 * NOTE(review): this body contains no `await`, so the flag is `true` again before any
 * promise callback can run. If `dispose()` causes `device.lost` to resolve, the handler
 * registered with `.then()` would presumably observe `deviceLostIsError === true` --
 * TODO confirm whether the flag actually suppresses the error in that case.
 */
async unload() {
this.deviceLostIsError = false; // so that unload() does not trigger device.lost error
// Optional chaining: nothing to dispose when no pipeline has been created yet.
this.pipeline?.dispose();
this.pipeline = undefined;
this.currentModelId = undefined;
// Re-arm error reporting for device loss that happens after this unload.
this.deviceLostIsError = true;
}

async getMaxStorageBufferBindingSize(): Promise<number> {
Expand Down