Skip to content

Commit 62fbad2

Browse files
KenoJeff Bezanson
andcommitted
Use loop inverted loop lowering
The primary idea of the new iteration protocol is that for a function like: ``` function iterate(itr) done(itr) ? nothing : next(itr) end ``` we can fuse the `done` comparison into the loop condition and recover the same loop structure we had before (while retaining the flexibility of not requiring the done function to be separate), i.e. for ``` y = iterate(itr) y === nothing && break ``` we want to have after inlining and early optimization: ``` done(itr) && break y = next(itr) ``` LLVM performs this optimization in jump threading. However, we run into a problem. At the top of the loop we have: ``` y = iterate top: %cond = y === nothing br i1 %cond, %exit, %loop .... ``` We'd want to thread over the `top` block (this makes sense, since by the discussion above, we need to merge our condition into the loop exit condition). However, LLVM (quite sensibly) refuses to thread over loop headers and since `top` is both a loop header and a loop exit, we fail to perform the appropriate transformation. However, there's a simple fix. Instead of emitting a foor loop as ``` y = iterate(itr) while y !== nothing x, state = y ... y = iterate(itr, state) end ``` we can emit it as ``` y = iterate(itr) if y !== nothing while true x, state = y ... y = iterate(itr, state) y === nothing && break end end ``` This transformation is known as `loop inversion` (or a special case of `loop rotation`. In our case the primary benefit is that we can fuse the condition contained in the initial `iterate` call into the bypass if, which then lets LLVM understand our loop structure. Co-authored-by: Jeff Bezanson <[email protected]>
1 parent 1a1d6b6 commit 62fbad2

File tree

2 files changed

+18
-8
lines changed

2 files changed

+18
-8
lines changed

src/codegen.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6507,7 +6507,7 @@ static std::unique_ptr<Module> emit_function(
65076507
};
65086508

65096509
// Codegen Phi nodes
6510-
std::map<BasicBlock *, BasicBlock*> BB_rewrite_map;
6510+
std::map<std::pair<BasicBlock *, BasicBlock*>, BasicBlock*> BB_rewrite_map;
65116511
std::vector<llvm::PHINode*> ToDelete;
65126512
for (auto &tup : ctx.PhiNodes) {
65136513
jl_cgval_t phi_result;
@@ -6526,8 +6526,9 @@ static std::unique_ptr<Module> emit_function(
65266526
Value *V = NULL;
65276527
BasicBlock *IncomingBB = come_from_bb[edge];
65286528
BasicBlock *FromBB = IncomingBB;
6529-
if (BB_rewrite_map.count(FromBB)) {
6530-
FromBB = BB_rewrite_map[IncomingBB];
6529+
std::pair<BasicBlock *, BasicBlock*> LookupKey(IncomingBB, PhiBB);
6530+
if (BB_rewrite_map.count(LookupKey)) {
6531+
FromBB = BB_rewrite_map[LookupKey];
65316532
}
65326533
#ifndef JL_NDEBUG
65336534
bool found_pred = false;
@@ -6681,7 +6682,7 @@ static std::unique_ptr<Module> emit_function(
66816682
// Check any phi nodes in the Phi block to see if by splitting the edges,
66826683
// we made things inconsistent
66836684
if (FromBB != ctx.builder.GetInsertBlock()) {
6684-
BB_rewrite_map[IncomingBB] = ctx.builder.GetInsertBlock();
6685+
BB_rewrite_map[LookupKey] = ctx.builder.GetInsertBlock();
66856686
for (BasicBlock::iterator I = PhiBB->begin(); isa<PHINode>(I); ++I) {
66866687
PHINode *PN = cast<PHINode>(I);
66876688
ssize_t BBIdx = PN->getBasicBlockIndex(FromBB);

src/julia-syntax.scm

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1606,10 +1606,11 @@
16061606
;; TODO avoid `local declared twice` error from this
16071607
;;,@(if outer `((local ,lhs)) '())
16081608
,@(if outer `((require-existing-local ,lhs)) '())
1609-
(_while
1610-
(call (|.| (core Intrinsics) 'not_int) (call (core ===) ,next (null)))
1611-
(block ,body
1612-
(= ,next (call (top iterate) ,coll ,state)))))))))))
1609+
(if (call (top not_int) (call (core ===) ,next (null)))
1610+
(_do_while
1611+
(block ,body
1612+
(= ,next (call (top iterate) ,coll ,state)))
1613+
(call (|.| (core Intrinsics) 'not_int) (call (core ===) ,next (null))))))))))))
16131614

16141615
;; wrap `expr` in a function appropriate for consuming values from given ranges
16151616
(define (func-for-generator-ranges expr range-exprs flat outervars)
@@ -3644,6 +3645,14 @@ f(x) = yt(x)
36443645
(compile (caddr e) break-labels #f #f)
36453646
(emit `(goto ,topl))
36463647
(mark-label endl)))
3648+
((_do_while)
3649+
(let* ((endl (make-label))
3650+
(topl (make&mark-label)))
3651+
(compile (cadr e) break-labels #f #f)
3652+
(let ((test (compile-cond (caddr e) break-labels)))
3653+
(emit `(gotoifnot ,test ,endl)))
3654+
(emit `(goto ,topl))
3655+
(mark-label endl)))
36473656
((break-block)
36483657
(let ((endl (make-label)))
36493658
(begin0 (compile (caddr e)

0 commit comments

Comments
 (0)