Skip to content

Commit 1bb0000

Browse files
fix: correct usage of Ops.select for Base.ifelse (#332)
* fix: correct usage of Ops.select for Base.ifelse * Update src/TracedRNumber.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent b7de1e6 commit 1bb0000

File tree

2 files changed

+26
-6
lines changed

2 files changed

+26
-6
lines changed

src/TracedRNumber.jl

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -200,12 +200,19 @@ function Base.ifelse(
200200
@nospecialize(x::TracedRNumber{T1}),
201201
@nospecialize(y::TracedRNumber{T2})
202202
) where {T1,T2}
203-
return TracedRNumber{promote_type(T1, T2)}(
204-
(),
205-
MLIR.IR.result(
206-
MLIR.Dialects.stablehlo.select(pred.mlir_data, x.mlir_data, y.mlir_data), 1
207-
),
208-
)
203+
@warn "`ifelse` with different element-types in Reactant works by promoting the \
204+
element-type to the common type. This is semantically different from the \
205+
behavior of `ifelse` in Base. Use with caution" maxlog = 1
206+
T = promote_type(T1, T2)
207+
return ifelse(pred, promote_to(TracedRNumber{T}, x), promote_to(TracedRNumber{T}, y))
208+
end
209+
210+
function Base.ifelse(
211+
@nospecialize(pred::TracedRNumber{Bool}),
212+
@nospecialize(x::TracedRNumber{T}),
213+
@nospecialize(y::TracedRNumber{T})
214+
) where {T}
215+
return Ops.select(pred, x, y)
209216
end
210217

211218
for (T1, T2) in zip((Bool, Integer), (Bool, Integer))

test/basic.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,3 +640,16 @@ end
640640

641641
@test @jit(f_row_major(x_ra)) f_row_major(x)
642642
end
643+
644+
@testset "ifelse" begin
645+
@test 1.0 ==
646+
@jit ifelse(ConcreteRNumber(true), ConcreteRNumber(1.0), ConcreteRNumber(0.0f0))
647+
@test @jit(
648+
ifelse(ConcreteRNumber(false), ConcreteRNumber(1.0), ConcreteRNumber(0.0f0))
649+
) isa ConcreteRNumber{Float64}
650+
@test 0.0f0 ==
651+
@jit ifelse(ConcreteRNumber(false), ConcreteRNumber(1.0), ConcreteRNumber(0.0f0))
652+
@test @jit(
653+
ifelse(ConcreteRNumber(false), ConcreteRNumber(1.0f0), ConcreteRNumber(0.0f0))
654+
) isa ConcreteRNumber{Float32}
655+
end

0 commit comments

Comments
 (0)