122
122
end
123
123
end
124
124
125
- @noinline function constant (
126
- x:: AbstractArray{T,N} ; location= mlir_stacktrace (" constant" , @__FILE__ , @__LINE__ )
127
- ) where {T,N}
128
- return constant (collect (x); location)
129
- end
130
-
131
- @noinline function constant (x:: Reactant.AbstractConcreteArray ; kwargs... )
132
- return constant (Base. convert (Array, x); kwargs... )
133
- end
134
-
135
125
@noinline function constant (
136
126
x:: T ; location= mlir_stacktrace (" constant" , @__FILE__ , @__LINE__ )
137
127
) where {T<: Number }
140
130
return TracedRNumber {T} ((), res. mlir_data)
141
131
end
142
132
143
- @noinline function constant (x:: Reactant.AbstractConcreteNumber{T} ; kwargs... ) where {T}
144
- return constant (Base. convert (T, x); kwargs... )
145
- end
146
-
147
133
function fill (
148
134
v, dims:: Base.DimOrInd... ; location= mlir_stacktrace (" fill" , @__FILE__ , @__LINE__ )
149
135
)
391
377
end
392
378
393
379
# shape ops
394
- function reshape (x:: TracedRArray , dims:: Integer ... ; kwargs... )
380
+ function reshape (x:: TracedRArray , dims... ; kwargs... )
395
381
return reshape (x, collect (dims); kwargs... )
396
382
end
397
383
@@ -2394,7 +2380,7 @@ end
2394
2380
x::TracedRArray{T},
2395
2381
init_values::TracedRNumber{T},
2396
2382
dimensions::Vector{Int},
2397
- fn::Function;
2383
+ fn::Function,
2398
2384
location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
2399
2385
)
2400
2386
@@ -2426,43 +2412,25 @@ Applies a reduction function `fn` along the specified `dimensions` of input `x`,
2426
2412
- **CPU version & Julia's `reduce`**:
2427
2413
- Reduce along dimension 1 → `[(15) (21); (18) (24)]`
2428
2414
- Reduce along dimension 3 → `[(33 + 2) (45 + 2)]` → `[35 47]`
2429
-
2415
+
2430
2416
- **GPU version**:
2431
2417
- Reduce along dimension 1 → `[(15 + 2) (21 + 2); (18 + 2) (24 + 2)]`
2432
2418
- Reduce along dimension 3 → `[37 49]`
2433
2419
"""
2434
2420
@noinline function reduce (
2435
2421
x:: TracedRArray{T} ,
2436
- init_values:: Union{ TracedRNumber{T},Nothing } ,
2422
+ init_values:: TracedRNumber{T} ,
2437
2423
dimensions:: Vector{Int} ,
2438
- fn:: Function ;
2424
+ fn:: Function ,
2439
2425
location= mlir_stacktrace (" reduce" , @__FILE__ , @__LINE__ ),
2440
2426
) where {T}
2441
- elT = T
2442
- if init_values === nothing
2443
- if fn === min || fn === Base. FastMath. min_fast
2444
- init = typemax (elT)
2445
- elseif fn === max || fn === Base. FastMath. max_fast
2446
- init = typemin (elT)
2447
- else
2448
- init = Base. reduce_empty (Base. BottomRF (fn), elT)
2449
- end
2450
-
2451
- initT = unwrapped_eltype (typeof (init))
2452
- if initT != elT # Bool, etc. reductions
2453
- elT = promote_type (initT, elT)
2454
- x = elT .(x)
2455
- end
2456
- init_values = Reactant. TracedUtils. promote_to (TracedRNumber{elT}, init)
2457
- end
2458
-
2459
2427
reduced_shape = Tuple (deleteat! (collect (size (x)), dimensions))
2460
2428
2461
- result_type = mlir_type (TracedRArray{elT ,length (reduced_shape)}, reduced_shape)
2429
+ result_type = mlir_type (TracedRArray{T ,length (reduced_shape)}, reduced_shape)
2462
2430
2463
2431
sample_inputs = [
2464
- Reactant. TracedUtils. promote_to (TracedRNumber{elT }, 0 ),
2465
- Reactant. TracedUtils. promote_to (TracedRNumber{elT }, 0 ),
2432
+ Reactant. TracedUtils. promote_to (TracedRNumber{T }, 0 ),
2433
+ Reactant. TracedUtils. promote_to (TracedRNumber{T }, 0 ),
2466
2434
]
2467
2435
2468
2436
func =
@@ -2476,8 +2444,14 @@ Applies a reduction function `fn` along the specified `dimensions` of input `x`,
2476
2444
return_dialect= :stablehlo ,
2477
2445
). f
2478
2446
@assert MLIR. IR. nregions (func) == 1
2479
- ftype = MLIR. IR. Type (MLIR. IR. attr (func, " function_type" ))
2480
- @assert MLIR. IR. result (ftype) == MLIR. IR. TensorType ((), MLIR. IR. Type (elT)) " $fn return type is not tensor<i1>"
2447
+ fn_name = String (
2448
+ MLIR. IR. attr (func, String (MLIR. API. mlirSymbolTableGetSymbolAttributeName ()))
2449
+ )
2450
+ ftype_attr = MLIR. IR. attr (func, " function_type" )
2451
+ ftype = MLIR. IR. Type (ftype_attr)
2452
+ @assert MLIR. IR. result (ftype) == MLIR. IR. TensorType ((), MLIR. IR. Type (T)) error (
2453
+ " $fn return type is not tensor<i1>"
2454
+ )
2481
2455
fn = MLIR. IR. Region ()
2482
2456
MLIR. API. mlirRegionTakeBody (fn, MLIR. IR. region (func, 1 ))
2483
2457
MLIR. IR. rmfromparent! (func)
@@ -2495,7 +2469,7 @@ Applies a reduction function `fn` along the specified `dimensions` of input `x`,
2495
2469
),
2496
2470
)
2497
2471
2498
- return TracedRArray {elT ,length(reduced_shape)} ((), res, reduced_shape)
2472
+ return TracedRArray {T ,length(reduced_shape)} ((), res, reduced_shape)
2499
2473
end
2500
2474
2501
2475
end # module Ops
0 commit comments