Skip to content

Commit a68dda4

Browse files
authored
commit: batch dump stages, do not re-read on save from the workspace (#10839)
1 parent 4b1adfd commit a68dda4

File tree

6 files changed

+212
-81
lines changed

6 files changed

+212
-81
lines changed

dvc/dvcfile.py

Lines changed: 63 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,9 @@ def remove(self, force=False): # noqa: ARG002
162162
def dump(self, stage, **kwargs):
163163
raise NotImplementedError
164164

165+
def dump_stages(self, stages, **kwargs):
166+
raise NotImplementedError
167+
165168
def merge(self, ancestor, other, allowed=None):
166169
raise NotImplementedError
167170

@@ -198,6 +201,13 @@ def dump(self, stage, **kwargs) -> None:
198201
dump_yaml(self.path, serialize.to_single_stage_file(stage, **kwargs))
199202
self.repo.scm_context.track_file(self.relpath)
200203

204+
def dump_stages(self, stages, **kwargs) -> None:
205+
if not stages:
206+
return
207+
208+
assert len(stages) == 1, "SingleStageFile can only dump one stage."
209+
return self.dump(stages[0], **kwargs)
210+
201211
def remove_stage(self, stage): # noqa: ARG002
202212
self.remove()
203213

@@ -228,17 +238,27 @@ def _reset(self):
228238

229239
def dump(self, stage, update_pipeline=True, update_lock=True, **kwargs):
230240
"""Dumps given stage appropriately in the dvcfile."""
241+
return self.dump_stages(
242+
[stage], update_pipeline=update_pipeline, update_lock=update_lock, **kwargs
243+
)
244+
245+
def dump_stages(self, stages, update_pipeline=True, update_lock=True, **kwargs):
231246
from dvc.stage import PipelineStage
232247

233-
assert isinstance(stage, PipelineStage)
248+
if not stages:
249+
return
250+
251+
for stage in stages:
252+
assert isinstance(stage, PipelineStage)
253+
234254
if self.verify:
235255
check_dvcfile_path(self.repo, self.path)
236256

237-
if update_pipeline and not stage.is_data_source:
238-
self._dump_pipeline_file(stage)
257+
if update_pipeline:
258+
self._dump_pipeline_file(stages)
239259

240260
if update_lock:
241-
self._dump_lockfile(stage, **kwargs)
261+
self._dump_lockfile(stages, **kwargs)
242262

243263
def dump_dataset(self, dataset):
244264
with modify_yaml(self.path, fs=self.repo.fs) as data:
@@ -260,32 +280,37 @@ def dump_dataset(self, dataset):
260280
raw.append(dataset)
261281
self.repo.scm_context.track_file(self.relpath)
262282

263-
def _dump_lockfile(self, stage, **kwargs):
264-
self._lockfile.dump(stage, **kwargs)
283+
def _dump_lockfile(self, stages, **kwargs):
284+
self._lockfile.dump_stages(stages, **kwargs)
265285

266286
@staticmethod
267287
def _check_if_parametrized(stage, action: str = "dump") -> None:
268288
if stage.raw_data.parametrized:
269289
raise ParametrizedDumpError(f"cannot {action} a parametrized {stage}")
270290

271-
def _dump_pipeline_file(self, stage):
272-
self._check_if_parametrized(stage)
273-
stage_data = serialize.to_pipeline_file(stage)
291+
def _dump_pipeline_file(self, stages):
292+
stages = stages if isinstance(stages, list) else [stages]
293+
if not stages:
294+
return
295+
296+
for stage in stages:
297+
self._check_if_parametrized(stage)
274298

275299
with modify_yaml(self.path, fs=self.repo.fs) as data:
276300
if not data:
277301
logger.info("Creating '%s'", self.relpath)
278302

279303
data["stages"] = data.get("stages", {})
280-
existing_entry = stage.name in data["stages"]
281-
action = "Modifying" if existing_entry else "Adding"
282-
logger.info("%s stage '%s' in '%s'", action, stage.name, self.relpath)
283-
284-
if existing_entry:
285-
orig_stage_data = data["stages"][stage.name]
286-
apply_diff(stage_data[stage.name], orig_stage_data)
287-
else:
288-
data["stages"].update(stage_data)
304+
for stage in stages:
305+
stage_data = serialize.to_pipeline_file(stage)
306+
existing_entry = stage.name in data["stages"]
307+
action = "Modifying" if existing_entry else "Adding"
308+
logger.info("%s stage '%s' in '%s'", action, stage.name, self.relpath)
309+
if existing_entry:
310+
orig_stage_data = data["stages"][stage.name]
311+
apply_diff(stage_data[stage.name], orig_stage_data)
312+
else:
313+
data["stages"].update(stage_data)
289314

290315
self.repo.scm_context.track_file(self.relpath)
291316

@@ -399,27 +424,37 @@ def dump_dataset(self, dataset: dict):
399424
data.setdefault("stages", {})
400425
self.repo.scm_context.track_file(self.relpath)
401426

402-
def dump(self, stage, **kwargs):
403-
stage_data = serialize.to_lockfile(stage, **kwargs)
427+
def dump_stages(self, stages, **kwargs):
428+
if not stages:
429+
return
404430

431+
is_modified = False
432+
log_updated = False
405433
with modify_yaml(self.path, fs=self.repo.fs) as data:
406434
if not data:
407435
data.update({"schema": "2.0"})
408436
# order is important, meta should always be at the top
409437
logger.info("Generating lock file '%s'", self.relpath)
410438

411439
data["stages"] = data.get("stages", {})
412-
modified = data["stages"].get(stage.name, {}) != stage_data.get(
413-
stage.name, {}
414-
)
415-
if modified:
416-
logger.info("Updating lock file '%s'", self.relpath)
417-
418-
data["stages"].update(stage_data)
440+
for stage in stages:
441+
stage_data = serialize.to_lockfile(stage, **kwargs)
442+
modified = data["stages"].get(stage.name, {}) != stage_data.get(
443+
stage.name, {}
444+
)
445+
if modified:
446+
is_modified = True
447+
if not log_updated:
448+
logger.info("Updating lock file '%s'", self.relpath)
449+
log_updated = True
450+
data["stages"].update(stage_data)
419451

420-
if modified:
452+
if is_modified:
421453
self.repo.scm_context.track_file(self.relpath)
422454

455+
def dump(self, stage, **kwargs):
456+
self.dump_stages([stage], **kwargs)
457+
423458
def remove_stage(self, stage):
424459
if not self.exists():
425460
return

dvc/output.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1452,25 +1452,43 @@ def restore_fields(self, other: "Output"):
14521452
self.remote = other.remote
14531453
self.can_push = other.can_push
14541454

1455-
def merge_version_meta(self, other: "Output"):
1455+
def _get_versioned_meta(
1456+
self,
1457+
) -> Optional[
1458+
tuple["HashInfo", Optional["Meta"], Optional[Union["HashFile", "Tree"]]]
1459+
]:
1460+
if self.files is not None or (
1461+
self.meta is not None and self.meta.version_id is not None
1462+
):
1463+
old_obj = self.obj if self.obj is not None else self.get_obj()
1464+
return self.hash_info, self.meta, old_obj
1465+
return None
1466+
1467+
def merge_version_meta(
1468+
self,
1469+
old_hi: "HashInfo",
1470+
old_meta: Optional["Meta"],
1471+
old_obj: Optional[Union["HashFile", "Tree"]],
1472+
):
14561473
"""Merge version meta for files which are unchanged from other."""
14571474
if not self.hash_info:
14581475
return
14591476
if self.hash_info.isdir:
1460-
return self._merge_dir_version_meta(other)
1461-
if self.hash_info != other.hash_info:
1477+
return self._merge_dir_version_meta(old_hi, old_obj)
1478+
if self.hash_info != old_hi:
14621479
return
1463-
self.meta = other.meta
1480+
self.meta = old_meta
14641481

1465-
def _merge_dir_version_meta(self, other: "Output"):
1482+
def _merge_dir_version_meta(
1483+
self, old_hi: "HashInfo", old_obj: Optional[Union["HashFile", "Tree"]]
1484+
):
14661485
from dvc_data.hashfile.tree import update_meta
14671486

1468-
if not self.obj or not other.hash_info.isdir:
1487+
if not self.obj or not old_hi.isdir:
14691488
return
1470-
other_obj = other.obj if other.obj is not None else other.get_obj()
14711489
assert isinstance(self.obj, Tree)
1472-
assert isinstance(other_obj, Tree)
1473-
updated = update_meta(self.obj, other_obj)
1490+
assert isinstance(old_obj, Tree)
1491+
updated = update_meta(self.obj, old_obj)
14741492
assert updated.hash_info == self.obj.hash_info
14751493
self.obj = updated
14761494
self.files = updated.as_list(with_meta=True)

dvc/repo/commit.py

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from itertools import groupby
12
from typing import TYPE_CHECKING
23

34
from dvc import prompt
@@ -52,29 +53,38 @@ def commit(
5253
data_only=False,
5354
relink=True,
5455
):
55-
stages_info = [
56-
info
57-
for info in self.stage.collect_granular(
58-
target, with_deps=with_deps, recursive=recursive
59-
)
60-
if not data_only or info.stage.is_data_source
61-
]
62-
for stage_info in stages_info:
63-
stage = stage_info.stage
64-
if force:
65-
stage.save(allow_missing=allow_missing)
66-
else:
67-
changes = stage.changed_entries()
68-
if any(changes):
69-
prompt_to_commit(stage, changes, force=force)
56+
committed_stages = []
57+
groups = groupby(
58+
[
59+
info
60+
for info in self.stage.collect_granular(
61+
target, with_deps=with_deps, recursive=recursive
62+
)
63+
if not data_only or info.stage.is_data_source
64+
],
65+
key=lambda info: info.stage.dvcfile,
66+
)
67+
68+
for dvcfile, stages_info_group in groups:
69+
to_dump = []
70+
for stage_info in stages_info_group:
71+
stage = stage_info.stage
72+
if force:
7073
stage.save(allow_missing=allow_missing)
71-
stage.commit(
72-
filter_info=stage_info.filter_info,
73-
allow_missing=allow_missing,
74-
relink=relink,
75-
)
76-
stage.dump(update_pipeline=False)
77-
return [s.stage for s in stages_info]
74+
else:
75+
changes = stage.changed_entries()
76+
if any(changes):
77+
prompt_to_commit(stage, changes, force=force)
78+
stage.save(allow_missing=allow_missing)
79+
stage.commit(
80+
filter_info=stage_info.filter_info,
81+
allow_missing=allow_missing,
82+
relink=relink,
83+
)
84+
to_dump.append(stage)
85+
dvcfile.dump_stages(to_dump, update_pipeline=False)
86+
committed_stages.extend(to_dump)
87+
return committed_stages
7888

7989

8090
@locked

dvc/stage/__init__.py

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -509,34 +509,20 @@ def save_deps(self, allow_missing=False):
509509
if not allow_missing:
510510
raise
511511

512-
def get_versioned_outs(self) -> dict[str, "Output"]:
513-
from .exceptions import StageFileDoesNotExistError, StageNotFound
514-
515-
try:
516-
old = self.reload()
517-
except (StageFileDoesNotExistError, StageNotFound):
518-
return {}
519-
520-
return {
521-
out.def_path: out
522-
for out in old.outs
523-
if out.files is not None
524-
or (out.meta is not None and out.meta.version_id is not None)
525-
}
526-
527512
def save_outs(self, allow_missing: bool = False):
528513
from dvc.output import OutputDoesNotExistError
529514

530-
old_versioned_outs = self.get_versioned_outs()
531515
for out in self.outs:
516+
# old state just before saving so to merge them later
517+
old_state = out._get_versioned_meta()
532518
try:
533519
out.save()
534520
except OutputDoesNotExistError:
535521
if not allow_missing:
536522
raise
537523

538-
if old_out := old_versioned_outs.get(out.def_path):
539-
out.merge_version_meta(old_out)
524+
if old_state:
525+
out.merge_version_meta(*old_state)
540526

541527
def ignore_outs(self) -> None:
542528
for out in self.outs:
@@ -579,8 +565,9 @@ def add_outs(self, filter_info=None, allow_missing: bool = False, **kwargs):
579565
from dvc.output import OutputDoesNotExistError
580566

581567
link_failures = []
582-
old_versioned_outs = self.get_versioned_outs()
583568
for out in self.filter_outs(filter_info):
569+
# old state just before saving so to merge them later
570+
old_state = out._get_versioned_meta()
584571
try:
585572
out.add(filter_info, **kwargs)
586573
except (FileNotFoundError, OutputDoesNotExistError):
@@ -589,8 +576,8 @@ def add_outs(self, filter_info=None, allow_missing: bool = False, **kwargs):
589576
except CacheLinkError:
590577
link_failures.append(filter_info or out.fs_path)
591578

592-
if old_out := old_versioned_outs.get(out.def_path):
593-
out.merge_version_meta(old_out)
579+
if old_state:
580+
out.merge_version_meta(*old_state)
594581

595582
if link_failures:
596583
raise CacheLinkError(link_failures)

tests/func/test_commit.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pytest
55

66
from dvc.dependency.base import DependencyDoesNotExistError
7-
from dvc.dvcfile import PROJECT_FILE
7+
from dvc.dvcfile import PROJECT_FILE, Lockfile, ProjectFile, SingleStageFile
88
from dvc.fs import localfs
99
from dvc.output import OutputDoesNotExistError
1010
from dvc.stage.exceptions import StageCommitError
@@ -325,3 +325,51 @@ def test_commit_dos2unix(tmp_dir, dvc):
325325
dvc.commit("foo.dvc", force=True)
326326
content = (tmp_dir / "foo.dvc").read_text()
327327
assert "hash: md5" in content
328+
329+
330+
def test_commit_multiple_files(tmp_dir, dvc, mocker):
331+
tmp_dir.gen({"foo": "foo", "bar": "bar"})
332+
stages = dvc.add(["foo", "bar"], no_commit=True)
333+
test1_stage = dvc.stage.add(name="test", cmd="echo test", deps=["foo"])
334+
test2_stage = dvc.stage.add(name="test2", cmd="echo test2", deps=["foo"])
335+
336+
subdir = tmp_dir / "subdir"
337+
subdir.mkdir()
338+
with subdir.chdir():
339+
bar_relpath = os.path.relpath(tmp_dir / "bar", subdir)
340+
test3_stage = dvc.stage.add(name="test3", cmd="echo test3", deps=[bar_relpath])
341+
342+
pointerfile_spy = mocker.spy(SingleStageFile, "dump_stages")
343+
projectfile_spy = mocker.spy(ProjectFile, "dump_stages")
344+
lockfile_spy = mocker.spy(Lockfile, "dump_stages")
345+
346+
assert set(dvc.commit(force=True)) == {
347+
*stages,
348+
test1_stage,
349+
test2_stage,
350+
test3_stage,
351+
}
352+
pointerfile_spy.assert_has_calls(
353+
[
354+
mocker.call(stages[0].dvcfile, [stages[0]], update_pipeline=False),
355+
mocker.call(stages[1].dvcfile, [stages[1]], update_pipeline=False),
356+
],
357+
any_order=True,
358+
)
359+
projectfile_spy.assert_has_calls(
360+
[
361+
mocker.call(
362+
test1_stage.dvcfile, [test1_stage, test2_stage], update_pipeline=False
363+
),
364+
mocker.call(test3_stage.dvcfile, [test3_stage], update_pipeline=False),
365+
],
366+
any_order=True,
367+
)
368+
lockfile_spy.assert_has_calls(
369+
[
370+
mocker.call(test1_stage.dvcfile._lockfile, [test1_stage, test2_stage]),
371+
mocker.call(test3_stage.dvcfile._lockfile, [test3_stage]),
372+
],
373+
any_order=True,
374+
)
375+
assert dvc.status() == {}

0 commit comments

Comments
 (0)