You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
### Manually aborting reload
This PR updates the engine `reload()` and `unload()` methods to allow
users to abort an uncompleted `reload()` by either:
- call `unload()` any time before `reload()` completed
- call `reload()` again before the previous `reload()` completed
### Note on unload() and unexpected device lost error
Previously, we had an issue where a device lost error is reported when
we simply switch a model intentionally (i.e. calling `reload()`). This
is because `unload()` sets `deviceLostIsError` back to true immediately
after calling `this.pipeline.dispose()`, which destroys the WebGPU
device internally. However, WebGPU is asynchronous and may not finish
after `dispose()` returns. This PR also fixes this issue by making
`unload()` wait until the device is actually destroyed by introducing
`LLMChatPipeline.sync()`.
---------
Co-authored-by: Charlie Ruan <[email protected]>
@@ -248,7 +274,8 @@ export class MLCEngine implements MLCEngineInterface {
248
274
gpuDetectOutput.device.lost.then((info: any)=>{
249
275
if(this.deviceLostIsError){
250
276
log.error(
251
-
`Device was lost during reload. This can happen due to insufficient memory or other GPU constraints. Detailed error: ${info}. Please try to reload WebLLM with a less resource-intensive model.`,
277
+
`Device was lost. This can happen due to insufficient memory or other GPU constraints. `+
278
+
`Detailed error: ${info}. Please try to reload WebLLM with a less resource-intensive model.`,
252
279
);
253
280
this.unload();
254
281
deviceLostInReload=true;
@@ -267,6 +294,7 @@ export class MLCEngine implements MLCEngineInterface {
267
294
tvm.webgpu(),
268
295
"webllm/model",
269
296
cacheType,
297
+
this.reloadController?.signal,
270
298
);
271
299
this.pipeline=newLLMChatPipeline(
272
300
tvm,
@@ -646,12 +674,23 @@ export class MLCEngine implements MLCEngineInterface {
646
674
this.pipeline?.resetChat(keepStats);
647
675
}
648
676
677
+
/**
678
+
* Unloads the currently loaded model and destroy the webgpu device. Waits
679
+
* until the webgpu device finishes all submitted work and destroys itself.
680
+
* @note This is an asynchronous function.
681
+
*/
649
682
asyncunload(){
650
683
this.deviceLostIsError=false;// so that unload() does not trigger device.lost error
651
684
this.pipeline?.dispose();
685
+
// Wait until device is actually destroyed so we can safely set deviceLostIsError back to true
686
+
awaitthis.pipeline?.sync();
652
687
this.pipeline=undefined;
653
688
this.currentModelId=undefined;
654
689
this.deviceLostIsError=true;
690
+
if(this.reloadController){
691
+
this.reloadController.abort("Engine.unload() is called.");
0 commit comments