Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions python/tvm/script/parser/core/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,8 @@ def _visit(self, node: doc.AST) -> Any:
value = self._eval_slice(fields)
else:
value = self._eval_expr(node.__class__(**fields))
except Exception as e: # pylint: disable=broad-except,invalid-name
self.parser.report_error(node, e)
except Exception as err: # pylint: disable=broad-except
self.parser.report_error(node, err)
return self._add_intermediate_result(value)

def _eval_lambda(self, node: doc.Lambda) -> Any:
Expand All @@ -286,8 +286,8 @@ def _eval_lambda(self, node: doc.Lambda) -> Any:
"""
try:
value = self._eval_expr(node)
except Exception as e: # pylint: disable=broad-except,invalid-name
self.parser.report_error(node, str(e))
except Exception as err: # pylint: disable=broad-except
self.parser.report_error(node, err)
return self._add_intermediate_result(value)

def _eval_bool_op(self, fields: Dict[str, Any]) -> Any:
Expand Down Expand Up @@ -463,8 +463,8 @@ def eval_assign(
"""
try:
return _eval_assign(target, source)
except Exception as e: # pylint: disable=broad-except,invalid-name
parser.report_error(target, f"Failed to evaluate assignment: {str(e)}")
except Exception as err: # pylint: disable=broad-except
parser.report_error(target, err)
raise


Expand Down
19 changes: 10 additions & 9 deletions python/tvm/script/parser/core/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,10 +307,8 @@ def _dispatch_wrapper(func: dispatch.ParseMethod) -> dispatch.ParseMethod:
def _wrapper(self: "Parser", node: doc.AST) -> None:
try:
return func(self, node)
except DiagnosticError:
raise
except Exception as e: # pylint: disable=broad-except,invalid-name
self.report_error(node, e)
except Exception as err: # pylint: disable=broad-except
self.report_error(node, err)
raise

return _wrapper
Expand Down Expand Up @@ -547,6 +545,12 @@ def report_error(
err: Union[Exception, str]
The error to report.
"""

# If the error is already being raised as a DiagnosticError,
# re-raise it without wrapping it in a DiagnosticContext.
if isinstance(err, DiagnosticError):
raise err

# Only take the last line of the error message
if isinstance(err, TVMError):
msg = list(filter(None, str(err).split("\n")))[-1]
Expand Down Expand Up @@ -595,11 +599,8 @@ def visit(self, node: doc.AST) -> None:
raise NotImplementedError(f"Visitor of AST node is not implemented: {name}")
try:
func(node)
except DiagnosticError:
raise
except Exception as e: # pylint: disable=broad-except,invalid-name
self.report_error(node, str(e))
raise
except Exception as err: # pylint: disable=broad-except
self.report_error(node, err)

def visit_body(self, node: List[doc.stmt]) -> Any:
"""The general body visiting method.
Expand Down
10 changes: 5 additions & 5 deletions python/tvm/script/parser/relax/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,19 +104,19 @@ def eval_struct_info_proxy(self: Parser, node: doc.expr) -> StructInfoProxy:
try:
annotation = self.eval_expr(node)
return _normalize_struct_info_proxy(annotation)
except Exception as err:
self.report_error(node, str(err))
raise err
except Exception as err: # pylint: disable=broad-except
self.report_error(node, err)
raise


def eval_struct_info(self: Parser, node: doc.expr, eval_str: bool = False) -> StructInfo:
var_table = self.var_table.get() if eval_str else None
try:
struct_info = self.eval_expr(node)
return _normalize_struct_info(struct_info, var_table)
except Exception as err:
except Exception as err: # pylint: disable=broad-except
self.report_error(node, err)
raise err
raise


def is_called(node: Any, func_name: str) -> bool:
Expand Down
3 changes: 2 additions & 1 deletion src/ir/diagnostic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ void DiagnosticContext::Render() {
}

if (errs) {
(*this)->renderer = DiagnosticRenderer();
(*this)->renderer = DiagnosticRenderer([](DiagnosticContext) {});
// (*this)->diagnostics.clear();
LOG(FATAL) << "DiagnosticError: one or more error diagnostics were "
<< "emitted, please check diagnostic render for output.";
}
Expand Down
14 changes: 11 additions & 3 deletions tests/python/relax/test_tvmscript_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,15 @@ def f(x: R.Tensor):
return x


def test_incorrect_tensor_shape():
with pytest.raises(tvm.error.DiagnosticError):

@R.function
def f(x: R.Tensor([16])):
y: R.Tensor(16) = R.add(x, x)
return y


def test_simple_module():
@I.ir_module
class TestModule:
Expand Down Expand Up @@ -1045,7 +1054,6 @@ def main(


def test_call_tir_inplace_with_tuple_var_raises_error():

with pytest.raises(tvm.error.DiagnosticError):

@tvm.script.ir_module
Expand Down Expand Up @@ -1838,7 +1846,7 @@ def mul_add(x: R.Tensor) -> R.Tensor:
_check(InputModule, OutputModule)


def test_context_aware_parsing():
def test_context_aware_parsing(monkeypatch):
@tvm.script.ir_module
class Module:
@T.prim_func
Expand All @@ -1863,7 +1871,7 @@ def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32
def _break_env(self, *args):
raise RuntimeError("Fail to pass context-aware parsing")

tvm.ir.GlobalVar.__call__ = _break_env
monkeypatch.setattr(tvm.ir.GlobalVar, "__call__", _break_env)

_check(Module)

Expand Down
8 changes: 5 additions & 3 deletions tests/python/tvmscript/test_tvmscript_printer_highlight.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import tvm.testing
from tvm import relay
from tvm.script import tir as T
from tvm.script.highlight import cprint
from tvm.script.highlight import cprint, _format


def test_highlight_script():
Expand Down Expand Up @@ -58,12 +58,14 @@ def test_cprint():
# Print nodes with `script` method, e.g. PrimExpr
cprint(tvm.tir.Var("v", "int32") + 1)

# Cannot print non-Python-style codes if black installed
# Cannot print non-Python-style codes when using the black
# formatter. This error comes from `_format`, used internally by
# `cprint`, and doesn't occur when using the `ruff` formatter.
try:
import black

with pytest.raises(ValueError):
cprint("if (a == 1) { a +=1; }")
_format("if (a == 1) { a +=1; }", formatter="black")
except ImportError:
pass

Expand Down