-
Notifications
You must be signed in to change notification settings - Fork 29
FixedSizeArrays support (And Memory Support) #1669
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
e0906bc
460dd4c
73205e2
15eca0b
cbd5e75
702ac1d
658a3d7
7e79d15
08a425e
8b9f24d
dc00c6a
ff2d39b
995c8a3
8ef1575
c1842c3
a033cc6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,36 @@ | ||||||||||||||||||||||||||||||||||||||||||
module ReactantFixedSizeArraysExt | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
using FixedSizeArrays | ||||||||||||||||||||||||||||||||||||||||||
using Reactant | ||||||||||||||||||||||||||||||||||||||||||
using Reactant: TracedRArray, TracedRNumber, Ops | ||||||||||||||||||||||||||||||||||||||||||
using ReactantCore: ReactantCore | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
function Reactant.traced_type_inner( | ||||||||||||||||||||||||||||||||||||||||||
@nospecialize(_::Type{FixedSizeArrays.FixedSizeArrayDefault{T,N}}), | ||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||
seen, | ||||||||||||||||||||||||||||||||||||||||||
@nospecialize(mode::Reactant.TraceMode), | ||||||||||||||||||||||||||||||||||||||||||
@nospecialize(track_numbers::Type), | ||||||||||||||||||||||||||||||||||||||||||
@nospecialize(sharding), | ||||||||||||||||||||||||||||||||||||||||||
@nospecialize(runtime) | ||||||||||||||||||||||||||||||||||||||||||
) where {T,N} | ||||||||||||||||||||||||||||||||||||||||||
T2 = Reactant.TracedRNumber{T} | ||||||||||||||||||||||||||||||||||||||||||
return FixedSizeArrays.FixedSizeArrayDefault{T2,N} | ||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
Base.@nospecializeinfer function Reactant.make_tracer( | ||||||||||||||||||||||||||||||||||||||||||
seen, | ||||||||||||||||||||||||||||||||||||||||||
@nospecialize(prev::FixedSizeArrays.FixedSizeArrayDefault{T,N}), | ||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||
@nospecialize(path), | ||||||||||||||||||||||||||||||||||||||||||
mode; | ||||||||||||||||||||||||||||||||||||||||||
kwargs..., | ||||||||||||||||||||||||||||||||||||||||||
) where {T,N} | ||||||||||||||||||||||||||||||||||||||||||
shape = size(prev) | ||||||||||||||||||||||||||||||||||||||||||
return reshape( | ||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. but array itself we override to keep the dimensionality and allocate everything ourselves for Array. if the actual tracing of a fixedsizearray does the current "generic recursion into structs" it will eventually allocate a 1-dim memory, always There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The whole point of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It feels to me like Lines 1177 to 1196 in e4bb34f
size(prev) but could be overridden if passed explicitly.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I opened #1712 to implement my suggestion. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function needs to be updated to return a FixedSizedArray where the internal data would be a Traced/Concrete Array (#1669 (comment)) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess I'm missing something fundamental: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Array maps to our Concrete/Traced Array. For all "wrapped" types, we need to preserve the wrapper type, if we want the outputs to preserve the wrapped type. Else any operation you perform on a FixedSizeArray with inevitably be mapped to an Array output There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not really convinced a proposal for changing the memory backend in FixedSizeArrays is going to fly: it's meant to follow Array very closely, and the memory backend is always a flattened dense vector. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I mean once #1696 goes in, I think we should be able to make that work (though the backing memory might need to be a reshape(tracedrarray{n}, 1), but the reshape will optimize out with the one that's emitted currently There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing the call to FixedSizedArray There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Returning a FixedSizeArray causes problems. With
I don't know why it's throwing this, as after inspection the inner Without it I get:
For this one I don't have a clue. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So: mem = Memory{Float32}([1.0f0, 2.0f0])
getfield(mem, 2)
This is where the pointer is coming from. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since Reactant isn't supporting resizing for normal arrays, and FixedSizeArray isn't doing any fancy memory/optimization stuff other than fixed size (unlike OneHotArrays, FillArrays) shouldn't it be fine for the FixedSizeArray to transform into an ordinary Concrete Array? |
||||||||||||||||||||||||||||||||||||||||||
Reactant.make_tracer( | ||||||||||||||||||||||||||||||||||||||||||
seen, parent(prev), (path..., 1), mode; kwargs..., track_numbers=Number | ||||||||||||||||||||||||||||||||||||||||||
), | ||||||||||||||||||||||||||||||||||||||||||
shape, | ||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||
end | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
end |
Original file line number | Diff line number | Diff line change | ||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -214,8 +214,8 @@ function ConcretePJRTArray{T,N}( | |||||||||||||||
return ConcretePJRTArray{T,N,D,typeof(sharding)}(data, shape, sharding) | ||||||||||||||||
end | ||||||||||||||||
|
||||||||||||||||
function ConcretePJRTArray( | ||||||||||||||||
data::Array{T,N}; | ||||||||||||||||
function make_concrete_PJRT_array( | ||||||||||||||||
data::AbstractArray{T,N}, | ||||||||||||||||
client::Union{Nothing,XLA.PJRT.Client}=nothing, | ||||||||||||||||
idx::Union{Int,Nothing}=nothing, | ||||||||||||||||
device::Union{Nothing,XLA.PJRT.Device}=nothing, | ||||||||||||||||
|
@@ -228,6 +228,28 @@ function ConcretePJRTArray( | |||||||||||||||
return ConcretePJRTArray{T,N,nsharded,typeof(shardinfo)}(sharded_data, shape, shardinfo) | ||||||||||||||||
end | ||||||||||||||||
|
||||||||||||||||
function ConcretePJRTArray( | ||||||||||||||||
data::Array{T,N}; | ||||||||||||||||
client::Union{Nothing,XLA.PJRT.Client}=nothing, | ||||||||||||||||
idx::Union{Int,Nothing}=nothing, | ||||||||||||||||
device::Union{Nothing,XLA.PJRT.Device}=nothing, | ||||||||||||||||
sharding::Sharding.AbstractSharding=Sharding.NoSharding(), | ||||||||||||||||
) where {T,N} | ||||||||||||||||
return make_concrete_PJRT_array(data, client, idx, device, sharding) | ||||||||||||||||
end | ||||||||||||||||
|
||||||||||||||||
if isdefined(Base, :Memory) | ||||||||||||||||
function ConcretePJRTArray( | ||||||||||||||||
data::Memory{T}; | ||||||||||||||||
client::Union{Nothing,XLA.PJRT.Client}=nothing, | ||||||||||||||||
idx::Union{Int,Nothing}=nothing, | ||||||||||||||||
device::Union{Nothing,XLA.PJRT.Device}=nothing, | ||||||||||||||||
sharding::Sharding.AbstractSharding=Sharding.NoSharding(), | ||||||||||||||||
) where {T} | ||||||||||||||||
return make_concrete_PJRT_array(data, client, idx, device, sharding) | ||||||||||||||||
end | ||||||||||||||||
end | ||||||||||||||||
|
||||||||||||||||
Base.wait(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = foreach(wait, x.data) | ||||||||||||||||
XLA.client(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = XLA.client(x.data) | ||||||||||||||||
function XLA.device(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) | ||||||||||||||||
|
@@ -342,8 +364,8 @@ function ConcreteIFRTArray{T,N}( | |||||||||||||||
return ConcreteIFRTArray{T,N,typeof(sharding)}(data, shape, sharding) | ||||||||||||||||
end | ||||||||||||||||
|
||||||||||||||||
function ConcreteIFRTArray( | ||||||||||||||||
data::Array{T,N}; | ||||||||||||||||
function make_concrete_IFRT_array( | ||||||||||||||||
data::AbstractArray{T,N}, | ||||||||||||||||
client::Union{Nothing,XLA.IFRT.Client}=nothing, | ||||||||||||||||
idx::Union{Int,Nothing}=nothing, | ||||||||||||||||
device::Union{Nothing,XLA.IFRT.Device}=nothing, | ||||||||||||||||
|
@@ -356,6 +378,28 @@ function ConcreteIFRTArray( | |||||||||||||||
return ConcreteIFRTArray{T,N,typeof(shardinfo)}(sharded_data, shape, shardinfo, padding) | ||||||||||||||||
end | ||||||||||||||||
|
||||||||||||||||
function ConcreteIFRTArray( | ||||||||||||||||
data::Array{T,N}; | ||||||||||||||||
client::Union{Nothing,XLA.IFRT.Client}=nothing, | ||||||||||||||||
idx::Union{Int,Nothing}=nothing, | ||||||||||||||||
device::Union{Nothing,XLA.IFRT.Device}=nothing, | ||||||||||||||||
sharding::Sharding.AbstractSharding=Sharding.NoSharding(), | ||||||||||||||||
) where {T,N} | ||||||||||||||||
return make_concrete_IFRT_array(data, client, idx, device, sharding) | ||||||||||||||||
end | ||||||||||||||||
|
||||||||||||||||
if isdefined(Base, :Memory) | ||||||||||||||||
function ConcreteIFRTArray( | ||||||||||||||||
data::Memory{T}; | ||||||||||||||||
client::Union{Nothing,XLA.IFRT.Client}=nothing, | ||||||||||||||||
idx::Union{Int,Nothing}=nothing, | ||||||||||||||||
device::Union{Nothing,XLA.IFRT.Device}=nothing, | ||||||||||||||||
sharding::Sharding.AbstractSharding=Sharding.NoSharding(), | ||||||||||||||||
) where {T} | ||||||||||||||||
return make_concrete_IFRT_array(data, client, idx, device, sharding) | ||||||||||||||||
end | ||||||||||||||||
end | ||||||||||||||||
|
||||||||||||||||
# Assemble data from multiple arrays. Needed in distributed setting where each process wont | ||||||||||||||||
# have enough host memory to hold all the arrays. We assume that the data is only provided | ||||||||||||||||
# for all of the addressable devices. | ||||||||||||||||
|
@@ -472,8 +516,11 @@ elseif XLA.REACTANT_XLA_RUNTIME == "IFRT" | |||||||||||||||
ConcreteIFRTArray | ||||||||||||||||
end | ||||||||||||||||
|
||||||||||||||||
@inline ConcreteRArray{T}(::UndefInitializer, shape::Integer...; kwargs...) where {T} = | ||||||||||||||||
ConcreteRArray{T}(undef, Dims(shape); kwargs...) | ||||||||||||||||
@inline ConcreteRArray{T}(::UndefInitializer, shape::Integer...; kwargs...) where {T} = ConcreteRArray{ | ||||||||||||||||
T | ||||||||||||||||
}( | ||||||||||||||||
undef, Dims(shape); kwargs... | ||||||||||||||||
) | ||||||||||||||||
Comment on lines
+519
to
+523
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||
|
||||||||||||||||
""" | ||||||||||||||||
ConcreteRNumber( | ||||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
|
||
using Reactant, Test, FixedSizeArrays | ||
|
||
fn(x, y) = (2 .* x .- 3) * y' | ||
|
||
@testset "FixedSizeArrays" begin | ||
@testset "1D" begin | ||
x = FixedSizeArray(fill(3.0f0, 100)) | ||
rx = Reactant.to_rarray(x) | ||
@test @jit(fn(rx, rx)) ≈ fn(x, x) | ||
end | ||
@testset "2D" begin | ||
x = FixedSizeArray(fill(3.0f0, (4, 5))) | ||
rx = Reactant.to_rarray(x) | ||
@test @jit(fn(rx, rx)) ≈ fn(x, x) | ||
end | ||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
using Reactant, Test | ||
|
||
fn(x, y) = sin.(x) .+ cos.(y) | ||
|
||
@testset "Memory test" begin | ||
x = Memory{Float32}(fill(2.0f0, 10)) | ||
x_ra = Reactant.to_rarray(x) | ||
|
||
@test @jit(fn(x_ra, x_ra)) ≈ fn(x, x) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.