Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions include/tvm/runtime/crt/platform.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,24 @@ extern "C" {
*/
void __attribute__((noreturn)) TVMPlatformAbort(int code);

/*! \brief Start a device timer.
*
* The device timer used must not be running.
*
* \return An error code.
*/
int TVMPlatformTimerStart();

/*! \brief Stop the running device timer and get the elapsed time (in microseconds).
*
* The device timer used must be running.
*
* \param res_us Pointer to write elapsed time into.
*
* \return An error code.
*/
int TVMPlatformTimerStop(double* res_us);

#ifdef __cplusplus
} // extern "C"
#endif
Expand Down
14 changes: 10 additions & 4 deletions python/tvm/contrib/debugger/debug_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
_DUMP_PATH_PREFIX = "_tvmdbg_"


def create(graph_json_str, libmod, ctx, dump_root=None):
def create(graph_json_str, libmod, ctx, dump_root=None, number=10, repeat=1, min_repeat_ms=1):
"""Create a runtime executor module given a graph and module.

Parameters
Expand Down Expand Up @@ -72,7 +72,8 @@ def create(graph_json_str, libmod, ctx, dump_root=None):
"config.cmake and rebuild TVM to enable debug mode"
)
func_obj = fcreate(graph_json_str, libmod, *device_type_id)
return GraphModuleDebug(func_obj, ctx, graph_json_str, dump_root)
return GraphModuleDebug(func_obj, ctx, graph_json_str, dump_root,
number=number, repeat=repeat, min_repeat_ms=min_repeat_ms)


class GraphModuleDebug(graph_runtime.GraphModule):
Expand All @@ -99,13 +100,17 @@ class GraphModuleDebug(graph_runtime.GraphModule):
None will make a temp folder in /tmp/tvmdbg<rand_string> and does the dumping
"""

def __init__(self, module, ctx, graph_json_str, dump_root):
def __init__(self, module, ctx, graph_json_str, dump_root,
number, repeat, min_repeat_ms):
self._dump_root = dump_root
self._dump_path = None
self._get_output_by_layer = module["get_output_by_layer"]
self._run_individual = module["run_individual"]
graph_runtime.GraphModule.__init__(self, module)
self._create_debug_env(graph_json_str, ctx)
self.number = number
self.repeat = repeat
self.min_repeat_ms = min_repeat_ms

def _format_context(self, ctx):
return str(ctx[0]).upper().replace("(", ":").replace(")", "")
Expand Down Expand Up @@ -180,7 +185,8 @@ def _run_debug(self):

"""
self.debug_datum._time_list = [
[float(t) * 1e-6] for t in self.run_individual(10, 1, 1)
[float(t) * 1e-6] for t in
self.run_individual(self.number, self.repeat, self.min_repeat_ms)
]
for i, node in enumerate(self.debug_datum.get_graph_nodes()):
num_outputs = self.debug_datum.get_graph_node_output_num(node)
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/micro/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""MicroTVM module for bare-metal backends"""

from .artifact import Artifact
from .build import build_static_runtime, DefaultOptions, TVM_ROOT_DIR, Workspace
from .build import build_static_runtime, DefaultOptions, TVM_ROOT_DIR, CRT_ROOT_DIR, Workspace
from .compiler import Compiler, DefaultCompiler, Flasher
from .debugger import GdbRemoteDebugger, RpcDebugger
from .micro_library import MicroLibrary
Expand Down
47 changes: 47 additions & 0 deletions python/tvm/relay/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
configuring the passes and scripting them in Python.
"""
from tvm.ir import IRModule
# TODO(weberlo) remove when we port dtype collectors to C++
from tvm.relay.expr_functor import ExprVisitor
from tvm.relay.type_functor import TypeVisitor

from . import _ffi_api
from .feature import Feature
Expand Down Expand Up @@ -219,6 +222,50 @@ def all_type_vars(expr, mod=None):
return _ffi_api.all_type_vars(expr, use_mod)


class TyDtypeCollector(TypeVisitor):
"""Pass that collects data types used in the visited type."""

def __init__(self):
TypeVisitor.__init__(self)
self.dtypes = set()

def visit_tensor_type(self, tt):
self.dtypes.add(tt.dtype)


class ExprDtypeCollector(ExprVisitor):
"""Pass that collects data types used in all types in the visited expression."""

def __init__(self):
ExprVisitor.__init__(self)
self.ty_visitor = TyDtypeCollector()

def visit(self, expr):
if hasattr(expr, 'checked_type'):
self.ty_visitor.visit(expr.checked_type)
elif hasattr(expr, 'type_annotation'):
self.ty_visitor.visit(expr.type_annotation)
ExprVisitor.visit(self, expr)


def all_dtypes(expr):
"""Collect set of all data types used in `expr`.

Parameters
----------
expr : tvm.relay.Expr
The input expression

Returns
-------
ret : Set[String]
Set of data types used in the expression
"""
dtype_collector = ExprDtypeCollector()
dtype_collector.visit(expr)
return dtype_collector.ty_visitor.dtypes


def collect_device_info(expr):
"""Collect the device allocation map for the given expression. The device
ids are propagated from the `device_copy` operators.
Expand Down
Loading