@@ -162,6 +162,9 @@ def remove(self, force=False): # noqa: ARG002
162
162
def dump (self , stage , ** kwargs ):
163
163
raise NotImplementedError
164
164
165
+ def dump_stages (self , stages , ** kwargs ):
166
+ raise NotImplementedError
167
+
165
168
def merge (self , ancestor , other , allowed = None ):
166
169
raise NotImplementedError
167
170
@@ -198,6 +201,13 @@ def dump(self, stage, **kwargs) -> None:
198
201
dump_yaml (self .path , serialize .to_single_stage_file (stage , ** kwargs ))
199
202
self .repo .scm_context .track_file (self .relpath )
200
203
204
+ def dump_stages (self , stages , ** kwargs ) -> None :
205
+ if not stages :
206
+ return
207
+
208
+ assert len (stages ) == 1 , "SingleStageFile can only dump one stage."
209
+ return self .dump (stages [0 ], ** kwargs )
210
+
201
211
def remove_stage (self , stage ): # noqa: ARG002
202
212
self .remove ()
203
213
@@ -228,17 +238,27 @@ def _reset(self):
228
238
229
239
def dump (self , stage , update_pipeline = True , update_lock = True , ** kwargs ):
230
240
"""Dumps given stage appropriately in the dvcfile."""
241
+ return self .dump_stages (
242
+ [stage ], update_pipeline = update_pipeline , update_lock = update_lock , ** kwargs
243
+ )
244
+
245
+ def dump_stages (self , stages , update_pipeline = True , update_lock = True , ** kwargs ):
231
246
from dvc .stage import PipelineStage
232
247
233
- assert isinstance (stage , PipelineStage )
248
+ if not stages :
249
+ return
250
+
251
+ for stage in stages :
252
+ assert isinstance (stage , PipelineStage )
253
+
234
254
if self .verify :
235
255
check_dvcfile_path (self .repo , self .path )
236
256
237
- if update_pipeline and not stage . is_data_source :
238
- self ._dump_pipeline_file (stage )
257
+ if update_pipeline :
258
+ self ._dump_pipeline_file (stages )
239
259
240
260
if update_lock :
241
- self ._dump_lockfile (stage , ** kwargs )
261
+ self ._dump_lockfile (stages , ** kwargs )
242
262
243
263
def dump_dataset (self , dataset ):
244
264
with modify_yaml (self .path , fs = self .repo .fs ) as data :
@@ -260,32 +280,37 @@ def dump_dataset(self, dataset):
260
280
raw .append (dataset )
261
281
self .repo .scm_context .track_file (self .relpath )
262
282
263
- def _dump_lockfile (self , stage , ** kwargs ):
264
- self ._lockfile .dump ( stage , ** kwargs )
283
+ def _dump_lockfile (self , stages , ** kwargs ):
284
+ self ._lockfile .dump_stages ( stages , ** kwargs )
265
285
266
286
@staticmethod
267
287
def _check_if_parametrized (stage , action : str = "dump" ) -> None :
268
288
if stage .raw_data .parametrized :
269
289
raise ParametrizedDumpError (f"cannot { action } a parametrized { stage } " )
270
290
271
- def _dump_pipeline_file (self , stage ):
272
- self ._check_if_parametrized (stage )
273
- stage_data = serialize .to_pipeline_file (stage )
291
+ def _dump_pipeline_file (self , stages ):
292
+ stages = stages if isinstance (stages , list ) else [stages ]
293
+ if not stages :
294
+ return
295
+
296
+ for stage in stages :
297
+ self ._check_if_parametrized (stage )
274
298
275
299
with modify_yaml (self .path , fs = self .repo .fs ) as data :
276
300
if not data :
277
301
logger .info ("Creating '%s'" , self .relpath )
278
302
279
303
data ["stages" ] = data .get ("stages" , {})
280
- existing_entry = stage .name in data ["stages" ]
281
- action = "Modifying" if existing_entry else "Adding"
282
- logger .info ("%s stage '%s' in '%s'" , action , stage .name , self .relpath )
283
-
284
- if existing_entry :
285
- orig_stage_data = data ["stages" ][stage .name ]
286
- apply_diff (stage_data [stage .name ], orig_stage_data )
287
- else :
288
- data ["stages" ].update (stage_data )
304
+ for stage in stages :
305
+ stage_data = serialize .to_pipeline_file (stage )
306
+ existing_entry = stage .name in data ["stages" ]
307
+ action = "Modifying" if existing_entry else "Adding"
308
+ logger .info ("%s stage '%s' in '%s'" , action , stage .name , self .relpath )
309
+ if existing_entry :
310
+ orig_stage_data = data ["stages" ][stage .name ]
311
+ apply_diff (stage_data [stage .name ], orig_stage_data )
312
+ else :
313
+ data ["stages" ].update (stage_data )
289
314
290
315
self .repo .scm_context .track_file (self .relpath )
291
316
@@ -399,27 +424,37 @@ def dump_dataset(self, dataset: dict):
399
424
data .setdefault ("stages" , {})
400
425
self .repo .scm_context .track_file (self .relpath )
401
426
402
- def dump (self , stage , ** kwargs ):
403
- stage_data = serialize .to_lockfile (stage , ** kwargs )
427
+ def dump_stages (self , stages , ** kwargs ):
428
+ if not stages :
429
+ return
404
430
431
+ is_modified = False
432
+ log_updated = False
405
433
with modify_yaml (self .path , fs = self .repo .fs ) as data :
406
434
if not data :
407
435
data .update ({"schema" : "2.0" })
408
436
# order is important, meta should always be at the top
409
437
logger .info ("Generating lock file '%s'" , self .relpath )
410
438
411
439
data ["stages" ] = data .get ("stages" , {})
412
- modified = data ["stages" ].get (stage .name , {}) != stage_data .get (
413
- stage .name , {}
414
- )
415
- if modified :
416
- logger .info ("Updating lock file '%s'" , self .relpath )
417
-
418
- data ["stages" ].update (stage_data )
440
+ for stage in stages :
441
+ stage_data = serialize .to_lockfile (stage , ** kwargs )
442
+ modified = data ["stages" ].get (stage .name , {}) != stage_data .get (
443
+ stage .name , {}
444
+ )
445
+ if modified :
446
+ is_modified = True
447
+ if not log_updated :
448
+ logger .info ("Updating lock file '%s'" , self .relpath )
449
+ log_updated = True
450
+ data ["stages" ].update (stage_data )
419
451
420
- if modified :
452
+ if is_modified :
421
453
self .repo .scm_context .track_file (self .relpath )
422
454
455
+ def dump (self , stage , ** kwargs ):
456
+ self .dump_stages ([stage ], ** kwargs )
457
+
423
458
def remove_stage (self , stage ):
424
459
if not self .exists ():
425
460
return
0 commit comments