Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 137 additions & 0 deletions src/InlineStrings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,9 @@ function ==(x::String, y::T) where {T <: InlineString}
end
==(y::InlineString, x::String) = x == y

Base.cmp(a::T, b::T) where {T <: InlineString} =
Base.eq_int(a, b) ? 0 : Base.ult_int(a, b) ? -1 : 1

function Base.hash(x::T, h::UInt) where {T <: InlineString}
h += Base.memhash_seed
ref = Ref{T}(_bswap(x))
Expand Down Expand Up @@ -579,4 +582,138 @@ function Parsers.xparse(::Type{T}, source::Union{AbstractVector{UInt8}, IO}, pos
return Parsers.Result{S}(code, res.tlen, x)
end

## InlineString sorting
using Base.Sort, Base.Order

# Only small-ish InlineStrings benefit from RadixSort algorithm
const SmallInlineStrings = Union{String1, String3, String7, String15}

# And under certain thresholds, MergeSort is faster than RadixSort, even for small InlineStrings
const MergeSortThresholds = Dict(
1 => 2^5,
4 => 2^7,
8 => 2^9,
16 => 2^23
)

struct InlineStringSortAlg <: Algorithm end
const InlineStringSort = InlineStringSortAlg()

Base.Sort.defalg(::AbstractArray{<:Union{SmallInlineStrings, Missing}}) = InlineStringSort

struct Radix
size::Int
pow::Int
mask::UInt16
end

Radix(size) = Radix(size, 2^size, typemax(UInt16) >> (16 - size))

sortvalue(o::By, x ) = sortvalue(Forward, o.by(x))
sortvalue(o::Perm, i::Int) = sortvalue(o.order, o.data[i])
sortvalue(o::Lt, x ) = error("sortvalue does not work with general Lt Orderings")
sortvalue(rev::ReverseOrdering, x) = Base.not_int(sortvalue(rev.fwd, x))
sortvalue(::Base.ForwardOrdering, x) = x

_oftype(::Type{T}, x::S) where {T, S} = sizeof(T) == sizeof(S) ? Base.bitcast(T, x) : sizeof(T) > sizeof(S) ? Base.zext_int(T, x) : Base.trunc_int(T, x)

radix(v::T, j, radix_size, radix_mask) where {T} = _oftype(Int64, Base.and_int(Base.lshr_int(v, (j - 1) * radix_size), _oftype(T, radix_mask))) + 1

@noinline requireprimitivetype(T) = throw(ArgumentError("InlineStringSort requires isprimitivetype input: `$T` invalid"))

function Base.sort!(vs::AbstractVector, lo::Int, hi::Int, ::InlineStringSortAlg, o::Ordering)
# Input checking
lo >= hi && return vs

# Make sure we're sorting a primitive type
T = Base.Order.ordtype(o, vs)
isprimitivetype(T) || requireprimitivetype(T)

if hi - lo < MergeSortThresholds[sizeof(T)]
return sort!(vs, lo, hi, MergeSort, o)
end

# setup
ts = similar(vs)
rdx = Radix(sizeof(T) == 1 ? 8 : 11)
radix_size = rdx.size
radix_mask = rdx.mask
radix_size_pow = rdx.pow
# iters is the # of 11-bit chunks we split each element up into
# they each represent a "significant digit" we'll be sorting on
iters = cld(sizeof(T) * 8, radix_size)
# bin has a row for each unique 11-bit pattern
# and a column for each 11-bit chunk we'll split each element up into
bin = zeros(UInt32, radix_size_pow, iters)
# if for some reason our lo isn't 1, we want to start our
# 1st row bin values as the 1st index we'll start at in the output
# i.e. we're assuming firstindex(vs):(lo - 1) is already sorted
if lo > 1; bin[1, :] .= lo-1; end

# for each element, split into 11-bit chunks (radix)
# and accumulate counts per unique pattern in bin
for i = lo:hi
v = sortvalue(o, vs[i])
for j = 1:iters
idx = radix(v, j, radix_size, radix_mask)
@inbounds bin[idx, j] += 1
end
end

# now we sort elements by sorting each radix using counting sort
swaps = 0
len = hi - lo + 1
@inbounds for j = 1:iters
# we first check if the radix for each element happened to be
# the exact same bit pattern; if so, they're "already sorted"
# for this radix and we can skip to the next. This would be common
# if we, for example, had many small integer values stored in Int64
# which would result in many "wasted" zero bits in most elements
v = sortvalue(o, vs[hi])
idx = radix(v, j, radix_size, radix_mask)

# if every element was counted at this bit pattern
# we can skip to the next radix chunk
bin[idx, j] == len && continue

# otherwise, we perform the counting sort for this radix
# by doing a cumulative sum for this radix column in bin
x = bin[1, j]
for i = 2:radix_size_pow
x += bin[i, j]
bin[i, j] = x
end
# now we extract the output index for our 1st element (vs[hi])
ci = bin[idx, j]
# and decrement the count for that bit pattern which
# will result in a subsequent identical bit pattern being
# placed one index ahead of the current one
bin[idx, j] -= 1
ts[ci] = vs[hi]

# now we sort the rest of the elements' radix similarly
for i in (hi - 1):-1:lo
v = sortvalue(o, vs[i])
idx = radix(v, j, radix_size, radix_mask)
ci = bin[idx, j]
bin[idx, j] -= 1
ts[ci] = vs[i]
end
# we keep 2 arrays, vs and ts
# because we can't overwrite where the current
# element will go in the output before we've sorted
# the element already there
vs, ts = ts, vs
swaps += 1
end

if isodd(swaps)
vs, ts = ts, vs
@inbounds for i = lo:hi
vs[i] = ts[i]
end
end
return vs
end

end # module
13 changes: 12 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Test, InlineStrings, Parsers, Serialization
using Test, InlineStrings, Parsers, Serialization, Random
import Parsers: SENTINEL, OK, EOF, OVERFLOW, QUOTED, DELIMITED, INVALID_QUOTED_FIELD, ESCAPED_STRING, NEWLINE, SUCCESS, peekbyte, incr!, checksentinel, checkdelim, checkcmtemptylines

@testset "InlineString basics" begin
Expand Down Expand Up @@ -189,3 +189,14 @@ end # @testset
@test String127 == InlineString127
@test String255 == InlineString255
end

@testset "sorting tests" begin
for nelems in (50, 100, 500, 1000, 5000, 100_000)
for T in (String1, String3, String7, String15, String31, String63, String127, String255)
x = [randstring(rand(1:(max(1, sizeof(T) - 1)))) for _ = 1:nelems];
y = map(T, x);
@test sort(x) == sort(y)
@test sort(x; rev=true) == sort(y; rev=true)
end
end
end