Skip to content

Commit 10d6c9a

Browse files
More disable custom interp (#1029)
* More disable custom interp * err * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * materialize traced array * fix * fix * fix * oneto * fixup * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * up docs --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 6c071ab commit 10d6c9a

File tree

7 files changed

+86
-34
lines changed

7 files changed

+86
-34
lines changed

docs/src/api/api.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,5 +50,5 @@ Reactant.addressable_devices
5050
## Internal utils
5151

5252
```@docs
53-
Reactant.TracedUtils.materialize_traced_array
53+
ReactantCore.materialize_traced_array
5454
```

lib/ReactantCore/src/ReactantCore.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,4 +516,12 @@ function error_if_any_control_flow(expr)
516516
end
517517
end
518518

519+
"""
520+
materialize_traced_array(AbstractArray{<:TracedRNumber})::TracedRArray
521+
522+
Given an AbstractArray{TracedRNumber}, return or create an equivalent TracedRArray.
523+
524+
"""
525+
function materialize_traced_array end
526+
519527
end

src/Reactant.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module Reactant
22

3-
using ReactantCore: ReactantCore, @trace, within_compile, MissingTracedValue
3+
using ReactantCore:
4+
ReactantCore, @trace, within_compile, MissingTracedValue, materialize_traced_array
45

56
using LinearAlgebra: LinearAlgebra
67
using Random: Random, AbstractRNG

src/TracedRNumber.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,31 @@ function TracedUnitRange{T}(r::AbstractUnitRange) where {T<:Real}
575575
end
576576
TracedUnitRange(r::AbstractUnitRange) = TracedUnitRange(first(r), last(r))
577577

578+
@inline function Base.length(r::TracedUnitRange{TracedRNumber{T}}) where {T}
579+
start, stop = first(r), last(r)
580+
a = Base.oneunit(Base.zero(stop) - Base.zero(start))
581+
if a isa Signed
582+
# Signed are allowed to go negative
583+
Ops.select(stop >= start, a + stop - start, a)
584+
else
585+
Ops.select(stop >= start, a + stop - start, zero(a))
586+
end
587+
end
588+
589+
function Base._reshape(v::TracedUnitRange, dims::Dims{1})
590+
Base.require_one_based_indexing(v)
591+
len = dims[1]
592+
# TODO support errors
593+
# len == length(v) || Base._throw_dmrs(length(v), "length", len)
594+
return v
595+
end
596+
function Base._reshape(parent::TracedUnitRange, dims::Dims)
597+
n = length(parent)
598+
# TODO support errors
599+
# prod(dims) == n || Base._throw_dmrs(n, "size", dims)
600+
return Base.__reshape((parent, IndexStyle(parent)), dims)
601+
end
602+
578603
AbstractUnitRange{T}(r::TracedUnitRange) where {T} = TracedUnitRange{T}(r)
579604

580605
struct TracedStepRangeLen{T,R,S,L} <: AbstractRange{T}

src/TracedUtils.jl

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,48 +14,47 @@ using ..Reactant:
1414
OrderedIdDict,
1515
ReactantPrimitive,
1616
Ops
17-
using ReactantCore: MissingTracedValue, is_traced
17+
using ReactantCore: ReactantCore
18+
using ReactantCore: MissingTracedValue, is_traced, materialize_traced_array
1819
using Functors: Functors
1920

20-
"""
21-
materialize_traced_array(AbstractArray{<:TracedRNumber})::TracedRArray
21+
ReactantCore.materialize_traced_array(x::AbstractArray) = x
2222

23-
Given an AbstractArray{TracedRNumber}, return or create an equivalent TracedRArray.
23+
ReactantCore.materialize_traced_array(x::TracedRArray) = x
2424

25-
"""
26-
materialize_traced_array(x::AbstractArray) = x
25+
ReactantCore.materialize_traced_array(x::AnyTracedRArray) = x[axes(x)...]
2726

28-
materialize_traced_array(x::TracedRArray) = x
29-
30-
materialize_traced_array(x::AnyTracedRArray) = x[axes(x)...]
31-
32-
function materialize_traced_array(x::AbstractRange{<:TracedRNumber})
27+
function ReactantCore.materialize_traced_array(x::AbstractRange)
3328
return Reactant.aos_to_soa(collect(x))
3429
end
3530

36-
function materialize_traced_array(x::UnitRange{<:TracedRNumber})
31+
function ReactantCore.materialize_traced_array(x::Base.OneTo)
32+
return Ops.iota(Reactant.unwrapped_eltype(x), [length(x)]; iota_dimension=1)
33+
end
34+
35+
function ReactantCore.materialize_traced_array(x::UnitRange)
3736
return Ops.add(
3837
Ops.iota(Reactant.unwrapped_eltype(x), [length(x)]; iota_dimension=1),
3938
Ops.fill(first(x), [length(x)]),
4039
)
4140
end
4241

43-
function materialize_traced_array(x::SubArray{TracedRNumber{T}}) where {T}
42+
function ReactantCore.materialize_traced_array(x::SubArray)
4443
z = SubArray(materialize_traced_array(parent(x)), x.indices)
4544
return z[axes(z)...]
4645
end
4746

48-
function materialize_traced_array(x::Base.ReshapedArray{TracedRNumber{T}}) where {T}
47+
function ReactantCore.materialize_traced_array(x::Base.ReshapedArray)
4948
return Ops.reshape(materialize_traced_array(parent(x)), size(x)...)
5049
end
5150

52-
function materialize_traced_array(
53-
x::PermutedDimsArray{TracedRNumber{T},N,perm}
54-
) where {T,N,perm}
51+
function ReactantCore.materialize_traced_array(
52+
x::PermutedDimsArray{<:Any,<:Any,perm}
53+
) where {perm}
5554
return permutedims(materialize_traced_array(parent(x)), perm)
5655
end
5756

58-
function materialize_traced_array(x::AbstractArray{TracedRNumber{T}}) where {T}
57+
function ReactantCore.materialize_traced_array(x::AbstractArray{TracedRNumber{T}}) where {T}
5958
return Reactant.aos_to_soa(x)
6059
end
6160

src/stdlibs/LinearAlgebra.jl

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,34 +12,37 @@ using ..Reactant:
1212
Ops,
1313
MLIR
1414

15-
using ..TracedUtils: TracedUtils, get_mlir_data, materialize_traced_array, set_mlir_data!
15+
using ReactantCore: ReactantCore
16+
using ReactantCore: materialize_traced_array
17+
18+
using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!
1619

1720
using LinearAlgebra
1821

1922
# Various Wrapper Arrays defined in LinearAlgebra
20-
function TracedUtils.materialize_traced_array(
23+
function ReactantCore.materialize_traced_array(
2124
x::Transpose{TracedRNumber{T},<:AnyTracedRArray}
2225
) where {T}
23-
px = TracedUtils.materialize_traced_array(parent(x))
26+
px = materialize_traced_array(parent(x))
2427
A = ndims(px) == 1 ? reshape(px, :, 1) : px
2528
return permutedims(A, (2, 1))
2629
end
2730

28-
function TracedUtils.materialize_traced_array(
31+
function ReactantCore.materialize_traced_array(
2932
x::Adjoint{TracedRNumber{T},<:AnyTracedRArray}
3033
) where {T}
3134
return Ops.conj(
3235
materialize_traced_array(transpose(materialize_traced_array(parent(x))))
3336
)
3437
end
3538

36-
function TracedUtils.materialize_traced_array(
39+
function ReactantCore.materialize_traced_array(
3740
x::Diagonal{TracedRNumber{T},<:AnyTracedRVector}
3841
) where {T}
3942
return diagm(materialize_traced_array(parent(x)))
4043
end
4144

42-
function TracedUtils.materialize_traced_array(
45+
function ReactantCore.materialize_traced_array(
4346
x::Tridiagonal{TracedRNumber{T},<:AnyTracedRVector}
4447
) where {T}
4548
return diagm(-1 => x.dl, 0 => x.d, 1 => x.du)
@@ -48,22 +51,22 @@ end
4851
for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE"))
4952
uAT = Symbol(:Unit, AT)
5053
@eval begin
51-
function TracedUtils.materialize_traced_array(
54+
function ReactantCore.materialize_traced_array(
5255
x::$(AT){TracedRNumber{T},<:AnyTracedRMatrix}
5356
) where {T}
5457
m, n = size(x)
55-
px = TracedUtils.materialize_traced_array(parent(x))
58+
px = materialize_traced_array(parent(x))
5659
row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1)
5760
col_idxs = Ops.iota(Int, [m, n]; iota_dimension=2)
5861
indicator = Ops.compare(row_idxs, col_idxs; comparison_direction=$(comp))
5962
return Ops.select(indicator, px, zero(px))
6063
end
6164

62-
function TracedUtils.materialize_traced_array(
65+
function ReactantCore.materialize_traced_array(
6366
x::$(uAT){TracedRNumber{T},<:AnyTracedRMatrix}
6467
) where {T}
6568
m, n = size(x)
66-
px = TracedUtils.materialize_traced_array(parent(x))
69+
px = materialize_traced_array(parent(x))
6770
row_idxs = Ops.iota(Int, [m, n]; iota_dimension=1)
6871
col_idxs = Ops.iota(Int, [m, n]; iota_dimension=2)
6972
nondiag_indicator = Ops.compare(row_idxs, col_idxs; comparison_direction="NE")
@@ -73,7 +76,7 @@ for (AT, comp) in ((:LowerTriangular, "GE"), (:UpperTriangular, "LE"))
7376
end
7477
end
7578

76-
function TracedUtils.materialize_traced_array(
79+
function ReactantCore.materialize_traced_array(
7780
x::Symmetric{TracedRNumber{T},<:AnyTracedRMatrix}
7881
) where {T}
7982
m, n = size(x)

src/utils.jl

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,9 @@ function should_rewrite_call(@nospecialize(ft))
179179
ft <: typeof(Adapt.adapt_structure) ||
180180
ft <: typeof(Core.is_top_bit_set) ||
181181
ft <: typeof(Base.setindex_widen_up_to) ||
182-
ft <: typeof(Base.typejoin)
182+
ft <: typeof(Base.typejoin) ||
183+
ft <: typeof(Base.argtype_decl) ||
184+
ft <: typeof(Base.arg_decl_parts)
183185
return false
184186
end
185187

@@ -188,7 +190,12 @@ function should_rewrite_call(@nospecialize(ft))
188190
end
189191

190192
# by default, same as `should_rewrite_call`
191-
should_rewrite_invoke(@nospecialize(ft), @nospecialize(args)) = should_rewrite_call(ft)
193+
function should_rewrite_invoke(@nospecialize(ft), @nospecialize(args))
194+
if ft <: typeof(repeat) && args == Tuple{String,Int64}
195+
return false
196+
end
197+
return should_rewrite_call(ft)
198+
end
192199

193200
# Avoid recursively interpreting into methods we define explicitly
194201
# as overloads, which we assume should handle the entirety of the
@@ -582,7 +589,16 @@ function call_with_reactant_generator(
582589
safe_print("ir", ir)
583590
end
584591

585-
if !is_reactant_method(mi::Core.MethodInstance) || guaranteed_error
592+
mi = mi::Core.MethodInstance
593+
594+
if !(
595+
is_reactant_method(mi) || (
596+
mi.def.sig isa DataType &&
597+
!should_rewrite_invoke(
598+
mi.def.sig.parameters[1], Tuple{mi.def.sig.parameters[2:end]...}
599+
)
600+
)
601+
) || guaranteed_error
586602
ir, any_changed = rewrite_insts!(ir, interp, guaranteed_error)
587603
end
588604

0 commit comments

Comments
 (0)