Improve model load: Parallelise download of model shards #280

@DavidGOrtega

Description

Currently the shards are downloaded sequentially. All of the downloads should be promisified so that they can run in parallel.
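
As a rough illustration of the idea (generic code, not taken from this repo; fetchShardBuffers and shardUrls are made-up names), promisifying the downloads means starting every request first and then awaiting them as a group, e.g. with Promise.all:

// Hypothetical sketch: issue all shard requests up front, then wait for all of them.
async function fetchShardBuffers(shardUrls: string[]): Promise<ArrayBuffer[]> {
  return Promise.all(
    shardUrls.map(async (url) => (await fetch(url)).arrayBuffer())
  );
}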

chat_module.js

await tvm.fetchNDArrayCache(modelUrl, tvm.webgpu(), "webllm/model");

runtime.js (TVM)

  /**
   * Fetch list of NDArray into the NDArrayCache.
   *
   * @param ndarrayCacheUrl The cache url.
   * @param list The list of array data.
   * @param device The device to store the data to.
   * @param artifactCache The artifact cache
   */
  private async fetchNDArrayCacheInternal(
    ndarrayCacheUrl: string,
    list: Array<NDArrayShardEntry>,
    device: DLDevice,
    artifactCache: ArtifactCache
  ) {
    const perf = compact.getPerformance();
    const tstart = perf.now();

    let totalBytes = 0;
    for (let i = 0; i < list.length; ++i) {
      totalBytes += list[i].nbytes;
    }
    let fetchedBytes = 0;
    let timeElapsed = 0;

    const cacheOnly = await artifactCache.hasAllKeys(list.map(key => new URL(key.dataPath, ndarrayCacheUrl).href))

    const reportCallback = (iter: number) => {
      // report
      for (let j = 0; j < this.initProgressCallback.length; ++j) {
        let text = "Fetching param cache[" + iter + "/" + list.length + "]: ";
        text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + "MB fetched. "
        text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "% completed, "
        text += timeElapsed + " secs elapsed.";
        text += " It can take a while when we first visit this page to populate the cache."
        text += " Later refreshes will become faster.";
        if (cacheOnly) {
          text = "Loading model from cache[" + iter + "/" + list.length + "]: ";
          text += Math.ceil(fetchedBytes / (1024 * 1024)).toString() + "MB loaded. "
          text += Math.floor(fetchedBytes * 100 / totalBytes).toString() + "% completed, "
          text += timeElapsed + " secs elapsed.";
        }
        this.initProgressCallback[j]({
          progress: fetchedBytes / totalBytes,
          timeElapsed: timeElapsed,
          cacheOnly: cacheOnly,
          text: text
        });
      }
    };

    for (let j = 0; j < this.initProgressCallback.length; ++j) {
      this.initProgressCallback[j]({
        progress: fetchedBytes / totalBytes,
        timeElapsed: 0,
        cacheOnly: cacheOnly,
        text: "Start to fetch params",
      });
    }

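    // NOTE: each iteration below awaits the previous shard's download before
    // starting the next request, so the shards are fetched strictly sequentially.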
    for (let i = 0; i < list.length; ++i) {
      reportCallback(i);
      fetchedBytes += list[i].nbytes;
      const dataUrl = new URL(list[i].dataPath, ndarrayCacheUrl).href;
      let buffer;
      try {
        buffer = await (await artifactCache.fetchWithCache(dataUrl)).arrayBuffer();
      } catch (err) {
        this.env.logger("Error: Cannot fetch " + dataUrl + " err= " + err);
        throw err;
      }
      const shardRecords = list[i].records;
      for (let j = 0; j < shardRecords.length; ++j) {
        const rec = shardRecords[j];
        const cpu_arr = this.withNewScope(() => {
          return this.detachFromCurrentScope(
            this.empty(rec.shape, rec.dtype, this.cpu())
          )
        });
        const recSource = buffer.slice(rec.byteOffset, rec.byteOffset + rec.nbytes);
        // first sync copy to cpu.
        this.ctx.arrayDecodeStorage(cpu_arr, new Uint8Array(recSource), rec.format, rec.dtype);
        // then async stream into GPU if needed
        if (device.deviceType === DeviceStrToEnum.cpu) {
          this.ndarrayCacheUpdate(rec.name, cpu_arr, false);
          cpu_arr.dispose();
        } else {
          // allocate a gpu arr and async copy to it.
          const gpu_arr = this.withNewScope(() => {
            return this.detachFromCurrentScope(
              this.empty(rec.shape, rec.dtype, device)
            )
          });
          gpu_arr.copyFrom(cpu_arr);
          await device.sync();
          this.ndarrayCacheUpdate(rec.name, gpu_arr, false);
          cpu_arr.dispose();
          gpu_arr.dispose();
        }
      }
      timeElapsed = Math.ceil((perf.now() - tstart) / 1000);
    }
    reportCallback(list.length);
  }
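
One way to parallelise this (a sketch only under stated assumptions, not code from the TVM runtime or from an existing PR): start every artifactCache.fetchWithCache request up front so the shard downloads overlap, and keep the decode/GPU-upload phase sequential exactly as it is today. The helper name startShardDownloads and the downloads array are illustrative assumptions.

  // Hypothetical helper (not part of runtime.js): kick off all shard downloads
  // at once and return one promise per shard, in the same order as `list`.
  private startShardDownloads(
    ndarrayCacheUrl: string,
    list: Array<NDArrayShardEntry>,
    artifactCache: ArtifactCache
  ): Promise<ArrayBuffer>[] {
    return list.map(async (shard) => {
      const dataUrl = new URL(shard.dataPath, ndarrayCacheUrl).href;
      try {
        return await (await artifactCache.fetchWithCache(dataUrl)).arrayBuffer();
      } catch (err) {
        this.env.logger("Error: Cannot fetch " + dataUrl + " err= " + err);
        throw err;
      }
    });
  }

    // Inside fetchNDArrayCacheInternal, the download loop would then await an
    // already-running download instead of starting a new one each iteration:
    const downloads = this.startShardDownloads(ndarrayCacheUrl, list, artifactCache);
    for (let i = 0; i < list.length; ++i) {
      reportCallback(i);
      fetchedBytes += list[i].nbytes;
      const buffer = await downloads[i]; // already in flight; resolves once shard i has arrived
      // ...decode shardRecords and copy to CPU/GPU exactly as in the loop above...
      timeElapsed = Math.ceil((perf.now() - tstart) / 1000);
    }
    reportCallback(list.length);

A production change would likely also cap the number of concurrent requests and attach rejection handlers to promises that have not been awaited yet (an early failure would otherwise surface as an unhandled rejection); this sketch leaves both out.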
