@@ -2,16 +2,10 @@ module ReactantNNlibExt
22
33using  NNlib
44using  GPUArraysCore:  @allowscalar 
5- using  Reactant: 
6-     Reactant,
7-     Ops,
8-     TracedRArray,
9-     AnyTracedRArray,
10-     materialize_traced_array,
11-     MLIR,
12-     TracedRNumber,
13-     get_mlir_data,
14-     set_mlir_data!
5+ using  Reactant:  Reactant, Ops, TracedRArray, AnyTracedRArray, MLIR, TracedRNumber
6+ 
7+ using  Reactant. TracedUtils:  materialize_traced_array, get_mlir_data, set_mlir_data!
8+ 
159using  ReactantCore:  @trace 
1610using  LinearAlgebra:  LinearAlgebra, triu
1711
@@ -238,9 +232,9 @@ function NNlib.batched_mul!(
238232    if  size (x, 3 ) !=  size (y, 3 )
239233        B =  max (size (x, 3 ), size (y, 3 ))
240234        if  size (x, 3 ) ==  1 
241-             x =  Reactant. broadcast_to_size (x, (size (x, 1 ), size (x, 2 ), B))
235+             x =  Reactant. TracedUtils . broadcast_to_size (x, (size (x, 1 ), size (x, 2 ), B))
242236        elseif  size (y, 3 ) ==  1 
243-             y =  Reactant. broadcast_to_size (y, (size (y, 1 ), size (y, 2 ), B))
237+             y =  Reactant. TracedUtils . broadcast_to_size (y, (size (y, 1 ), size (y, 2 ), B))
244238        end 
245239    end 
246240
@@ -250,9 +244,9 @@ function NNlib.batched_mul!(
250244    if  size (x, 1 ) !=  size (y, 1 )
251245        B =  max (size (x, 1 ), size (y, 1 ))
252246        if  size (x, 1 ) ==  1 
253-             x =  Reactant. broadcast_to_size (x, (B, size (x, 2 ), size (x, 3 )))
247+             x =  Reactant. TracedUtils . broadcast_to_size (x, (B, size (x, 2 ), size (x, 3 )))
254248        elseif  size (y, 1 ) ==  1 
255-             y =  Reactant. broadcast_to_size (y, (B, size (y, 2 ), size (y, 3 )))
249+             y =  Reactant. TracedUtils . broadcast_to_size (y, (B, size (y, 2 ), size (y, 3 )))
256250        end 
257251    end 
258252
270264function  NNlib. pad_constant (
271265    x:: AnyTracedRArray{T,N} , pad:: NTuple{N,Tuple{Int,Int}} , value
272266) where  {T,N}
273-     value =  Reactant. promote_to (TracedRNumber{T}, value)
267+     value =  Reactant. TracedUtils . promote_to (TracedRNumber{T}, value)
274268    low =  [i[1 ] for  i in  pad]
275269    high =  [i[2 ] for  i in  pad]
276270    interior =  [0  for  i in  pad]
@@ -329,7 +323,8 @@ function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractAr
329323    start_sizes =  ntuple (i ->  size (src, i), dims)
330324    results =  map (CartesianIndices (idxs)) do  k
331325        res =  @allowscalar  src[colons... , Tuple (idxs[k])... ]
332-         res isa  TracedRNumber &&  (res =  Reactant. broadcast_to_size (res, (1 ,)))
326+         res isa  TracedRNumber && 
327+             (res =  Reactant. TracedUtils. broadcast_to_size (res, (1 ,)))
333328        return  reshape (res, start_sizes... , :)
334329    end 
335330    res =  reshape (cat (results... ; dims= (dims +  1 )), size (dst))
0 commit comments