Skip to content

Commit 6b9f46b

Browse files
authored
fix: clamp! dispatch (#1661)
* fix: clamp! dispatch * Update test/basic.jl
1 parent 20c6a4c commit 6b9f46b

File tree

2 files changed

+9
-0
lines changed

2 files changed

+9
-0
lines changed

src/TracedRArray.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -846,6 +846,9 @@ end
846846

847847
for (minT, maxT) in Iterators.product((Number, TracedRNumber), (Number, TracedRNumber))
848848
@eval function Base.clamp!(x::AnyTracedRArray, min::$(minT), max::$(maxT))
849+
T = unwrapped_eltype(x)
850+
min = Reactant.promote_to(TracedRNumber{T}, min)
851+
max = Reactant.promote_to(TracedRNumber{T}, max)
849852
y = @opcall clamp(min, materialize_traced_array(x), max)
850853
TracedUtils.set_mlir_data!(x, y.mlir_data)
851854
return x

test/basic.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1568,3 +1568,9 @@ end
15681568

15691569
@test @jit(f(params_ra, points_ra)) f(params, points)
15701570
end
1571+
1572+
@testset "clamp!" begin
1573+
x = rand(Float32, 32, 32)
1574+
x_ra = Reactant.to_rarray(x)
1575+
@test @jit(clamp!(x_ra, 0.5, Inf32)) clamp!(x, 0.5, Inf32)
1576+
end

0 commit comments

Comments
 (0)