@@ -12,6 +12,8 @@ import Combine
1212class Downloader : NSObject , ObservableObject {
1313 private( set) var destination : URL
1414
15+ private let chunkSize = 10 * 1024 * 1024 // 10MB
16+
1517 enum DownloadState {
1618 case notStarted
1719 case downloading( Double )
@@ -29,7 +31,17 @@ class Downloader: NSObject, ObservableObject {
2931
3032 private var urlSession : URLSession ? = nil
3133
32- init ( from url: URL , to destination: URL , using authToken: String ? = nil , inBackground: Bool = false ) {
34+ init (
35+ from url: URL ,
36+ to destination: URL ,
37+ using authToken: String ? = nil ,
38+ inBackground: Bool = false ,
39+ resumeSize: Int = 0 ,
40+ headers: [ String : String ] ? = nil ,
41+ expectedSize: Int ? = nil ,
42+ timeout: TimeInterval = 10 ,
43+ numRetries: Int = 5
44+ ) {
3345 self . destination = destination
3446 super. init ( )
3547 let sessionIdentifier = " swift-transformers.hub.downloader "
@@ -43,10 +55,28 @@ class Downloader: NSObject, ObservableObject {
4355
4456 self . urlSession = URLSession ( configuration: config, delegate: self , delegateQueue: nil )
4557
46- setupDownload ( from: url, with: authToken)
58+ setupDownload ( from: url, with: authToken, resumeSize : resumeSize , headers : headers , expectedSize : expectedSize , timeout : timeout , numRetries : numRetries )
4759 }
4860
49- private func setupDownload( from url: URL , with authToken: String ? ) {
61+ /// Sets up and initiates a file download operation
62+ ///
63+ /// - Parameters:
64+ /// - url: Source URL to download from
65+ /// - authToken: Bearer token for authentication with Hugging Face
66+ /// - resumeSize: Number of bytes already downloaded for resuming interrupted downloads
67+ /// - headers: Additional HTTP headers to include in the request
68+ /// - expectedSize: Expected file size in bytes for validation
69+ /// - timeout: Time interval before the request times out
70+ /// - numRetries: Number of retry attempts for failed downloads
71+ private func setupDownload(
72+ from url: URL ,
73+ with authToken: String ? ,
74+ resumeSize: Int ,
75+ headers: [ String : String ] ? ,
76+ expectedSize: Int ? ,
77+ timeout: TimeInterval ,
78+ numRetries: Int
79+ ) {
5080 downloadState. value = . downloading( 0 )
5181 urlSession? . getAllTasks { tasks in
5282 // If there's an existing pending background task with the same URL, let it proceed.
@@ -71,14 +101,137 @@ class Downloader: NSObject, ObservableObject {
71101 }
72102 }
73103 var request = URLRequest ( url: url)
104+
105+ // Use headers from argument else create an empty header dictionary
106+ var requestHeaders = headers ?? [ : ]
107+
108+ // Populate header auth and range fields
74109 if let authToken = authToken {
75- request. setValue ( " Bearer \( authToken) " , forHTTPHeaderField: " Authorization " )
110+ requestHeaders [ " Authorization " ] = " Bearer \( authToken) "
111+ }
112+ if resumeSize > 0 {
113+ requestHeaders [ " Range " ] = " bytes= \( resumeSize) - "
76114 }
115+
116+
117+ request. timeoutInterval = timeout
118+ request. allHTTPHeaderFields = requestHeaders
77119
78- self . urlSession? . downloadTask ( with: request) . resume ( )
120+ Task {
121+ do {
122+ // Create a temp file to write
123+ let tempURL = FileManager . default. temporaryDirectory. appendingPathComponent ( UUID ( ) . uuidString)
124+ FileManager . default. createFile ( atPath: tempURL. path, contents: nil )
125+ let tempFile = try FileHandle ( forWritingTo: tempURL)
126+
127+ defer { tempFile. closeFile ( ) }
128+ try await self . httpGet ( request: request, tempFile: tempFile, resumeSize: resumeSize, numRetries: numRetries, expectedSize: expectedSize)
129+
130+ // Clean up and move the completed download to its final destination
131+ tempFile. closeFile ( )
132+ try FileManager . default. moveDownloadedFile ( from: tempURL, to: self . destination)
133+
134+ self . downloadState. value = . completed( self . destination)
135+ } catch {
136+ self . downloadState. value = . failed( error)
137+ }
138+ }
79139 }
80140 }
81141
/// Streams the body of `request` into `tempFile` in chunks, retrying on network errors.
///
/// Reference: https://github.com/huggingface/huggingface_hub/blob/418a6ffce7881f5c571b2362ed1c23ef8e4d7d20/src/huggingface_hub/file_download.py#L306
///
/// - Parameters:
///   - request: The URLRequest for the file to download
///   - tempFile: Open file handle that receives the downloaded bytes, written in arrival order
///   - resumeSize: The number of bytes already downloaded. If 0, the whole file is downloaded. If positive, a `Range: bytes=<resumeSize>-` header is sent so the download resumes at that position
///   - numRetries: The number of retry attempts remaining for failed downloads
///   - expectedSize: The expected size of the file to download. If set, an error is thrown when the size of `tempFile` after the download differs from it
/// - Throws: `DownloadError.unexpectedError` if no session is available, the response is not a 2xx HTTP response, or the file size does not match;
///           `URLError` if streaming fails after all retries are exhausted;
///           any error thrown while writing to `tempFile`
private func httpGet(
    request: URLRequest,
    tempFile: FileHandle,
    resumeSize: Int,
    numRetries: Int,
    expectedSize: Int?
) async throws {
    guard let session = urlSession else {
        throw DownloadError.unexpectedError
    }

    // Create a new request with Range header for resuming
    var newRequest = request
    if resumeSize > 0 {
        newRequest.setValue("bytes=\(resumeSize)-", forHTTPHeaderField: "Range")
    }

    // Start the download and get the byte stream.
    // Errors thrown here (before any byte is streamed) are not retried.
    let (asyncBytes, response) = try await session.bytes(for: newRequest)

    guard let response = response as? HTTPURLResponse,
          (200..<300).contains(response.statusCode) else {
        throw DownloadError.unexpectedError
    }

    var downloadedSize = resumeSize

    // Collect bytes in memory and write them to disk one chunk at a time
    var buffer = Data(capacity: chunkSize)

    var newNumRetries = numRetries
    do {
        for try await byte in asyncBytes {
            buffer.append(byte)
            // Keep accumulating until a full chunk is available
            guard buffer.count == chunkSize else { continue }

            try tempFile.write(contentsOf: buffer)
            buffer.removeAll(keepingCapacity: true)
            downloadedSize += chunkSize
            // A chunk made it to disk: restore the retry budget
            newNumRetries = 5

            // Progress can only be reported when the total size is known
            guard let expectedSize = expectedSize else { continue }
            let progress = expectedSize != 0 ? Double(downloadedSize) / Double(expectedSize) : 0
            downloadState.value = .downloading(progress)
        }

        // Write the final, partially filled chunk
        if !buffer.isEmpty {
            try tempFile.write(contentsOf: buffer)
            downloadedSize += buffer.count
            buffer.removeAll(keepingCapacity: true)
            newNumRetries = 5
        }
    } catch let error as URLError {
        if newNumRetries <= 0 {
            throw error
        }

        // Persist the bytes received before the failure so the retry resumes right
        // after them instead of discarding and re-downloading up to one full chunk.
        if !buffer.isEmpty {
            try tempFile.write(contentsOf: buffer)
            downloadedSize += buffer.count
            buffer.removeAll(keepingCapacity: true)
        }

        try await Task.sleep(nanoseconds: 1_000_000_000)

        // A URLSession keeps a strong reference to its delegate until it is invalidated;
        // release the old session before replacing it so retries do not leak sessions.
        session.finishTasksAndInvalidate()
        let config = URLSessionConfiguration.default
        urlSession = URLSession(configuration: config, delegate: self, delegateQueue: nil)

        try await httpGet(
            request: request,
            tempFile: tempFile,
            resumeSize: downloadedSize,
            numRetries: newNumRetries - 1,
            expectedSize: expectedSize
        )
        // The nested call has already verified the final file size
        return
    }

    // Verify the downloaded file size matches the expected size
    let actualSize = try tempFile.seekToEnd()
    if let expectedSize = expectedSize, expectedSize != actualSize {
        throw DownloadError.unexpectedError
    }
}
234+
82235 @discardableResult
83236 func waitUntilDone( ) throws -> URL {
84237 // It's either this, or stream the bytes ourselves (add to a buffer, save to disk, etc; boring and finicky)
0 commit comments