3 changes: 3 additions & 0 deletions Project.toml
@@ -47,6 +47,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
FixedSizeArrays = "3821ddf9-e5b5-40d5-8e25-6813ab96b5e2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[sources]
@@ -58,6 +59,7 @@ ReactantArrayInterfaceExt = "ArrayInterface"
ReactantCUDAExt = ["CUDA", "GPUCompiler", "KernelAbstractions", "LLVM"]
ReactantDLFP8TypesExt = "DLFP8Types"
ReactantFillArraysExt = "FillArrays"
ReactantFixedSizeArraysExt = "FixedSizeArrays"
ReactantFloat8sExt = "Float8s"
ReactantKernelAbstractionsExt = "KernelAbstractions"
ReactantMPIExt = "MPI"
@@ -84,6 +86,7 @@ EnumX = "1"
Enzyme = "0.13.78"
EnzymeCore = "0.8.13"
FillArrays = "1.13"
FixedSizeArrays = "1.2.0"
Float8s = "0.1"
Functors = "0.5"
GPUArraysCore = "0.2"
36 changes: 36 additions & 0 deletions ext/ReactantFixedSizeArraysExt.jl
@@ -0,0 +1,36 @@
module ReactantFixedSizeArraysExt

using FixedSizeArrays
using Reactant
using Reactant: TracedRArray, TracedRNumber, Ops
using ReactantCore: ReactantCore
Comment on lines +3 to +6
Collaborator:

Suggested change:
-using FixedSizeArrays
-using Reactant
-using Reactant: TracedRArray, TracedRNumber, Ops
-using ReactantCore: ReactantCore
+using FixedSizeArrays: FixedSizeArrayDefault
+using Reactant: Reactant, TracedRArray, TracedRNumber, Ops
+using ReactantCore: ReactantCore


function Reactant.traced_type_inner(
@nospecialize(_::Type{FixedSizeArrays.FixedSizeArrayDefault{T,N}}),
Collaborator:

Suggested change:
-@nospecialize(_::Type{FixedSizeArrays.FixedSizeArrayDefault{T,N}}),
+@nospecialize(_::Type{FixedSizeArrayDefault{T,N}}),

seen,
@nospecialize(mode::Reactant.TraceMode),
@nospecialize(track_numbers::Type),
@nospecialize(sharding),
@nospecialize(runtime)
) where {T,N}
T2 = Reactant.TracedRNumber{T}
return FixedSizeArrays.FixedSizeArrayDefault{T2,N}
end

Base.@nospecializeinfer function Reactant.make_tracer(
seen,
@nospecialize(prev::FixedSizeArrays.FixedSizeArrayDefault{T,N}),
Collaborator:

Suggested change:
-@nospecialize(prev::FixedSizeArrays.FixedSizeArrayDefault{T,N}),
+@nospecialize(prev::FixedSizeArrayDefault{T,N}),

@nospecialize(path),
mode;
kwargs...,
) where {T,N}
shape = size(prev)
return reshape(
Member:

I guess this reshape is the culprit? How does Array work? Is it possible to construct this object directly with the right size instead of reshaping it?

Member:

But Array itself we override, to keep the dimensionality and allocate everything ourselves.

If the actual tracing of a FixedSizeArray goes through the current "generic recursion into structs", it will always end up allocating a 1-dimensional Memory.

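For context, a rough sketch of the wrapper layout (a minimal sketch; field names are assumptions based on FixedSizeArrays.jl's design, not its verbatim source):

struct FixedSizeArray{T,N,Mem<:DenseVector{T}} <: DenseArray{T,N}
    mem::Mem             # flat 1-d backing store, e.g. Memory{T}
    size::NTuple{N,Int}  # the fixed shape
end

Generic recursion into the struct only ever reaches mem, which is 1-dimensional, hence the reshape in the diff above.
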
Member @giordano (Sep 23, 2025):

The whole point of FixedSizeArray is that the size is...well...fixed. Having to reshape it all the time seems to go in the opposite direction, especially when Array doesn't have that.

Member @giordano (Sep 23, 2025):

It feels to me like make_tracer should take the size as an (optional) argument. Looking at Reactant.jl/src/Tracing.jl, lines 1177 to 1196 at e4bb34f:

Base.@nospecializeinfer function make_tracer(
seen,
@nospecialize(prev::ConcreteIFRTArray{T,N}),
@nospecialize(path),
mode;
@nospecialize(sharding = Sharding.NoSharding()),
@nospecialize(device = nothing),
@nospecialize(client = nothing),
kwargs...,
) where {T,N}
if mode == TracedToTypes
throw("Cannot have ConcreteIFRTArray as function call argument.")
end
mode == ArrayToConcrete && return ConcreteIFRTArray(prev; sharding, device, client)
mode != ConcreteToTraced && throw("Cannot trace concrete")
haskey(seen, prev) && return seen[prev]::TracedRArray{T,N}
res = TracedRArray{T,N}((path,), nothing, size(prev))
seen[prev] = res
return res
end
(and all similar methods) the size could be another argument which defaults to size(prev) but could be overridden if passed explicitly.

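A minimal sketch of that suggestion (the shape keyword is hypothetical, not existing API):

Base.@nospecializeinfer function make_tracer(
    seen,
    @nospecialize(prev::ConcreteIFRTArray{T,N}),
    @nospecialize(path),
    mode;
    shape=size(prev),  # hypothetical optional argument a caller could override
    kwargs...,
) where {T,N}
    # ... same mode checks and caching as the method quoted above ...
    res = TracedRArray{T,N}((path,), nothing, shape)
    seen[prev] = res
    return res
end
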
Member:

I opened #1712 to implement my suggestion.

Collaborator:

This function needs to be updated to return a FixedSizeArray where the internal data would be a Traced/Concrete Array (#1669 (comment)).

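A rough sketch of that return shape (the rewrap helper below is hypothetical; the real spelling would come from FixedSizeArrays' constructors):

# Trace only the flat backing memory, then rewrap it so the output type
# stays FixedSizeArray; `rewrap_fixed_size` is an illustrative stand-in.
inner = Reactant.make_tracer(seen, parent(prev), (path..., 1), mode; kwargs...)
return rewrap_fixed_size(inner, size(prev))
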
Member:

I guess I'm missing something fundamental: Array has the same memory backend, why should FixedSizeArray be any different?

Collaborator:

Array maps to our Concrete/Traced Array. For all "wrapped" types, we need to preserve the wrapper type if we want the outputs to preserve it; otherwise any operation you perform on a FixedSizeArray will inevitably be mapped to an Array output.

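To illustrate the concern (hypothetical behavior, shown only to make the point):

x = FixedSizeArray(fill(1.0f0, 4))
rx = Reactant.to_rarray(x)
y = @jit identity(rx)
# If tracing erased the wrapper, y would come back as a plain ConcreteRArray
# rather than a FixedSizeArray, even for an identity function.
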
Member:

I'm not really convinced a proposal for changing the memory backend in FixedSizeArrays is going to fly: it's meant to follow Array very closely, and the memory backend is always a flattened dense vector.

Member:

I mean, once #1696 goes in, I think we should be able to make that work (though the backing memory might need to be a reshape(tracedrarray{n}, 1), but that reshape will optimize out with the one that's emitted currently).

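A minimal sketch of that idea (names illustrative, assuming #1696's machinery is available):

# Flatten the traced array into a 1-d backing to play the role of Memory;
# the compiler should fold this reshape with the one currently emitted.
traced = Reactant.make_tracer(seen, parent(prev), (path..., 1), mode; kwargs...)
backing = reshape(traced, length(traced))
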
Collaborator:

Missing the call to FixedSizeArray.

Author @Qfl3x (Sep 30, 2025):

Returning a FixedSizeArray causes problems.

With track_numbers=Number I get this from @jit(f(rx)):

ERROR: TypeError: in RNumber, in T, expected T<:Union{Complex{Int16}, Complex{Int32}, Complex{Int64}, Complex{Int8}, Complex{UInt16}, Complex{UInt32}, Complex{UInt64}, Complex{UInt8}, Core.BFloat16, Bool, Float16, Float32, Float64, Int16, Int32, Int64, Int8, UInt16, UInt32, UInt64, UInt8, Reactant.F8E4M3B11FNUZ, Reactant.F8E4M3FN, Reactant.F8E4M3FNUZ, Reactant.F8E5M2, Reactant.F8E5M2FNUZ, ComplexF64, ComplexF32}, got Type{Reactant.TracedRNumber{Float32}}
Stacktrace:
  [1] copyto!(dest::Memory{Reactant.TracedRNumber{Float32}}, src::Vector{Reactant.TracedRNumber{Float32}})
    @ Reactant.TracedRArrayOverrides ~/projects/Reactant.jl/src/TracedRArray.jl:492
  [2] collect_as_vectorlike_with_known_eltype_and_length(::Type{Memory{Reactant.TracedRNumber{Float32}}}, collection::Vector{Reactant.TracedRNumber{Float32}})
    @ Collects ~/.julia/packages/Collects/iMuH5/src/Collects.jl:269
  [3] collect_as_memory_with_known_eltype_and_known_length(::Type{Reactant.TracedRNumber{Float32}}, collection::Vector{Reactant.TracedRNumber{Float32}})
    @ Collects ~/.julia/packages/Collects/iMuH5/src/Collects.jl:345
  [4] collect_as_memory_with_known_eltype(::Type{Reactant.TracedRNumber{Float32}}, collection::Vector{Reactant.TracedRNumber{Float32}})
    @ Collects ~/.julia/packages/Collects/iMuH5/src/Collects.jl:353
  [5] collect_as_memory(::Function, ::Type{Memory{Reactant.TracedRNumber{Float32}}}, collection::Vector{Reactant.TracedRNumber{Float32}})
    @ Collects ~/.julia/packages/Collects/iMuH5/src/Collects.jl:373
  [6] collect_as_common_invariant(e::typeof(Collects.EmptyIteratorHandling.just_throws), ::Type{Memory{Reactant.TracedRNumber{Float32}}}, collection::Vector{Reactant.TracedRNumber{Float32}})
    @ Collects ~/.julia/packages/Collects/iMuH5/src/Collects.jl:376
  [7] collect_as_common(e::typeof(Collects.EmptyIteratorHandling.just_throws), type::Type{Memory{Reactant.TracedRNumber{Float32}}}, collection::Vector{Reactant.TracedRNumber{Float32}})
    @ Collects ~/.julia/packages/Collects/iMuH5/src/Collects.jl:400
  [8] (::Collects.Collect{typeof(Collects.EmptyIteratorHandling.just_throws)})(type::Type{Memory{Reactant.TracedRNumber{Float32}}}, collection::Vector{Reactant.TracedRNumber{Float32}})
    @ Collects ~/.julia/packages/Collects/iMuH5/src/Collects.jl:411
  [9] (::Collects.Collect{typeof(Collects.EmptyIteratorHandling.just_throws)})(::Type{FixedSizeArrayDefault{Reactant.TracedRNumber{Float32}}}, iterator::Vector{Reactant.TracedRNumber{Float32}})
    @ FixedSizeArrays ~/.julia/packages/FixedSizeArrays/VHLXx/src/collect_as.jl:45
 [10] collect_as(::Type{FixedSizeArrayDefault{Reactant.TracedRNumber{Float32}}}, collection::Vector{Reactant.TracedRNumber{Float32}}; empty_iterator_handler::typeof(Collects.EmptyIteratorHandling.just_throws))
    @ Collects ~/.julia/packages/Collects/iMuH5/src/Collects.jl:425
 [11] collect_as(::Type{FixedSizeArrayDefault{Reactant.TracedRNumber{Float32}}}, collection::Vector{Reactant.TracedRNumber{Float32}})
    @ Collects ~/.julia/packages/Collects/iMuH5/src/Collects.jl:423
 [12] collect_as_haseltype(::Type{FixedSizeArrayDefault}, iterator::Vector{Reactant.TracedRNumber{Float32}})
    @ FixedSizeArrays ~/.julia/packages/FixedSizeArrays/VHLXx/src/collect_as.jl:56
 [13] converting_constructor(::Type{FixedSizeArrayDefault}, src::Vector{Reactant.TracedRNumber{Float32}})
    @ FixedSizeArrays ~/.julia/packages/FixedSizeArrays/VHLXx/src/FixedSizeArray.jl:334
 [14] (FixedSizeArrayDefault)(src::Vector{Reactant.TracedRNumber{Float32}})
    @ FixedSizeArrays ~/.julia/packages/FixedSizeArrays/VHLXx/src/FixedSizeArray.jl:345
 [15] make_tracer(seen::Reactant.OrderedIdDict{Any, Any}, prev::FixedSizeArray{Float32, 1, Memory{Float32}}, path::Any, mode::Reactant.TraceMode; kwargs::@Kwargs{runtime::Val{:PJRT}})
    @ ReactantFixedSizeArraysExt ~/projects/Reactant.jl/ext/ReactantFixedSizeArraysExt.jl:33

I don't know why it's throwing this; on inspection, the inner Memory object is traced properly.

Without it I get:

ERROR: cannot copy Ptr{Nothing} @0x0000795f973dc1e0 of type Ptr{Nothing}
Stacktrace:
 [1] error(s::String)
   @ Base ./error.jl:35
 [2] create_result(tocopy::Ptr{…}, path::Tuple{…}, result_stores::Dict{…}, path_to_shard_info::Nothing, to_unreshard_results::Dict{…}, unresharded_code::Vector{…}, unresharded_arrays_cache::Dict{…}, used_shardinfo::Set{…}, result_cache::IdDict{…}, var_idx::Base.RefValue{…}, resultgen_code::Vector{…})
   @ Reactant.Compiler ~/projects/Reactant.jl/src/Compiler.jl:231
 [3] create_result(tocopy::Memory{…}, path::Tuple{…}, result_stores::Dict{…}, path_to_shard_info::Nothing, to_unreshard_results::Dict{…}, unresharded_code::Vector{…}, unresharded_arrays_cache::Dict{…}, used_shardinfo::Set{…}, result_cache::IdDict{…}, var_idx::Base.RefValue{…}, resultgen_code::Vector{…})
   @ Reactant.Compiler ~/projects/Reactant.jl/src/Compiler.jl:256
 [4] create_result(tocopy::FixedSizeArray{…}, path::Tuple{}, result_stores::Dict{…}, path_to_shard_info::Nothing, to_unreshard_results::Dict{…}, unresharded_code::Vector{…}, unresharded_arrays_cache::Dict{…}, used_shardinfo::Set{…}, result_cache::IdDict{…}, var_idx::Base.RefValue{…}, resultgen_code::Vector{…})
   @ Reactant.Compiler ~/projects/Reactant.jl/src/Compiler.jl:256
 [5] codegen_unflatten!(linear_args::Vector{…}, preserved_args::Vector{…}, concretized_res_names::Vector{…}, linear_results::Vector{…}, concrete_result::FixedSizeArray{…}, result_stores::Dict{…}, path_to_shard_info::Nothing, linear_result_shard_info::Vector{…}, client::Reactant.XLA.PJRT.Client, resharded_inputs::Dict{…})
   @ Reactant.Compiler ~/projects/Reactant.jl/src/Compiler.jl:3147
 [6] compile(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
   @ Reactant.Compiler ~/projects/Reactant.jl/src/Compiler.jl:3599
 [7] top-level scope
   @ ~/projects/Reactant.jl/src/Compiler.jl:2614

For this one I don't have a clue.

Author:

So:

julia> mem = Memory{Float32}([1.0f0, 2.0f0]);

julia> getfield(mem, 2)
Ptr{Nothing} @0x... (pointer address)

This is where the pointer is coming from.

Author:

Since Reactant doesn't support resizing for normal arrays, and FixedSizeArray doesn't do any fancy memory/optimization tricks beyond the fixed size (unlike OneHotArrays or FillArrays), shouldn't it be fine for a FixedSizeArray to be transformed into an ordinary Concrete Array?

Reactant.make_tracer(
seen, parent(prev), (path..., 1), mode; kwargs..., track_numbers=Number
),
shape,
)
end

end
51 changes: 48 additions & 3 deletions src/Tracing.jl
@@ -1516,11 +1516,11 @@ Base.@nospecializeinfer function make_tracer(
)
end

Base.@nospecializeinfer function make_tracer(
Base.@nospecializeinfer function make_tracer_array(
seen,
@nospecialize(prev::Array),
@nospecialize(prev::AbstractArray),
@nospecialize(path),
mode;
mode,
@nospecialize(track_numbers::Type = Union{}),
@nospecialize(sharding = Sharding.NoSharding()),
@nospecialize(runtime = nothing),
@@ -1605,6 +1605,23 @@ Base.@nospecializeinfer function make_tracer(
return newa
end

Base.@nospecializeinfer function make_tracer(
seen,
@nospecialize(prev::Array),
@nospecialize(path),
mode;
@nospecialize(track_numbers::Type = Union{}),
@nospecialize(sharding = Sharding.NoSharding()),
@nospecialize(runtime = nothing),
@nospecialize(device = nothing),
@nospecialize(client = nothing),
kwargs...,
)
return make_tracer_array(
seen, prev, path, mode, track_numbers, sharding, runtime, device, client, kwargs...
)
end

Base.@nospecializeinfer function make_tracer(
seen,
@nospecialize(prev::Dict{Key,Value}),
@@ -1812,6 +1829,34 @@ Base.@nospecializeinfer function make_tracer(
return res
end

if isdefined(Base, :Memory)
Base.@nospecializeinfer function make_tracer(
seen,
@nospecialize(prev::Memory),
@nospecialize(path),
mode;
@nospecialize(track_numbers::Type = Union{}),
@nospecialize(sharding = Sharding.NoSharding()),
@nospecialize(runtime = nothing),
@nospecialize(device = nothing),
@nospecialize(client = nothing),
kwargs...,
)
return make_tracer_array(
seen,
prev,
path,
mode,
track_numbers,
sharding,
runtime,
device,
client,
kwargs...,
)
end
end

Base.@nospecializeinfer function make_tracer(
seen,
@nospecialize(prev::Sharding.Mesh),
59 changes: 53 additions & 6 deletions src/Types.jl
@@ -214,8 +214,8 @@ function ConcretePJRTArray{T,N}(
return ConcretePJRTArray{T,N,D,typeof(sharding)}(data, shape, sharding)
end

function ConcretePJRTArray(
data::Array{T,N};
function make_concrete_PJRT_array(
data::AbstractArray{T,N},
client::Union{Nothing,XLA.PJRT.Client}=nothing,
idx::Union{Int,Nothing}=nothing,
device::Union{Nothing,XLA.PJRT.Device}=nothing,
@@ -228,6 +228,28 @@ function ConcretePJRTArray(
return ConcretePJRTArray{T,N,nsharded,typeof(shardinfo)}(sharded_data, shape, shardinfo)
end

function ConcretePJRTArray(
data::Array{T,N};
client::Union{Nothing,XLA.PJRT.Client}=nothing,
idx::Union{Int,Nothing}=nothing,
device::Union{Nothing,XLA.PJRT.Device}=nothing,
sharding::Sharding.AbstractSharding=Sharding.NoSharding(),
) where {T,N}
return make_concrete_PJRT_array(data, client, idx, device, sharding)
end

if isdefined(Base, :Memory)
function ConcretePJRTArray(
data::Memory{T};
client::Union{Nothing,XLA.PJRT.Client}=nothing,
idx::Union{Int,Nothing}=nothing,
device::Union{Nothing,XLA.PJRT.Device}=nothing,
sharding::Sharding.AbstractSharding=Sharding.NoSharding(),
) where {T}
return make_concrete_PJRT_array(data, client, idx, device, sharding)
end
end

Base.wait(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = foreach(wait, x.data)
XLA.client(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = XLA.client(x.data)
function XLA.device(x::Union{ConcretePJRTArray,ConcretePJRTNumber})
@@ -342,8 +364,8 @@ function ConcreteIFRTArray{T,N}(
return ConcreteIFRTArray{T,N,typeof(sharding)}(data, shape, sharding)
end

function ConcreteIFRTArray(
data::Array{T,N};
function make_concrete_IFRT_array(
data::AbstractArray{T,N},
client::Union{Nothing,XLA.IFRT.Client}=nothing,
idx::Union{Int,Nothing}=nothing,
device::Union{Nothing,XLA.IFRT.Device}=nothing,
@@ -356,6 +378,28 @@ function ConcreteIFRTArray(
return ConcreteIFRTArray{T,N,typeof(shardinfo)}(sharded_data, shape, shardinfo, padding)
end

function ConcreteIFRTArray(
data::Array{T,N};
client::Union{Nothing,XLA.IFRT.Client}=nothing,
idx::Union{Int,Nothing}=nothing,
device::Union{Nothing,XLA.IFRT.Device}=nothing,
sharding::Sharding.AbstractSharding=Sharding.NoSharding(),
) where {T,N}
return make_concrete_IFRT_array(data, client, idx, device, sharding)
end

if isdefined(Base, :Memory)
function ConcreteIFRTArray(
data::Memory{T};
client::Union{Nothing,XLA.IFRT.Client}=nothing,
idx::Union{Int,Nothing}=nothing,
device::Union{Nothing,XLA.IFRT.Device}=nothing,
sharding::Sharding.AbstractSharding=Sharding.NoSharding(),
) where {T}
return make_concrete_IFRT_array(data, client, idx, device, sharding)
end
end

# Assemble data from multiple arrays. Needed in distributed setting where each process wont
# have enough host memory to hold all the arrays. We assume that the data is only provided
# for all of the addressable devices.
@@ -472,8 +516,11 @@ elseif XLA.REACTANT_XLA_RUNTIME == "IFRT"
ConcreteIFRTArray
end

@inline ConcreteRArray{T}(::UndefInitializer, shape::Integer...; kwargs...) where {T} =
ConcreteRArray{T}(undef, Dims(shape); kwargs...)
@inline ConcreteRArray{T}(::UndefInitializer, shape::Integer...; kwargs...) where {T} = ConcreteRArray{
T
}(
undef, Dims(shape); kwargs...
)
Comment on lines +519 to +523
Member:

https://github.com/EnzymeAD/Reactant.jl/actions/runs/17949332627/job/51048248283?pr=1669#step:2:570

Suggested change:
-@inline ConcreteRArray{T}(::UndefInitializer, shape::Integer...; kwargs...) where {T} = ConcreteRArray{
-    T
-}(
-    undef, Dims(shape); kwargs...
-)
+@inline ConcreteRArray{T}(::UndefInitializer, shape::Integer...; kwargs...) where {T} =
+    ConcreteRArray{T}(undef, Dims(shape); kwargs...)


"""
ConcreteRNumber(
63 changes: 54 additions & 9 deletions src/xla/IFRT/Array.jl
@@ -16,11 +16,8 @@ function Array(
return Array(client, fill(array), device, memory_kind)
end

function Array(
client::Client,
array::Base.Array{T,N},
device::Device=XLA.default_device(client),
memory_kind::AbstractString=string(convert(MemoryKind, XLA.default_memory(device))),
function make_array_singleshard(
client::Client, array::AbstractArray{T,N}, device::Device, memory_kind::AbstractString
) where {T<:Reactant.ReactantPrimitive,N}
sizear = collect(Int64, reverse(size(array)))
buffer = GC.@preserve array sizear begin
@@ -39,7 +36,27 @@ end
end

function Array(
client::Client, array::Base.Array{T,N}, sharding::Sharding
client::Client,
array::Base.Array{T,N},
device::Device=XLA.default_device(client),
memory_kind::AbstractString=string(convert(MemoryKind, XLA.default_memory(device))),
) where {T<:Reactant.ReactantPrimitive,N}
return make_array_singleshard(client, array, device, memory_kind)
end

if isdefined(Base, :Memory)
function Array(
client::Client,
memory::Base.Memory{T},
device::Device=XLA.default_device(client),
memory_kind::AbstractString=string(convert(MemoryKind, XLA.default_memory(device))),
) where {T<:Reactant.ReactantPrimitive}
return make_array_singleshard(client, memory, device, memory_kind)
end
end

function make_array_sharding(
client::Client, array::AbstractArray{T,N}, sharding::Sharding
) where {T<:Reactant.ReactantPrimitive,N}
all_devices = XLA.devices(sharding)
all_logical_device_ids = collect(Int64, 0:(length(all_devices) - 1))
@@ -75,14 +92,20 @@ return Array(client, host_buffers, addressable_shard_indices, size(array), sharding)
return Array(client, host_buffers, addressable_shard_indices, size(array), sharding)
end

function Array(
client::Client, array::Base.Array{T,N}, sharding::Sharding
) where {T<:Reactant.ReactantPrimitive,N}
return make_array_sharding(client, array, sharding)
end

function Array(
client::Client,
host_buffers::Vector{Base.Array{T,N}},
addressable_shard_indices::Vector{Vector{Int64}},
array_shape,
sharding::Sharding,
) where {T<:Reactant.ReactantPrimitive,N}
# Construct using the slower path, the faster path is only implemented for IFRT-Proxy
# make using the slower path, the faster path is only implemented for IFRT-Proxy
# and seems to cause issues with IFRT-PJRT
all_addressable_devices = filter(XLA.is_addressable, XLA.devices(sharding))

@@ -143,8 +166,16 @@ function Array(
return Array(buffer)
end

function Array(
client::Client, array::Base.Array{T,N}, sharding
if isdefined(Base, :Memory)
function Array(
client::Client, memory::Base.Memory{T}, sharding::Sharding
) where {T<:Reactant.ReactantPrimitive}
return make_array_sharding(client, memory, sharding)
end
end

function make_array_ifrt_sharding(
client::Client, array::Base.AbstractArray{T,N}, sharding
) where {T<:Reactant.ReactantPrimitive,N}
@assert sharding isa Reactant.Sharding.AbstractSharding
if !(sharding isa Reactant.Sharding.HloSharding)
@@ -158,6 +189,20 @@ function Array(
return Array(client, array, ifrt_sharding)
end

function Array(
client::Client, array::Base.Array{T,N}, sharding
) where {T<:Reactant.ReactantPrimitive,N}
return make_array_ifrt_sharding(client, array, sharding)
end

if isdefined(Base, :Memory)
function Array(
client::Client, memory::Base.Memory{T}, sharding
) where {T<:Reactant.ReactantPrimitive}
return make_array_ifrt_sharding(client, memory, sharding)
end
end

@inline function XLA.free_buffer(buffer::Array)
if buffer.buffer != C_NULL
@ccall MLIR.API.mlir_c.ifrt_free_array(buffer.buffer::Ptr{Cvoid})::Cvoid
14 changes: 13 additions & 1 deletion src/xla/PJRT/Buffer.jl
@@ -6,7 +6,9 @@ mutable struct Buffer <: XLA.AbstractBuffer
end
end

function Buffer(client::Client, array::Array{T,N}, device::Device) where {T,N}
function make_buffer_array(
client::Client, array::AbstractArray{T,N}, device::Device
) where {T,N}
sizear = collect(Int64, reverse(size(array)))
buffer = GC.@preserve array sizear begin
@ccall MLIR.API.mlir_c.ArrayFromHostBuffer(
@@ -21,6 +23,16 @@ function Buffer(client::Client, array::Array{T,N}, device::Device) where {T,N}
return Buffer(buffer)
end

function Buffer(client::Client, array::Array{T,N}, device::Device) where {T,N}
return make_buffer_array(client, array, device)
end

if isdefined(Base, :Memory)
function Buffer(client::Client, memory::Memory{T}, device::Device) where {T}
return make_buffer_array(client, memory, device)
end
end

function Base.similar(a::Buffer)
buffer = GC.@preserve a begin
@ccall MLIR.API.mlir_c.UninitPJRTBuffer(
1 change: 1 addition & 0 deletions test/Project.toml
@@ -9,6 +9,7 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
FixedSizeArrays = "3821ddf9-e5b5-40d5-8e25-6813ab96b5e2"
Float8s = "81dfefd7-55b0-40c6-a251-db853704e186"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
17 changes: 17 additions & 0 deletions test/integration/fixedsizearrays.jl
@@ -0,0 +1,17 @@

using Reactant, Test, FixedSizeArrays

fn(x, y) = (2 .* x .- 3) * y'

@testset "FixedSizeArrays" begin
@testset "1D" begin
x = FixedSizeArray(fill(3.0f0, 100))
rx = Reactant.to_rarray(x)
@test @jit(fn(rx, rx)) ≈ fn(x, x)
end
@testset "2D" begin
x = FixedSizeArray(fill(3.0f0, (4, 5)))
rx = Reactant.to_rarray(x)
@test @jit(fn(rx, rx)) ≈ fn(x, x)
end
end
10 changes: 10 additions & 0 deletions test/memory.jl
@@ -0,0 +1,10 @@
using Reactant, Test

fn(x, y) = sin.(x) .+ cos.(y)

@testset "Memory test" begin
x = Memory{Float32}(fill(2.0f0, 10))
x_ra = Reactant.to_rarray(x)

@test @jit(fn(x_ra, x_ra)) ≈ fn(x, x)
end