Skip to content

Commit a39b055

Browse files
feat: support 2 levels of wrapping (#824)
* feat: support 2 levels of wrapping * feat: nicer printing * Update src/ConcreteRArray.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 85bdbee commit a39b055

File tree

5 files changed

+54
-1
lines changed

5 files changed

+54
-1
lines changed

src/ConcreteRArray.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,13 @@ function Base.print_array(io::IO, X::AnyConcretePJRTArray)
167167
return Base.print_array(io, convert(Array, X))
168168
end
169169

170+
function Base.showarg(io::IO, a::ConcretePJRTArray{T,N}, toplevel) where {T,N}
171+
toplevel || print(io, "::")
172+
print(io, "ConcretePJRTArray{$T,$N}")
173+
Sharding.is_sharded(a) && print(io, " with sharding $(typeof(a.sharding.sharding))")
174+
return nothing
175+
end
176+
170177
function Base.show(io::IO, X::AnyConcretePJRTArray)
171178
if isempty(X)
172179
print(io, "<Empty Buffer eltype $(eltype(X)) of size $(size(X))>")

src/TracedRArray.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module TracedRArrayOverrides
22

3-
using Adapt: WrappedReshapedArray
3+
using Adapt: WrappedReshapedArray, WrappedArray
44
using Base.Broadcast
55
using Base.Broadcast: BroadcastStyle, Broadcasted, AbstractArrayStyle, instantiate
66

src/TracedUtils.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -602,6 +602,9 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
602602
end
603603

604604
function broadcast_to_size(arg::AbstractArray{<:TracedRNumber}, rsize)
605+
if Reactant.ancestor(arg) isa TracedRArray
606+
return broadcast_to_size(materialize_traced_array(arg), rsize)
607+
end
605608
return broadcast_to_size(reshape(Ops.vcat(arg...), size(arg)...), rsize)
606609
end
607610
broadcast_to_size(arg::AbstractArray, rsize) = broadcast_to_size(Ops.constant(arg), rsize)

src/stdlibs/LinearAlgebra.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using ..Reactant:
77
AnyTracedRMatrix,
88
AnyTracedRVector,
99
AnyTracedRVecOrMat,
10+
WrappedTracedRArray,
1011
unwrapped_eltype,
1112
Ops,
1213
MLIR
@@ -24,18 +25,36 @@ function TracedUtils.materialize_traced_array(
2425
return permutedims(A, (2, 1))
2526
end
2627

28+
function TracedUtils.materialize_traced_array(
29+
x::Transpose{TracedRNumber{T},<:WrappedTracedRArray{T,N}}
30+
) where {T,N}
31+
return materialize_traced_array(transpose(materialize_traced_array(parent(x))))
32+
end
33+
2734
function TracedUtils.materialize_traced_array(
2835
x::Adjoint{TracedRNumber{T},TracedRArray{T,N}}
2936
) where {T,N}
3037
return conj(materialize_traced_array(transpose(parent(x))))
3138
end
3239

40+
function TracedUtils.materialize_traced_array(
41+
x::Adjoint{TracedRNumber{T},<:WrappedTracedRArray{T,N}}
42+
) where {T,N}
43+
return materialize_traced_array(adjoint(materialize_traced_array(parent(x))))
44+
end
45+
3346
function TracedUtils.materialize_traced_array(
3447
x::Diagonal{TracedRNumber{T},TracedRArray{T,1}}
3548
) where {T}
3649
return diagm(parent(x))
3750
end
3851

52+
function TracedUtils.materialize_traced_array(
53+
x::Diagonal{TracedRNumber{T},WrappedTracedRArray{T,1}}
54+
) where {T}
55+
return diagm(materialize_traced_array(parent(x)))
56+
end
57+
3958
function TracedUtils.materialize_traced_array(
4059
x::Tridiagonal{TracedRNumber{T},TracedRArray{T,1}}
4160
) where {T}

test/wrapped_arrays.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,3 +264,27 @@ end
264264
hlo = repr(@code_hlo(permutedims_getindex(x_ra)))
265265
@test !occursin("stablehlo.gather", hlo)
266266
end
267+
268+
function view_adjoint(x)
269+
y = view(x, 1:2, 1:2)
270+
return adjoint(y) .+ y
271+
end
272+
273+
function view_transpose(x)
274+
y = view(x, 1:2, 1:2)
275+
return transpose(y) .+ y
276+
end
277+
278+
function view_diagonal(x)
279+
y = view(x, 1:2, 1:2)
280+
return Diagonal(y) .+ y
281+
end
282+
283+
@testset "2 levels of wrapping" begin
284+
x = reshape(collect(Float32, 1:8), 2, 4)
285+
x_ra = Reactant.to_rarray(x)
286+
287+
@test @jit(view_adjoint(x_ra)) view_adjoint(x)
288+
@test @jit(view_transpose(x_ra)) view_transpose(x)
289+
@test @jit(view_diagonal(x_ra)) view_diagonal(x)
290+
end

0 commit comments

Comments
 (0)