Currently the shards are downloaded sequentially. All the downloads should be promisified so they can run in parallel (a sketch follows the runtime listing below).
chat_module.js:

```js
await tvm.fetchNDArrayCache(modelUrl, tvm.webgpu(), "webllm/model");
```
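For reference, this is how the per-shard progress surfaces to applications. A minimal sketch, assuming web-llm's `ChatModule` and `setInitProgressCallback` API (the model id is illustrative; check the current API for exact names):

```ts
import { ChatModule } from "@mlc-ai/web-llm";

// Sketch: each report corresponds to one shard, so with sequential
// downloads the progress value advances one shard at a time.
const chat = new ChatModule();
chat.setInitProgressCallback((report) => {
  console.log(report.text, report.progress);
});
await chat.reload("vicuna-v1-7b"); // model id is illustrative
```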
runtime.js (TVM):

```ts
/**
 * Fetch list of NDArray into the NDArrayCache.
 *
 * @param ndarrayCacheUrl The cache url.
 * @param list The list of array data.
 * @param device The device to store the data to.
 * @param artifactCache The artifact cache.
 */
private async fetchNDArrayCacheInternal(
  ndarrayCacheUrl: string,
  list: Array<NDArrayShardEntry>,
  device: DLDevice,
  artifactCache: ArtifactCache
) {
  const perf = compact.getPerformance();
  const tstart = perf.now();

  let totalBytes = 0;
  for (let i = 0; i < list.length; ++i) {
    totalBytes += list[i].nbytes;
  }
  let fetchedBytes = 0;
  let timeElapsed = 0;

  const cacheOnly = await artifactCache.hasAllKeys(
    list.map(key => new URL(key.dataPath, ndarrayCacheUrl).href)
  );

  const reportCallback = (iter: number) => {
    // report
    for (let j = 0; j < this.initProgressCallback.length; ++j) {
      let text = "Fetching param cache[" + iter + "/" + list.length + "]: ";
      text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + "MB fetched. ";
      text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "% completed, ";
      text += timeElapsed + " secs elapsed.";
      text += " It can take a while when we first visit this page to populate the cache.";
      text += " Later refreshes will become faster.";
      if (cacheOnly) {
        text = "Loading model from cache[" + iter + "/" + list.length + "]: ";
        text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + "MB loaded. ";
        text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "% completed, ";
        text += timeElapsed + " secs elapsed.";
      }
      this.initProgressCallback[j]({
        progress: fetchedBytes / totalBytes,
        timeElapsed: timeElapsed,
        cacheOnly: cacheOnly,
        text: text
      });
    }
  };

  for (let j = 0; j < this.initProgressCallback.length; ++j) {
    this.initProgressCallback[j]({
      progress: fetchedBytes / totalBytes,
      timeElapsed: 0,
      cacheOnly: cacheOnly,
      text: "Start to fetch params",
    });
  }
  // NOTE: each iteration awaits the previous shard's download before the next
  // fetch starts, so downloads run strictly one after another.
  for (let i = 0; i < list.length; ++i) {
    reportCallback(i);
    fetchedBytes += list[i].nbytes;
    const dataUrl = new URL(list[i].dataPath, ndarrayCacheUrl).href;
    let buffer;
    try {
      buffer = await (await artifactCache.fetchWithCache(dataUrl)).arrayBuffer();
    } catch (err) {
      this.env.logger("Error: Cannot fetch " + dataUrl + " err= " + err);
      throw err;
    }
    const shardRecords = list[i].records;
    for (let j = 0; j < shardRecords.length; ++j) {
      const rec = shardRecords[j];
      const cpu_arr = this.withNewScope(() => {
        return this.detachFromCurrentScope(
          this.empty(rec.shape, rec.dtype, this.cpu())
        );
      });
      const recSource = buffer.slice(rec.byteOffset, rec.byteOffset + rec.nbytes);
      // first sync copy to cpu.
      this.ctx.arrayDecodeStorage(cpu_arr, new Uint8Array(recSource), rec.format, rec.dtype);
      // then async stream into GPU if needed
      if (device.deviceType === DeviceStrToEnum.cpu) {
        this.ndarrayCacheUpdate(rec.name, cpu_arr, false);
        cpu_arr.dispose();
      } else {
        // allocate a gpu arr and async copy to it.
        const gpu_arr = this.withNewScope(() => {
          return this.detachFromCurrentScope(
            this.empty(rec.shape, rec.dtype, device)
          );
        });
        gpu_arr.copyFrom(cpu_arr);
        await device.sync();
        this.ndarrayCacheUpdate(rec.name, gpu_arr, false);
        cpu_arr.dispose();
        gpu_arr.dispose();
      }
    }
    timeElapsed = Math.ceil((perf.now() - tstart) / 1000);
  }
  reportCallback(list.length);
}
```
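One way to promisify the downloads, per the proposal above: start every shard fetch up front and await the buffers in order, so the network requests overlap while decoding stays sequential. A minimal sketch against the loop above, reusing names from the listing (`list`, `ndarrayCacheUrl`, `artifactCache`, `reportCallback`, `fetchedBytes`); it is illustrative, not the actual TVM implementation:

```ts
// Sketch: fire all shard fetches immediately so downloads overlap.
// Error handling and progress accounting kept minimal on purpose.
const bufferPromises: Array<Promise<ArrayBuffer>> = list.map((shard) => {
  const dataUrl = new URL(shard.dataPath, ndarrayCacheUrl).href;
  return artifactCache
    .fetchWithCache(dataUrl)
    .then((response) => response.arrayBuffer());
});

for (let i = 0; i < list.length; ++i) {
  reportCallback(i);
  // Downloads are already in flight; this await only waits for shard i.
  const buffer = await bufferPromises[i];
  fetchedBytes += list[i].nbytes;
  // ... existing per-record decode and GPU copy logic unchanged ...
}
```

In practice one would likely cap concurrency (browsers limit parallel connections per host, and many in-flight ArrayBuffers can spike memory), e.g. with a small promise pool that keeps at most N fetches outstanding.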