@@ -1256,76 +1256,184 @@ function cfg_simplify!(ir::IRCode)
12561256    return  finish (compact)
12571257end 
12581258
1259- function  is_allocation (stmt)
1259+ #  function is_known_fcall(stmt::Expr, @nospecialize(func))
1260+ #      isexpr(stmt, :foreigncall) || return false
1261+ #      s = stmt.args[1]
1262+ #      isa(s, QuoteNode) && (s = s.value)
1263+ #      return s === func
1264+ #  end
1265+ 
1266+ function  is_known_fcall (stmt:: Expr , funcs:: Vector{Symbol} )
12601267    isexpr (stmt, :foreigncall ) ||  return  false 
12611268    s =  stmt. args[1 ]
12621269    isa (s, QuoteNode) &&  (s =  s. value)
1263-     return  s ===  :jl_alloc_array_1d 
1270+     #  return any(e -> s === e, funcs)
1271+     return  true  in  map (e ->  s ===  e, funcs)
1272+ end 
1273+ 
1274+ function  is_allocation (stmt:: Expr )
1275+     isexpr (stmt, :foreigncall ) ||  return  false 
1276+     s =  stmt. args[1 ]
1277+     isa (s, QuoteNode) &&  (s =  s. value)
1278+     return  (s ===  :jl_alloc_array_1d 
1279+          ||  s ===  :jl_alloc_array_2d 
1280+          ||  s ===   :jl_alloc_array_3d 
1281+          ||  s ===  :jl_new_array )
12641282end 
12651283
12661284function  memory_opt! (ir:: IRCode )
12671285    compact =  IncrementalCompact (ir, false )
12681286    uses =  IdDict {Int, Vector{Int}} ()
1269-     relevant =  IdSet {Int} ()
1270-     revisit =  Int[]
1271-     function  mark_val (val)
1287+     relevant =  IdSet {Int} () #  allocations
1288+     revisit =  Int[] #  potential targets for a mutating_arrayfreeze drop-in
1289+     maybecopies =  Int[] #  calls to maybecopy
1290+ 
1291+     function  mark_escape (@nospecialize  val)
12721292        isa (val, SSAValue) ||  return 
1293+         # println(val.id, " escaped.")
12731294        val. id in  relevant &&  pop! (relevant, val. id)
12741295    end 
1296+ 
1297+     function  mark_use (val, idx)
1298+         isa (val, SSAValue) ||  return  
1299+         id =  val. id
1300+         id in  relevant ||  return 
1301+         (haskey (uses, id)) ||  (uses[id] =  Int[])
1302+         push! (uses[id], idx)
1303+     end 
1304+ 
12751305    for  ((_, idx), stmt) in  compact
1306+ 
1307+         # println("idx: ", idx, " = ", stmt)
1308+ 
12761309        if  isa (stmt, ReturnNode)
12771310            isdefined (stmt, :val ) ||  continue 
12781311            val =  stmt. val
1279-             if  isa (val, SSAValue) &&  val. id in  relevant
1280-                 (haskey (uses, val. id)) ||  (uses[val. id] =  Int[])
1281-                 push! (uses[val. id], idx)
1282-             end 
1312+             mark_use (val, idx)
12831313            continue 
1314+ 
1315+         #  check for phinodes that are possibly allocations
1316+         elseif  isa (stmt, PhiNode)
1317+ 
1318+             #  ensure all of the phinode values are defined
1319+             defined =  true 
1320+             for  i =  1 : length (stmt. values)
1321+                 if  ! isassigned (stmt. values, i)
1322+                     defined =  false 
1323+                 end 
1324+             end 
1325+ 
1326+             defined ||  continue 
1327+ 
1328+             for  val in  stmt. values
1329+                 if  isa (val, SSAValue) &&  val. id in  relevant
1330+                     push! (relevant, idx)
1331+                 end 
1332+             end 
12841333        end 
1334+ 
12851335        (isexpr (stmt, :call ) ||  isexpr (stmt, :foreigncall )) ||  continue 
1336+ 
1337+         if  is_known_call (stmt, Core. maybecopy, compact)
1338+             push! (maybecopies, idx)
1339+             continue 
1340+         end 
1341+ 
12861342        if  is_allocation (stmt)
12871343            push! (relevant, idx)
12881344            #  TODO : Mark everything else here
12891345            continue 
12901346        end 
1291-          #   TODO : Replace this by interprocedural escape analysis 
1292-         if  is_known_call (stmt, arrayset, compact)
1347+ 
1348+         if  is_known_call (stmt, arrayset, compact)  &&   length (stmt . args)  >=   5 
12931349            #  The value being set escapes, everything else doesn't
1294-             mark_val (stmt. args[4 ])
1350+             mark_escape (stmt. args[4 ])
12951351            arr =  stmt. args[3 ]
1296-             if  isa (arr, SSAValue) &&  arr. id in  relevant
1297-                 (haskey (uses, arr. id)) ||  (uses[arr. id] =  Int[])
1298-                 push! (uses[arr. id], idx)
1299-             end 
1352+             mark_use (arr, idx)
1353+ 
1354+         elseif  is_known_call (stmt, arrayref, compact) &&  length (stmt. args) ==  4 
1355+             arr =  stmt. args[3 ]
1356+             mark_use (arr, idx)
1357+ 
1358+         elseif  is_known_call (stmt, setindex!, compact) &&  length (stmt. args) ==  4 
1359+             #  handle similarly to arrayset
1360+             val =  stmt. args[3 ]
1361+             mark_escape (val)
1362+ 
1363+             arr =  stmt. args[2 ]
1364+             mark_use (arr, idx)
1365+ 
1366+         elseif  is_known_call (stmt, (=== ), compact) &&  length (stmt. args) ==  3 
1367+             arr1 =  stmt. args[2 ]
1368+             arr2 =  stmt. args[3 ]
1369+ 
1370+             mark_use (arr1, idx)
1371+             mark_use (arr2, idx)
1372+ 
1373+         #  these foreigncalls have similar structure and don't escape our array, so handle them all at once
1374+         elseif  is_known_fcall (stmt, [:jl_array_ptr , :jl_array_copy ]) &&  length (stmt. args) ==  6 
1375+             arr =  stmt. args[6 ]
1376+             mark_use (arr, idx)
1377+ 
1378+         elseif  is_known_call (stmt, arraysize, compact) &&  isa (stmt. args[2 ], SSAValue)
1379+             arr =  stmt. args[2 ]
1380+             mark_use (arr, idx)
1381+ 
13001382        elseif  is_known_call (stmt, Core. arrayfreeze, compact) &&  isa (stmt. args[2 ], SSAValue)
1383+             #  mark these for potential replacement with mutating_arrayfreeze
13011384            push! (revisit, idx)
1385+ 
13021386        else 
1303-             #  For now we assume everything escapes
1304-             #  TODO : We could handle PhiNodes specially and improve this
1387+             #  Assume everything else escapes
13051388            for  ur in  userefs (stmt)
1306-                 mark_val (ur[])
1389+                 mark_escape (ur[])
13071390            end 
13081391        end 
13091392    end 
1393+ 
13101394    ir =  finish (compact)
1311-     isempty (revisit) &&  return  ir
1395+     isempty (revisit) &&  isempty (maybecopies) &&  return  ir
1396+ 
13121397    domtree =  construct_domtree (ir. cfg. blocks)
1398+ 
13131399    for  idx in  revisit
13141400        #  Make sure that the value we reference didn't escape
1315-         id =  ir. stmts[idx][:inst ]. args[2 ]. id
1401+         stmt =  ir. stmts[idx][:inst ]:: Expr 
1402+         id =  (stmt. args[2 ]:: SSAValue ). id
13161403        (id in  relevant) ||  continue 
13171404
1405+         # println("Revisiting ", stmt)
1406+ 
13181407        #  We're ok to steal the memory if we don't dominate any uses
13191408        ok =  true 
1320-         for  use in  uses[id]
1321-             if  ssadominates (ir, domtree, idx, use)
1322-                 ok =  false 
1323-                 break 
1409+         if  haskey (uses, id)
1410+             for  use in  uses[id]
1411+                 if  ssadominates (ir, domtree, idx, use)
1412+                     ok =  false 
1413+                     break 
1414+                 end 
13241415            end 
13251416        end 
13261417        ok ||  continue 
1327- 
1328-         ir. stmts[idx][:inst ]. args[1 ] =  Core. mutating_arrayfreeze
1418+         stmt. args[1 ] =  Core. mutating_arrayfreeze
13291419    end 
1420+ 
1421+     #  TODO : Use escape analysis info to determine if maybecopy should copy
1422+ 
1423+     #  for idx in maybecopies
1424+     #      stmt = ir.stmts[idx][:inst]::Expr
1425+     #      #println(stmt.args)
1426+     #      arr = stmt.args[2]
1427+     #      id = isa(arr, SSAValue) ? arr.id : arr.n # SSAValue or Core.Argument
1428+ 
1429+     #      if (id in relevant) # didn't escape elsewhere, so make a copy to keep it un-escaped
1430+     #          #println("didn't escape maybecopy")
1431+     #          stmt.args[1] = Main.Base.copy
1432+     #      else # already escaped, so save the cost of copying and just pass the actual object
1433+     #          #println("escaped maybecopy")
1434+     #          ir.stmts[idx][:inst] = arr
1435+     #      end
1436+     #  end
1437+ 
13301438    return  ir
13311439end 
0 commit comments