Skip to content

Commit d1a6a24

Browse files
Support for dicts (#748)
* Support for dicts * fix * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 6e8ef9f commit d1a6a24

File tree

4 files changed

+86
-0
lines changed

4 files changed

+86
-0
lines changed

src/Compiler.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,10 @@ import ..Reactant:
2222

2323
import ..ReactantCore: correct_maybe_bcast_call
2424

25+
@inline function traced_getfield(@nospecialize(obj::Dict), field)
26+
return Base.getindex(obj, field)
27+
end
28+
2529
@inline function traced_getfield(@nospecialize(obj), field)
2630
return Base.getfield(obj, field)
2731
end
@@ -40,6 +44,10 @@ end
4044
return Base.setindex!(obj, val, field)
4145
end
4246

47+
@inline function traced_setfield!(@nospecialize(obj::Dict), field, val)
48+
return Base.setindex!(obj, field, val)
49+
end
50+
4351
function create_result(
4452
tocopy::T, path, result_stores, path_to_shard_info, sharding_mesh
4553
) where {T}

src/Sharding.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ struct NoSharding <: AbstractSharding end
5858

5959
# This allows us to mark entire branches as NoSharding
6060
Base.getproperty(::NoSharding, x) = NoSharding()
61+
Base.getproperty(::NoSharding, x::Symbol) = NoSharding()
6162

6263
function (::NoSharding)(client::XLA.Client, device, x::Union{AbstractArray,Number})
6364
buffer = XLA.AsyncBuffer(XLA.ArrayFromHostBuffer(client, x, device), nothing)

src/Tracing.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,6 +1232,68 @@ function make_tracer(
12321232
return newa
12331233
end
12341234

1235+
function make_tracer(
1236+
seen,
1237+
@nospecialize(prev::Dict{Key,Value}),
1238+
@nospecialize(path),
1239+
mode;
1240+
@nospecialize(track_numbers::Type = Union{}),
1241+
@nospecialize(sharding = Sharding.NoSharding()),
1242+
kwargs...,
1243+
) where {Key,Value}
1244+
RT = Core.Typeof(prev)
1245+
# XXX: If someone wants to shard the same array with different shardings, we need to
1246+
# somehow handle this correctly... Right now we just use the first sharding.
1247+
if mode != NoStopTracedTrack && haskey(seen, prev)
1248+
if mode == TracedToTypes
1249+
visited = seen[prev]
1250+
push!(path, visited)
1251+
return nothing
1252+
end
1253+
return seen[prev]
1254+
end
1255+
if eltype(RT) <: ReactantPrimitive
1256+
if mode == ArrayToConcrete && return seen[prev] = ConcreteRArray(prev; sharding)
1257+
elseif mode == TracedToTypes
1258+
# Original array can get mutated so we store a copy:
1259+
push!(path, copy(prev))
1260+
seen[prev] = VisitedObject(length(seen) + 1)
1261+
return nothing
1262+
end
1263+
elseif mode == TracedToTypes
1264+
push!(path, RT)
1265+
for (k, v) in prev
1266+
make_tracer(seen, k, path, mode; track_numbers, sharding, kwargs...)
1267+
make_tracer(seen, v, path, mode; track_numbers, sharding, kwargs...)
1268+
end
1269+
return nothing
1270+
end
1271+
Value2 = traced_type(Value, Val(mode), track_numbers, sharding)
1272+
newa = Dict{Key,Value2}()
1273+
seen[prev] = newa
1274+
same = true
1275+
for (k, v) in prev
1276+
nv = make_tracer(
1277+
seen,
1278+
v,
1279+
append_path(path, k),
1280+
mode;
1281+
track_numbers,
1282+
sharding=Base.getproperty(sharding, k),
1283+
kwargs...,
1284+
)
1285+
if v !== nv
1286+
same = false
1287+
end
1288+
newa[k] = nv
1289+
end
1290+
if same
1291+
seen[prev] = prev
1292+
return prev
1293+
end
1294+
return newa
1295+
end
1296+
12351297
function make_tracer(
12361298
seen,
12371299
@nospecialize(prev::Tuple),

test/basic.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -901,3 +901,18 @@ end
901901
@test contains(hlo, "stablehlo.add")
902902
@test Array(@jit(fn(Base.OneTo(10000)))) collect(Base.OneTo(10000))
903903
end
904+
905+
function dip!(x)
906+
x[:a] = x[:a] .* x[:b]
907+
return nothing
908+
end
909+
910+
@testset "Dict" begin
911+
x = Dict{Symbol,Vector{Float32}}()
912+
x[:a] = 2.7 * ones(4)
913+
x[:b] = 3.1 * ones(4)
914+
915+
ra = Reactant.to_rarray(x)
916+
Reactant.@jit dip!(ra)
917+
ra[:a] (2.7 * 2) * ones(4)
918+
end

0 commit comments

Comments
 (0)