Skip to content
Closed
13 changes: 12 additions & 1 deletion Sources/Hub/Downloader.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ import Combine

class Downloader: NSObject, ObservableObject {
private(set) var destination: URL
private(set) var metadataDestination: URL

private let chunkSize = 10 * 1024 * 1024 // 10MB

enum DownloadState {
case notStarted
Expand All @@ -29,8 +32,16 @@ class Downloader: NSObject, ObservableObject {

private var urlSession: URLSession? = nil

init(from url: URL, to destination: URL, using authToken: String? = nil, inBackground: Bool = false) {
init(
from url: URL,
to destination: URL,
metadataDirURL: URL,
using authToken: String? = nil,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we allowing this to potentially be nil? Is there a case where you can download from HF without providing an auth token?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was made optional by Hugging Face: you don't need an auth token to download from public repos.

inBackground: Bool = false
) {
self.destination = destination
let filename = (destination.lastPathComponent as NSString).deletingPathExtension
self.metadataDestination = metadataDirURL.appending(component: "\(filename).metadata")
super.init()
let sessionIdentifier = "swift-transformers.hub.downloader"

Expand Down
264 changes: 252 additions & 12 deletions Sources/Hub/HubApi.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
//

import Foundation
import CryptoKit
import os

public struct HubApi {
var downloadBase: URL
Expand All @@ -29,6 +31,8 @@ public struct HubApi {
}

public static let shared = HubApi()

private static let logger = Logger()
}

private extension HubApi {
Expand Down Expand Up @@ -91,18 +95,24 @@ public extension HubApi {
return (data, response)
}

/// Throws error if page does not exist or is not accessible.
/// Allows relative redirects but ignores absolute ones for LFS files.
func httpHead(for url: URL) async throws -> (Data, HTTPURLResponse) {
var request = URLRequest(url: url)
request.httpMethod = "HEAD"
if let hfToken = hfToken {
request.setValue("Bearer \(hfToken)", forHTTPHeaderField: "Authorization")
}
request.setValue("identity", forHTTPHeaderField: "Accept-Encoding")
let (data, response) = try await URLSession.shared.data(for: request)

let redirectDelegate = RedirectDelegate()
let session = URLSession(configuration: .default, delegate: redirectDelegate, delegateQueue: nil)

let (data, response) = try await session.data(for: request)
guard let response = response as? HTTPURLResponse else { throw Hub.HubClientError.unexpectedError }

switch response.statusCode {
case 200..<300: break
case 200..<400: break // Allow redirects to pass through to the redirect delegate
case 400..<500: throw Hub.HubClientError.authorizationRequired
default: throw Hub.HubClientError.httpStatusCode(response.statusCode)
}
Expand Down Expand Up @@ -138,6 +148,26 @@ public extension HubApi {
}
}

/// Additional Errors
public extension HubApi {
    /// Errors that each carry a descriptive message.
    ///
    /// The associated `String` of every case is returned unchanged by `errorDescription`.
    enum EnvironmentError: LocalizedError {
        case consistencyError(String)
        case diskSpaceError(String)
        case permissionError(String)
        case invalidMetadataError(String)

        /// The message attached to the error case.
        public var errorDescription: String? {
            switch self {
            case let .consistencyError(text):
                return text
            case let .diskSpaceError(text):
                return text
            case let .permissionError(text):
                return text
            case let .invalidMetadataError(text):
                return text
            }
        }
    }
}

/// Configuration loading helpers
public extension HubApi {
/// Assumes the file has already been downloaded.
Expand Down Expand Up @@ -201,6 +231,12 @@ public extension HubApi {
repoDestination.appending(path: relativeFilename)
}

/// Directory `.cache/huggingface` located under `repoDestination`.
var metadataDestination: URL {
    let cacheDirectory = repoDestination.appendingPathComponent(".cache")
    return cacheDirectory.appendingPathComponent("huggingface")
}

/// Whether an item already exists on disk at `destination`.
var downloaded: Bool {
    let localPath = destination.path
    return FileManager.default.fileExists(atPath: localPath)
}
Expand All @@ -209,16 +245,160 @@ public extension HubApi {
let directoryURL = destination.deletingLastPathComponent()
try FileManager.default.createDirectory(at: directoryURL, withIntermediateDirectories: true, attributes: nil)
}


/// Creates the directory at `metadataDestination`, including any missing intermediate directories.
///
/// - Throws: Any error raised by `FileManager` while creating the directory.
func prepareMetadataDestination() throws {
    let fileManager = FileManager.default
    try fileManager.createDirectory(
        at: metadataDestination,
        withIntermediateDirectories: true,
        attributes: nil
    )
}

/// Loads the download metadata stored for a file in a local directory.
///
/// The metadata file is expected to hold at least three lines: a commit hash, an etag,
/// and a timestamp expressed as seconds since 1970. A file that cannot be read or parsed
/// is deleted from disk and treated as absent.
///
/// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/_local_folder.py#L263
///
/// - Parameters:
///   - localDir: The directory that contains the metadata files.
///   - filePath: The path of the metadata file, relative to `localDir`.
/// - Throws: `EnvironmentError.invalidMetadataError` when an invalid metadata file cannot be removed.
/// - Returns: The parsed `LocalDownloadFileMetadata`, or `nil` when the file is missing or was invalid.
func readDownloadMetadata(localDir: URL, filePath: String) throws -> LocalDownloadFileMetadata? {
    let metadataPath = localDir.appending(path: filePath)

    // metadata file does not exist
    guard FileManager.default.fileExists(atPath: metadataPath.path) else {
        return nil
    }

    do {
        let fields = try String(contentsOf: metadataPath, encoding: .utf8)
            .components(separatedBy: .newlines)

        guard fields.count >= 3 else {
            throw EnvironmentError.invalidMetadataError("Metadata file is missing required fields.")
        }

        guard let secondsSince1970 = Double(fields[2].trimmingCharacters(in: .whitespacesAndNewlines)) else {
            throw EnvironmentError.invalidMetadataError("Missing or invalid timestamp.")
        }

        // TODO: check if file hasn't been modified since the metadata was saved

        return LocalDownloadFileMetadata(
            commitHash: fields[0].trimmingCharacters(in: .whitespacesAndNewlines),
            etag: fields[1].trimmingCharacters(in: .whitespacesAndNewlines),
            filename: filePath,
            timestamp: Date(timeIntervalSince1970: secondsSince1970)
        )
    } catch {
        // Any read or parse failure lands here: drop the unusable file so that the
        // caller can proceed as if no metadata had been stored.
        do {
            logger.warning("Invalid metadata file \(metadataPath): \(error). Removing it from disk and continue.")
            try FileManager.default.removeItem(at: metadataPath)
        } catch {
            throw EnvironmentError.invalidMetadataError("Could not remove corrupted metadata file \(metadataPath): \(error)")
        }
        return nil
    }
}

func isValidSHA256(_ hash: String) -> Bool {
let sha256Pattern = "^[0-9a-f]{64}$"

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SHA-256 hashes can also contain uppercase letters. We should modify the regex to account for them so we don't get false negatives.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

afaik, sha256 hashes are usually lower case only. at least that's the format huggingface uses to serve the hashes for large files, so I reused their pattern from the python library.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, never mind, I forgot that this is a hexadecimal encoding and not a base64 encoding. With hex, 0xab37 is not different from 0xAB37, so we're good.

let regex = try? NSRegularExpression(pattern: sha256Pattern)
let range = NSRange(location: 0, length: hash.utf16.count)
return regex?.firstMatch(in: hash, options: [], range: range) != nil
}

/// Writes the download metadata file for a downloaded file.
///
/// The file holds three newline-terminated lines: the commit hash, the etag, and the
/// current time as seconds since 1970. Missing parent directories are created first.
///
/// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/_local_folder.py#L391
///
/// - Parameters:
///   - commitHash: The commit hash to record.
///   - etag: The etag to record.
///   - metadataRelativePath: Path of the metadata file, relative to `metadataDestination`.
/// - Throws: `EnvironmentError.invalidMetadataError` when the directory or the file cannot be written.
func writeDownloadMetadata(commitHash: String, etag: String, metadataRelativePath: String) throws {
    let metadataContent = "\(commitHash)\n\(etag)\n\(Date().timeIntervalSince1970)\n"

    // `appending(path:)` is meant for relative paths that may span several components
    // (e.g. `sub/dir/file.metadata`) and is the same call `readDownloadMetadata` uses,
    // so the writer and the reader always agree on the location.
    let metadataPath = metadataDestination.appending(path: metadataRelativePath)

    do {
        try FileManager.default.createDirectory(at: metadataPath.deletingLastPathComponent(), withIntermediateDirectories: true)
        try metadataContent.write(to: metadataPath, atomically: true, encoding: .utf8)
    } catch {
        // Keep the underlying cause in the message instead of discarding it.
        throw EnvironmentError.invalidMetadataError("Failed to write metadata file \(metadataPath): \(error)")
    }
}

/// Computes the SHA-256 digest of a file, reading it in 1 MB chunks.
///
/// - Parameter url: Location of the file to hash.
/// - Throws: `Hub.HubClientError.unexpectedError` when the file cannot be opened, or the
///   underlying error when reading fails part-way through. A read failure is no longer
///   mistaken for end-of-file, so the digest of a truncated file is never returned.
/// - Returns: The digest as a lowercase hexadecimal string.
func computeFileHash(file url: URL) throws -> String {
    // Open file for reading
    guard let fileHandle = try? FileHandle(forReadingFrom: url) else {
        throw Hub.HubClientError.unexpectedError
    }

    defer {
        try? fileHandle.close()
    }

    var hasher = SHA256()
    let chunkSize = 1024 * 1024 // 1MB chunks

    var reachedEndOfFile = false
    while !reachedEndOfFile {
        // Each chunk is read inside its own autorelease pool, as before.
        try autoreleasepool {
            if let nextChunk = try fileHandle.read(upToCount: chunkSize), !nextChunk.isEmpty {
                hasher.update(data: nextChunk)
            } else {
                reachedEndOfFile = true
            }
        }
    }

    let digest = hasher.finalize()
    return digest.map { String(format: "%02x", $0) }.joined()
}


// Note we go from Combine in Downloader to callback-based progress reporting
// We'll probably need to support Combine as well to play well with Swift UI
// (See for example PipelineLoader in swift-coreml-diffusers)
@discardableResult
func download(progressHandler: @escaping (Double) -> Void) async throws -> URL {
guard !downloaded else { return destination }

var metadataRelativePath = (relativeFilename as NSString).deletingPathExtension
metadataRelativePath += ".metadata"

let localMetadata = try readDownloadMetadata(localDir: metadataDestination, filePath: metadataRelativePath)
let remoteMetadata = try await HubApi.shared.getFileMetadata(url: source)

let localCommitHash = localMetadata?.commitHash ?? ""
let remoteCommitHash = remoteMetadata.commitHash ?? ""

// Local file exists + metadata exists + commit_hash matches => return file

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

if isValidSHA256(remoteCommitHash) && downloaded && localMetadata != nil && localCommitHash == remoteCommitHash {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if isValidSHA256(remoteCommitHash) && downloaded && localMetadata != nil && localCommitHash == remoteCommitHash {
if downloaded && localMetadata != nil && localCommitHash == remoteCommitHash {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here isValidSHA256 is actually immaterial since we're also checking whether the local commit hash is the same as the remote one. There is no case where the hash is not a real SHA hash and the condition localCommitHash == remoteCommitHash is true.

return destination
}

// From now on, etag, commit_hash, url and size are not empty
guard let remoteCommitHash = remoteMetadata.commitHash,
let remoteEtag = remoteMetadata.etag,
remoteMetadata.location != "" else {
throw EnvironmentError.invalidMetadataError("File metadata must have been retrieved from server")
}

// Local file exists => check if it's up-to-date
if downloaded {
// etag matches => update metadata and return file
if localMetadata?.etag == remoteEtag {
try writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataRelativePath: metadataRelativePath)
return destination
}

// metadata is outdated + etag is a sha256
// => means it's an LFS file (large)
// => let's compute local hash and compare
// => if match, update metadata and return file
if localMetadata != nil && isValidSHA256(remoteEtag) {
let fileHash = try computeFileHash(file: destination)
if fileHash == remoteEtag {
try writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataRelativePath: metadataRelativePath)
return destination
}
}
}

// Otherwise, let's download the file!
try prepareDestination()
let downloader = Downloader(from: source, to: destination, using: hfToken, inBackground: backgroundSession)
try prepareMetadataDestination()

let downloader = Downloader(from: source, to: destination, metadataDirURL: metadataDestination, using: hfToken, inBackground: backgroundSession)
let downloadSubscriber = downloader.downloadState.sink { state in
if case .downloading(let progress) = state {
progressHandler(progress)
Expand All @@ -227,6 +407,9 @@ public extension HubApi {
_ = try withExtendedLifetime(downloadSubscriber) {
try downloader.waitUntilDone()
}

try writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataRelativePath: metadataRelativePath)

return destination
}
}
Expand Down Expand Up @@ -274,20 +457,36 @@ public extension HubApi {

/// Metadata
public extension HubApi {
/// Metadata about a remote file versioned on the Hub.
struct FileMetadata {
/// The Git commit hash related to the file.
public let commitHash: String?

/// ETag of the file as provided by the server (used for caching).
public let etag: String?

/// Stringified URL from which the file can be downloaded. Can be a Hub URL or not (e.g. a CDN).
public let location: String

/// Size of the file in bytes. For an LFS file this is the size of the actual LFS file, not the pointer.
public let size: Int?
}

/// Metadata about a file in the local directory related to a download process.
struct LocalDownloadFileMetadata {
/// Commit hash of the file in the repo.
public let commitHash: String

/// ETag of the file in the repo. Used to check if the file has changed.
/// For LFS files, this is the sha256 of the file. For regular files, it corresponds to the git hash.
public let etag: String

/// Path of the file in the repo.
/// NOTE(review): `readDownloadMetadata` stores the relative path it was asked to read here,
/// which in `download` is the `.metadata` path rather than the file's own path — confirm which is intended.
public let filename: String

/// The timestamp of when the metadata was saved, i.e. when the metadata was accurate.
public let timestamp: Date
}

private func normalizeEtag(_ etag: String?) -> String? {
guard let etag = etag else { return nil }
Expand All @@ -296,13 +495,14 @@ public extension HubApi {

func getFileMetadata(url: URL) async throws -> FileMetadata {
let (_, response) = try await httpHead(for: url)
let location = response.statusCode == 302 ? response.value(forHTTPHeaderField: "Location") : response.url?.absoluteString

return FileMetadata(
commitHash: response.value(forHTTPHeaderField: "X-Repo-Commit"),
etag: normalizeEtag(
(response.value(forHTTPHeaderField: "X-Linked-Etag")) ?? (response.value(forHTTPHeaderField: "Etag"))
),
location: (response.value(forHTTPHeaderField: "Location")) ?? url.absoluteString,
location: location ?? url.absoluteString,
size: Int(response.value(forHTTPHeaderField: "X-Linked-Size") ?? response.value(forHTTPHeaderField: "Content-Length") ?? "")
)
}
Expand Down Expand Up @@ -395,3 +595,43 @@ public extension [String] {
filter { fnmatch(glob, $0, 0) == 0 }
}
}

/// Only allow relative redirects and reject others
/// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/file_download.py#L258
private class RedirectDelegate: NSObject, URLSessionTaskDelegate {
    /// Decides whether a redirect is followed.
    ///
    /// A redirect whose `Location` header has no host is followed: the path and query of the
    /// task's original URL are replaced with those of the `Location` value, and the headers of
    /// the original request are copied over. Every other redirect is rejected by passing `nil`
    /// to `completionHandler`.
    ///
    /// - Parameters:
    ///   - session: The session that owns the task.
    ///   - task: The task being redirected; its original request supplies the base URL and headers.
    ///   - response: The redirect response received from the server.
    ///   - request: The request proposed for the redirect; its HTTP method is reused.
    ///   - completionHandler: Called exactly once with the request to perform, or `nil` to stop.
    func urlSession(
        _ session: URLSession,
        task: URLSessionTask,
        willPerformHTTPRedirection response: HTTPURLResponse,
        newRequest request: URLRequest,
        completionHandler: @escaping (URLRequest?) -> Void
    ) {
        guard
            // Check if it's a redirect status code (300-399)
            (300...399).contains(response.statusCode),
            // Get the Location header
            let locationString = response.value(forHTTPHeaderField: "Location"),
            let locationUrl = URL(string: locationString),
            // Only relative redirects (no host component) are followed
            locationUrl.host == nil,
            let locationComponents = URLComponents(url: locationUrl, resolvingAgainstBaseURL: false),
            let originalUrl = task.originalRequest?.url,
            var components = URLComponents(url: originalUrl, resolvingAgainstBaseURL: true)
        else {
            // Non-redirects, absolute redirects and unparsable locations are not followed
            completionHandler(nil)
            return
        }

        components.path = locationUrl.path
        // The query taken from the location is already percent-encoded. Assigning it to `query`
        // would encode it a second time (e.g. `%20` -> `%2520`), so set the encoded form directly.
        components.percentEncodedQuery = locationComponents.percentEncodedQuery

        guard let resolvedUrl = components.url else {
            completionHandler(nil)
            return
        }

        var newRequest = URLRequest(url: resolvedUrl)
        // A fresh `URLRequest` defaults to GET. Reuse the method of the proposed request so that
        // a redirected HEAD stays a HEAD instead of fetching the whole body.
        newRequest.httpMethod = request.httpMethod
        // Copy headers from original request
        task.originalRequest?.allHTTPHeaderFields?.forEach { key, value in
            newRequest.setValue(value, forHTTPHeaderField: key)
        }
        newRequest.setValue(resolvedUrl.absoluteString, forHTTPHeaderField: "Location")
        completionHandler(newRequest)
    }
}
Loading