Skip to content

Commit f33c023

Browse files
authored
feat: add isinf dispatches (#826)
1 parent a39b055 commit f33c023

File tree

3 files changed

+34
-45
lines changed

3 files changed

+34
-45
lines changed

src/Ops.jl

Lines changed: 7 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -280,20 +280,19 @@ for (dialect, op) in [
280280
end
281281

282282
# is* checks
283-
for (dialect, op) in [
284-
#(:stablehlo, :is_finite),
285-
(:chlo, :is_inf),
286-
(:chlo, :is_neg_inf),
287-
(:chlo, :is_pos_inf),
288-
]
283+
for (dialect, op) in
284+
[(:stablehlo, :is_finite), (:chlo, :is_inf), (:chlo, :is_neg_inf), (:chlo, :is_pos_inf)]
285+
result = dialect == :stablehlo ? :y : :result
289286
@eval begin
290287
@noinline function $op(
291288
x::TracedRArray{T,N};
292289
location=mlir_stacktrace($(string(op)), @__FILE__, @__LINE__),
293290
) where {T,N}
294291
res = MLIR.IR.result(
295292
$(:($dialect.$op))(
296-
x.mlir_data; result=mlir_type(TracedRArray{Bool,N}, size(x)), location
293+
x.mlir_data;
294+
$(result)=mlir_type(TracedRArray{Bool,N}, size(x)),
295+
location,
297296
),
298297
)
299298
return TracedRArray{Bool,N}((), res, size(x))
@@ -305,34 +304,14 @@ for (dialect, op) in [
305304
) where {T}
306305
res = MLIR.IR.result(
307306
$(:($dialect.$op))(
308-
x.mlir_data; result=mlir_type(TracedRArray{Bool,0}, ()), location
307+
x.mlir_data; $(result)=mlir_type(TracedRArray{Bool,0}, ()), location
309308
),
310309
)
311310
return TracedRNumber{Bool}((), res)
312311
end
313312
end
314313
end
315314

316-
@noinline function is_finite(
317-
x::TracedRArray{T,N}; location=mlir_stacktrace("is_finite", @__FILE__, @__LINE__)
318-
) where {T,N}
319-
res = MLIR.IR.result(
320-
stablehlo.is_finite(
321-
x.mlir_data; y=mlir_type(TracedRArray{Bool,N}, size(x)), location
322-
),
323-
)
324-
return TracedRArray{Bool,N}((), res, size(x))
325-
end
326-
327-
@noinline function is_finite(
328-
x::TracedRNumber{T}; location=mlir_stacktrace("is_finite", @__FILE__, @__LINE__)
329-
) where {T}
330-
res = MLIR.IR.result(
331-
stablehlo.is_finite(x.mlir_data; y=mlir_type(TracedRArray{Bool,0}, ()), location)
332-
)
333-
return TracedRNumber{Bool}((), res)
334-
end
335-
336315
# fixes to default automated implementations
337316
@noinline function abs(
338317
x::TracedRArray{Complex{T},N}; location=mlir_stacktrace("abs", @__FILE__, @__LINE__)

src/TracedRNumber.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,18 @@ end
1919
function Base.isfinite(x::TracedRNumber{<:Complex})
2020
return isfinite(real(x)) & isfinite(imag(x))
2121
end
22-
function Base.isfinite(x::TracedRNumber{T}) where {T<:AbstractFloat}
23-
return Reactant.Ops.is_finite(x)
24-
end
22+
Base.isfinite(x::TracedRNumber{<:AbstractFloat}) = Ops.is_finite(x)
2523

26-
function Base.isnan(x::TracedRNumber{T}) where {T<:AbstractFloat}
27-
return !isfinite(x) & (x != typemax(T)) & (x != typemin(T))
28-
end
2924
function Base.isnan(x::TracedRNumber{<:Complex})
3025
return isnan(real(x)) | isnan(imag(x))
3126
end
27+
function Base.isnan(x::TracedRNumber{T}) where {T<:AbstractFloat}
28+
return !isfinite(x) & (x != typemax(T)) & (x != typemin(T))
29+
end
30+
31+
Base.isinf(x::TracedRNumber{<:Complex}) = isinf(real(x)) | isinf(imag(x))
32+
Base.isinf(x::TracedRNumber{<:AbstractFloat}) = Ops.is_inf(x)
33+
Base.isinf(::TracedRNumber{<:Integer}) = false
3234

3335
function Base.show(io::IOty, X::TracedRNumber{T}) where {T,IOty<:Union{IO,IOContext}}
3436
return print(io, "TracedRNumber{", T, "}(", X.paths, ")")

test/basic.jl

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -766,18 +766,18 @@ end
766766

767767
@testset "isfinite" begin
768768
x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN])
769-
@test Reactant.@jit(isfinite.(x)) == [true, false, false, false, false]
769+
@test @jit(isfinite.(x)) == [true, false, false, false, false]
770770

771771
x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN] .* im)
772-
@test Reactant.@jit(isfinite.(x)) == [true, false, false, false, false]
772+
@test @jit(isfinite.(x)) == [true, false, false, false, false]
773773
end
774774

