diff --git a/Sources/Hub/Downloader.swift b/Sources/Hub/Downloader.swift index 1f3b8c6d..3334b8de 100644 --- a/Sources/Hub/Downloader.swift +++ b/Sources/Hub/Downloader.swift @@ -11,6 +11,9 @@ import Combine class Downloader: NSObject, ObservableObject { private(set) var destination: URL + private(set) var metadataDestination: URL + + private let chunkSize = 10 * 1024 * 1024 // 10MB enum DownloadState { case notStarted @@ -29,8 +32,16 @@ class Downloader: NSObject, ObservableObject { private var urlSession: URLSession? = nil - init(from url: URL, to destination: URL, using authToken: String? = nil, inBackground: Bool = false) { + init( + from url: URL, + to destination: URL, + metadataDirURL: URL, + using authToken: String? = nil, + inBackground: Bool = false + ) { self.destination = destination + let filename = (destination.lastPathComponent as NSString).deletingPathExtension + self.metadataDestination = metadataDirURL.appending(component: "\(filename).metadata") super.init() let sessionIdentifier = "swift-transformers.hub.downloader" diff --git a/Sources/Hub/HubApi.swift b/Sources/Hub/HubApi.swift index fdf12568..bc013515 100644 --- a/Sources/Hub/HubApi.swift +++ b/Sources/Hub/HubApi.swift @@ -6,6 +6,8 @@ // import Foundation +import CryptoKit +import os public struct HubApi { var downloadBase: URL @@ -29,6 +31,8 @@ public struct HubApi { } public static let shared = HubApi() + + private static let logger = Logger() } private extension HubApi { @@ -91,6 +95,8 @@ public extension HubApi { return (data, response) } + /// Throws error if page does not exist or is not accessible. + /// Allows relative redirects but ignores absolute ones for LFS files. func httpHead(for url: URL) async throws -> (Data, HTTPURLResponse) { var request = URLRequest(url: url) request.httpMethod = "HEAD" @@ -98,11 +104,15 @@ public extension HubApi { request.setValue("Bearer \(hfToken)", forHTTPHeaderField: "Authorization") } request.setValue("identity", forHTTPHeaderField: "Accept-Encoding") - let (data, response) = try await URLSession.shared.data(for: request) + + let redirectDelegate = RedirectDelegate() + let session = URLSession(configuration: .default, delegate: redirectDelegate, delegateQueue: nil) + + let (data, response) = try await session.data(for: request) guard let response = response as? HTTPURLResponse else { throw Hub.HubClientError.unexpectedError } switch response.statusCode { - case 200..<300: break + case 200..<400: break // Allow redirects to pass through to the redirect delegate case 400..<500: throw Hub.HubClientError.authorizationRequired default: throw Hub.HubClientError.httpStatusCode(response.statusCode) } @@ -138,6 +148,26 @@ public extension HubApi { } } +/// Additional Errors +public extension HubApi { + enum EnvironmentError: LocalizedError { + case consistencyError(String) + case diskSpaceError(String) + case permissionError(String) + case invalidMetadataError(String) + + public var errorDescription: String? { + switch self { + case .consistencyError(let message), + .diskSpaceError(let message), + .permissionError(let message), + .invalidMetadataError(let message): + return message + } + } + } +} + /// Configuration loading helpers public extension HubApi { /// Assumes the file has already been downloaded. @@ -201,6 +231,12 @@ public extension HubApi { repoDestination.appending(path: relativeFilename) } + var metadataDestination: URL { + repoDestination + .appendingPathComponent(".cache") + .appendingPathComponent("huggingface") + } + var downloaded: Bool { FileManager.default.fileExists(atPath: destination.path) } @@ -209,16 +245,160 @@ public extension HubApi { let directoryURL = destination.deletingLastPathComponent() try FileManager.default.createDirectory(at: directoryURL, withIntermediateDirectories: true, attributes: nil) } - + + func prepareMetadataDestination() throws { + try FileManager.default.createDirectory(at: metadataDestination, withIntermediateDirectories: true, attributes: nil) + } + + /// Reads metadata about a file in the local directory related to a download process. + /// + /// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/_local_folder.py#L263 + /// + /// - Parameters: + /// - localDir: The local directory where files are downloaded. + /// - filePath: The path of the file for which metadata is being read. + /// - Throws: An `EnvironmentError.invalidMetadataError` if the metadata file is invalid and cannot be removed. + /// - Returns: A `LocalDownloadFileMetadata` object if the metadata file exists and is valid, or `nil` if the file is missing or invalid. + func readDownloadMetadata(localDir: URL, filePath: String) throws -> LocalDownloadFileMetadata? { + let metadataPath = localDir.appending(path: filePath) + if FileManager.default.fileExists(atPath: metadataPath.path) { + do { + let contents = try String(contentsOf: metadataPath, encoding: .utf8) + let lines = contents.components(separatedBy: .newlines) + + guard lines.count >= 3 else { + throw EnvironmentError.invalidMetadataError("Metadata file is missing required fields.") + } + + let commitHash = lines[0].trimmingCharacters(in: .whitespacesAndNewlines) + let etag = lines[1].trimmingCharacters(in: .whitespacesAndNewlines) + guard let timestamp = Double(lines[2].trimmingCharacters(in: .whitespacesAndNewlines)) else { + throw EnvironmentError.invalidMetadataError("Missing or invalid timestamp.") + } + let timestampDate = Date(timeIntervalSince1970: timestamp) + + // TODO: check if file hasn't been modified since the metadata was saved + + return LocalDownloadFileMetadata(commitHash: commitHash, etag: etag, filename: filePath, timestamp: timestampDate) + } catch { + do { + logger.warning("Invalid metadata file \(metadataPath): \(error). Removing it from disk and continue.") + try FileManager.default.removeItem(at: metadataPath) + } catch { + throw EnvironmentError.invalidMetadataError("Could not remove corrupted metadata file \(metadataPath): \(error)") + } + return nil + } + } + + // metadata file does not exist + return nil + } + + func isValidSHA256(_ hash: String) -> Bool { + let sha256Pattern = "^[0-9a-f]{64}$" + let regex = try? NSRegularExpression(pattern: sha256Pattern) + let range = NSRange(location: 0, length: hash.utf16.count) + return regex?.firstMatch(in: hash, options: [], range: range) != nil + } + + /// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/_local_folder.py#L391 + func writeDownloadMetadata(commitHash: String, etag: String, metadataRelativePath: String) throws { + let metadataContent = "\(commitHash)\n\(etag)\n\(Date().timeIntervalSince1970)\n" + let metadataPath = metadataDestination.appending(component: metadataRelativePath) + + do { + try FileManager.default.createDirectory(at: metadataPath.deletingLastPathComponent(), withIntermediateDirectories: true) + try metadataContent.write(to: metadataPath, atomically: true, encoding: .utf8) + } catch { + throw EnvironmentError.invalidMetadataError("Failed to write metadata file \(metadataPath)") + } + } + + func computeFileHash(file url: URL) throws -> String { + // Open file for reading + guard let fileHandle = try? FileHandle(forReadingFrom: url) else { + throw Hub.HubClientError.unexpectedError + } + + defer { + try? fileHandle.close() + } + + var hasher = SHA256() + let chunkSize = 1024 * 1024 // 1MB chunks + + while autoreleasepool(invoking: { + let nextChunk = try? fileHandle.read(upToCount: chunkSize) + + guard let nextChunk, + !nextChunk.isEmpty + else { + return false + } + + hasher.update(data: nextChunk) + + return true + }) { } + + let digest = hasher.finalize() + return digest.map { String(format: "%02x", $0) }.joined() + } + + // Note we go from Combine in Downloader to callback-based progress reporting // We'll probably need to support Combine as well to play well with Swift UI // (See for example PipelineLoader in swift-coreml-diffusers) @discardableResult func download(progressHandler: @escaping (Double) -> Void) async throws -> URL { - guard !downloaded else { return destination } - + var metadataRelativePath = (relativeFilename as NSString).deletingPathExtension + metadataRelativePath += ".metadata" + + let localMetadata = try readDownloadMetadata(localDir: metadataDestination, filePath: metadataRelativePath) + let remoteMetadata = try await HubApi.shared.getFileMetadata(url: source) + + let localCommitHash = localMetadata?.commitHash ?? "" + let remoteCommitHash = remoteMetadata.commitHash ?? "" + + // Local file exists + metadata exists + commit_hash matches => return file + if isValidSHA256(remoteCommitHash) && downloaded && localMetadata != nil && localCommitHash == remoteCommitHash { + return destination + } + + // From now on, etag, commit_hash, url and size are not empty + guard let remoteCommitHash = remoteMetadata.commitHash, + let remoteEtag = remoteMetadata.etag, + remoteMetadata.location != "" else { + throw EnvironmentError.invalidMetadataError("File metadata must have been retrieved from server") + } + + // Local file exists => check if it's up-to-date + if downloaded { + // etag matches => update metadata and return file + if localMetadata?.etag == remoteEtag { + try writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataRelativePath: metadataRelativePath) + return destination + } + + // metadata is outdated + etag is a sha256 + // => means it's an LFS file (large) + // => let's compute local hash and compare + // => if match, update metadata and return file + if localMetadata != nil && isValidSHA256(remoteEtag) { + let fileHash = try computeFileHash(file: destination) + if fileHash == remoteEtag { + try writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataRelativePath: metadataRelativePath) + return destination + } + } + } + + // Otherwise, let's download the file! try prepareDestination() - let downloader = Downloader(from: source, to: destination, using: hfToken, inBackground: backgroundSession) + try prepareMetadataDestination() + + let downloader = Downloader(from: source, to: destination, metadataDirURL: metadataDestination, using: hfToken, inBackground: backgroundSession) let downloadSubscriber = downloader.downloadState.sink { state in if case .downloading(let progress) = state { progressHandler(progress) @@ -227,6 +407,9 @@ public extension HubApi { _ = try withExtendedLifetime(downloadSubscriber) { try downloader.waitUntilDone() } + + try writeDownloadMetadata(commitHash: remoteCommitHash, etag: remoteEtag, metadataRelativePath: metadataRelativePath) + return destination } } @@ -274,20 +457,36 @@ public extension HubApi { /// Metadata public extension HubApi { - /// A structure representing metadata for a remote file + /// Data structure containing information about a file versioned on the Hub struct FileMetadata { - /// The file's Git commit hash + /// The commit hash related to the file public let commitHash: String? - /// Server-provided ETag for caching + /// Etag of the file on the server public let etag: String? - /// Stringified URL location of the file + /// Location where to download the file. Can be a Hub url or not (CDN). public let location: String - /// The file's size in bytes + /// Size of the file. In case of an LFS file, contains the size of the actual LFS file, not the pointer. public let size: Int? } + + /// Metadata about a file in the local directory related to a download process + struct LocalDownloadFileMetadata { + /// Commit hash of the file in the repo + public let commitHash: String + + /// ETag of the file in the repo. Used to check if the file has changed. + /// For LFS files, this is the sha256 of the file. For regular files, it corresponds to the git hash. + public let etag: String + + /// Path of the file in the repo + public let filename: String + + /// The timestamp of when the metadata was saved i.e. when the metadata was accurate + public let timestamp: Date + } private func normalizeEtag(_ etag: String?) -> String? { guard let etag = etag else { return nil } @@ -296,13 +495,14 @@ public extension HubApi { func getFileMetadata(url: URL) async throws -> FileMetadata { let (_, response) = try await httpHead(for: url) + let location = response.statusCode == 302 ? response.value(forHTTPHeaderField: "Location") : response.url?.absoluteString return FileMetadata( commitHash: response.value(forHTTPHeaderField: "X-Repo-Commit"), etag: normalizeEtag( (response.value(forHTTPHeaderField: "X-Linked-Etag")) ?? (response.value(forHTTPHeaderField: "Etag")) ), - location: (response.value(forHTTPHeaderField: "Location")) ?? url.absoluteString, + location: location ?? url.absoluteString, size: Int(response.value(forHTTPHeaderField: "X-Linked-Size") ?? response.value(forHTTPHeaderField: "Content-Length") ?? "") ) } @@ -395,3 +595,43 @@ public extension [String] { filter { fnmatch(glob, $0, 0) == 0 } } } + +/// Only allow relative redirects and reject others +/// Reference: https://github.com/huggingface/huggingface_hub/blob/b2c9a148d465b43ab90fab6e4ebcbbf5a9df27d4/src/huggingface_hub/file_download.py#L258 +private class RedirectDelegate: NSObject, URLSessionTaskDelegate { + func urlSession(_ session: URLSession, task: URLSessionTask, willPerformHTTPRedirection response: HTTPURLResponse, newRequest request: URLRequest, completionHandler: @escaping (URLRequest?) -> Void) { + // Check if it's a redirect status code (300-399) + if (300...399).contains(response.statusCode) { + // Get the Location header + if let locationString = response.value(forHTTPHeaderField: "Location"), + let locationUrl = URL(string: locationString) { + + // Check if it's a relative redirect (no host component) + if locationUrl.host == nil { + // For relative redirects, construct the new URL using the original request's base + if let originalUrl = task.originalRequest?.url, + var components = URLComponents(url: originalUrl, resolvingAgainstBaseURL: true) { + // Update the path component with the relative path + components.path = locationUrl.path + components.query = locationUrl.query + + // Create new request with the resolved URL + if let resolvedUrl = components.url { + var newRequest = URLRequest(url: resolvedUrl) + // Copy headers from original request + task.originalRequest?.allHTTPHeaderFields?.forEach { key, value in + newRequest.setValue(value, forHTTPHeaderField: key) + } + newRequest.setValue(resolvedUrl.absoluteString, forHTTPHeaderField: "Location") + completionHandler(newRequest) + return + } + } + } + } + } + + // For all other cases (non-redirects or absolute redirects), prevent redirect + completionHandler(nil) + } +} diff --git a/Tests/HubTests/HubApiTests.swift b/Tests/HubTests/HubApiTests.swift index 9871ba6f..9b12ee64 100644 --- a/Tests/HubTests/HubApiTests.swift +++ b/Tests/HubTests/HubApiTests.swift @@ -144,6 +144,26 @@ class HubApiTests: XCTestCase { XCTAssertGreaterThan(metadata.size!, 0) } } + + /// Verify with `curl -I https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel` + func testGetLargeFileMetadata() async throws { + do { + let revision = "eaf97358a37d03fd48e5a87d15aff2e8423c1afb" + let etag = "fc329090bfbb2570382c9af997cffd5f4b78b39b8aeca62076db69534e020107" + let location = "https://cdn-lfs.hf.co/repos/4a/4e/4a4e587f66a2979dcd75e1d7324df8ee9ef74be3582a05bea31c2c26d0d467d0/fc329090bfbb2570382c9af997cffd5f4b78b39b8aeca62076db69534e020107?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27model.mlmodel%3B+filename%3D%22model.mlmodel" + let size = 504766 + + let url = URL(string: "https://huggingface.co/coreml-projects/Llama-2-7b-chat-coreml/resolve/main/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel") + let metadata = try await Hub.getFileMetadata(fileURL: url!) + + XCTAssertEqual(metadata.commitHash, revision) + XCTAssertNotNil(metadata.etag, etag) + XCTAssertTrue(metadata.location.contains(location)) + XCTAssertEqual(metadata.size, size) + } catch { + XCTFail("\(error)") + } + } } class SnapshotDownloadTests: XCTestCase { @@ -252,4 +272,239 @@ class SnapshotDownloadTests: XCTestCase { ]) ) } + + func testDownloadFileMetadata() async throws { + let hubApi = HubApi(downloadBase: downloadDestination) + var lastProgress: Progress? = nil + let downloadedTo = try await hubApi.snapshot(from: repo, matching: "*.json") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + XCTAssertEqual(lastProgress?.fractionCompleted, 1) + XCTAssertEqual(lastProgress?.completedUnitCount, 6) + XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) + + let downloadedFilenames = getRelativeFiles(url: downloadDestination) + XCTAssertEqual( + Set(downloadedFilenames), + Set([ + "config.json", "tokenizer.json", "tokenizer_config.json", + "llama-2-7b-chat.mlpackage/Manifest.json", + "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/FeatureDescriptions.json", + "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json", + ]) + ) + + let metadataDestination = downloadedTo.appending(component: ".cache/huggingface") + let downloadedMetadataFilenames = getRelativeFiles(url: metadataDestination) + XCTAssertEqual( + Set(downloadedMetadataFilenames), + Set([ + ".cache/huggingface/config.metadata", + ".cache/huggingface/tokenizer.metadata", + ".cache/huggingface/tokenizer_config.metadata", + ".cache/huggingface/llama-2-7b-chat.mlpackage/Manifest.metadata", + ".cache/huggingface/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/FeatureDescriptions.metadata", + ".cache/huggingface/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.metadata", + ]) + ) + } + + func testDownloadFileMetadataExists() async throws { + let hubApi = HubApi(downloadBase: downloadDestination) + var lastProgress: Progress? = nil + + let downloadedTo = try await hubApi.snapshot(from: repo, matching: "*.json") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + + XCTAssertEqual(lastProgress?.fractionCompleted, 1) + XCTAssertEqual(lastProgress?.completedUnitCount, 6) + XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) + + let downloadedFilenames = getRelativeFiles(url: downloadDestination) + XCTAssertEqual( + Set(downloadedFilenames), + Set([ + "config.json", "tokenizer.json", "tokenizer_config.json", + "llama-2-7b-chat.mlpackage/Manifest.json", + "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/FeatureDescriptions.json", + "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json", + ]) + ) + + let metadataDestination = downloadedTo.appending(component: ".cache/huggingface") + + let configPath = downloadedTo.appending(path: "config.json") + var attributes = try FileManager.default.attributesOfItem(atPath: configPath.path) + let originalTimestamp = attributes[.modificationDate] as! Date + + let downloadedMetadataFilenames = getRelativeFiles(url: metadataDestination) + XCTAssertEqual( + Set(downloadedMetadataFilenames), + Set([ + ".cache/huggingface/config.metadata", + ".cache/huggingface/tokenizer.metadata", + ".cache/huggingface/tokenizer_config.metadata", + ".cache/huggingface/llama-2-7b-chat.mlpackage/Manifest.metadata", + ".cache/huggingface/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/FeatureDescriptions.metadata", + ".cache/huggingface/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.metadata", + ]) + ) + + let _ = try await hubApi.snapshot(from: repo, matching: "*.json") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + + attributes = try FileManager.default.attributesOfItem(atPath: configPath.path) + let secondDownloadTimestamp = attributes[.modificationDate] as! Date + + // File will not be downloaded again thus last modified date will remain unchanged + XCTAssertTrue(originalTimestamp == secondDownloadTimestamp) + } + + func testDownloadFileMetadataCorrupted() async throws { + let hubApi = HubApi(downloadBase: downloadDestination) + var lastProgress: Progress? = nil + + let downloadedTo = try await hubApi.snapshot(from: repo, matching: "*.json") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + + XCTAssertEqual(lastProgress?.fractionCompleted, 1) + XCTAssertEqual(lastProgress?.completedUnitCount, 6) + XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) + + let downloadedFilenames = getRelativeFiles(url: downloadDestination) + XCTAssertEqual( + Set(downloadedFilenames), + Set([ + "config.json", "tokenizer.json", "tokenizer_config.json", + "llama-2-7b-chat.mlpackage/Manifest.json", + "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/FeatureDescriptions.json", + "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.json", + ]) + ) + + let metadataDestination = downloadedTo.appending(component: ".cache/huggingface") + + let configPath = downloadedTo.appending(path: "config.json") + var attributes = try FileManager.default.attributesOfItem(atPath: configPath.path) + let originalTimestamp = attributes[.modificationDate] as! Date + + let downloadedMetadataFilenames = getRelativeFiles(url: metadataDestination) + XCTAssertEqual( + Set(downloadedMetadataFilenames), + Set([ + ".cache/huggingface/config.metadata", + ".cache/huggingface/tokenizer.metadata", + ".cache/huggingface/tokenizer_config.metadata", + ".cache/huggingface/llama-2-7b-chat.mlpackage/Manifest.metadata", + ".cache/huggingface/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/FeatureDescriptions.metadata", + ".cache/huggingface/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/Metadata.metadata", + ]) + ) + + // Corrupt config.metadata + print("Testing corrupted file.") + try "a".write(to: metadataDestination.appendingPathComponent("config.metadata"), atomically: true, encoding: .utf8) + + let _ = try await hubApi.snapshot(from: repo, matching: "*.json") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + + attributes = try FileManager.default.attributesOfItem(atPath: configPath.path) + let secondDownloadTimestamp = attributes[.modificationDate] as! Date + + // File will be downloaded again thus last modified date will change + XCTAssertTrue(originalTimestamp != secondDownloadTimestamp) + + // Corrupt config.metadata again + print("Testing corrupted timestamp.") + try "a\nb\nc\n".write(to: metadataDestination.appendingPathComponent("config.metadata"), atomically: true, encoding: .utf8) + + let _ = try await hubApi.snapshot(from: repo, matching: "*.json") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + + attributes = try FileManager.default.attributesOfItem(atPath: configPath.path) + let thirdDownloadTimestamp = attributes[.modificationDate] as! Date + + // File will be downloaded again thus last modified date will change + XCTAssertTrue(originalTimestamp != thirdDownloadTimestamp) + } + + func testDownloadLargeFileMetadataCorrupted() async throws { + let hubApi = HubApi(downloadBase: downloadDestination) + var lastProgress: Progress? = nil + + let downloadedTo = try await hubApi.snapshot(from: repo, matching: "*.mlmodel") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + + XCTAssertEqual(lastProgress?.fractionCompleted, 1) + XCTAssertEqual(lastProgress?.completedUnitCount, 1) + XCTAssertEqual(downloadedTo, downloadDestination.appending(path: "models/\(repo)")) + + let downloadedFilenames = getRelativeFiles(url: downloadDestination) + XCTAssertEqual( + Set(downloadedFilenames), + Set([ +"llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel", + ]) + ) + + let metadataDestination = downloadedTo.appending(component: ".cache/huggingface") + + let modelPath = downloadedTo.appending(path: "llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.mlmodel") + var attributes = try FileManager.default.attributesOfItem(atPath: modelPath.path) + let originalTimestamp = attributes[.modificationDate] as! Date + + let downloadedMetadataFilenames = getRelativeFiles(url: metadataDestination) + XCTAssertEqual( + Set(downloadedMetadataFilenames), + Set([ + ".cache/huggingface/llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.metadata", + ]) + ) + + // Corrupt model.metadata etag + print("Testing corrupted etag.") + let corruptedMetadataString = "a\nfc329090bfbb2570382c9af997cffd5f4b78b39b8aeca62076db69534e020108\n0\n" + let metadataFile = metadataDestination.appendingPathComponent("llama-2-7b-chat.mlpackage/Data/com.apple.CoreML/model.metadata") + try corruptedMetadataString.write(to: metadataFile, atomically: true, encoding: .utf8) + + let _ = try await hubApi.snapshot(from: repo, matching: "*.mlmodel") { progress in + print("Total Progress: \(progress.fractionCompleted)") + print("Files Completed: \(progress.completedUnitCount) of \(progress.totalUnitCount)") + lastProgress = progress + } + + attributes = try FileManager.default.attributesOfItem(atPath: modelPath.path) + let thirdDownloadTimestamp = attributes[.modificationDate] as! Date + + // File will not be downloaded again because this is an LFS file. + // While downloading LFS files, we first check if local file ETag is the same as remote ETag. + // If that's the case we just update the metadata and keep the local file. + XCTAssertEqual(originalTimestamp, thirdDownloadTimestamp) + + let metadataString = try String(contentsOfFile: metadataFile.path) + + // Updated metadata file needs to have the correct commit hash, etag and timestamp. + // This is being updated because the local etag (SHA256 checksum) matches the remote etag + XCTAssertNotEqual(metadataString, corruptedMetadataString) + } }