4 changes: 4 additions & 0 deletions tritonbench/utils/constants.py
@@ -0,0 +1,4 @@
DEFAULT_WARMUP = 25
DEFAULT_REP = 100
DEFAULT_QUANTILES = [0.5, 0.1, 0.9]
DEFAULT_SLEEP = 0.0
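These four defaults previously lived in triton_op.py (see the deletion below); hoisting them into a standalone constants module lets lightweight modules such as parser.py import them without depending on triton_op. A minimal usage sketch; the millisecond interpretation of warmup/rep follows triton.testing.do_bench's convention and is an assumption here:

# Sketch: consumers import shared benchmark defaults from the new module.
from tritonbench.utils.constants import DEFAULT_REP, DEFAULT_WARMUP

# Assumed units (ms) per triton.testing.do_bench's warmup/rep convention.
print(f"warmup={DEFAULT_WARMUP} ms, rep={DEFAULT_REP} ms")  # warmup=25 ms, rep=100 ms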
2 changes: 1 addition & 1 deletion tritonbench/utils/parser.py
@@ -1,7 +1,7 @@
import argparse

from tritonbench.utils.env_utils import AVAILABLE_PRECISIONS, is_fbcode
-from tritonbench.utils.triton_op import DEFAULT_REP, DEFAULT_WARMUP
+from tritonbench.utils.constants import DEFAULT_REP, DEFAULT_WARMUP


def get_parser(args=None):
68 changes: 57 additions & 11 deletions tritonbench/utils/triton_op.py
@@ -14,6 +14,7 @@
import sys
import tempfile
import time
import types

from collections import defaultdict, OrderedDict
from dataclasses import asdict, dataclass, fields
@@ -35,6 +36,7 @@
)
from tritonbench.components.export import export_data

from tritonbench.utils.constants import (
    DEFAULT_QUANTILES,
    DEFAULT_REP,
    DEFAULT_SLEEP,
    DEFAULT_WARMUP,
)
from tritonbench.utils.env_utils import (
apply_precision,
is_fbcode,
@@ -43,6 +45,7 @@
set_random_seed,
)
from tritonbench.utils.input import input_cast
from tritonbench.utils.parser import get_parser
from tritonbench.utils.path_utils import add_cmd_parameter, remove_cmd_parameter

if is_hip():
@@ -77,11 +80,6 @@ class BenchmarkOperatorBackend:
# ci = False implies enabled = False
ci: bool = True


-DEFAULT_WARMUP = 25
-DEFAULT_REP = 100
-DEFAULT_QUANTILES = [0.5, 0.1, 0.9]
-DEFAULT_SLEEP = 0.0
REGISTERED_BENCHMARKS: Dict[str, OrderedDict[str, BenchmarkOperatorBackend]] = {}
REGISTERED_METRICS: defaultdict[str, List[str]] = defaultdict(list)
OVERRIDDEN_METRICS: defaultdict[str, List[str]] = defaultdict(list)
@@ -590,10 +588,11 @@ def register_benchmark(
label: Optional[str] = None,
):
def decorator(function):

        op_name = (
-           _find_op_name_from_module_path(function.__module__)
-           if not operator_name
-           else operator_name
+           operator_name
+           if operator_name
+           else _find_op_name_from_module_path(function.__module__)
        )
fn_name = function.__name__ if not func_name else func_name
backend_config = BenchmarkOperatorBackend(
@@ -667,6 +666,11 @@ def _has_and_true(attr):
if _has_and_true("fwd_no_grad"):
tb_args.mode = "fwd_no_grad"

def override_args(args_to_override):
parser = get_parser()
tb_args, extra_args = parser.parse_known_args(args_to_override)
return tb_args, extra_args
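A hedged sketch of how override_args is meant to be used: it feeds a raw CLI-style list through the shared parser, so known flags land on the namespace and unknown ones fall through via parse_known_args. The --metrics flag is assumed to be one of get_parser's options; --my-custom-flag is hypothetical:

# Known flags are parsed onto tb_args; unrecognized ones are returned
# untouched in extra_args for the operator to consume later.
tb_args, extra_args = override_args(["--metrics", "latency", "--my-custom-flag", "1"])
print(tb_args.metrics)  # "latency"
print(extra_args)       # ["--my-custom-flag", "1"]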


class BenchmarkOperator(metaclass=PostInitProcessor):
mode: Mode = Mode.FWD
@@ -692,11 +696,19 @@ class BenchmarkOperator(metaclass=PostInitProcessor):
"""

def __init__(
-        self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
+        self, tb_args: Optional[argparse.Namespace] = None, extra_args: Optional[List[str]] = None
):
set_env()
set_random_seed()
self.name = _find_op_name_from_module_path(self.__class__.__module__)
if extra_args and not tb_args:
tb_args, extra_args = override_args(extra_args)
elif not tb_args:
            raise ValueError(
                "No arguments provided: pass an argparse.Namespace as tb_args "
                "or a list of CLI overrides as extra_args."
            )

if tb_args.benchmark_name:
self.name = tb_args.benchmark_name
else:
self.name = _find_op_name_from_module_path(self.__class__.__module__)
self._raw_extra_args = copy.deepcopy(extra_args)
self.tb_args = tb_args
self.add_production_shapes = (
@@ -807,6 +819,39 @@ def fwd_no_grad_fn():

setattr(fwd_no_grad_fn, "_name", bm_func_name)
return fwd_no_grad_fn
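Taken together with override_args above, the relaxed __init__ signature allows constructing an operator from an override list alone. A sketch with a hypothetical subclass; the "--benchmark-name" spelling is an assumption inferred from the tb_args.benchmark_name attribute read in __init__:

class MyAddBench(BenchmarkOperator):  # hypothetical operator
    ...

op = MyAddBench(extra_args=["--benchmark-name", "my_add"])
print(op.name)  # "my_add", overriding the module-path-derived name

MyAddBench()    # raises ValueError: neither tb_args nor extra_args given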

    def set_input_iter(self, input_iter: Callable):
        def input_decorator(input_iter):
            def input_callable(self):
                return input_iter()

            return input_callable

        # Bind the wrapper via __get__ so it replaces the get_input_iter method.
        self.get_input_iter = input_decorator(input_iter).__get__(
            self, BenchmarkOperator
        )
        self.input_iter = input_iter
        self._available_num_inputs = sum(1 for _ in self.get_input_iter())
        self._num_inputs = self._available_num_inputs - self._input_id
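A sketch of driving a benchmark with caller-supplied inputs via set_input_iter; the generator below is hypothetical, and op stands for any constructed BenchmarkOperator instance:

import torch

def my_inputs():
    # Three (x, y) pairs of growing size.
    for n in (128, 256, 512):
        yield torch.randn(n), torch.randn(n)

op.set_input_iter(my_inputs)
print(op._available_num_inputs)  # 3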

    def add_benchmark(
        self,
        bm_callable: Callable,
        operator_name: Optional[str] = None,
        func_name: Optional[str] = None,
        baseline: bool = False,
        fwd_only: bool = False,
        label: Optional[str] = None,
    ) -> None:
        decorator_kwargs = {
            "operator_name": operator_name or self.name,
            "func_name": func_name,
            "enabled": True,
            "baseline": baseline,
            "fwd_only": fwd_only,
            "label": label,
        }
        # register_benchmark already records the backend config in
        # REGISTERED_BENCHMARKS under the operator name, so no manual
        # registry write is needed here.
        decorated_func = register_benchmark(**decorator_kwargs)(bm_callable)
        bound_method = types.MethodType(decorated_func, self)
        setattr(self, func_name or bm_callable.__name__, bound_method)
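A sketch of registering a backend at runtime with add_benchmark instead of the class-level @register_benchmark decorator; eager_add is hypothetical and follows the usual tritonbench backend shape (takes the inputs, returns a zero-argument callable for the timer), and op is the instance from the earlier sketch:

def eager_add(self, x, y):
    return lambda: x + y  # hypothetical baseline backend

op.add_benchmark(eager_add, func_name="eager", baseline=True)
op.eager  # the callable is now bound to the instance under func_name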

def run(
self,
@@ -959,9 +1004,10 @@ def get_bwd_fn(self, fwd_fn: Callable) -> Callable:

def get_input_iter(self) -> Generator:
"""Return the dynamic input iterator for the model."""
-        raise NotImplementedError(
+        logger.warning(
            "Each operator must implement its own input iterator."
        )
+        return []

def get_grad_to_none(self, args):
return None