@@ -36,35 +36,16 @@ ReactantCore.is_traced(::TracedRArray) = true
3636
3737new_traced_value (A:: TracedRArray{T,N} ) where  {T,N} =  TracedRArray {T,N} ((), nothing , size (A))
3838
39+ TracedRArray {T,N} (rhs:: TracedRArray{T,N} ) where  {T,N} =  rhs
3940function  TracedRArray {T,N} (rhs:: TracedRArray{T0,N} ) where  {T,T0,N}
40-     if  T ==  T0
41-         return  rhs
42-     else 
43-         return  TracedRArray {T,N} (
44-             (),
45-             MLIR. IR. result (
46-                 MLIR. Dialects. stablehlo. convert (
47-                     rhs. mlir_data; result= mlir_type (TracedRArray{T,N}, size (rhs))
48-                 ),
49-                 1 ,
50-             ),
51-             size (rhs),
52-         )
53-     end 
41+     return  Ops. convert (TracedRArray{T,N}, rhs)
5442end 
5543
5644function  TracedRArray {T,N} (rhs:: WrappedTracedRArray{T0,N} ) where  {T0,T,N}
5745    return  TracedRArray {T,N} (materialize_traced_array (rhs))
5846end 
5947
60- function  TracedRArray {T,N} (rhs:: AbstractArray{T0,N} ) where  {T0,T,N}
61-     attr =  MLIR. IR. DenseElementsAttribute (collect (rhs))
62-     return  TracedRArray {T,N} (
63-         TracedRArray {T0,length(size(rhs))} (
64-             (), MLIR. IR. result (MLIR. Dialects. stablehlo. constant (; value= attr), 1 ), size (rhs)
65-         ),
66-     )
67- end 
48+ TracedRArray {T,N} (rhs:: AbstractArray{T0,N} ) where  {T0,T,N} =  Ops. constant (collect (rhs))
6849
6950materialize_traced_array (x:: TracedRArray ) =  x
7051materialize_traced_array (x:: WrappedTracedRArray ) =  x[axes (x)... ]
@@ -164,7 +145,6 @@ function Base.getindex(
164145        ),
165146        1 ,
166147    )
167- 
168148    return  TracedRNumber {T} ((), res2)
169149end 
170150
@@ -254,9 +234,7 @@ Base.copy(A::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), A.mlir_data,
254234
255235#  TODO  is there a way to create an unitialized `tensor`? does it show an advantage? maybe `fill`?
256236function  Base. similar (:: TracedRArray , :: Type{T} , dims:: Dims{N} ) where  {T,N}
257-     attr =  MLIR. IR. DenseElementsAttribute (zeros (T, dims))
258-     res =  MLIR. IR. result (MLIR. Dialects. stablehlo. constant (; value= attr), 1 )
259-     return  TracedRArray {T,N} ((), res, dims)
237+     return  Ops. constant (zeros (T, dims))
260238end 
261239
262240function  Base. show (io:: IOty , X:: TracedRArray{T,N} ) where  {T,N,IOty<: Union{IO,IOContext} }
@@ -266,69 +244,20 @@ function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOC
266244end 
267245
268246function  Base. permutedims (A:: AnyTracedRArray{T,N} , perm) where  {T,N}
269-     return  TracedRArray {T,N} (
270-         (),
271-         MLIR. IR. result (
272-             MLIR. Dialects. stablehlo. transpose (
273-                 get_mlir_data (A);
274-                 permutation= MLIR. IR. DenseArrayAttribute ([Int64 (i -  1 ) for  i in  perm]),
275-             ),
276-             1 ,
277-         ),
278-         Tuple (size (A, i) for  i in  perm),
279-     )
247+     return  Ops. transpose (materialize_traced_array (A), perm)
280248end 
281249
282- Base. conj (A:: TracedRArray ) =  A
283- function  Base. conj (A:: TracedRArray{T,N} ) where  {T<: Complex ,N}
284-     return  TracedRArray {T,N} (
285-         (),
286-         MLIR. IR. result (
287-             MLIR. Dialects. chlo. conj (
288-                 A. mlir_data; result= mlir_type (TracedRArray{T,N}, size (A))
289-             ),
290-             1 ,
291-         ),
292-         size (A),
293-     )
294- end 
295- 
296- Base. conj! (A:: TracedRArray ) =  A
297- function  Base. conj! (A:: TracedRArray{T,N} ) where  {T<: Complex ,N}
298-     A. mlir_data =  MLIR. IR. result (
299-         MLIR. Dialects. chlo. conj (A. mlir_data; result= mlir_type (TracedRArray{T,N}, size (A))),
300-         1 ,
301-     )
250+ Base. conj! (A:: AnyTracedRArray ) =  A
251+ function  Base. conj! (A:: AnyTracedRArray{<:Complex} )
252+     set_mlir_data! (A, Ops. conj (materialize_traced_array (A)). mlir_data)
302253    return  A
303254end 
304255
305- Base. real (A:: TracedRArray ) =  A
306- function  Base. real (A:: TracedRArray{Complex{T},N} ) where  {T,N}
307-     return  TracedRArray {T,N} (
308-         (),
309-         MLIR. IR. result (
310-             MLIR. Dialects. stablehlo. real (
311-                 A. mlir_data; result= mlir_type (TracedRArray{T,N}, size (A))
312-             ),
313-             1 ,
314-         ),
315-         size (A),
316-     )
317- end 
256+ Base. real (A:: AnyTracedRArray ) =  A
257+ Base. real (A:: AnyTracedRArray{<:Complex} ) =  Ops. real (materialize_traced_array (A))
318258
319- Base. imag (A:: TracedRArray ) =  zero (A)
320- function  Base. imag (A:: TracedRArray{Complex{T},N} ) where  {T,N}
321-     return  TracedRArray {T,N} (
322-         (),
323-         MLIR. IR. result (
324-             MLIR. Dialects. stablehlo. imag (
325-                 A. mlir_data; result= mlir_type (TracedRArray{T,N}, size (A))
326-             ),
327-             1 ,
328-         ),
329-         size (A),
330-     )
331- end 
259+ Base. imag (A:: AnyTracedRArray ) =  zero (A)
260+ Base. imag (A:: AnyTracedRArray{<:Complex} ) =  Ops. imag (materialize_traced_array (A))
332261
333262promote_to (:: Type{TracedRArray{T,N}} , rhs) where  {T,N} =  TracedRArray {T,N} (rhs)
334263
@@ -521,13 +450,7 @@ function Base.mapreduce(
521450    redT =  eltype (MLIR. IR. julia_type (MLIR. IR. type (red)))
522451
523452    if  dims !=  (:)
524-         red =  MLIR. IR. result (
525-             MLIR. Dialects. stablehlo. reshape (
526-                 red; result_0= MLIR. IR. TensorType (toonedims, eltype (MLIR. IR. type (red)))
527-             ),
528-             1 ,
529-         )
530-         red =  TracedRArray {redT,length(toonedims)} ((), red, (toonedims... ,))
453+         red =  Ops. reshape (red, toonedims... )
531454    else 
532455        if  length (outdims) ==  0 
533456            red =  TracedRNumber {redT} ((), red)
@@ -633,27 +556,14 @@ function Base.copyto!(dest::TracedRArray{T,N}, src::TracedRArray{T,N}) where {T,
633556    return  dest
634557end 
635558
636- function  broadcast_to_size (arg:: AbstractArray , rsize)
637-     attr =  MLIR. IR. DenseElementsAttribute (arg)
638-     len =  ndims (arg)
639-     @assert  typeof (len) ==  Int
640-     arg =  TracedRArray {eltype(arg),len} (
641-         (), MLIR. IR. result (MLIR. Dialects. stablehlo. constant (; value= attr), 1 ), size (arg)
642-     )
643-     return  broadcast_to_size (arg, rsize)
644- end 
559+ broadcast_to_size (arg:: AbstractArray , rsize) =  broadcast_to_size (Ops. constant (arg), rsize)
645560
646561function  broadcast_to_size (arg:: Base.RefValue , rsize)
647562    #  XXX : don't we want to expand here to rsize?
648563    return  arg
649564end 
650565
651- function  broadcast_to_size (arg:: T , rsize) where  {T<: Number }
652-     attr =  MLIR. IR. DenseElementsAttribute (Base. fill (arg, Tuple (rsize)))
653-     return  TracedRArray {T,length(rsize)} (
654-         (), MLIR. IR. result (MLIR. Dialects. stablehlo. constant (; value= attr), 1 ), rsize
655-     )
656- end 
566+ broadcast_to_size (arg:: Number , rsize) =  Ops. constant (Base. fill (arg, Tuple (rsize)))
657567
658568function  broadcast_to_size (arg:: TracedRNumber , rsize)
659569    length (rsize) ==  0  &&  return  arg
0 commit comments