Skip to content

Commit ae7d9db

Browse files
[Codegen] Fix if_then_else codegen (#16242)
* fix * lint * fix while loop * clean * clean --------- Co-authored-by: Junru Shao <[email protected]>
1 parent 5308739 commit ae7d9db

File tree

1 file changed

+31
-8
lines changed

1 file changed

+31
-8
lines changed

src/target/source/codegen_c.cc

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -631,13 +631,33 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*)
631631
} else if (op->op.same_as(builtin::shift_right())) {
632632
PrintBinaryIntrinsic(op, " >> ", os, this);
633633
} else if (op->op.same_as(builtin::if_then_else())) {
634-
os << "(";
635-
PrintExpr(op->args[0], os);
636-
os << " ? ";
637-
PrintExpr(op->args[1], os);
638-
os << " : ";
639-
PrintExpr(op->args[2], os);
640-
os << ")";
634+
// conditional that skips eval if cond evals to false
635+
std::string result = name_supply_->FreshName("condval");
636+
std::string cond = PrintExpr(op->args[0]);
637+
this->PrintIndent();
638+
PrintType(op->dtype, this->stream);
639+
this->stream << " " << result << ";\n";
640+
this->PrintIndent();
641+
this->stream << "if (" << cond << ") {\n";
642+
{
643+
int then_scope = this->BeginScope();
644+
std::string true_val = PrintExpr(op->args[1]);
645+
this->PrintIndent();
646+
this->stream << result << " = " << true_val << ";\n";
647+
this->EndScope(then_scope);
648+
this->PrintIndent();
649+
this->stream << "} else {\n";
650+
}
651+
{
652+
int else_scope = this->BeginScope();
653+
std::string false_val = PrintExpr(op->args[2]);
654+
this->PrintIndent();
655+
this->stream << result << " = " << false_val << ";\n";
656+
this->EndScope(else_scope);
657+
this->PrintIndent();
658+
this->stream << "}\n";
659+
}
660+
os << result;
641661
} else if (op->op.same_as(builtin::address_of())) {
642662
const BufferLoadNode* load = op->args[0].as<BufferLoadNode>();
643663
ICHECK(op->args.size() == 1 && load);
@@ -1022,8 +1042,11 @@ void CodeGenC::VisitStmt_(const ForNode* op) {
10221042

10231043
void CodeGenC::VisitStmt_(const WhileNode* op) {
10241044
PrintIndent();
1025-
stream << "while (" << PrintExpr(op->condition) << ") {\n";
1045+
stream << "while (1) {\n";
10261046
int while_scope = BeginScope();
1047+
std::string cond = PrintExpr(op->condition);
1048+
PrintIndent();
1049+
stream << "if (!(" << cond << ")) { break; }\n";
10271050
PrintStmt(op->body);
10281051
this->EndScope(while_scope);
10291052
PrintIndent();

0 commit comments

Comments
 (0)