
Commit 954241f

profile each benchmark individually for cleaner traces
1 parent d754094

File tree

1 file changed (+16, -10 lines)

src/libkernelbot/run_eval.py

Lines changed: 16 additions & 10 deletions
@@ -1,3 +1,4 @@
+import copy
 import dataclasses
 import datetime
 import functools
@@ -678,12 +679,13 @@ def run_pytorch_script( # noqa: C901
 
 
 class _EvalRunner(Protocol):
-    def __call__(self, mode: str) -> EvalResult: ...
+    def __call__(self, mode: str, **kwargs) -> EvalResult: ...
 
 
 def run_evaluation(
     call: _EvalRunner,
     mode: str,
+    common_args: dict,
 ) -> dict[str, EvalResult]:
     """
     Given a "runner" function `call`, interprets the mode
@@ -693,22 +695,28 @@ def run_evaluation(
     require multiple runner calls.
     """
     results: dict[str, EvalResult] = {}
-    if mode in ["test", "benchmark", "profile"]:
-        results[mode] = call(mode=mode)
+    if mode == "profile":
+        benchmarks = copy.deepcopy(common_args["benchmarks"])
+        for i, benchmark in enumerate(benchmarks.splitlines()):
+            common_args["benchmarks"] = benchmark
+            results[f"{mode}.{i}"] = call(mode=mode, **common_args)
+
+    elif mode in ["test", "benchmark"]:
+        results[mode] = call(mode=mode, **common_args)
     elif mode in ["private", "leaderboard"]:
         # first, run the tests
-        results["test"] = call(mode="test")
+        results["test"] = call(mode="test", **common_args)
 
         if not results["test"].run or not results["test"].run.passed:
            return results
 
-        results["benchmark"] = call(mode="benchmark")
+        results["benchmark"] = call(mode="benchmark", **common_args)
 
         if not results["benchmark"].run or not results["benchmark"].run.passed:
            return results
 
         # if they pass, run the leaderboard validation
-        results["leaderboard"] = call(mode="leaderboard")
+        results["leaderboard"] = call(mode="leaderboard", **common_args)
     else:
         raise AssertionError("Invalid mode")
 
@@ -742,8 +750,7 @@ def run_config(config: dict):
         runner = functools.partial(
             run_pytorch_script,
             sources=config["sources"],
-            main=config["main"],
-            **common_args,
+            main=config["main"]
         )
     elif config["lang"] == "cu":
         runner = functools.partial(
@@ -755,10 +762,9 @@ def run_config(config: dict):
             include_dirs=config.get("include_dirs", []),
             libraries=config.get("libraries", []),
             flags=CUDA_FLAGS,
-            **common_args,
         )
     else:
         raise ValueError(f"Invalid language {config['lang']}")
 
-    results = run_evaluation(runner, config["mode"])
+    results = run_evaluation(runner, config["mode"], common_args)
     return FullResult(success=True, error="", runs=results, system=system)
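
For readers skimming the diff: a self-contained sketch of the behaviour the new "profile" branch introduces. It copies the multi-line `benchmarks` entry from `common_args`, splits it per line, and invokes the runner once per benchmark so each profiling run produces its own trace, keyed "profile.<i>". The `dummy_runner` function, `profile_each_benchmark` helper, and benchmark strings below are illustrative stand-ins, not objects from run_eval.py.

import copy

def profile_each_benchmark(call, common_args: dict) -> dict:
    # Mirrors the new "profile" branch of run_evaluation: one runner call per
    # benchmark line, stored under "profile.<i>", so every benchmark gets its
    # own clean trace instead of one combined profile.
    results = {}
    benchmarks = copy.deepcopy(common_args["benchmarks"])
    for i, benchmark in enumerate(benchmarks.splitlines()):
        common_args["benchmarks"] = benchmark  # run just this one benchmark
        results[f"profile.{i}"] = call(mode="profile", **common_args)
    return results

# Hypothetical stand-in for the functools.partial runner built in run_config.
def dummy_runner(mode: str, **kwargs):
    return f"{mode} run of {kwargs['benchmarks']!r}"

print(profile_each_benchmark(dummy_runner, {"benchmarks": "bench_a\nbench_b"}))
# {'profile.0': "profile run of 'bench_a'", 'profile.1': "profile run of 'bench_b'"}

Note that, as in the diff, `common_args["benchmarks"]` is overwritten inside the loop; the copy taken up front is what preserves the original multi-line string for iteration.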
