diff --git a/Project.toml b/Project.toml index 88908f154b..55f5bed6f1 100644 --- a/Project.toml +++ b/Project.toml @@ -47,6 +47,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df" +FixedSizeArrays = "3821ddf9-e5b5-40d5-8e25-6813ab96b5e2" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [sources] @@ -58,6 +59,7 @@ ReactantArrayInterfaceExt = "ArrayInterface" ReactantCUDAExt = ["CUDA", "GPUCompiler", "KernelAbstractions", "LLVM"] ReactantDLFP8TypesExt = "DLFP8Types" ReactantFillArraysExt = "FillArrays" +ReactantFixedSizeArraysExt = "FixedSizeArrays" ReactantFloat8sExt = "Float8s" ReactantKernelAbstractionsExt = "KernelAbstractions" ReactantMPIExt = "MPI" @@ -84,6 +86,7 @@ EnumX = "1" Enzyme = "0.13.78" EnzymeCore = "0.8.13" FillArrays = "1.13" +FixedSizeArrays = "1.2.0" Float8s = "0.1" Functors = "0.5" GPUArraysCore = "0.2" diff --git a/ext/ReactantFixedSizeArraysExt.jl b/ext/ReactantFixedSizeArraysExt.jl new file mode 100644 index 0000000000..ba5045e07a --- /dev/null +++ b/ext/ReactantFixedSizeArraysExt.jl @@ -0,0 +1,36 @@ +module ReactantFixedSizeArraysExt + +using FixedSizeArrays +using Reactant +using Reactant: TracedRArray, TracedRNumber, Ops +using ReactantCore: ReactantCore + +function Reactant.traced_type_inner( + @nospecialize(_::Type{FixedSizeArrays.FixedSizeArrayDefault{T,N}}), + seen, + @nospecialize(mode::Reactant.TraceMode), + @nospecialize(track_numbers::Type), + @nospecialize(sharding), + @nospecialize(runtime) +) where {T,N} + T2 = Reactant.TracedRNumber{T} + return FixedSizeArrays.FixedSizeArrayDefault{T2,N} +end + +Base.@nospecializeinfer function Reactant.make_tracer( + seen, + @nospecialize(prev::FixedSizeArrays.FixedSizeArrayDefault{T,N}), + @nospecialize(path), + mode; + kwargs..., +) where {T,N} + shape = size(prev) + return reshape( + Reactant.make_tracer( + seen, parent(prev), 
(path..., 1), mode; kwargs..., track_numbers=Number + ), + shape, + ) +end + +end diff --git a/src/Tracing.jl b/src/Tracing.jl index 015a2ccff7..41ecd8e601 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -1516,11 +1516,11 @@ Base.@nospecializeinfer function make_tracer( ) end -Base.@nospecializeinfer function make_tracer( +Base.@nospecializeinfer function make_tracer_array( seen, - @nospecialize(prev::Array), + @nospecialize(prev::AbstractArray), @nospecialize(path), - mode; + mode, @nospecialize(track_numbers::Type = Union{}), @nospecialize(sharding = Sharding.NoSharding()), @nospecialize(runtime = nothing), @@ -1605,6 +1605,23 @@ Base.@nospecializeinfer function make_tracer( return newa end +Base.@nospecializeinfer function make_tracer( + seen, + @nospecialize(prev::Array), + @nospecialize(path), + mode; + @nospecialize(track_numbers::Type = Union{}), + @nospecialize(sharding = Sharding.NoSharding()), + @nospecialize(runtime = nothing), + @nospecialize(device = nothing), + @nospecialize(client = nothing), + kwargs..., +) + return make_tracer_array( + seen, prev, path, mode, track_numbers, sharding, runtime, device, client, kwargs... 
+ ) +end + Base.@nospecializeinfer function make_tracer( seen, @nospecialize(prev::Dict{Key,Value}), @@ -1812,6 +1829,34 @@ Base.@nospecializeinfer function make_tracer( return res end +if isdefined(Base, :Memory) + Base.@nospecializeinfer function make_tracer( + seen, + @nospecialize(prev::Memory), + @nospecialize(path), + mode; + @nospecialize(track_numbers::Type = Union{}), + @nospecialize(sharding = Sharding.NoSharding()), + @nospecialize(runtime = nothing), + @nospecialize(device = nothing), + @nospecialize(client = nothing), + kwargs..., + ) + return make_tracer_array( + seen, + prev, + path, + mode, + track_numbers, + sharding, + runtime, + device, + client, + kwargs..., + ) + end +end + Base.@nospecializeinfer function make_tracer( seen, @nospecialize(prev::Sharding.Mesh), diff --git a/src/Types.jl b/src/Types.jl index cc257c4ebf..09d47d5fb9 100644 --- a/src/Types.jl +++ b/src/Types.jl @@ -214,8 +214,8 @@ function ConcretePJRTArray{T,N}( return ConcretePJRTArray{T,N,D,typeof(sharding)}(data, shape, sharding) end -function ConcretePJRTArray( - data::Array{T,N}; +function make_concrete_PJRT_array( + data::AbstractArray{T,N}, client::Union{Nothing,XLA.PJRT.Client}=nothing, idx::Union{Int,Nothing}=nothing, device::Union{Nothing,XLA.PJRT.Device}=nothing, @@ -228,6 +228,28 @@ function ConcretePJRTArray( return ConcretePJRTArray{T,N,nsharded,typeof(shardinfo)}(sharded_data, shape, shardinfo) end +function ConcretePJRTArray( + data::Array{T,N}; + client::Union{Nothing,XLA.PJRT.Client}=nothing, + idx::Union{Int,Nothing}=nothing, + device::Union{Nothing,XLA.PJRT.Device}=nothing, + sharding::Sharding.AbstractSharding=Sharding.NoSharding(), +) where {T,N} + return make_concrete_PJRT_array(data, client, idx, device, sharding) +end + +if isdefined(Base, :Memory) + function ConcretePJRTArray( + data::Memory{T}; + client::Union{Nothing,XLA.PJRT.Client}=nothing, + idx::Union{Int,Nothing}=nothing, + device::Union{Nothing,XLA.PJRT.Device}=nothing, + 
sharding::Sharding.AbstractSharding=Sharding.NoSharding(), + ) where {T} + return make_concrete_PJRT_array(data, client, idx, device, sharding) + end +end + Base.wait(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = foreach(wait, x.data) XLA.client(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) = XLA.client(x.data) function XLA.device(x::Union{ConcretePJRTArray,ConcretePJRTNumber}) @@ -342,8 +364,8 @@ function ConcreteIFRTArray{T,N}( return ConcreteIFRTArray{T,N,typeof(sharding)}(data, shape, sharding) end -function ConcreteIFRTArray( - data::Array{T,N}; +function make_concrete_IFRT_array( + data::AbstractArray{T,N}, client::Union{Nothing,XLA.IFRT.Client}=nothing, idx::Union{Int,Nothing}=nothing, device::Union{Nothing,XLA.IFRT.Device}=nothing, @@ -356,6 +378,28 @@ function ConcreteIFRTArray( return ConcreteIFRTArray{T,N,typeof(shardinfo)}(sharded_data, shape, shardinfo, padding) end +function ConcreteIFRTArray( + data::Array{T,N}; + client::Union{Nothing,XLA.IFRT.Client}=nothing, + idx::Union{Int,Nothing}=nothing, + device::Union{Nothing,XLA.IFRT.Device}=nothing, + sharding::Sharding.AbstractSharding=Sharding.NoSharding(), +) where {T,N} + return make_concrete_IFRT_array(data, client, idx, device, sharding) +end + +if isdefined(Base, :Memory) + function ConcreteIFRTArray( + data::Memory{T}; + client::Union{Nothing,XLA.IFRT.Client}=nothing, + idx::Union{Int,Nothing}=nothing, + device::Union{Nothing,XLA.IFRT.Device}=nothing, + sharding::Sharding.AbstractSharding=Sharding.NoSharding(), + ) where {T} + return make_concrete_IFRT_array(data, client, idx, device, sharding) + end +end + # Assemble data from multiple arrays. Needed in distributed setting where each process wont # have enough host memory to hold all the arrays. We assume that the data is only provided # for all of the addressable devices. @@ -472,8 +516,11 @@ elseif XLA.REACTANT_XLA_RUNTIME == "IFRT" ConcreteIFRTArray end -@inline ConcreteRArray{T}(::UndefInitializer, shape::Integer...; kwargs...) 
where {T} = - ConcreteRArray{T}(undef, Dims(shape); kwargs...) +@inline ConcreteRArray{T}(::UndefInitializer, shape::Integer...; kwargs...) where {T} = ConcreteRArray{ + T +}( + undef, Dims(shape); kwargs... +) """ ConcreteRNumber( diff --git a/src/xla/IFRT/Array.jl b/src/xla/IFRT/Array.jl index 5a618dd131..c45dfc34c4 100644 --- a/src/xla/IFRT/Array.jl +++ b/src/xla/IFRT/Array.jl @@ -16,11 +16,8 @@ function Array( return Array(client, fill(array), device, memory_kind) end -function Array( - client::Client, - array::Base.Array{T,N}, - device::Device=XLA.default_device(client), - memory_kind::AbstractString=string(convert(MemoryKind, XLA.default_memory(device))), +function make_array_singleshard( + client::Client, array::AbstractArray{T,N}, device::Device, memory_kind::AbstractString ) where {T<:Reactant.ReactantPrimitive,N} sizear = collect(Int64, reverse(size(array))) buffer = GC.@preserve array sizear begin @@ -39,7 +36,27 @@ function Array( end function Array( - client::Client, array::Base.Array{T,N}, sharding::Sharding + client::Client, + array::Base.Array{T,N}, + device::Device=XLA.default_device(client), + memory_kind::AbstractString=string(convert(MemoryKind, XLA.default_memory(device))), +) where {T<:Reactant.ReactantPrimitive,N} + return make_array_singleshard(client, array, device, memory_kind) +end + +if isdefined(Base, :Memory) + function Array( + client::Client, + memory::Base.Memory{T}, + device::Device=XLA.default_device(client), + memory_kind::AbstractString=string(convert(MemoryKind, XLA.default_memory(device))), + ) where {T<:Reactant.ReactantPrimitive} + return make_array_singleshard(client, memory, device, memory_kind) + end +end + +function make_array_sharding( + client::Client, array::AbstractArray{T,N}, sharding::Sharding ) where {T<:Reactant.ReactantPrimitive,N} all_devices = XLA.devices(sharding) all_logical_device_ids = collect(Int64, 0:(length(all_devices) - 1)) @@ -75,6 +92,12 @@ function Array( return Array(client, host_buffers, 
addressable_shard_indices, size(array), sharding) end +function Array( + client::Client, array::Base.Array{T,N}, sharding::Sharding +) where {T<:Reactant.ReactantPrimitive,N} + return make_array_sharding(client, array, sharding) +end + function Array( client::Client, host_buffers::Vector{Base.Array{T,N}}, @@ -82,7 +105,7 @@ function Array( array_shape, sharding::Sharding, ) where {T<:Reactant.ReactantPrimitive,N} - # Construct using the slower path, the faster path is only implemented for IFRT-Proxy + # Construct using the slower path, the faster path is only implemented for IFRT-Proxy # and seems to cause issues with IFRT-PJRT all_addressable_devices = filter(XLA.is_addressable, XLA.devices(sharding)) @@ -143,8 +166,16 @@ function Array( return Array(buffer) end -function Array( - client::Client, array::Base.Array{T,N}, sharding +if isdefined(Base, :Memory) + function Array( + client::Client, memory::Base.Memory{T}, sharding::Sharding + ) where {T<:Reactant.ReactantPrimitive} + return make_array_sharding(client, memory, sharding) + end +end + +function make_array_ifrt_sharding( + client::Client, array::Base.AbstractArray{T,N}, sharding ) where {T<:Reactant.ReactantPrimitive,N} @assert sharding isa Reactant.Sharding.AbstractSharding if !(sharding isa Reactant.Sharding.HloSharding) @@ -158,6 +189,20 @@ function Array( return Array(client, array, ifrt_sharding) end +function Array( + client::Client, array::Base.Array{T,N}, sharding +) where {T<:Reactant.ReactantPrimitive,N} + return make_array_ifrt_sharding(client, array, sharding) +end + +if isdefined(Base, :Memory) + function Array( + client::Client, memory::Base.Memory{T}, sharding + ) where {T<:Reactant.ReactantPrimitive} + return make_array_ifrt_sharding(client, memory, sharding) + end +end + @inline function XLA.free_buffer(buffer::Array) if buffer.buffer != C_NULL @ccall MLIR.API.mlir_c.ifrt_free_array(buffer.buffer::Ptr{Cvoid})::Cvoid diff --git a/src/xla/PJRT/Buffer.jl b/src/xla/PJRT/Buffer.jl index 
2b36292c93..a8a6725717 100644 --- a/src/xla/PJRT/Buffer.jl +++ b/src/xla/PJRT/Buffer.jl @@ -6,7 +6,9 @@ mutable struct Buffer <: XLA.AbstractBuffer end end -function Buffer(client::Client, array::Array{T,N}, device::Device) where {T,N} +function make_buffer_array( + client::Client, array::AbstractArray{T,N}, device::Device +) where {T,N} sizear = collect(Int64, reverse(size(array))) buffer = GC.@preserve array sizear begin @ccall MLIR.API.mlir_c.ArrayFromHostBuffer( @@ -21,6 +23,16 @@ function Buffer(client::Client, array::Array{T,N}, device::Device) where {T,N} return Buffer(buffer) end +function Buffer(client::Client, array::Array{T,N}, device::Device) where {T,N} + return make_buffer_array(client, array, device) +end + +if isdefined(Base, :Memory) + function Buffer(client::Client, memory::Memory{T}, device::Device) where {T} + return make_buffer_array(client, memory, device) + end +end + function Base.similar(a::Buffer) buffer = GC.@preserve a begin @ccall MLIR.API.mlir_c.UninitPJRTBuffer( diff --git a/test/Project.toml b/test/Project.toml index 05b1e8906a..1fa4be3887 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -9,6 +9,7 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +FixedSizeArrays = "3821ddf9-e5b5-40d5-8e25-6813ab96b5e2" Float8s = "81dfefd7-55b0-40c6-a251-db853704e186" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" diff --git a/test/integration/fixedsizearrays.jl b/test/integration/fixedsizearrays.jl new file mode 100644 index 0000000000..3200df2e57 --- /dev/null +++ b/test/integration/fixedsizearrays.jl @@ -0,0 +1,17 @@ + +using Reactant, Test, FixedSizeArrays + +fn(x, y) = (2 .* x .- 3) * y' + +@testset "FixedSizeArrays" begin + @testset "1D" begin + x = FixedSizeArray(fill(3.0f0, 100)) + rx = Reactant.to_rarray(x) + @test @jit(fn(rx, rx)) 
≈ fn(x, x) + end + @testset "2D" begin + x = FixedSizeArray(fill(3.0f0, (4, 5))) + rx = Reactant.to_rarray(x) + @test @jit(fn(rx, rx)) ≈ fn(x, x) + end +end diff --git a/test/memory.jl b/test/memory.jl new file mode 100644 index 0000000000..d2a9e13558 --- /dev/null +++ b/test/memory.jl @@ -0,0 +1,10 @@ +using Reactant, Test + +fn(x, y) = sin.(x) .+ cos.(y) + +@testset "Memory test" begin + x = Memory{Float32}(fill(2.0f0, 10)) + x_ra = Reactant.to_rarray(x) + + @test @jit(fn(x_ra, x_ra)) ≈ fn(x, x) +end diff --git a/test/runtests.jl b/test/runtests.jl index 98a02a7de0..1d078e9d35 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -38,6 +38,9 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Config" include("config.jl") @safetestset "Batching" include("batching.jl") @safetestset "QA" include("qa.jl") + if isdefined(Base, :Memory) + @safetestset "Memory" include("memory.jl") + end end if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" @@ -52,6 +55,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Python" include("integration/python.jl") @safetestset "Optimisers" include("integration/optimisers.jl") @safetestset "FillArrays" include("integration/fillarrays.jl") + @safetestset "FixedSizeArrays" include("integration/fixedsizearrays.jl") @safetestset "Zygote" include("integration/zygote.jl") end