775775
@testset "isnan" begin
776776
x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN])
777-
@test Reactant.@jit(isnan.(x)) == [false, true, false, false, true]
777+
@test @jit(isnan.(x)) == [false, true, false, false, true]
778778

779779
x = Reactant.to_rarray([1.0, NaN, Inf, -Inf, NaN] .* im)
780-
@test Reactant.@jit(isnan.(x)) == [false, true, false, false, true]
780+
@test @jit(isnan.(x)) == [false, true, false, false, true]
781781
end
782782

783783
@testset "isnan/isfinite" begin
@@ -787,19 +787,27 @@ end
787787
@test !isfinite(Reactant.to_rarray(Inf; track_numbers=Number))
788788
end
789789

790+
@testset "isinf" begin
791+
@test Bool(@jit(isinf(ConcreteRNumber(Inf))))
792+
@test Bool(@jit(isinf(ConcreteRNumber(-Inf))))
793+
@test !Bool(@jit(isinf(ConcreteRNumber(2))))
794+
@test !Bool(@jit(isinf(ConcreteRNumber(2.0))))
795+
@test !Bool(@jit(isinf(ConcreteRNumber(true))))
796+
end
797+
790798
@testset "mod and rem" begin
791799
a = [-1.1, 7.7, -3.3, 9.9, -5.5]
792800
b = [6.6, -2.2, -8.8, 4.4, -10.1]
793801

794802
expected_mod = mod.(a, b)
795-
@test Reactant.@jit(mod.(Reactant.to_rarray(a), Reactant.to_rarray(b))) expected_mod
796-
@test Reactant.@jit(mod.(a, Reactant.to_rarray(b))) expected_mod
797-
@test Reactant.@jit(mod.(Reactant.to_rarray(a), b)) expected_mod
803+
@test @jit(mod.(Reactant.to_rarray(a), Reactant.to_rarray(b))) expected_mod
804+
@test @jit(mod.(a, Reactant.to_rarray(b))) expected_mod
805+
@test @jit(mod.(Reactant.to_rarray(a), b)) expected_mod
798806

799807
expected_rem = rem.(a, b)
800-
@test Reactant.@jit(rem.(Reactant.to_rarray(a), Reactant.to_rarray(b))) expected_rem
801-
@test Reactant.@jit(rem.(a, Reactant.to_rarray(b))) expected_rem
802-
@test Reactant.@jit(rem.(Reactant.to_rarray(a), b)) expected_rem
808+
@test @jit(rem.(Reactant.to_rarray(a), Reactant.to_rarray(b))) expected_rem
809+
@test @jit(rem.(a, Reactant.to_rarray(b))) expected_rem
810+
@test @jit(rem.(Reactant.to_rarray(a), b)) expected_rem
803811
end
804812

805813
@testset "xor" begin
@@ -910,7 +918,7 @@ end
910918
x[:b] = 3.1 * ones(4)
911919

912920
ra = Reactant.to_rarray(x)
913-
Reactant.@jit dip!(ra)
921+
@jit dip!(ra)
914922
ra[:a] (2.7 * 2) * ones(4)
915923
end
916924

0 commit comments

Comments
 (0)