3030T = TypeVar ("T" , bound = "BaseModel" )
3131
3232
class OverrideConfigs(BaseModel):
    """Optional per-quantization overrides for a model delivery task.

    Every field defaults to ``None``, so a bare ``OverrideConfigs()`` carries
    no overrides; two instances compare equal when all five fields match.
    """

    # Window / chunking related overrides.
    context_window_size: Optional[int] = None
    sliding_window_size: Optional[int] = None
    prefill_chunk_size: Optional[int] = None
    attention_sink_size: Optional[int] = None

    # Sharding related override.
    tensor_parallel_shards: Optional[int] = None
43+
44+
3345class ModelDeliveryTask (BaseModel ):
3446 """
3547 Example:
@@ -38,21 +50,21 @@ class ModelDeliveryTask(BaseModel):
3850 "model": "HF://microsoft/Phi-3-mini-128k-instruct",
3951 "conv_template": "phi-3",
4052 "quantization": ["q3f16_1"],
41- "context_window_size": 4096
53+ "overrides": {
54+ "q3f16_1": {
55+ "context_window_size": 512
56+ }
57+ }
4258 }
4359 """
4460
4561 model_id : str
4662 model : str
4763 conv_template : str
48- quantization : Optional [Union [List [str ], str ]] = Field (default_factory = list )
64+ quantization : Union [List [str ], str ] = Field (default_factory = list )
65+ overrides : Dict [str , OverrideConfigs ] = Field (default_factory = dict )
4966 destination : Optional [str ] = None
50-
51- context_window_size : Optional [int ] = None
52- sliding_window_size : Optional [int ] = None
53- prefill_chunk_size : Optional [int ] = None
54- attention_sink_size : Optional [int ] = None
55- tensor_parallel_shards : Optional [int ] = None
67+ gen_config_only : Optional [bool ] = False
5668
5769
5870class ModelDeliveryList (BaseModel ):
@@ -63,7 +75,8 @@ class ModelDeliveryList(BaseModel):
6375 tasks : List [ModelDeliveryTask ]
6476 # For delivered log, the default destination and quantization fields are optional
6577 default_destination : Optional [str ] = None
66- default_quantization : Optional [List [str ]] = None
78+ default_quantization : List [str ] = Field (default_factory = list )
79+ default_overrides : Dict [str , OverrideConfigs ] = Field (default_factory = dict )
6780
6881 @classmethod
6982 def from_json (cls : Type [T ], json_dict : Dict [str , Any ]) -> T :
@@ -115,10 +128,7 @@ def _run_quantization(
115128 except HfHubHTTPError as error :
116129 if error .response .status_code != 409 :
117130 raise
118- logger .info ("[HF] Repo already exists. Recreating..." )
119- api .delete_repo (repo_id = repo )
120- api .create_repo (repo_id = repo , private = False )
121- logger .info ("[HF] Repo recreated" )
131+ logger .info ("[HF] Repo already exists. Skipping creation." )
122132 succeeded = True
123133 log_path = Path (output_dir ) / "logs.txt"
124134 with log_path .open ("a" , encoding = "utf-8" ) as log_file :
@@ -147,21 +157,24 @@ def _run_quantization(
147157
148158 print (" " .join (cmd ), file = log_file , flush = True )
149159 subprocess .run (cmd , check = True , stdout = log_file , stderr = subprocess .STDOUT , env = os .environ )
150- cmd = [
151- sys .executable ,
152- "-m" ,
153- "mlc_llm" ,
154- "convert_weight" ,
155- str (model_info .model ),
156- "--quantization" ,
157- model_info .quantization ,
158- "--output" ,
159- output_dir ,
160- ]
161- print (" " .join (cmd ), file = log_file , flush = True )
162- subprocess .run (cmd , check = False , stdout = log_file , stderr = subprocess .STDOUT , env = os .environ )
160+ if not model_info .gen_config_only :
161+ cmd = [
162+ sys .executable ,
163+ "-m" ,
164+ "mlc_llm" ,
165+ "convert_weight" ,
166+ str (model_info .model ),
167+ "--quantization" ,
168+ model_info .quantization ,
169+ "--output" ,
170+ output_dir ,
171+ ]
172+ print (" " .join (cmd ), file = log_file , flush = True )
173+ subprocess .run (
174+ cmd , check = False , stdout = log_file , stderr = subprocess .STDOUT , env = os .environ
175+ )
163176 logger .info ("[MLC] Complete!" )
164- if not (Path (output_dir ) / "ndarray-cache.json" ).exists ():
177+ if not (Path (output_dir ) / "ndarray-cache.json" ).exists () and not model_info . gen_config_only :
165178 logger .error (
166179 "[%s] Model %s. Quantization %s. No weights metadata found." ,
167180 red ("FAILED" ),
@@ -175,7 +188,7 @@ def _run_quantization(
175188 api .upload_folder (
176189 folder_path = output_dir ,
177190 repo_id = repo ,
178- commit_message = "Initial commit" ,
191+ ignore_patterns = [ "logs.txt" ] ,
179192 )
180193 except Exception as exc : # pylint: disable=broad-except
181194 logger .error ("[%s] %s. Retrying..." , red ("FAILED" ), exc )
@@ -198,38 +211,99 @@ def _get_current_log(log: str) -> ModelDeliveryList:
198211 return current_log
199212
200213
def _generate_model_delivery_diff(  # pylint: disable=too-many-locals
    spec: ModelDeliveryList, log: ModelDeliveryList
) -> ModelDeliveryList:
    """Compute the delivery tasks in ``spec`` that are not yet covered by ``log``.

    For every task in ``spec`` and every quantization it requests (the spec's
    ``default_quantization`` entries followed by the task's own), the delivered
    log is searched for an entry with the same ``model_id`` and quantization:

    * same overrides and same ``conv_template``: already delivered, skipped;
    * same overrides but a different ``conv_template``: re-delivered with
      ``gen_config_only`` set;
    * otherwise (missing entry or different overrides): delivered in full.

    Parameters
    ----------
    spec : ModelDeliveryList
        The requested deliveries.
    log : ModelDeliveryList
        The log of past deliveries. Each log task must hold a single
        quantization as a ``str``.

    Returns
    -------
    ModelDeliveryList
        A copy of ``spec`` whose ``tasks`` contain one entry per
        (task, quantization) pair still to be delivered, and whose
        ``default_quantization`` / ``default_overrides`` are emptied because
        they have been folded into the individual tasks.

    Raises
    ------
    TypeError
        If a log task's ``quantization`` is not a ``str``.
    """
    diff_tasks = []
    default_quantization = spec.default_quantization
    default_overrides = spec.default_overrides

    for task in spec.tasks:
        model_id = task.model_id
        conv_template = task.conv_template
        quantization = task.quantization
        # Task-level overrides replace the defaults for the same quantization key.
        overrides = {**default_overrides, **task.overrides}

        logger.info("Checking task: %s %s %s %s", model_id, conv_template, quantization, overrides)

        # `quantization` may be a bare string; wrap it so that it is not
        # iterated character by character below.
        if isinstance(quantization, str):
            task_quantizations = [quantization]
        else:
            task_quantizations = list(quantization or [])

        delivered_quantizations = set()
        gen_config_only = set()

        for log_task in log.tasks:
            if log_task.model_id != model_id:
                continue
            log_quantization = log_task.quantization
            if not isinstance(log_quantization, str):
                raise TypeError(
                    "Expected a single quantization string in the delivered log for model "
                    f"{model_id}, got {type(log_quantization).__name__}"
                )
            log_override = log_task.overrides.get(log_quantization, OverrideConfigs())
            override = overrides.get(log_quantization, OverrideConfigs())
            if log_override == override:
                if log_task.conv_template == conv_template:
                    delivered_quantizations.add(log_quantization)
                else:
                    # Only the conversation template changed, so regenerating
                    # the config is sufficient.
                    gen_config_only.add(log_quantization)

        # De-duplicate while keeping the spec order so that the resulting task
        # list is deterministic across runs.
        all_quantizations = dict.fromkeys([*default_quantization, *task_quantizations])
        quantization_diff = [q for q in all_quantizations if q not in delivered_quantizations]

        if quantization_diff:
            for q in quantization_diff:
                logger.info("Adding task %s %s %s to the diff.", model_id, conv_template, q)
                task_copy = task.model_copy()
                task_copy.quantization = [q]
                task_copy.overrides = {q: overrides.get(q, OverrideConfigs())}
                task_copy.gen_config_only = task_copy.gen_config_only or q in gen_config_only
                diff_tasks.append(task_copy)
        else:
            logger.info("Task %s %s %s is up-to-date.", model_id, conv_template, quantization)

    diff_config = spec.model_copy()
    # Defaults have been expanded into each task above.
    diff_config.default_quantization = []
    diff_config.default_overrides = {}
    diff_config.tasks = diff_tasks

    logger.info("Model delivery diff: %s", diff_config.model_dump_json(indent=4, exclude_none=True))

    return diff_config
265+
266+
201267def _main ( # pylint: disable=too-many-locals, too-many-arguments
202268 username : str ,
203269 api : HfApi ,
204270 spec : ModelDeliveryList ,
205271 log : str ,
206272 hf_local_dir : Optional [str ],
207273 output : str ,
274+ dry_run : bool ,
208275):
276+ delivery_diff = _generate_model_delivery_diff (spec , _get_current_log (log ))
277+ if dry_run :
278+ logger .info ("Dry run. No actual delivery." )
279+ return
280+
209281 failed_cases : List [Tuple [str , str ]] = []
210282 delivered_log = _get_current_log (log )
211- for task_index , task in enumerate (spec .tasks , 1 ):
283+ for task_index , task in enumerate (delivery_diff .tasks , 1 ):
212284 logger .info (
213285 bold ("[{task_index}/{total_tasks}] Processing model: " ).format (
214286 task_index = task_index ,
215- total_tasks = len (spec .tasks ),
287+ total_tasks = len (delivery_diff .tasks ),
216288 )
217289 + green (task .model_id )
218290 )
219291 model = _clone_repo (task .model , hf_local_dir )
220292
221293 quantizations = []
222294
223- if spec .default_quantization :
224- quantizations += spec .default_quantization
295+ if delivery_diff .default_quantization :
296+ quantizations += delivery_diff .default_quantization
225297
226298 if task .quantization :
227299 if isinstance (task .quantization , str ):
228300 quantizations .append (task .quantization )
229301 else :
230302 quantizations += task .quantization
231303
232- default_destination = spec .default_destination or "{username}/{model_id}-{quantization}-MLC"
304+ default_destination = (
305+ delivery_diff .default_destination or "{username}/{model_id}-{quantization}-MLC"
306+ )
233307 for quantization in quantizations :
234308 repo = default_destination .format (
235309 username = username ,
@@ -260,12 +334,19 @@ def _main( # pylint: disable=too-many-locals, too-many-arguments
260334 (task .model_id , quantization ),
261335 )
262336 else :
337+ delivered_log .tasks = [
338+ task
339+ for task in delivered_log .tasks
340+ if task .model_id != model_info .model_id
341+ or task .quantization != model_info .quantization
342+ ]
263343 delivered_log .tasks .append (model_info )
264344 if failed_cases :
265345 logger .info ("Total %s %s:" , len (failed_cases ), red ("failures" ))
266346 for model_id , quantization in failed_cases :
267347 logger .info (" Model %s. Quantization %s." , model_id , quantization )
268348
349+ delivered_log .tasks .sort (key = lambda task : task .model_id )
269350 logger .info ("Writing log to %s" , log )
270351 with open (log , "w" , encoding = "utf-8" ) as o_f :
271352 json .dump (delivered_log .to_json (), o_f , indent = 4 )
@@ -336,6 +417,11 @@ def _get_default_hf_token() -> str:
336417 required = False ,
337418 help = "Local directory to store the downloaded HuggingFace model" ,
338419 )
420+ parser .add_argument (
421+ "--dry-run" ,
422+ action = "store_true" ,
423+ help = "Dry run without uploading to HuggingFace Hub" ,
424+ )
339425 parsed = parser .parse_args ()
340426 _main (
341427 parsed .username ,
@@ -344,6 +430,7 @@ def _get_default_hf_token() -> str:
344430 api = HfApi (token = parsed .token ),
345431 hf_local_dir = parsed .hf_local_dir ,
346432 output = parsed .output ,
433+ dry_run = parsed .dry_run ,
347434 )
348435
349436
0 commit comments