Conversation

Qfl3x

@Qfl3x Qfl3x commented Sep 16, 2025

#1649
Struggling with the Memory object. This is what I'm getting currently:

ERROR: NoFieldMatchError:
Cannot convert type Memory{Float32}, best attempt Memory{ConcretePJRTNumber{Float32, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}} failed.
This could be because the type does not capture the fieldtypes that should be converted in its type parameters.
name=length idx=1 Derived: ConcretePJRTNumber{Int64, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}} Existing: Int64 Best Attempt: Int64
name=ptr idx=2 Derived: Ptr{Nothing} Existing: Ptr{Nothing} Best Attempt: Ptr{Nothing}

Stacktrace:
  [1] traced_type_inner(T::Type, seen::Dict{Type, Type}, mode::Reactant.TraceMode, track_numbers::Type, sharding::Any, runtime::Any)
    @ Reactant ~/projects/Reactant.jl/src/Tracing.jl:759
  [2] traced_type(T::Type, ::Val{Reactant.ArrayToConcrete}, track_numbers::Type, sharding::Reactant.Sharding.NoSharding, runtime::Val{:PJRT})
    @ Reactant ~/projects/Reactant.jl/src/Tracing.jl:870
  [3] make_tracer_unknown(seen::Reactant.OrderedIdDict{Any, Any}, prev::Any, path::Any, mode::Reactant.TraceMode; track_numbers::Type, sharding::Any, runtime::Any, kwargs::@Kwargs{device::Nothing, client::Nothing})
    @ Reactant ~/projects/Reactant.jl/src/Tracing.jl:1032
  [4] make_tracer_unknown
    @ ~/projects/Reactant.jl/src/Tracing.jl:1009 [inlined]
  [5] #make_tracer#191
    @ ~/projects/Reactant.jl/src/Tracing.jl:1146 [inlined]
  [6] make_tracer
    @ ~/projects/Reactant.jl/src/Tracing.jl:1136 [inlined]
  [7] make_tracer(seen::Reactant.OrderedIdDict{…}, prev::FixedSizeArray{…}, path::Any, mode::Reactant.TraceMode; kwargs::@Kwargs{…})
    @ ReactantFixedSizeArraysExt ~/projects/Reactant.jl/ext/ReactantFixedSizeArraysExt.jl:27
  [8] to_rarray_internal(x::Any, track_numbers::Type, sharding::Any, runtime::Any, device::Any, client::Any)
    @ Reactant ~/projects/Reactant.jl/src/Tracing.jl:1848
  [9] #to_rarray#156
    @ ~/projects/Reactant.jl/src/Tracing.jl:1841 [inlined]
 [10] to_rarray(x::Any)
    @ Reactant ~/projects/Reactant.jl/src/Tracing.jl:1827
 [11] top-level scope
    @ REPL[10]:1
Some type information was truncated. Use `show(err)` to see complete types.


codecov bot commented Sep 16, 2025

Codecov Report

❌ Patch coverage is 75.60976% with 10 lines in your changes missing coverage. Please review.
✅ Project coverage is 69.82%. Comparing base (b39a1fc) to head (a033cc6).
⚠️ Report is 3 commits behind head on main.

Files with missing lines            Patch %   Lines
src/xla/IFRT/Array.jl               50.00%    7 Missing ⚠️
ext/ReactantFixedSizeArraysExt.jl   50.00%    3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1669      +/-   ##
==========================================
+ Coverage   68.16%   69.82%   +1.66%     
==========================================
  Files         109      111       +2     
  Lines       11779    12157     +378     
==========================================
+ Hits         8029     8489     +460     
+ Misses       3750     3668      -82     


@Qfl3x
Author

Qfl3x commented Sep 17, 2025

With the latest commit, I'm outputting a ConcretePJRTArray rather than a FixedSizeArray and assuming the default memory backend. So perhaps there is more work to do.

It turned out to be more intricate than I initially thought, as I had to add support for Memory as well. This may need some thinking over.

In addition, I stumbled on the fact that IFRT is broken, at least for Arrays: #1672

Tests are still to be done. If this is shippable, I'll start writing tests for it (and for Memory too).

@Qfl3x Qfl3x changed the title from "[WIP] FixedSizeArrays support" to "[WIP] FixedSizeArrays support (And Memory Support)" on Sep 17, 2025
@Qfl3x
Author

Qfl3x commented Sep 17, 2025

Hmm, the builds are failing because they run on Julia v1.10, where `Memory` isn't defined.

@giordano
Member

You can guard the Memory-specific bits with

if isdefined(Base, :Memory)
    # ...
end

Relatedly, note that FixedSizeArrays.jl uses Memory only in v1.11+: https://juliaarrays.github.io/FixedSizeArrays.jl/stable/usage/#The-memory-backend
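The guard combines naturally with ordinary multiple dispatch, so the same file loads on both v1.10 and v1.11+. A minimal sketch (`backend_kind` is a hypothetical helper for illustration, not part of Reactant):

```julia
# Hypothetical helper: classify an array type's storage backend.
backend_kind(::Type{<:Array}) = :array

# `Memory` was introduced in Julia v1.11; only define the method when the name
# exists, so the file still loads cleanly on v1.10.
if isdefined(Base, :Memory)
    backend_kind(::Type{<:Memory}) = :memory
end
```

On v1.10 the `Memory` method is simply never defined, and no other code path references the type.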

@Qfl3x
Author

Qfl3x commented Sep 18, 2025

I've found a problem with this implementation: it only works in the 1-D case. The parent of a FixedSizeArray{T,N} is always a one-dimensional Memory object; the shape of the array is stored in the FixedSizeArray wrapper itself.

In the latest commit, I reshaped the RArray to the shape of the underlying FixedSizeArray.
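To see why the 1-D-only behavior appears, here is a plain-Julia stand-in (`FlatBacked` is an illustrative struct, not FixedSizeArrays' real type): the N-D shape lives on the wrapper, while the parent storage is always one-dimensional, so tracing only the parent drops the shape and a trailing reshape has to restore it.

```julia
# Illustrative stand-in for a FixedSizeArray: a flat 1-D buffer plus a shape tuple.
struct FlatBacked{T,N} <: AbstractArray{T,N}
    mem::Vector{T}        # stand-in for the 1-D Memory backend
    dims::NTuple{N,Int}   # the N-D shape lives on the wrapper, not the buffer
end
Base.size(a::FlatBacked) = a.dims
Base.getindex(a::FlatBacked, i::Int) = a.mem[i]
Base.IndexStyle(::Type{<:FlatBacked}) = IndexLinear()

x = FlatBacked(collect(1.0:12.0), (3, 4))
# Tracing only the parent would see a 12-element vector; reshaping to size(x)
# recovers the 3×4 array, mirroring what the latest commit does for the RArray.
restored = reshape(copy(x.mem), x.dims)
```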

@Qfl3x Qfl3x force-pushed the FixedSizeArrays_support branch from 8625b1a to 658a3d7 on September 18, 2025 at 09:35
@Qfl3x
Author

Qfl3x commented Sep 19, 2025

Those parts have been refactored into their own functions (with AbstractArray{T,N} as the argument type).

@avik-pal
Collaborator

1D: Error During Test at D:\a\Reactant.jl\Reactant.jl\test\integration\fixedsizearrays.jl:7
  Got exception outside of a @test
  UndefVarError: `array` not defined in `Reactant.XLA.PJRT`
  Stacktrace:
    [1] Reactant.XLA.PJRT.Buffer(client::Reactant.XLA.PJRT.Client, memory::Memory{Float32}, device::Reactant.XLA.PJRT.Device)
      @ Reactant.XLA.PJRT D:\a\Reactant.jl\Reactant.jl\src\xla\PJRT\Buffer.jl:31
    [2] Reactant.XLA.PJRT.AsyncBuffer(::Reactant.XLA.PJRT.Client, ::Vararg{Any}; kwargs::@Kwargs{})
      @ Reactant.XLA.PJRT D:\a\Reactant.jl\Reactant.jl\src\xla\PJRT\AsyncBuffer.jl:8
    [3] AsyncBuffer
      @ D:\a\Reactant.jl\Reactant.jl\src\xla\PJRT\AsyncBuffer.jl:8 [inlined]
    [4] NoSharding
      @ D:\a\Reactant.jl\Reactant.jl\src\Sharding.jl:193 [inlined]
    [5] make_concrete_PJRT_array(data::Memory{Float32}, client::Nothing, idx::Nothing, device::Nothing, sharding::Reactant.Sharding.NoSharding)
      @ Reactant D:\a\Reactant.jl\Reactant.jl\src\Types.jl:225
    [6] #ConcretePJRTArray#35
      @ D:\a\Reactant.jl\Reactant.jl\src\Types.jl:249 [inlined]
    [7] make_tracer_array(::Reactant.OrderedIdDict{Any, Any}, ::AbstractArray, ::Any, ::Reactant.TraceMode, ::Type, ::Any, ::Any, ::Any, ::Any)
      @ Reactant D:\a\Reactant.jl\Reactant.jl\src\Tracing.jl:1544
    [8] #make_tracer#162
      @ D:\a\Reactant.jl\Reactant.jl\src\Tracing.jl:1843 [inlined]
    [9] make_tracer
      @ D:\a\Reactant.jl\Reactant.jl\src\Tracing.jl:1831 [inlined]
   [10] make_tracer(seen::Reactant.OrderedIdDict{Any, Any}, prev::FixedSizeArrays.FixedSizeArray{Float32, 1, Memory{Float32}}, path::Any, mode::Reactant.TraceMode; kwargs::@Kwargs{track_numbers::Core.TypeofBottom, sharding::Reactant.Sharding.NoSharding, runtime::Val{:PJRT}, device::Nothing, client::Nothing})
      @ ReactantFixedSizeArraysExt D:\a\Reactant.jl\Reactant.jl\ext\ReactantFixedSizeArraysExt.jl:28
   [11] to_rarray_internal(x::Any, track_numbers::Type, sharding::Any, runtime::Any, device::Any, client::Any)
      @ Reactant D:\a\Reactant.jl\Reactant.jl\src\Tracing.jl:1881

Probably should be `memory` instead of `array`.
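The failure mode is the classic renamed-argument slip; a minimal reproduction (`buffer_len_*` are hypothetical names, not Reactant functions):

```julia
# The argument was renamed to `memory`, but the body still says `array`:
buffer_len_buggy(memory) = length(array)   # throws UndefVarError: `array` not defined
buffer_len_fixed(memory) = length(memory)  # refers to the actual argument
```

Because the lookup only happens at run time, the broken method loads fine and the error surfaces in the test suite, exactly as in the stack trace above.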

@avik-pal avik-pal requested a review from giordano September 19, 2025 14:23
@giordano
Member

https://github.com/EnzymeAD/Reactant.jl/actions/runs/17861101375/job/50791872236?pr=1669#step:11:920

Memory test: Error During Test at /home/runner/work/Reactant.jl/Reactant.jl/test/memory.jl:5
  Got exception outside of a @test
  MethodError: no method matching make_buffer_array(::Reactant.XLA.PJRT.Client, ::Memory{Float32}, ::Reactant.XLA.PJRT.Device)
  The function `make_buffer_array` exists, but no method is defined for this combination of argument types.
  
  Closest candidates are:
    make_buffer_array(::Reactant.XLA.PJRT.Client, ::Array{T, N}, ::Reactant.XLA.PJRT.Device) where {T, N}
     @ Reactant ~/work/Reactant.jl/Reactant.jl/src/xla/PJRT/Buffer.jl:9
  
  Stacktrace:
    [1] Reactant.XLA.PJRT.Buffer(client::Reactant.XLA.PJRT.Client, memory::Memory{Float32}, device::Reactant.XLA.PJRT.Device)
      @ Reactant.XLA.PJRT ~/work/Reactant.jl/Reactant.jl/src/xla/PJRT/Buffer.jl:31
    [2] Reactant.XLA.PJRT.AsyncBuffer(::Reactant.XLA.PJRT.Client, ::Vararg{Any}; kwargs::@Kwargs{})
      @ Reactant.XLA.PJRT ~/work/Reactant.jl/Reactant.jl/src/xla/PJRT/AsyncBuffer.jl:8
    [3] (::Reactant.Sharding.NoSharding)(client::Reactant.XLA.PJRT.Client, device::Reactant.XLA.PJRT.Device, x::Memory{Float32})
      @ Reactant.Sharding ~/work/Reactant.jl/Reactant.jl/src/Sharding.jl:193
    [4] make_concrete_PJRT_array(data::Memory{Float32}, client::Nothing, idx::Nothing, device::Nothing, sharding::Reactant.Sharding.NoSharding)
      @ Reactant ~/work/Reactant.jl/Reactant.jl/src/Types.jl:225
    [5] ConcretePJRTArray(data::Memory{Float32}; client::Nothing, idx::Nothing, device::Nothing, sharding::Reactant.Sharding.NoSharding)
      @ Reactant ~/work/Reactant.jl/Reactant.jl/src/Types.jl:249
    [6] make_tracer_array(::Reactant.OrderedIdDict{Any, Any}, ::AbstractArray, ::Any, ::Reactant.TraceMode, ::Type, ::Any, ::Any, ::Any, ::Any)
      @ Reactant ~/work/Reactant.jl/Reactant.jl/src/Tracing.jl:1544
    [7] #make_tracer#162
      @ ~/work/Reactant.jl/Reactant.jl/src/Tracing.jl:1843 [inlined]
    [8] to_rarray_internal(x::Any, track_numbers::Type, sharding::Any, runtime::Any, device::Any, client::Any)
      @ Reactant ~/work/Reactant.jl/Reactant.jl/src/Tracing.jl:1881
    [9] #to_rarray#164
      @ ~/work/Reactant.jl/Reactant.jl/src/Tracing.jl:1870 [inlined]
   [10] to_rarray(x::Any)
      @ Reactant ~/work/Reactant.jl/Reactant.jl/src/Tracing.jl:1860
   [11] macro expansion
      @ ~/work/Reactant.jl/Reactant.jl/test/memory.jl:7 [inlined]
   [12] macro expansion
      @ /opt/hostedtoolcache/julia/1.11.7/x64/share/julia/stdlib/v1.11/Test/src/Test.jl:1709 [inlined]
   [13] top-level scope
      @ ~/work/Reactant.jl/Reactant.jl/test/memory.jl:6
   [14] include(mod::Module, _path::String)
      @ Base ./Base.jl:562
   [15] include(x::String)
      @ Main.var"##Memory#257" ~/.julia/packages/SafeTestsets/raUNr/src/SafeTestsets.jl:28
   [16] macro expansion
      @ ~/.julia/packages/SafeTestsets/raUNr/src/SafeTestsets.jl:24 [inlined]
   [17] macro expansion
      @ /opt/hostedtoolcache/julia/1.11.7/x64/share/julia/stdlib/v1.11/Test/src/Test.jl:1709 [inlined]
   [18] top-level scope
      @ ~/.julia/packages/SafeTestsets/raUNr/src/SafeTestsets.jl:24
   [19] eval(m::Module, e::Any)
      @ Core ./boot.jl:430
   [20] macro expansion
      @ ~/.julia/packages/SafeTestsets/raUNr/src/SafeTestsets.jl:28 [inlined]
   [21] macro expansion
      @ ~/work/Reactant.jl/Reactant.jl/test/runtests.jl:42 [inlined]
   [22] macro expansion
      @ /opt/hostedtoolcache/julia/1.11.7/x64/share/julia/stdlib/v1.11/Test/src/Test.jl:1709 [inlined]
   [23] top-level scope
      @ ~/work/Reactant.jl/Reactant.jl/test/runtests.jl:10
   [24] include(fname::String)
      @ Main ./sysimg.jl:38
   [25] top-level scope
      @ none:6

I can also reproduce it with

julia> using FixedSizeArrays, Reactant

julia> x = FixedSizeArray(randn(3, 4));

julia> rx = Reactant.to_rarray(x);
ERROR: MethodError: no method matching make_buffer_array(::Reactant.XLA.PJRT.Client, ::Memory{Float64}, ::Reactant.XLA.PJRT.Device)
The function `make_buffer_array` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  make_buffer_array(::Reactant.XLA.PJRT.Client, ::Array{T, N}, ::Reactant.XLA.PJRT.Device) where {T, N}
   @ Reactant ~/.julia/packages/Reactant/eaYLR/src/xla/PJRT/Buffer.jl:9

Stacktrace:

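The error says dispatch only covers `Array{T,N}`, so a `Memory` argument falls through. One hedged sketch of a fix is a forwarding method; the stub base method below stands in for Reactant's real `make_buffer_array`, and `collect`-ing the Memory into a 1-D Array is just one possible bridge, not necessarily the zero-copy path the PR should take:

```julia
# Stub standing in for Reactant.XLA.PJRT.make_buffer_array's Array method;
# the real function hands the data to the XLA client.
make_buffer_array(client, data::Array{T,N}, device) where {T,N} = (client, data, device)

# Forwarding method for Memory, guarded so the file still loads on Julia v1.10.
if isdefined(Base, :Memory)
    make_buffer_array(client, mem::Memory, device) =
        make_buffer_array(client, collect(mem), device)  # materialize as a 1-D Array
end
```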
@avik-pal avik-pal changed the title from "[WIP] FixedSizeArrays support (And Memory Support)" to "FixedSizeArrays support (And Memory Support)" on Sep 23, 2025
Comment on lines +519 to +523
@inline ConcreteRArray{T}(::UndefInitializer, shape::Integer...; kwargs...) where {T} = ConcreteRArray{
T
}(
undef, Dims(shape); kwargs...
)
Member

https://github.com/EnzymeAD/Reactant.jl/actions/runs/17949332627/job/51048248283?pr=1669#step:2:570

Suggested change
@inline ConcreteRArray{T}(::UndefInitializer, shape::Integer...; kwargs...) where {T} = ConcreteRArray{
T
}(
undef, Dims(shape); kwargs...
)
@inline ConcreteRArray{T}(::UndefInitializer, shape::Integer...; kwargs...) where {T} =
ConcreteRArray{T}(undef, Dims(shape); kwargs...)

@giordano
Member

julia> using FixedSizeArrays, Reactant

julia> M = randn(3, 4);

julia> M_fs = FixedSizeMatrix(M);

julia> rM = Reactant.to_rarray(M);

julia> rM_fs = Reactant.to_rarray(M_fs);

julia> fn(x, y) = sin.(x) .+ cos.(y)
fn (generic function with 1 method)

julia> @code_hlo fn(rM, rM)
module @reactant_fn attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  func.func @main(%arg0: tensor<4x3xf64>) -> tensor<4x3xf64> attributes {enzymexla.memory_effects = []} {
    %0 = stablehlo.sine %arg0 : tensor<4x3xf64>
    %1 = stablehlo.cosine %arg0 : tensor<4x3xf64>
    %2 = stablehlo.add %0, %1 : tensor<4x3xf64>
    return %2 : tensor<4x3xf64>
  }
}

julia> @code_hlo fn(rM_fs, rM_fs)
module @reactant_fn attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  func.func @main(%arg0: tensor<12xf64>) -> tensor<4x3xf64> attributes {enzymexla.memory_effects = []} {
    %0 = stablehlo.reshape %arg0 : (tensor<12xf64>) -> tensor<4x3xf64>
    %1 = stablehlo.sine %0 : tensor<4x3xf64>
    %2 = stablehlo.cosine %0 : tensor<4x3xf64>
    %3 = stablehlo.add %1, %2 : tensor<4x3xf64>
    return %3 : tensor<4x3xf64>
  }
}

Why is there an extra stablehlo.reshape? That doesn't look ideal.

@wsmoses
Member

wsmoses commented Sep 23, 2025

Presumably the memory itself is stored one-dimensionally.

@giordano
Member

It is, but that's also the case for Array; the memory backend is basically the same.

kwargs...,
) where {T,N}
shape = size(prev)
return reshape(
Member

I guess this reshape is the culprit? How does Array work? Is it possible to construct this object directly with the right size instead of reshaping it?

Member

But for Array itself we override tracing to keep the dimensionality and allocate everything ourselves.

If the actual tracing of a FixedSizeArray goes through the current "generic recursion into structs", it will always end up allocating a 1-dim Memory.

@giordano
Member

giordano commented Sep 23, 2025

The whole point of FixedSizeArray is that the size is...well...fixed. Having to reshape it all the time seems to go in the opposite direction, especially when Array doesn't need that.

@giordano
Member

giordano commented Sep 23, 2025

It feels to me like make_tracer should take the size as an (optional) argument. Looking at

Reactant.jl/src/Tracing.jl

Lines 1177 to 1196 in e4bb34f

Base.@nospecializeinfer function make_tracer(
    seen,
    @nospecialize(prev::ConcreteIFRTArray{T,N}),
    @nospecialize(path),
    mode;
    @nospecialize(sharding = Sharding.NoSharding()),
    @nospecialize(device = nothing),
    @nospecialize(client = nothing),
    kwargs...,
) where {T,N}
    if mode == TracedToTypes
        throw("Cannot have ConcreteIFRTArray as function call argument.")
    end
    mode == ArrayToConcrete && return ConcreteIFRTArray(prev; sharding, device, client)
    mode != ConcreteToTraced && throw("Cannot trace concrete")
    haskey(seen, prev) && return seen[prev]::TracedRArray{T,N}
    res = TracedRArray{T,N}((path,), nothing, size(prev))
    seen[prev] = res
    return res
end
(and all similar methods), the size could be another argument that defaults to size(prev) but can be overridden if passed explicitly.
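A minimal sketch of that proposal (`make_tracer_stub` and `TracedStub` are illustrative stand-ins, not Reactant's actual functions or types): a `sz` keyword defaults to `size(prev)`, so the FixedSizeArrays extension could hand over the 1-D Memory parent while passing the wrapper's N-D shape, avoiding the trailing reshape.

```julia
# Illustrative stand-in for TracedRArray: only carries a shape.
struct TracedStub{T,N}
    dims::NTuple{N,Int}
end

# The proposed pattern: size is an optional keyword defaulting to size(prev),
# but a caller may override it (e.g. with the FixedSizeArray's shape).
make_tracer_stub(prev::AbstractArray{T}; sz::Dims = size(prev)) where {T} =
    TracedStub{T,length(sz)}(sz)

flat = collect(1.0:12.0)                       # the 1-D parent, conceptually
t_default = make_tracer_stub(flat)             # dims == (12,)
t_shaped  = make_tracer_stub(flat; sz=(3, 4))  # dims == (3, 4), no reshape needed
```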
