Skip to content

Commit d68b3c7

Browse files
wsmosesgithub-actions[bot]giordano
authored
Add isless (#923)
* Add isless * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix test * Upgrade Reactant_jll to v0.0.89 --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Mosè Giordano <[email protected]>
1 parent 85ccf30 commit d68b3c7

File tree

3 files changed

+26
-1
lines changed

3 files changed

+26
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ PythonCall = "0.9"
8686
Random = "1.10"
8787
Random123 = "1.7"
8888
ReactantCore = "0.1.5"
89-
Reactant_jll = "0.0.88"
89+
Reactant_jll = "0.0.89"
9090
Scratch = "1.2"
9191
Sockets = "1.10"
9292
SpecialFunctions = "2.4"

ext/ReactantCUDAExt.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ function Base.convert(::Type{T}, RN::CuTracedRNumber) where {T<:Number}
5050
return Base.convert(T, Base.getindex(RN))
5151
end
5252

53+
Base.isless(a::CuTracedRNumber, b::CuTracedRNumber) = Base.isless(a[], b[])
54+
Base.isless(a, b::CuTracedRNumber) = Base.isless(a, b[])
55+
Base.isless(a::CuTracedRNumber, b) = Base.isless(a[], b)
56+
5357
function Base.promote_rule(
5458
::Type{<:CuTracedRNumber{T}}, ::Type{<:CuTracedRNumber{T2}}
5559
) where {T,T2}

test/integration/cuda.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,3 +154,24 @@ end
154154
@jit mul_number!(A, B)
155155
@test all(Array(A) .≈ oA .* 3.1)
156156
end
157+
158+
function searchsorted_kernel!(x, y)
159+
i = threadIdx().x
160+
times = 0:0.01:4.5
161+
z = searchsortedfirst(times, y)
162+
x[i] = z
163+
return nothing
164+
end
165+
166+
function searchsorted!(x, y)
167+
@cuda blocks = 1 threads = length(x) searchsorted_kernel!(x, y)
168+
return nothing
169+
end
170+
171+
@testset "Search sorted" begin
172+
oA = collect(Float64, 1:1:64)
173+
A = Reactant.to_rarray(oA)
174+
B = ConcreteRNumber(3.1)
175+
@jit searchsorted!(A, B)
176+
@test all(Array(A) .≈ 311)
177+
end

0 commit comments

Comments
 (0)