1+ import copy
12import dataclasses
23import datetime
34import functools
@@ -678,12 +679,13 @@ def run_pytorch_script( # noqa: C901
678679
679680
680681class _EvalRunner (Protocol ):
681- def __call__ (self , mode : str ) -> EvalResult : ...
682+ def __call__ (self , mode : str , ** kwargs ) -> EvalResult : ...
682683
683684
684685def run_evaluation (
685686 call : _EvalRunner ,
686687 mode : str ,
688+ common_args : dict ,
687689) -> dict [str , EvalResult ]:
688690 """
689691 Given a "runner" function `call`, interprets the mode
@@ -693,22 +695,28 @@ def run_evaluation(
693695 require multiple runner calls.
694696 """
695697 results : dict [str , EvalResult ] = {}
696- if mode in ["test" , "benchmark" , "profile" ]:
697- results [mode ] = call (mode = mode )
698+ if mode == "profile" :
699+ benchmarks = copy .deepcopy (common_args ["benchmarks" ])
700+ for i , benchmark in enumerate (benchmarks .splitlines ()):
701+ common_args ["benchmarks" ] = benchmark
702+ results [f"{ mode } .{ i } " ] = call (mode = mode , ** common_args )
703+
704+ elif mode in ["test" , "benchmark" ]:
705+ results [mode ] = call (mode = mode , ** common_args )
698706 elif mode in ["private" , "leaderboard" ]:
699707 # first, run the tests
700- results ["test" ] = call (mode = "test" )
708+ results ["test" ] = call (mode = "test" , ** common_args )
701709
702710 if not results ["test" ].run or not results ["test" ].run .passed :
703711 return results
704712
705- results ["benchmark" ] = call (mode = "benchmark" )
713+ results ["benchmark" ] = call (mode = "benchmark" , ** common_args )
706714
707715 if not results ["benchmark" ].run or not results ["benchmark" ].run .passed :
708716 return results
709717
710718 # if they pass, run the leaderboard validation
711- results ["leaderboard" ] = call (mode = "leaderboard" )
719+ results ["leaderboard" ] = call (mode = "leaderboard" , ** common_args )
712720 else :
713721 raise AssertionError ("Invalid mode" )
714722
@@ -742,8 +750,7 @@ def run_config(config: dict):
742750 runner = functools .partial (
743751 run_pytorch_script ,
744752 sources = config ["sources" ],
745- main = config ["main" ],
746- ** common_args ,
753+ main = config ["main" ]
747754 )
748755 elif config ["lang" ] == "cu" :
749756 runner = functools .partial (
@@ -755,10 +762,9 @@ def run_config(config: dict):
755762 include_dirs = config .get ("include_dirs" , []),
756763 libraries = config .get ("libraries" , []),
757764 flags = CUDA_FLAGS ,
758- ** common_args ,
759765 )
760766 else :
761767 raise ValueError (f"Invalid language { config ['lang' ]} " )
762768
763- results = run_evaluation (runner , config ["mode" ])
769+ results = run_evaluation (runner , config ["mode" ], common_args )
764770 return FullResult (success = True , error = "" , runs = results , system = system )
0 commit comments