Skip to content

Conversation

ardaatahan
Copy link

@ardaatahan ardaatahan commented Dec 29, 2024

This PR aims to bring over some of huggingface_hub Python library's capabilities to swift-transformers:

  • Disables automatic redirection during head requests to fix getFileMetadata, which currently doesn't return expected commit hash and etag for lfs files.
  • Adds resumable downloads by downloading HuggingFace files in async bytes, stores them in 10MB buffers and writes buffers to memory.
  • Adds retry logic to downloads to mitigate connection errors during downloading.
  • Adds file metadata support to prevent unnecessary redownloads and ensure file integrity.
  • Adds additional tests to verify metadata and download logic.

from url: URL,
to destination: URL,
metadataDirURL: URL,
using authToken: String? = nil,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we allowing this to potentially be nil? Is there a case where you can download from HF without providing an auth token?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it was made optional by huggingface. you don't need and auth token to download from public repos

if newNumRetries <= 0 {
throw error
}
try await Task.sleep(nanoseconds: 1_000_000_000)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not just specify one second? It's less confusing and doesn't require someone to do some quick math to figure out the conversion.

Comment on lines 172 to 177
try await downloadWithStreaming(
request: request,
resumeSize: downloadedSize,
numRetries: newNumRetries - 1,
expectedSize: expectedSize
)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This means we would be retrying the request on every error, but not every error is retryable. For instance, if the HF API returns 400, 401 or 403, we shouldn't retry that request because it's never going to succeed. We should only retry the request if the response from HF is in the [500, 599] range, which is a server-side error, which can be transient (hence, retryable).

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, generally recommend against implementing this with recursion since it's slightly less readable compared to a simple iterative solution (where we try/catch the error and retry the request until you reach the preset number of retries).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

checking for error ranges makes sense. I primarily did it this way to follow the design decisions made in the python library, which uses recursion for retries.

let contents = try String(contentsOf: metadataPath, encoding: .utf8)
let lines = contents.components(separatedBy: .newlines)

guard lines.count == 4 else {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we hard-coding 4 here? Would be useful to add a comment at the top of the function explaining why we expect the line count to be exactly 4 in the metadata file. Otherwise, if the line count can be something other than 4, maybe make it a parameter?

}

func isValidSHA256(_ hash: String) -> Bool {
let sha256Pattern = "^[0-9a-f]{64}$"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sha-256 hashes can also contain upper case letters. We should modify the regex to account for them so we don't have false negatives.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

afaik, sha256 hashes are usually lower case only. at least that's the format huggingface uses to serve the hashes for large files, so I reused their pattern from the python library.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, never mind, I forgot that this is a hexadecimal encoding and not a base64 encoding. With hex, 0xab37 is not different from 0xAB37, so we're good.

let localCommitHash = localMetadata?.commitHash ?? ""
let remoteCommitHash = remoteMetadata.commitHash ?? ""

// Local file exists + metadata exists + commit_hash matches => return file

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

Comment on lines +275 to +303
func testDownloadFileMetadata() async throws {
let hubApi = HubApi(downloadBase: downloadDestination)
var lastProgress: Progress? = nil
let downloadedTo = try await hubApi.snapshot(from: repo, matching: "*.json") { progress in
print("Total Progress: \(progress.fractionCompleted)")
print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)")
lastProgress = progress
}
XCTAssertEqual(lastProgress?.fractionCompleted, 1)
XCTAssertEqual(lastProgress?.completedUnitCount, 6)
XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)"))

let downloadedFilenames = getRelativeFiles(url: downloadDestination)
XCTAssertEqual(
Set(downloadedFilenames),
Set([
"config.json", "tokenizer.json", "tokenizer_config.json",
"llama-2-7b-chat.mlpackage/Manifest.json",
"llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/FeatureDescriptions.json",
"llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json",
])
)

let metadataDestination = downloadedTo.appending(component: ".cache/huggingface")
let downloadedMetadataFilenames = getRelativeFiles(url: metadataDestination)
XCTAssertEqual(
Set(downloadedMetadataFilenames),
Set([
".cache/huggingface/config.metadata",

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a very good test.

let remoteCommitHash = remoteMetadata.commitHash ?? ""

// Local file exists + metadata exists + commit_hash matches => return file
if isValidSHA256(remoteCommitHash) && downloaded && localMetadata != nil && localCommitHash == remoteCommitHash {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if isValidSHA256(remoteCommitHash) && downloaded && localMetadata != nil && localCommitHash == remoteCommitHash {
if downloaded && localMetadata != nil && localCommitHash == remoteCommitHash {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here isValidSHA256 is actually immaterial since we're also checking whether the local commit hash is the same as the remote one. There is no case where the hash is not a real SHA hash and the condition localCommitHash == remoteCommitHash is true.

Copy link

@ZachNagengast ZachNagengast left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to go

@ZachNagengast
Copy link

Approved, closing in favor of upstream PR

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants