-
Notifications
You must be signed in to change notification settings - Fork 0
add metadata and resumable download support with tests #2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
02d2571
f590932
22b6892
bedfc7a
af26e60
b4e1c49
26707b8
5839d33
9d39cf1
97b6163
fe2f32b
30adb75
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -6,6 +6,8 @@ | |||||
// | ||||||
|
||||||
import Foundation | ||||||
import CryptoKit | ||||||
import os | ||||||
|
||||||
public struct HubApi { | ||||||
var downloadBase: URL | ||||||
|
@@ -29,6 +31,8 @@ public struct HubApi { | |||||
} | ||||||
|
||||||
public static let shared = HubApi() | ||||||
|
||||||
private static let logger = Logger() | ||||||
} | ||||||
|
||||||
private extension HubApi { | ||||||
|
@@ -91,18 +95,24 @@ public extension HubApi { | |||||
return (data, response) | ||||||
} | ||||||
|
||||||
/// Throws error if page does not exist or is not accessible. | ||||||
/// Allows relative redirects but ignores absolute ones for LFS files. | ||||||
func httpHead(for url: URL) async throws -> (Data, HTTPURLResponse) { | ||||||
var request = URLRequest(url: url) | ||||||
request.httpMethod = "HEAD" | ||||||
if let hfToken = hfToken { | ||||||
request.setValue("Bearer \(hfToken)", forHTTPHeaderField: "Authorization") | ||||||
} | ||||||
request.setValue("identity", forHTTPHeaderField: "Accept-Encoding") | ||||||
let (data, response) = try await URLSession.shared.data(for: request) | ||||||
|
||||||
let redirectDelegate = RedirectDelegate() | ||||||
let session = URLSession(configuration: .default, delegate: redirectDelegate, delegateQueue: nil) | ||||||
|
||||||
let (data, response) = try await session.data(for: request) | ||||||
guard let response = response as? HTTPURLResponse else { throw Hub.HubClientError.unexpectedError } | ||||||
|
||||||
switch response.statusCode { | ||||||
case 200..<300: break | ||||||
case 200..<400: break // Allow redirects to pass through to the redirect delegate | ||||||
case 400..<500: throw Hub.HubClientError.authorizationRequired | ||||||
default: throw Hub.HubClientError.httpStatusCode(response.statusCode) | ||||||
} | ||||||
|
@@ -138,6 +148,26 @@ public extension HubApi { | |||||
} | ||||||
} | ||||||
|
||||||
/// Additional Errors | ||||||
public extension HubApi { | ||||||
enum EnvironmentError: LocalizedError { | ||||||
case consistencyError(String) | ||||||
case diskSpaceError(String) | ||||||
case permissionError(String) | ||||||
case invalidMetadataError(String) | ||||||
|
||||||
public var errorDescription: String? { | ||||||
switch self { | ||||||
case .consistencyError(let message), | ||||||
.diskSpaceError(let message), | ||||||
.permissionError(let message), | ||||||
.invalidMetadataError(let message): | ||||||
return message | ||||||
} | ||||||
} | ||||||
} | ||||||
} | ||||||
|
||||||
/// Configuration loading helpers | ||||||
public extension HubApi { | ||||||
/// Assumes the file has already been downloaded. | ||||||
|
@@ -201,6 +231,12 @@ public extension HubApi { | |||||
repoDestination.appending(path: relativeFilename) | ||||||
} | ||||||
|
||||||
var metadataDestination: URL { | ||||||
repoDestination | ||||||
.appendingPathComponent(".cache") | ||||||
.appendingPathComponent("huggingface") | ||||||
} | ||||||
|
||||||
var downloaded: Bool { | ||||||
FileManager.default.fileExists(atPath: destination.path) | ||||||
} | ||||||
|
@@ -209,16 +245,160 @@ public extension HubApi { | |||||
let directoryURL = destination.deletingLastPathComponent() | ||||||
try FileManager.default.createDirectory(at: directoryURL, withIntermediateDirectories: true, attributes: nil) | ||||||
} | ||||||
|
||||||
|
||||||
func prepareMetadataDestination() throws { | ||||||
try FileManager.default.createDirectory(at: metadataDestination, withIntermediateDirectories: true, attributes: nil) | ||||||
} | ||||||
|
||||||
/// Reads metadata about a file in the local directory related to a download process. | ||||||
/// | ||||||
/// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/_local_folder.py#L263 | ||||||
/// | ||||||
/// - Parameters: | ||||||
/// - localDir: The local directory where files are downloaded. | ||||||
/// - filePath: The path of the file for which metadata is being read. | ||||||
/// - Throws: An `EnvironmentError.invalidMetadataError` if the metadata file is invalid and cannot be removed. | ||||||
/// - Returns: A `LocalDownloadFileMetadata` object if the metadata file exists and is valid, or `nil` if the file is missing or invalid. | ||||||
func readDownloadMetadata(localDir: URL, filePath: String) throws -> LocalDownloadFileMetadata? { | ||||||
let metadataPath = localDir.appending(path: filePath) | ||||||
if FileManager.default.fileExists(atPath: metadataPath.path) { | ||||||
do { | ||||||
let contents = try String(contentsOf: metadataPath, encoding: .utf8) | ||||||
let lines = contents.components(separatedBy: .newlines) | ||||||
|
||||||
guard lines.count >= 3 else { | ||||||
throw EnvironmentError.invalidMetadataError("Metadata file is missing required fields.") | ||||||
} | ||||||
|
||||||
let commitHash = lines[0].trimmingCharacters(in: .whitespacesAndNewlines) | ||||||
let etag = lines[1].trimmingCharacters(in: .whitespacesAndNewlines) | ||||||
guard let timestamp = Double(lines[2].trimmingCharacters(in: .whitespacesAndNewlines)) else { | ||||||
throw EnvironmentError.invalidMetadataError("Missing or invalid timestamp.") | ||||||
} | ||||||
let timestampDate = Date(timeIntervalSince1970: timestamp) | ||||||
|
||||||
// TODO: check if file hasn't been modified since the metadata was saved | ||||||
|
||||||
return LocalDownloadFileMetadata(commitHash: commitHash, etag: etag, filename: filePath, timestamp: timestampDate) | ||||||
} catch { | ||||||
do { | ||||||
logger.warning("Invalid metadata file \(metadataPath): \(error). Removing it from disk and continue.") | ||||||
try FileManager.default.removeItem(at: metadataPath) | ||||||
} catch { | ||||||
throw EnvironmentError.invalidMetadataError("Could not remove corrupted metadata file \(metadataPath): \(error)") | ||||||
} | ||||||
return nil | ||||||
} | ||||||
} | ||||||
|
||||||
// metadata file does not exist | ||||||
return nil | ||||||
} | ||||||
|
||||||
func isValidSHA256(_ hash: String) -> Bool { | ||||||
let sha256Pattern = "^[0-9a-f]{64}$" | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sha-256 hashes can also contain upper case letters. We should modify the regex to account for them so we don't have false negatives. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. afaik, sha256 hashes are usually lower case only. at least that's the format huggingface uses to serve the hashes for large files, so I reused their pattern from the python library. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, never mind, I forgot that this is a hexadecimal encoding and not a base64 encoding. With hex, |
||||||
let regex = try? NSRegularExpression(pattern: sha256Pattern) | ||||||
let range = NSRange(location: 0, length: hash.utf16.count) | ||||||
return regex?.firstMatch(in: hash, options: [], range: range) != nil | ||||||
} | ||||||
|
||||||
/// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/_local_folder.py#L391 | ||||||
func writeDownloadMetadata(commitHash: String, etag: String, metadataRelativePath: String) throws { | ||||||
let metadataContent = "\(commitHash)\n\(etag)\n\(Date().timeIntervalSince1970)\n" | ||||||
let metadataPath = metadataDestination.appending(component: metadataRelativePath) | ||||||
|
||||||
do { | ||||||
try FileManager.default.createDirectory(at: metadataPath.deletingLastPathComponent(), withIntermediateDirectories: true) | ||||||
try metadataContent.write(to: metadataPath, atomically: true, encoding: .utf8) | ||||||
} catch { | ||||||
throw EnvironmentError.invalidMetadataError("Failed to write metadata file \(metadataPath)") | ||||||
} | ||||||
} | ||||||
|
||||||
func computeFileHash(file url: URL) throws -> String { | ||||||
// Open file for reading | ||||||
guard let fileHandle = try? FileHandle(forReadingFrom: url) else { | ||||||
throw Hub.HubClientError.unexpectedError | ||||||
} | ||||||
|
||||||
defer { | ||||||
try? fileHandle.close() | ||||||
} | ||||||
|
||||||
var hasher = SHA256() | ||||||
let chunkSize = 1024 * 1024 // 1MB chunks | ||||||
|
||||||
while autoreleasepool(invoking: { | ||||||
let nextChunk = try? fileHandle.read(upToCount: chunkSize) | ||||||
|
||||||
guard let nextChunk, | ||||||
!nextChunk.isEmpty | ||||||
else { | ||||||
return false | ||||||
} | ||||||
|
||||||
hasher.update(data: nextChunk) | ||||||
|
||||||
return true | ||||||
}) { } | ||||||
|
||||||
let digest = hasher.finalize() | ||||||
return digest.map { String(format: "%02x", $0) }.joined() | ||||||
} | ||||||
|
||||||
|
||||||
// Note we go from Combine in Downloader to callback-based progress reporting | ||||||
// We'll probably need to support Combine as well to play well with Swift UI | ||||||
// (See for example PipelineLoader in swift-coreml-diffusers) | ||||||
@discardableResult | ||||||
func download(progressHandler: @escaping (Double) -> Void) async throws -> URL { | ||||||
guard !downloaded else { return destination } | ||||||
|
||||||
var metadataRelativePath = (relativeFilename as NSString).deletingPathExtension | ||||||
metadataRelativePath += ".metadata" | ||||||
|
||||||
let localMetadata = try readDownloadMetadata(localDir: metadataDestination, filePath: metadataRelativePath) | ||||||
let remoteMetadata = try await HubApi.shared.getFileMetadata(url: source) | ||||||
|
||||||
let localCommitHash = localMetadata?.commitHash ?? "" | ||||||
let remoteCommitHash = remoteMetadata.commitHash ?? "" | ||||||
|
||||||
// Local file exists + metadata exists + commit_hash matches => return file | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice! |
||||||
if isValidSHA256(remoteCommitHash) && downloaded && localMetadata != nil && localCommitHash == remoteCommitHash { | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here |
||||||
return destination | ||||||
} | ||||||
|
||||||
// From now on, etag, commit_hash, url and size are not empty | ||||||
guard let remoteCommitHash = remoteMetadata.commitHash, | ||||||
let remoteEtag = remoteMetadata.etag, | ||||||
remoteMetadata.location != "" else { | ||||||
throw EnvironmentError.invalidMetadataError("File metadata must have been retrieved from server") | ||||||
} | ||||||
|
||||||
// Local file exists => check if it's up-to-date | ||||||
if downloaded { | ||||||
// etag matches => update metadata and return file | ||||||
if localMetadata?.etag == remoteEtag { | ||||||
try writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataRelativePath: metadataRelativePath) | ||||||
return destination | ||||||
} | ||||||
|
||||||
// metadata is outdated + etag is a sha256 | ||||||
// => means it's an LFS file (large) | ||||||
// => let's compute local hash and compare | ||||||
// => if match, update metadata and return file | ||||||
if localMetadata != nil && isValidSHA256(remoteEtag) { | ||||||
let fileHash = try computeFileHash(file: destination) | ||||||
if fileHash == remoteEtag { | ||||||
try writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataRelativePath: metadataRelativePath) | ||||||
return destination | ||||||
} | ||||||
} | ||||||
} | ||||||
|
||||||
// Otherwise, let's download the file! | ||||||
try prepareDestination() | ||||||
let downloader = Downloader(from: source, to: destination, using: hfToken, inBackground: backgroundSession) | ||||||
try prepareMetadataDestination() | ||||||
|
||||||
let downloader = Downloader(from: source, to: destination, metadataDirURL: metadataDestination, using: hfToken, inBackground: backgroundSession) | ||||||
let downloadSubscriber = downloader.downloadState.sink { state in | ||||||
if case .downloading(let progress) = state { | ||||||
progressHandler(progress) | ||||||
|
@@ -227,6 +407,9 @@ public extension HubApi { | |||||
_ = try withExtendedLifetime(downloadSubscriber) { | ||||||
try downloader.waitUntilDone() | ||||||
} | ||||||
|
||||||
try writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataRelativePath: metadataRelativePath) | ||||||
|
||||||
return destination | ||||||
} | ||||||
} | ||||||
|
@@ -274,20 +457,36 @@ public extension HubApi { | |||||
|
||||||
/// Metadata | ||||||
public extension HubApi { | ||||||
/// A structure representing metadata for a remote file | ||||||
/// Data structure containing information about a file versioned on the Hub | ||||||
struct FileMetadata { | ||||||
/// The file's Git commit hash | ||||||
/// The commit hash related to the file | ||||||
public let commitHash: String? | ||||||
|
||||||
/// Server-provided ETag for caching | ||||||
/// Etag of the file on the server | ||||||
public let etag: String? | ||||||
|
||||||
/// Stringified URL location of the file | ||||||
/// Location where to download the file. Can be a Hub url or not (CDN). | ||||||
public let location: String | ||||||
|
||||||
/// The file's size in bytes | ||||||
/// Size of the file. In case of an LFS file, contains the size of the actual LFS file, not the pointer. | ||||||
public let size: Int? | ||||||
} | ||||||
|
||||||
/// Metadata about a file in the local directory related to a download process | ||||||
struct LocalDownloadFileMetadata { | ||||||
/// Commit hash of the file in the repo | ||||||
public let commitHash: String | ||||||
|
||||||
/// ETag of the file in the repo. Used to check if the file has changed. | ||||||
/// For LFS files, this is the sha256 of the file. For regular files, it corresponds to the git hash. | ||||||
public let etag: String | ||||||
|
||||||
/// Path of the file in the repo | ||||||
public let filename: String | ||||||
|
||||||
/// The timestamp of when the metadata was saved i.e. when the metadata was accurate | ||||||
public let timestamp: Date | ||||||
} | ||||||
|
||||||
private func normalizeEtag(_ etag: String?) -> String? { | ||||||
guard let etag = etag else { return nil } | ||||||
|
@@ -296,13 +495,14 @@ public extension HubApi { | |||||
|
||||||
func getFileMetadata(url: URL) async throws -> FileMetadata { | ||||||
let (_, response) = try await httpHead(for: url) | ||||||
let location = response.statusCode == 302 ? response.value(forHTTPHeaderField: "Location") : response.url?.absoluteString | ||||||
|
||||||
return FileMetadata( | ||||||
commitHash: response.value(forHTTPHeaderField: "X-Repo-Commit"), | ||||||
etag: normalizeEtag( | ||||||
(response.value(forHTTPHeaderField: "X-Linked-Etag")) ?? (response.value(forHTTPHeaderField: "Etag")) | ||||||
), | ||||||
location: (response.value(forHTTPHeaderField: "Location")) ?? url.absoluteString, | ||||||
location: location ?? url.absoluteString, | ||||||
size: Int(response.value(forHTTPHeaderField: "X-Linked-Size") ?? response.value(forHTTPHeaderField: "Content-Length") ?? "") | ||||||
) | ||||||
} | ||||||
|
@@ -395,3 +595,43 @@ public extension [String] { | |||||
filter { fnmatch(glob, $0, 0) == 0 } | ||||||
} | ||||||
} | ||||||
|
||||||
/// Only allow relative redirects and reject others | ||||||
/// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/file_download.py#L258 | ||||||
private class RedirectDelegate: NSObject, URLSessionTaskDelegate { | ||||||
func urlSession(_ session: URLSession, task: URLSessionTask, willPerformHTTPRedirection response: HTTPURLResponse, newRequest request: URLRequest, completionHandler: @escaping (URLRequest?) -> Void) { | ||||||
// Check if it's a redirect status code (300-399) | ||||||
if (300...399).contains(response.statusCode) { | ||||||
// Get the Location header | ||||||
if let locationString = response.value(forHTTPHeaderField: "Location"), | ||||||
let locationUrl = URL(string: locationString) { | ||||||
|
||||||
// Check if it's a relative redirect (no host component) | ||||||
if locationUrl.host == nil { | ||||||
// For relative redirects, construct the new URL using the original request's base | ||||||
if let originalUrl = task.originalRequest?.url, | ||||||
var components = URLComponents(url: originalUrl, resolvingAgainstBaseURL: true) { | ||||||
// Update the path component with the relative path | ||||||
components.path = locationUrl.path | ||||||
components.query = locationUrl.query | ||||||
|
||||||
// Create new request with the resolved URL | ||||||
if let resolvedUrl = components.url { | ||||||
var newRequest = URLRequest(url: resolvedUrl) | ||||||
// Copy headers from original request | ||||||
task.originalRequest?.allHTTPHeaderFields?.forEach { key, value in | ||||||
newRequest.setValue(value, forHTTPHeaderField: key) | ||||||
} | ||||||
newRequest.setValue(resolvedUrl.absoluteString, forHTTPHeaderField: "Location") | ||||||
completionHandler(newRequest) | ||||||
return | ||||||
} | ||||||
} | ||||||
} | ||||||
} | ||||||
} | ||||||
|
||||||
// For all other cases (non-redirects or absolute redirects), prevent redirect | ||||||
completionHandler(nil) | ||||||
} | ||||||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why are we allowing this to potentially be
nil
? Is there a case where you can download from HF without providing an auth token?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it was made optional by huggingface. you don't need and auth token to download from public repos