Commit ddac6d1

[Engine] Allow manually aborting reload, fix unexpected deviceLostError (#525)
### Manually aborting reload

This PR updates the engine `reload()` and `unload()` methods so that users can abort an incomplete `reload()` by either:

- calling `unload()` at any time before `reload()` completes, or
- calling `reload()` again before the previous `reload()` completes.

### Note on unload() and unexpected device lost error

Previously, a device lost error was reported when we simply switched models intentionally (i.e. by calling `reload()`). This happened because `unload()` sets `deviceLostIsError` back to true immediately after calling `this.pipeline.dispose()`, which destroys the WebGPU device internally. However, WebGPU is asynchronous and may not have finished by the time `dispose()` returns. This PR fixes the issue by making `unload()` wait until the device is actually destroyed, via the new `LLMChatPipeline.sync()`.

---------

Co-authored-by: Charlie Ruan <[email protected]>
1 parent 7690707 commit ddac6d1
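
As a quick orientation before the diff, here is a minimal sketch of the new abort behavior, modeled on the example app added in this commit (the model id and the 5-second delay are illustrative):

```ts
import * as webllm from "@mlc-ai/web-llm";

const engine = new webllm.MLCEngine();

// Start loading a model; intentionally not awaited so it can be
// aborted while fetching is still in flight.
engine.reload("Llama-3.1-8B-Instruct-q4f32_1-MLC");

// Abort path 1: unload() cancels the in-flight reload().
// (Abort path 2 would be calling engine.reload() again with another model.)
setTimeout(() => {
  console.log("calling unload");
  engine.unload().catch((err) => console.log(err));
}, 5000);
```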

File tree

6 files changed: +139 additions, −4 deletions

examples/abort-reload/README.md

Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
+# WebLLM Get Started App
+
+This folder provides a demo for cancelling model fetching after calling `engine.reload()`.
+
+```bash
+npm install
+npm start
+```
+
+Note: if you would like to hack the WebLLM core package, you can change the
+web-llm dependency to `"file:../.."` and follow the build-from-source instructions
+in the project to build WebLLM locally. This option is only recommended if you
+intend to modify the WebLLM core package.

examples/abort-reload/package.json

Lines changed: 20 additions & 0 deletions

@@ -0,0 +1,20 @@
+{
+  "name": "get-started",
+  "version": "0.1.0",
+  "private": true,
+  "scripts": {
+    "start": "parcel src/get_started.html --port 8887",
+    "build": "parcel build src/get_started.html --dist-dir lib"
+  },
+  "devDependencies": {
+    "buffer": "^5.7.1",
+    "parcel": "^2.8.3",
+    "process": "^0.11.10",
+    "tslib": "^2.3.1",
+    "typescript": "^4.9.5",
+    "url": "^0.11.3"
+  },
+  "dependencies": {
+    "@mlc-ai/web-llm": "file:../../lib"
+  }
+}
examples/abort-reload/src/get_started.html

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+<!doctype html>
+<html>
+  <script>
+    webLLMGlobal = {};
+  </script>
+  <body>
+    <h2>WebLLM Test Page</h2>
+    Open console to see output
+    <br />
+    <br />
+    <label id="init-label"> </label>
+
+    <h3>Prompt</h3>
+    <label id="prompt-label"> </label>
+
+    <h3>Response</h3>
+    <label id="generate-label"> </label>
+    <br />
+    <label id="stats-label"> </label>
+
+    <script type="module" src="./get_started.js"></script>
+  </body>
+</html>
examples/abort-reload/src/get_started.js

Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
+import * as webllm from "@mlc-ai/web-llm";
+import { error } from "loglevel";
+
+let engine;
+
+function setLabel(id, text) {
+  const label = document.getElementById(id);
+  if (label == null) {
+    throw Error("Cannot find label " + id);
+  }
+  label.innerText = text;
+}
+
+async function main() {
+  const initProgressCallback = (report) => {
+    console.log(report.text);
+    setLabel("init-label", report.text);
+  };
+  // Option 1: If we do not specify appConfig, we use `prebuiltAppConfig` defined in `config.ts`
+  const selectedModel = "Llama-3.1-8B-Instruct-q4f32_1-MLC";
+  engine = new webllm.MLCEngine({
+    initProgressCallback,
+  });
+  engine.reload(selectedModel);
+}
+main();
+setTimeout(() => {
+  console.log("calling unload");
+  engine.unload().catch((err) => {
+    console.log(err);
+  });
+}, 5000);

src/engine.ts

Lines changed: 43 additions & 4 deletions

@@ -103,6 +103,7 @@ export class MLCEngine implements MLCEngineInterface {
   private initProgressCallback?: InitProgressCallback;
   private interruptSignal = false;
   private deviceLostIsError = true; // whether device.lost is due to actual error or model reload
+  private reloadController: AbortController | undefined;
   private config?: ChatConfig;
   private appConfig: AppConfig;

@@ -143,7 +144,25 @@
   */
  async reload(modelId: string, chatOpts?: ChatOptions): Promise<void> {
    await this.unload();
+   this.reloadController = new AbortController();

+   try {
+     await this.reloadInternal(modelId, chatOpts);
+   } catch (error) {
+     if (error instanceof DOMException && error.name === "AbortError") {
+       log.warn("Reload() is aborted.", error.message);
+       return;
+     }
+     throw error;
+   } finally {
+     this.reloadController = undefined;
+   }
+ }
+
+ private async reloadInternal(
+   modelId: string,
+   chatOpts?: ChatOptions,
+ ): Promise<void> {
    this.logitProcessor = this.logitProcessorRegistry?.get(modelId);
    const tstart = performance.now();

@@ -175,7 +194,11 @@
    // load config
    const configUrl = new URL("mlc-chat-config.json", modelUrl).href;
    this.config = {
-     ...(await configCache.fetchWithCache(configUrl, "json")),
+     ...(await configCache.fetchWithCache(
+       configUrl,
+       "json",
+       this.reloadController?.signal,
+     )),
      ...modelRecord.overrides,
      ...chatOpts,
    } as ChatConfig;

@@ -201,8 +224,11 @@
      // rely on the normal caching strategy
      return (await fetch(new URL(wasmUrl, baseUrl).href)).arrayBuffer();
    } else {
-     // use cache
-     return await wasmCache.fetchWithCache(wasmUrl, "arraybuffer");
+     return await wasmCache.fetchWithCache(
+       wasmUrl,
+       "arraybuffer",
+       this.reloadController?.signal,
+     );
    }
  };
  const wasmSource = await fetchWasmSource();

@@ -248,7 +274,8 @@
  gpuDetectOutput.device.lost.then((info: any) => {
    if (this.deviceLostIsError) {
      log.error(
-       `Device was lost during reload. This can happen due to insufficient memory or other GPU constraints. Detailed error: ${info}. Please try to reload WebLLM with a less resource-intensive model.`,
+       `Device was lost. This can happen due to insufficient memory or other GPU constraints. ` +
+         `Detailed error: ${info}. Please try to reload WebLLM with a less resource-intensive model.`,
      );
      this.unload();
      deviceLostInReload = true;

@@ -267,6 +294,7 @@
      tvm.webgpu(),
      "webllm/model",
      cacheType,
+     this.reloadController?.signal,
    );
    this.pipeline = new LLMChatPipeline(
      tvm,

@@ -646,12 +674,23 @@
    this.pipeline?.resetChat(keepStats);
  }

+ /**
+  * Unloads the currently loaded model and destroys the WebGPU device. Waits
+  * until the WebGPU device finishes all submitted work and destroys itself.
+  * @note This is an asynchronous function.
+  */
  async unload() {
    this.deviceLostIsError = false; // so that unload() does not trigger device.lost error
    this.pipeline?.dispose();
+   // Wait until the device is actually destroyed so we can safely set deviceLostIsError back to true
+   await this.pipeline?.sync();
    this.pipeline = undefined;
    this.currentModelId = undefined;
    this.deviceLostIsError = true;
+   if (this.reloadController) {
+     this.reloadController.abort("Engine.unload() is called.");
+     this.reloadController = undefined;
+   }
  }

  async getMaxStorageBufferBindingSize(): Promise<number> {
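
The `this.reloadController?.signal` threading above relies on standard `AbortController` semantics: a `fetch` given that signal rejects with a `DOMException` named `"AbortError"` once the controller aborts, which is exactly the case `reload()` catches. A standalone sketch of the pattern, assuming the cache helpers forward the signal to `fetch` (the helper name and URL here are illustrative):

```ts
const controller = new AbortController();

// Mirrors the shape of the cache helpers: an optional signal is forwarded to fetch.
async function fetchArrayBuffer(url: string, signal?: AbortSignal) {
  const response = await fetch(url, { signal });
  return response.arrayBuffer();
}

const download = fetchArrayBuffer("https://example.com/model.wasm", controller.signal);

// Aborting makes the pending fetch reject with a DOMException named "AbortError".
controller.abort();

download.catch((err) => {
  if (err instanceof DOMException && err.name === "AbortError") {
    console.log("download aborted as expected");
  }
});
```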

src/llm_chat.ts

Lines changed: 8 additions & 0 deletions

@@ -1040,6 +1040,14 @@ export class LLMChatPipeline {
    } as ChatCompletionTokenLogprob;
  }

+ /**
+  * Synchronize the device.
+  */
+ async sync(): Promise<void> {
+   // Is it equivalent to this.tvm.sync()?
+   await this.device.sync();
+ }
+
  async evaluate() {
    // run a canonical evaluation of the flow
    this.resetKVCache();
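
The ordering that `sync()` enables in `unload()` is the heart of the device-lost fix: `dispose()` only submits the device destruction, so `deviceLostIsError` must not be restored until the GPU has actually finished. A simplified sketch of that pattern (the pipeline type here is a stand-in, not the real `LLMChatPipeline`):

```ts
interface DisposablePipeline {
  dispose(): void; // kicks off device destruction (asynchronous under the hood)
  sync(): Promise<void>; // resolves once the device has finished all work
}

class UnloadSketch {
  private deviceLostIsError = true;
  private pipeline?: DisposablePipeline;

  async unload() {
    // Mark the upcoming device loss as intentional before triggering it.
    this.deviceLostIsError = false;
    this.pipeline?.dispose();
    // Without this await, the flag below could be restored before
    // device.lost fires, misreporting the intentional loss as an error.
    await this.pipeline?.sync();
    this.pipeline = undefined;
    this.deviceLostIsError = true;
  }
}
```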
