Skip to content

Commit 11ac5b5

Browse files
authored
Rewrote corkendall (issue #634) (#647)
New version of corkendall is approx 4 times faster if both arguments are vectors and 7 times faster if at least one is a matrix. See issue #634 for details.
1 parent 3b0b2da commit 11ac5b5

File tree

2 files changed

+235
-87
lines changed

2 files changed

+235
-87
lines changed

src/rankcorr.jl

Lines changed: 162 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -33,121 +33,203 @@ corspearman(X::RealMatrix) = (Z = mapslices(tiedrank, X, dims=1); cor(Z, Z))
3333
#
3434
#######################################
3535

36-
# Knight JASA (1966)
37-
38-
function corkendall!(x::RealVector, y::RealVector)
36+
# Knight, William R. “A Computer Method for Calculating Kendall's Tau with Ungrouped Data.”
37+
# Journal of the American Statistical Association, vol. 61, no. 314, 1966, pp. 436–439.
38+
# JSTOR, www.jstor.org/stable/2282833.
39+
function corkendall!(x::RealVector, y::RealVector, permx::AbstractVector{<:Integer}=sortperm(x))
3940
if any(isnan, x) || any(isnan, y) return NaN end
4041
n = length(x)
4142
if n != length(y) error("Vectors must have same length") end
4243

4344
# Initial sorting
44-
pm = sortperm(y)
45-
x[:] = x[pm]
46-
y[:] = y[pm]
47-
pm[:] = sortperm(x)
48-
x[:] = x[pm]
49-
50-
# Counting ties in x and y
51-
iT = 1
52-
nT = 0
53-
iU = 1
54-
nU = 0
55-
for i = 2:n
56-
if x[i] == x[i-1]
57-
iT += 1
58-
else
59-
nT += iT*(iT - 1)
60-
iT = 1
61-
end
62-
if y[i] == y[i-1]
63-
iU += 1
64-
else
65-
nU += iU*(iU - 1)
66-
iU = 1
45+
permute!(x, permx)
46+
permute!(y, permx)
47+
48+
# Use widen to avoid overflows on both 32bit and 64bit
49+
npairs = div(widen(n) * (n - 1), 2)
50+
ntiesx = ndoubleties = nswaps = widen(0)
51+
k = 0
52+
53+
@inbounds for i = 2:n
54+
if x[i - 1] == x[i]
55+
k += 1
56+
elseif k > 0
57+
# Sort the corresponding chunk of y, so the rows of hcat(x,y) are
58+
# sorted first on x, then (where x values are tied) on y. Hence
59+
# double ties can be counted by calling countties.
60+
sort!(view(y, (i - k - 1):(i - 1)))
61+
ntiesx += div(widen(k) * (k + 1), 2) # Must use wide integers here
62+
ndoubleties += countties(y, i - k - 1, i - 1)
63+
k = 0
6764
end
6865
end
69-
if iT > 1 nT += iT*(iT - 1) end
70-
nT = div(nT,2)
71-
if iU > 1 nU += iU*(iU - 1) end
72-
nU = div(nU,2)
73-
74-
# Sort y after x
75-
y[:] = y[pm]
76-
77-
# Calculate double ties
78-
iV = 1
79-
nV = 0
80-
jV = 1
81-
for i = 2:n
82-
if x[i] == x[i-1] && y[i] == y[i-1]
83-
iV += 1
84-
else
85-
nV += iV*(iV - 1)
86-
iV = 1
87-
end
66+
if k > 0
67+
sort!(view(y, (n - k):n))
68+
ntiesx += div(widen(k) * (k + 1), 2)
69+
ndoubleties += countties(y, n - k, n)
8870
end
89-
if iV > 1 nV += iV*(iV - 1) end
90-
nV = div(nV,2)
9171

92-
nD = div(n*(n - 1),2)
93-
return (nD - nT - nU + nV - 2swaps!(y)) / (sqrt(nD - nT) * sqrt(nD - nU))
94-
end
72+
nswaps = merge_sort!(y, 1, n)
73+
ntiesy = countties(y, 1, n)
9574

75+
# Calls to float below prevent possible overflow errors when
76+
# length(x) exceeds 77_936 (32 bit) or 5_107_605_667 (64 bit)
77+
(npairs + ndoubleties - ntiesx - ntiesy - 2 * nswaps) /
78+
sqrt(float(npairs - ntiesx) * float(npairs - ntiesy))
79+
end
9680

9781
"""
9882
corkendall(x, y=x)
9983
10084
Compute Kendall's rank correlation coefficient, τ. `x` and `y` must both be either
10185
matrices or vectors.
10286
"""
103-
corkendall(x::RealVector, y::RealVector) = corkendall!(float(copy(x)), float(copy(y)))
87+
corkendall(x::RealVector, y::RealVector) = corkendall!(copy(x), copy(y))
10488

105-
corkendall(X::RealMatrix, y::RealVector) = Float64[corkendall!(float(X[:,i]), float(copy(y))) for i in 1:size(X, 2)]
106-
107-
corkendall(x::RealVector, Y::RealMatrix) = (n = size(Y,2); reshape(Float64[corkendall!(float(copy(x)), float(Y[:,i])) for i in 1:n], 1, n))
89+
function corkendall(X::RealMatrix, y::RealVector)
90+
permy = sortperm(y)
91+
return([corkendall!(copy(y), X[:,i], permy) for i in 1:size(X, 2)])
92+
end
10893

109-
corkendall(X::RealMatrix, Y::RealMatrix) = Float64[corkendall!(float(X[:,i]), float(Y[:,j])) for i in 1:size(X, 2), j in 1:size(Y, 2)]
94+
function corkendall(x::RealVector, Y::RealMatrix)
95+
n = size(Y, 2)
96+
permx = sortperm(x)
97+
return(reshape([corkendall!(copy(x), Y[:,i], permx) for i in 1:n], 1, n))
98+
end
11099

111100
function corkendall(X::RealMatrix)
112101
n = size(X, 2)
113-
C = Matrix{eltype(X)}(I, n, n)
102+
C = Matrix{Float64}(I, n, n)
114103
for j = 2:n
115-
for i = 1:j-1
116-
C[i,j] = corkendall!(X[:,i],X[:,j])
117-
C[j,i] = C[i,j]
104+
permx = sortperm(X[:,j])
105+
for i = 1:j - 1
106+
C[j,i] = corkendall!(X[:,j], X[:,i], permx)
107+
C[i,j] = C[j,i]
108+
end
109+
end
110+
return C
111+
end
112+
113+
function corkendall(X::RealMatrix, Y::RealMatrix)
114+
nr = size(X, 2)
115+
nc = size(Y, 2)
116+
C = Matrix{Float64}(undef, nr, nc)
117+
for j = 1:nr
118+
permx = sortperm(X[:,j])
119+
for i = 1:nc
120+
C[j,i] = corkendall!(X[:,j], Y[:,i], permx)
118121
end
119122
end
120123
return C
121124
end
122125

123126
# Auxiliary functions for Kendall's rank correlation
124127

125-
function swaps!(x::RealVector)
126-
n = length(x)
127-
if n == 1 return 0 end
128-
n2 = div(n, 2)
129-
xl = view(x, 1:n2)
130-
xr = view(x, n2+1:n)
131-
nsl = swaps!(xl)
132-
nsr = swaps!(xr)
133-
sort!(xl)
134-
sort!(xr)
135-
return nsl + nsr + mswaps(xl,xr)
128+
"""
129+
countties(x::AbstractVector, lo::Integer, hi::Integer)
130+
131+
Return the number of ties within `x[lo:hi]`. Assumes `x` is sorted.
132+
"""
133+
function countties(x::AbstractVector, lo::Integer, hi::Integer)
134+
# Use of widen below prevents possible overflow errors when
135+
# length(x) exceeds 2^16 (32 bit) or 2^32 (64 bit)
136+
thistiecount = result = widen(0)
137+
checkbounds(x, lo:hi)
138+
@inbounds for i = (lo + 1):hi
139+
if x[i] == x[i - 1]
140+
thistiecount += 1
141+
elseif thistiecount > 0
142+
result += div(thistiecount * (thistiecount + 1), 2)
143+
thistiecount = widen(0)
144+
end
145+
end
146+
147+
if thistiecount > 0
148+
result += div(thistiecount * (thistiecount + 1), 2)
149+
end
150+
result
136151
end
137152

138-
function mswaps(x::RealVector, y::RealVector)
139-
i = 1
140-
j = 1
141-
nSwaps = 0
142-
n = length(x)
143-
while i <= n && j <= length(y)
144-
if y[j] < x[i]
145-
nSwaps += n - i + 1
153+
# Tests appear to show that a value of 64 is optimal,
154+
# but note that the equivalent constant in base/sort.jl is 20.
155+
const SMALL_THRESHOLD = 64
156+
157+
# merge_sort! copied from Julia Base
158+
# (commit 28330a2fef4d9d149ba0fd3ffa06347b50067647, dated 20 Sep 2020)
159+
"""
160+
merge_sort!(v::AbstractVector, lo::Integer, hi::Integer, t::AbstractVector=similar(v, 0))
161+
162+
Mutates `v` by sorting elements `v[lo:hi]` using the merge sort algorithm.
163+
This method is a copy-paste-edit of sort! in base/sort.jl, amended to return the bubblesort distance.
164+
"""
165+
function merge_sort!(v::AbstractVector, lo::Integer, hi::Integer, t::AbstractVector=similar(v, 0))
166+
# Use of widen below prevents possible overflow errors when
167+
# length(v) exceeds 2^16 (32 bit) or 2^32 (64 bit)
168+
nswaps = widen(0)
169+
@inbounds if lo < hi
170+
hi - lo <= SMALL_THRESHOLD && return insertion_sort!(v, lo, hi)
171+
172+
m = midpoint(lo, hi)
173+
(length(t) < m - lo + 1) && resize!(t, m - lo + 1)
174+
175+
nswaps = merge_sort!(v, lo, m, t)
176+
nswaps += merge_sort!(v, m + 1, hi, t)
177+
178+
i, j = 1, lo
179+
while j <= m
180+
t[i] = v[j]
181+
i += 1
146182
j += 1
147-
else
183+
end
184+
185+
i, k = 1, lo
186+
while k < j <= hi
187+
if v[j] < t[i]
188+
v[k] = v[j]
189+
j += 1
190+
nswaps += m - lo + 1 - (i - 1)
191+
else
192+
v[k] = t[i]
193+
i += 1
194+
end
195+
k += 1
196+
end
197+
while k < j
198+
v[k] = t[i]
199+
k += 1
148200
i += 1
149201
end
150202
end
151-
return nSwaps
203+
return nswaps
152204
end
153205

206+
# insertion_sort! and midpoint copied from Julia Base
207+
# (commit 28330a2fef4d9d149ba0fd3ffa06347b50067647, dated 20 Sep 2020)
208+
midpoint(lo::T, hi::T) where T <: Integer = lo + ((hi - lo) >>> 0x01)
209+
midpoint(lo::Integer, hi::Integer) = midpoint(promote(lo, hi)...)
210+
211+
"""
212+
insertion_sort!(v::AbstractVector, lo::Integer, hi::Integer)
213+
214+
Mutates `v` by sorting elements `v[lo:hi]` using the insertion sort algorithm.
215+
This method is a copy-paste-edit of sort! in base/sort.jl, amended to return the bubblesort distance.
216+
"""
217+
function insertion_sort!(v::AbstractVector, lo::Integer, hi::Integer)
218+
if lo == hi return widen(0) end
219+
nswaps = widen(0)
220+
@inbounds for i = lo + 1:hi
221+
j = i
222+
x = v[i]
223+
while j > lo
224+
if x < v[j - 1]
225+
nswaps += 1
226+
v[j] = v[j - 1]
227+
j -= 1
228+
continue
229+
end
230+
break
231+
end
232+
v[j] = x
233+
end
234+
return nswaps
235+
end

test/rankcorr.jl

Lines changed: 73 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,86 @@ c22 = corspearman(x2, x2)
2323
@test corspearman(X, X) ≈ [c11 c12; c12 c22]
2424
@test corspearman(X) ≈ [c11 c12; c12 c22]
2525

26-
2726
# corkendall
2827

29-
@test corkendall(x1, y) ≈ -0.105409255338946
30-
@test corkendall(x2, y) ≈ -0.117851130197758
28+
# Check error, handling of NaN, Inf etc
29+
@test_throws ErrorException("Vectors must have same length") corkendall([1,2,3,4], [1,2,3])
30+
@test isnan(corkendall([1,2], [3,NaN]))
31+
@test isnan(corkendall([1,1,1], [1,2,3]))
32+
@test corkendall([-Inf,-0.0,Inf],[1,2,3]) == 1.0
33+
34+
# Test, with exact equality, some known results.
35+
# RealVector, RealVector
36+
@test corkendall(x1, y) == -1/sqrt(90)
37+
@test corkendall(x2, y) == -1/sqrt(72)
38+
# RealMatrix, RealVector
39+
@test corkendall(X, y) == [-1/sqrt(90), -1/sqrt(72)]
40+
# RealVector, RealMatrix
41+
@test corkendall(y, X) == [-1/sqrt(90) -1/sqrt(72)]
42+
43+
# n = 78_000 tests for overflow errors on 32 bit
44+
# Testing for overflow errors on 64bit would require n be too large for practicality
45+
# This also tests merge_sort! since n is (much) greater than SMALL_THRESHOLD.
46+
n = 78_000
47+
# Test with many repeats
48+
@test corkendall(repeat(x1, n), repeat(y, n)) ≈ -1/sqrt(90)
49+
@test corkendall(repeat(x2, n), repeat(y, n)) ≈ -1/sqrt(72)
50+
@test corkendall(repeat(X, n), repeat(y, n)) ≈ [-1/sqrt(90), -1/sqrt(72)]
51+
@test corkendall(repeat(y, n), repeat(X, n)) ≈ [-1/sqrt(90) -1/sqrt(72)]
52+
@test corkendall(repeat([0,1,1,0], n), repeat([1,0,1,0], n)) == 0.0
53+
54+
# Test with no repeats, note testing for exact equality
55+
@test corkendall(collect(1:n), collect(1:n)) == 1.0
56+
@test corkendall(collect(1:n), reverse(collect(1:n))) == -1.0
3157

32-
@test corkendall(X, y) ≈ [-0.105409255338946, -0.117851130197758]
33-
@test corkendall(y, X) ≈ [-0.105409255338946 -0.117851130197758]
58+
# All elements identical should yield NaN
59+
@test isnan(corkendall(repeat([1], n), collect(1:n)))
3460

3561
c11 = corkendall(x1, x1)
3662
c12 = corkendall(x1, x2)
3763
c22 = corkendall(x2, x2)
3864

39-
@test c11 ≈ 1.0
40-
@test c22 ≈ 1.0
65+
# RealMatrix, RealMatrix
4166
@test corkendall(X, X) ≈ [c11 c12; c12 c22]
67+
# RealMatrix
4268
@test corkendall(X) ≈ [c11 c12; c12 c22]
69+
70+
@test c11 == 1.0
71+
@test c22 == 1.0
72+
@test c12 == 3/sqrt(20)
73+
74+
# Finished testing for overflow, so redefine n for speedier tests
75+
n = 100
76+
77+
@test corkendall(repeat(X, n), repeat(X, n)) ≈ [c11 c12; c12 c22]
78+
@test corkendall(repeat(X, n)) ≈ [c11 c12; c12 c22]
79+
80+
# All eight three-element permutations
81+
z = [1 1 1;
82+
1 1 2;
83+
1 2 2;
84+
1 2 2;
85+
1 2 1;
86+
2 1 2;
87+
1 1 2;
88+
2 2 2]
89+
90+
@test corkendall(z) == [1 0 1/3; 0 1 0; 1/3 0 1]
91+
@test corkendall(z, z) == [1 0 1/3; 0 1 0; 1/3 0 1]
92+
@test corkendall(z[:,1], z) == [1 0 1/3]
93+
@test corkendall(z, z[:,1]) == [1; 0; 1/3]
94+
95+
z = float(z)
96+
@test corkendall(z) == [1 0 1/3; 0 1 0; 1/3 0 1]
97+
@test corkendall(z, z) == [1 0 1/3; 0 1 0; 1/3 0 1]
98+
@test corkendall(z[:,1], z) == [1 0 1/3]
99+
@test corkendall(z, z[:,1]) == [1; 0; 1/3]
100+
101+
w = repeat(z, n)
102+
@test corkendall(w) == [1 0 1/3; 0 1 0; 1/3 0 1]
103+
@test corkendall(w, w) == [1 0 1/3; 0 1 0; 1/3 0 1]
104+
@test corkendall(w[:,1], w) == [1 0 1/3]
105+
@test corkendall(w, w[:,1]) == [1; 0; 1/3]
106+
107+
StatsBase.midpoint(1,10) == 5
108+
StatsBase.midpoint(1,widen(10)) == 5

0 commit comments

Comments
 (0)