Skip to content

Commit 6ef8f83

Browse files
jumerckxgithub-actions[bot]giordano
authored
Recurse through structs to check is_traced (#931)
* recurse through structs to check is_traced * test * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * handle recursive objects * julia 1.10 doesn't have an IdSet * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * try out not tracking numbers in Ops.while * bump reactantcore * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Mosè Giordano <[email protected]> * bump reactant * nospecialize is_traced arg * of course IdSet exists in 1.10 🤦‍♂️ --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Mosè Giordano <[email protected]>
1 parent d0ee001 commit 6ef8f83

File tree

7 files changed

+47
-4
lines changed

7 files changed

+47
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ Preferences = "1.4"
8585
PythonCall = "0.9"
8686
Random = "1.10"
8787
Random123 = "1.7"
88-
ReactantCore = "0.1.5"
88+
ReactantCore = "0.1.6"
8989
Reactant_jll = "0.0.92"
9090
Scratch = "1.2"
9191
Sockets = "1.10"

lib/ReactantCore/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ReactantCore"
22
uuid = "a3311ec8-5e00-46d5-b541-4f83e724a433"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>"]
4-
version = "0.1.5"
4+
version = "0.1.6"
55

66
[deps]
77
ExpressionExplorer = "21656369-7473-754a-2065-74616d696c43"

lib/ReactantCore/src/ReactantCore.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,18 @@ using MacroTools: MacroTools
66
export @trace, within_compile, MissingTracedValue
77

88
# Traits
9-
is_traced(x) = false
9+
function is_traced((@nospecialize x::T), seen=Base.IdSet()) where {T}
10+
if !isprimitivetype(x)
11+
for fn in fieldnames(T)
12+
f = getfield(x, fn)
13+
if !(f in seen)
14+
push!(seen, f)
15+
is_traced(f, seen) && return true
16+
end
17+
end
18+
end
19+
return false
20+
end
1021

1122
# New Type signifying that a value is missing
1223
mutable struct MissingTracedValue

src/Ops.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1724,7 +1724,10 @@ end
17241724
traced_args = Vector{Any}(undef, N)
17251725
for i in 1:N
17261726
@inbounds traced_args[i] = Reactant.make_tracer(
1727-
seen_args, args[i], (), Reactant.NoStopTracedTrack; track_numbers=Number
1727+
seen_args,
1728+
args[i],
1729+
(),
1730+
Reactant.NoStopTracedTrack, #; track_numbers=Number
17281731
)
17291732
end
17301733

src/TracedRArray.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!, materialize_tra
2222
using ReactantCore: ReactantCore
2323
using GPUArraysCore: GPUArraysCore, @allowscalar
2424

25+
ReactantCore.is_traced(::TracedRArray, seen) = true
2526
ReactantCore.is_traced(::TracedRArray) = true
2627

2728
Base.strides(x::TracedRArray) = Base.size_to_strides(1, size(x)...)

src/TracedRNumber.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using ..Reactant:
44
Reactant, TracedRNumber, TracedRArray, TracedUtils, Ops, MLIR, unwrapped_eltype
55
using ReactantCore
66

7+
ReactantCore.is_traced(::TracedRNumber, seen) = true
78
ReactantCore.is_traced(::TracedRNumber) = true
89

910
Base.getindex(a::TracedRNumber{T}) where {T} = a

test/control_flow.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -784,3 +784,30 @@ end
784784

785785
@test (@jit ternary_max(a, b)) == 2
786786
end
787+
788+
mutable struct MaybeTraced
789+
x
790+
end
791+
792+
@testset "is_traced of struct" begin
793+
containstraced = MaybeTraced(
794+
MaybeTraced(Reactant.TracedRArray{Float64,1}((), nothing, (3,)))
795+
)
796+
@test Reactant.ReactantCore.is_traced(containstraced)
797+
798+
doesnotcontaintraced = MaybeTraced(MaybeTraced(3))
799+
@test !Reactant.ReactantCore.is_traced(doesnotcontaintraced)
800+
801+
recursivetraced = MaybeTraced((
802+
1,
803+
"string",
804+
MaybeTraced(nothing),
805+
MaybeTraced(Reactant.TracedRArray{Float64,1}((), nothing, (3,))),
806+
))
807+
recursivetraced.x[3].x = recursivetraced
808+
@test Reactant.ReactantCore.is_traced(recursivetraced)
809+
810+
recursivenottraced = MaybeTraced((1, "string", MaybeTraced(nothing)))
811+
recursivenottraced.x[3].x = recursivenottraced
812+
@test !Reactant.ReactantCore.is_traced(recursivenottraced)
813+
end

0 commit comments

Comments
 (0)