add metadata and resumable download support with tests #2
from url: URL,
to destination: URL,
metadataDirURL: URL,
using authToken: String? = nil,
Why are we allowing this to potentially be `nil`? Is there a case where you can download from HF without providing an auth token?
It was made optional by Hugging Face. You don't need an auth token to download from public repos.
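As a rough illustration of why the parameter can stay optional, the sketch below attaches an `Authorization` header only when a token is supplied. The helper name `makeDownloadRequest` and the example URL are hypothetical and not taken from the PR; only the `from`/`using` parameter labels mirror the signature quoted above.

```swift
import Foundation
#if canImport(FoundationNetworking)
import FoundationNetworking
#endif

/// Hypothetical helper (not from the PR): builds a request for a file download.
/// The token is optional because public repos can be fetched anonymously.
func makeDownloadRequest(from url: URL, using authToken: String? = nil) -> URLRequest {
    var request = URLRequest(url: url)
    // Send credentials only when the caller actually provided a token.
    if let authToken = authToken {
        request.setValue("Bearer \(authToken)", forHTTPHeaderField: "Authorization")
    }
    return request
}
```

With this shape, callers targeting public repos simply omit the `using:` argument, and no header is sent.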
Sources/Hub/Downloader.swift
Outdated
if newNumRetries <= 0 {
    throw error
}
try await Task.sleep(nanoseconds: 1_000_000_000)
Why not just specify one second? It's less confusing and doesn't require someone to do some quick math to figure out the conversion.
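One way to act on this, sketched under the assumption that the code must keep using `Task.sleep(nanoseconds:)` for older OS versions, is to hide the conversion behind a named constant. The names `nanosecondsPerSecond`, `nanoseconds(fromSeconds:)`, and `pause(seconds:)` are hypothetical. Where the deployment target allows it (macOS 13 / iOS 16 and later), `Task.sleep(for: .seconds(1))` expresses the same thing directly.

```swift
import Foundation

// Hypothetical named constant; not part of the PR.
let nanosecondsPerSecond: UInt64 = 1_000_000_000

/// Converts whole seconds to the nanosecond count that
/// `Task.sleep(nanoseconds:)` expects.
func nanoseconds(fromSeconds seconds: UInt64) -> UInt64 {
    seconds * nanosecondsPerSecond
}

/// Hypothetical wrapper that keeps the unit visible at the call site.
func pause(seconds: UInt64) async throws {
    try await Task.sleep(nanoseconds: nanoseconds(fromSeconds: seconds))
}
```

A call such as `try await pause(seconds: 1)` then reads as "one second" without any mental arithmetic.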
Sources/Hub/Downloader.swift
Outdated
try await downloadWithStreaming(
    request: request,
    resumeSize: downloadedSize,
    numRetries: newNumRetries - 1,
    expectedSize: expectedSize
)
This means we would retry the request on every error, but not every error is retryable. For instance, if the HF API returns 400, 401, or 403, we shouldn't retry, because the request is never going to succeed. We should only retry when the response from HF is in the [500, 599] range: those are server-side errors, which can be transient and are therefore retryable.
Also, I'd generally recommend against implementing this with recursion, since it's slightly less readable than a simple iterative solution (where we catch the error and retry the request until we reach the preset number of retries).
Checking for error ranges makes sense. I primarily did it this way to follow the design decisions made in the Python library, which uses recursion for retries.
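The two suggestions in this thread (retry only on 5xx, and loop instead of recursing) could be combined roughly as follows. This is a minimal sketch, not the PR's implementation: `HTTPStatusError`, `isRetryable`, and `withRetries` are hypothetical names, and the sketch assumes the download code surfaces the HTTP status code through an error value. It deliberately leaves out other possibly transient failures, such as dropped connections.

```swift
import Foundation

// Hypothetical error type for illustration; the PR's actual error type may differ.
struct HTTPStatusError: Error {
    let statusCode: Int
}

/// Treats only server-side (5xx) responses as transient.
func isRetryable(_ error: Error) -> Bool {
    guard let httpError = error as? HTTPStatusError else { return false }
    return (500...599).contains(httpError.statusCode)
}

/// Runs `operation` up to `maxAttempts` times in a plain loop, sleeping
/// between attempts and rethrowing immediately on non-retryable errors.
func withRetries<T>(
    maxAttempts: Int,
    delayNanoseconds: UInt64 = 1_000_000_000,
    operation: () async throws -> T
) async throws -> T {
    precondition(maxAttempts > 0, "maxAttempts must be positive")
    var attempt = 1
    while true {
        do {
            return try await operation()
        } catch {
            // Give up when attempts are exhausted or the error cannot succeed on retry.
            if attempt >= maxAttempts || !isRetryable(error) { throw error }
            attempt += 1
            try await Task.sleep(nanoseconds: delayNanoseconds)
        }
    }
}
```

A 401 or 403 fails on the first attempt here, while a 503 is retried until the attempt budget runs out.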
Sources/Hub/HubApi.swift
Outdated
let contents = try String(contentsOf: metadataPath, encoding: .utf8)
let lines = contents.components(separatedBy: .newlines)

guard lines.count == 4 else {
Why are we hard-coding 4 here? Would be useful to add a comment at the top of the function explaining why we expect the line count to be exactly 4 in the metadata file. Otherwise, if the line count can be something other than 4, maybe make it a parameter?
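One detail worth documenting next to such a constant is how `components(separatedBy:)` counts. The sketch below assumes, purely for illustration, a file with three newline-terminated fields; the actual metadata layout is not shown in this excerpt, and `splitLines` and `expectedLineCount` are hypothetical names. Under that assumption the split yields four components, because the trailing newline produces a final empty string.

```swift
import Foundation

// Hypothetical named constant; the comment records why the value is what it is.
// Assumption: three newline-terminated fields, so the split produces
// three values plus one trailing empty component.
let expectedLineCount = 4

/// Splits file contents the same way as the snippet quoted above.
func splitLines(_ contents: String) -> [String] {
    contents.components(separatedBy: .newlines)
}
```

Whatever the real layout is, a named constant with a comment like this would answer the question at the point where the check happens.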
}

func isValidSHA256(_ hash: String) -> Bool {
    let sha256Pattern = "^[0-9a-f]{64}$"
SHA-256 hashes can also contain uppercase letters. We should modify the regex to account for them so we don't have false negatives.
afaik, sha256 hashes are usually lower case only. at least that's the format huggingface uses to serve the hashes for large files, so I reused their pattern from the python library.
Ah, never mind, I forgot that this is a hexadecimal encoding and not a base64 encoding. With hex, `0xab37` is not different from `0xAB37`, so we're good.
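For reference, a complete version of the check might look like the sketch below. Only the function name and the pattern come from the diff; the body is an assumption, since the excerpt cuts off after the pattern. Note that while hex digits denote the same value in either case, the pattern itself accepts lowercase only, so the sketch also includes a hypothetical `isValidSHA256IgnoringCase` variant that normalizes first, in case a source ever serves uppercase hex.

```swift
import Foundation

/// Checks a string against the lowercase pattern quoted in the diff.
/// The body is an assumed completion; the PR's implementation is not shown here.
func isValidSHA256(_ hash: String) -> Bool {
    let sha256Pattern = "^[0-9a-f]{64}$"
    return hash.range(of: sha256Pattern, options: .regularExpression) != nil
}

/// Hypothetical variant: lowercases the input before validating,
/// so uppercase hex is accepted as well.
func isValidSHA256IgnoringCase(_ hash: String) -> Bool {
    isValidSHA256(hash.lowercased())
}
```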
let localCommitHash = localMetadata?.commitHash ?? ""
let remoteCommitHash = remoteMetadata.commitHash ?? ""

// Local file exists + metadata exists + commit_hash matches => return file
Nice!
func testDownloadFileMetadata() async throws {
    let hubApi = HubApi(downloadBase: downloadDestination)
    var lastProgress: Progress? = nil
    let downloadedTo = try await hubApi.snapshot(from: repo, matching: "*.json") { progress in
        print("Total Progress: \(progress.fractionCompleted)")
        print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)")
        lastProgress = progress
    }
    XCTAssertEqual(lastProgress?.fractionCompleted, 1)
    XCTAssertEqual(lastProgress?.completedUnitCount, 6)
    XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)"))

    let downloadedFilenames = getRelativeFiles(url: downloadDestination)
    XCTAssertEqual(
        Set(downloadedFilenames),
        Set([
            "config.json", "tokenizer.json", "tokenizer_config.json",
            "llama-2-7b-chat.mlpackage/Manifest.json",
            "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/FeatureDescriptions.json",
            "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json",
        ])
    )

    let metadataDestination = downloadedTo.appending(component: ".cache/huggingface")
    let downloadedMetadataFilenames = getRelativeFiles(url: metadataDestination)
    XCTAssertEqual(
        Set(downloadedMetadataFilenames),
        Set([
            ".cache/huggingface/config.metadata",
This is a very good test.
let remoteCommitHash = remoteMetadata.commitHash ?? ""

// Local file exists + metadata exists + commit_hash matches => return file
if isValidSHA256(remoteCommitHash) && downloaded && localMetadata != nil && localCommitHash == remoteCommitHash {
Suggested change:
- if isValidSHA256(remoteCommitHash) && downloaded && localMetadata != nil && localCommitHash == remoteCommitHash {
+ if downloaded && localMetadata != nil && localCommitHash == remoteCommitHash {
Here `isValidSHA256` is actually immaterial since we're also checking whether the local commit hash is the same as the remote one. There is no case where the hash is not a real SHA hash and the condition `localCommitHash == remoteCommitHash` is true.
Looks good to go
Approved, closing in favor of upstream PR
This PR aims to bring over some of huggingface_hub Python library's capabilities to swift-transformers: