Skip to content

Commit 267f0f6

Browse files
Handle promoted traced numbers (#1247)
* Handle promoted traced numbers * fix mlir ownership * fix * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent d12a1fc commit 267f0f6

File tree

5 files changed

+59
-27
lines changed

5 files changed

+59
-27
lines changed

ext/ReactantCUDAExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -991,6 +991,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
991991
) where {F,tt}
992992
blockdim = CUDA.CuDim3(blocks)
993993
threaddim = CUDA.CuDim3(threads)
994+
mod = MLIR.IR.mmodule()
994995

995996
if convert == Val(true)
996997
args = recudaconvert.(args)
@@ -1021,7 +1022,6 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
10211022
end
10221023

10231024
sym_name = String(gensym("call_$fname"))
1024-
mod = MLIR.IR.mmodule()
10251025
CConv = MLIR.IR.Attribute(
10261026
MLIR.API.mlirLLVMCConvAttrGet(ctx, MLIR.API.MlirLLVMCConvPTX_Kernel)
10271027
)

src/Ops.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,9 @@ end
104104
end
105105

106106
value = MLIR.IR.DenseElementsAttribute(x)
107-
constants = constant_context()[2]
107+
constant_blk, constants = constant_context()
108+
parent = MLIR.IR.parent_op(constant_blk)
109+
@assert MLIR.IR.name(parent) != "builtin.module"
108110
if haskey(constants, value)
109111
return constants[value]
110112
else
@@ -137,6 +139,7 @@ end
137139
@noinline function constant(
138140
x::AbstractArray{T,N}; location=mlir_stacktrace("constant", @__FILE__, @__LINE__)
139141
) where {T,N}
142+
@assert !(x isa TracedRArray)
140143
return constant(collect(x); location)
141144
end
142145

src/TracedUtils.jl

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -280,12 +280,9 @@ function make_mlir_fn(
280280
return mlir_fn_res
281281
end
282282

283-
seen_args = OrderedIdDict()
284-
285-
(; N, traced_args, linear_args, inv_map, in_tys, sym_visibility, mod, traced_args_to_shardings, func, fnbody) = prepare_mlir_fn_args(
283+
(; N, traced_args, linear_args, inv_map, in_tys, sym_visibility, mod, traced_args_to_shardings, func, fnbody, seen_args, skipped_args) = prepare_mlir_fn_args(
286284
args,
287285
name,
288-
seen_args,
289286
concretein,
290287
toscalar,
291288
argprefix,
@@ -327,14 +324,12 @@ function make_mlir_fn(
327324
end
328325
end
329326

330-
seen_results = OrderedIdDict()
331-
332327
(func2, traced_result, ret, linear_args, in_tys, linear_results, num_partitions, is_sharded, unique_meshes, mutated_args, global_device_ids) = finalize_mlir_fn(
333328
result,
334329
traced_args,
335330
linear_args,
331+
skipped_args,
336332
seen_args,
337-
seen_results,
338333
fnbody,
339334
func,
340335
mod,
@@ -388,7 +383,6 @@ end
388383
function prepare_mlir_fn_args(
389384
args,
390385
name,
391-
seen_args,
392386
concretein,
393387
toscalar,
394388
argprefix,
@@ -406,16 +400,35 @@ function prepare_mlir_fn_args(
406400
else
407401
Reactant.TracedSetPath
408402
end
409-
for i in 1:N
410-
@inbounds traced_args[i] = Reactant.make_tracer(
411-
seen_args, args[i], (argprefix, i), inmode; toscalar, runtime
412-
)
403+
fnbody = MLIR.IR.Block(MLIR.IR.Type[], MLIR.IR.Location[])
404+
MLIR.IR.activate!(fnbody)
405+
Ops.activate_constant_context!(fnbody)
406+
seen_args0 = OrderedIdDict()
407+
try
408+
for i in 1:N
409+
@inbounds traced_args[i] = Reactant.make_tracer(
410+
seen_args0, args[i], (argprefix, i), inmode; toscalar, runtime
411+
)
412+
end
413+
finally
414+
MLIR.IR.deactivate!(fnbody)
415+
Ops.deactivate_constant_context!(fnbody)
413416
end
414417

418+
seen_args = OrderedIdDict()
415419
linear_args = Reactant.TracedType[]
420+
skipped_args = Reactant.TracedType[]
416421
inv_map = IdDict()
417-
for (k, v) in seen_args
422+
for (k, v) in seen_args0
418423
v isa Reactant.TracedType || continue
424+
arg = get_mlir_data(v)
425+
if (arg isa MLIR.IR.Value) &&
426+
MLIR.IR.is_op_res(arg) &&
427+
MLIR.IR.block(MLIR.IR.op_owner(arg)) == fnbody
428+
push!(skipped_args, v)
429+
continue
430+
end
431+
seen_args[k] = v
419432
push!(linear_args, v)
420433
inv_map[v] = k
421434
end
@@ -468,7 +481,7 @@ function prepare_mlir_fn_args(
468481
end
469482

470483
arglocs = MLIR.IR.Location[]
471-
for arg in linear_args
484+
for (i, arg) in enumerate(linear_args)
472485
path = get_idx(arg, argprefix)
473486
stridx = if verify_arg_names isa Nothing
474487
"arg" * string(path[2])
@@ -490,9 +503,12 @@ function prepare_mlir_fn_args(
490503
aval = getfield(aval, idx)
491504
end
492505
end
493-
push!(arglocs, MLIR.IR.Location(stridx * " (path=$path)", MLIR.IR.Location()))
506+
MLIR.IR.push_argument!(
507+
fnbody,
508+
in_tys[i];
509+
location=MLIR.IR.Location(stridx * " (path=$path)", MLIR.IR.Location()),
510+
)
494511
end
495-
fnbody = MLIR.IR.Block(in_tys, arglocs)
496512
push!(MLIR.IR.region(func, 1), fnbody)
497513

498514
return (;
@@ -506,6 +522,8 @@ function prepare_mlir_fn_args(
506522
traced_args_to_shardings,
507523
func,
508524
fnbody,
525+
seen_args,
526+
skipped_args,
509527
)
510528
end
511529

@@ -533,8 +551,8 @@ function finalize_mlir_fn(
533551
result,
534552
traced_args,
535553
linear_args,
554+
skipped_args,
536555
seen_args,
537-
seen_results,
538556
fnbody,
539557
func,
540558
mod,
@@ -578,6 +596,7 @@ function finalize_mlir_fn(
578596
Reactant.TracedTrack
579597
end
580598

599+
seen_results = OrderedIdDict()
581600
MLIR.IR.activate!(fnbody)
582601
traced_result = try
583602
traced_result = Reactant.make_tracer(
@@ -602,6 +621,9 @@ function finalize_mlir_fn(
602621
linear_results = Reactant.TracedType[]
603622
for (k, v) in seen_results
604623
v isa Reactant.TracedType || continue
624+
if any(Base.Fix1(===, k), skipped_args)
625+
continue
626+
end
605627
if args_in_result != :all
606628
if has_idx(v, argprefix)
607629
if !(

src/Tracing.jl

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,12 +1128,19 @@ Base.@nospecializeinfer function make_tracer_unknown(
11281128
if xi !== xi2
11291129
changed = true
11301130
end
1131-
if mode != TracedToTypes && !(Core.Typeof(xi2) <: fieldtype(TT, i))
1132-
throw(
1133-
AssertionError(
1134-
"Could not recursively make tracer of object of type $RT into $TT at field $i (named $(fieldname(TT, i))), need object of type $(fieldtype(TT, i)) found object of type $(Core.Typeof(xi2)) ",
1135-
),
1136-
)
1131+
FT = fieldtype(TT, i)
1132+
if mode != TracedToTypes && !(Core.Typeof(xi2) <: FT)
1133+
if FT <: TracedRNumber && xi2 isa unwrapped_eltype(FT)
1134+
xi2 = FT(xi2)
1135+
xi2 = Core.Typeof(xi2)((newpath,), xi2.mlir_data)
1136+
seen[xi2] = xi2
1137+
else
1138+
throw(
1139+
AssertionError(
1140+
"Could not recursively make tracer of object of type $RT into $TT at field $i (named $(fieldname(TT, i))), need object of type $(fieldtype(TT, i)) found object of type $(Core.Typeof(xi2)) ",
1141+
),
1142+
)
1143+
end
11371144
end
11381145
flds[i] = xi2
11391146
else

src/mlir/IR/Block.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,14 @@ Base.unsafe_convert(::Core.Type{API.MlirBlock}, block::Block) = block.block
3838
3939
Returns the closest surrounding operation that contains this block.
4040
"""
41-
parent_op(block::Block) = Operation(API.mlirBlockGetParentOperation(block))
41+
parent_op(block::Block) = Operation(API.mlirBlockGetParentOperation(block), false)
4242

4343
"""
4444
parent_region(block)
4545
4646
Returns the region that contains this block.
4747
"""
48-
parent_region(block::Block) = Region(API.mlirBlockGetParentRegion(block))
48+
parent_region(block::Block) = Region(API.mlirBlockGetParentRegion(block), false)
4949

5050
Base.parent(block::Block) = parent_region(block)
5151

0 commit comments

Comments
 (0)