Skip to content

Commit b14cd23

Browse files
committed
[SPARK-7140] [MLLIB] only scan the first 16 entries in Vector.hashCode
The Python SerDe calls `Object.hashCode`, which is very expensive for Vectors. It is not necessary to scan the whole vector, especially for large ones. In this PR, `hashCode` only scans the nonzeros among the first 16 entries (i.e. entries with index < 16). srowen Author: Xiangrui Meng <[email protected]> Closes apache#5697 from mengxr/SPARK-7140 and squashes the following commits: 2abc86d [Xiangrui Meng] typo; 8fb7d74 [Xiangrui Meng] update impl; 1ebad60 [Xiangrui Meng] only scan the first 16 nonzeros in Vector.hashCode
1 parent 6a827d5 commit b14cd23

File tree

1 file changed

+67
-21
lines changed

1 file changed

+67
-21
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala

Lines changed: 67 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ sealed trait Vector extends Serializable {
5252

5353
override def equals(other: Any): Boolean = {
5454
other match {
55-
case v2: Vector => {
55+
case v2: Vector =>
5656
if (this.size != v2.size) return false
5757
(this, v2) match {
5858
case (s1: SparseVector, s2: SparseVector) =>
@@ -63,20 +63,28 @@ sealed trait Vector extends Serializable {
6363
Vectors.equals(0 until d1.size, d1.values, s1.indices, s1.values)
6464
case (_, _) => util.Arrays.equals(this.toArray, v2.toArray)
6565
}
66-
}
6766
case _ => false
6867
}
6968
}
7069

70+
/**
71+
* Returns a hash code value for the vector. The hash code is based on its size and its nonzeros
72+
* in the first 16 entries, using a hash algorithm similar to [[java.util.Arrays.hashCode]].
73+
*/
7174
override def hashCode(): Int = {
72-
var result: Int = size + 31
73-
this.foreachActive { case (index, value) =>
74-
// ignore explict 0 for comparison between sparse and dense
75-
if (value != 0) {
76-
result = 31 * result + index
77-
// refer to {@link java.util.Arrays.equals} for hash algorithm
78-
val bits = java.lang.Double.doubleToLongBits(value)
79-
result = 31 * result + (bits ^ (bits >>> 32)).toInt
75+
// This is a reference implementation. It calls return in foreachActive, which is slow.
76+
// Subclasses should override it with optimized implementation.
77+
var result: Int = 31 + size
78+
this.foreachActive { (index, value) =>
79+
if (index < 16) {
80+
// ignore explicit 0 for comparison between sparse and dense
81+
if (value != 0) {
82+
result = 31 * result + index
83+
val bits = java.lang.Double.doubleToLongBits(value)
84+
result = 31 * result + (bits ^ (bits >>> 32)).toInt
85+
}
86+
} else {
87+
return result
8088
}
8189
}
8290
result
@@ -317,7 +325,7 @@ object Vectors {
317325
case SparseVector(n, ids, vs) => vs
318326
case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
319327
}
320-
val size = values.size
328+
val size = values.length
321329

322330
if (p == 1) {
323331
var sum = 0.0
@@ -371,8 +379,8 @@ object Vectors {
371379
val v1Indices = v1.indices
372380
val v2Values = v2.values
373381
val v2Indices = v2.indices
374-
val nnzv1 = v1Indices.size
375-
val nnzv2 = v2Indices.size
382+
val nnzv1 = v1Indices.length
383+
val nnzv2 = v2Indices.length
376384

377385
var kv1 = 0
378386
var kv2 = 0
@@ -401,7 +409,7 @@ object Vectors {
401409

402410
case (DenseVector(vv1), DenseVector(vv2)) =>
403411
var kv = 0
404-
val sz = vv1.size
412+
val sz = vv1.length
405413
while (kv < sz) {
406414
val score = vv1(kv) - vv2(kv)
407415
squaredDistance += score * score
@@ -422,7 +430,7 @@ object Vectors {
422430
var kv2 = 0
423431
val indices = v1.indices
424432
var squaredDistance = 0.0
425-
val nnzv1 = indices.size
433+
val nnzv1 = indices.length
426434
val nnzv2 = v2.size
427435
var iv1 = if (nnzv1 > 0) indices(kv1) else -1
428436

@@ -451,8 +459,8 @@ object Vectors {
451459
v1Values: Array[Double],
452460
v2Indices: IndexedSeq[Int],
453461
v2Values: Array[Double]): Boolean = {
454-
val v1Size = v1Values.size
455-
val v2Size = v2Values.size
462+
val v1Size = v1Values.length
463+
val v2Size = v2Values.length
456464
var k1 = 0
457465
var k2 = 0
458466
var allEqual = true
@@ -493,14 +501,30 @@ class DenseVector(val values: Array[Double]) extends Vector {
493501

494502
private[spark] override def foreachActive(f: (Int, Double) => Unit) = {
495503
var i = 0
496-
val localValuesSize = values.size
504+
val localValuesSize = values.length
497505
val localValues = values
498506

499507
while (i < localValuesSize) {
500508
f(i, localValues(i))
501509
i += 1
502510
}
503511
}
512+
513+
/**
 * Hash code computed from `size` and the nonzero values among the first 16 entries.
 *
 * At most `min(values.length, 16)` entries are read, so the cost does not grow with the
 * length of the vector. Each nonzero entry contributes its index and a 32-bit fold of its
 * IEEE-754 bit pattern, combined with the multiplier 31.
 */
override def hashCode(): Int = {
514+
// Seed the hash with the vector size.
var result: Int = 31 + size
515+
var i = 0
516+
// Upper bound of the scan: never look past index 15.
val end = math.min(values.length, 16)
517+
while (i < end) {
518+
val v = values(i)
519+
// Zero entries are skipped and contribute nothing to the hash.
if (v != 0.0) {
520+
result = 31 * result + i
521+
// NOTE(review): re-reads values(i) although `v` already holds the same value.
val bits = java.lang.Double.doubleToLongBits(values(i))
522+
// Fold the 64-bit pattern into 32 bits by XOR-ing the high and low halves.
result = 31 * result + (bits ^ (bits >>> 32)).toInt
523+
}
524+
i += 1
525+
}
526+
result
527+
}
504528
}
505529

506530
object DenseVector {
@@ -522,8 +546,8 @@ class SparseVector(
522546
val values: Array[Double]) extends Vector {
523547

524548
require(indices.length == values.length, "Sparse vectors require that the dimension of the" +
525-
s" indices match the dimension of the values. You provided ${indices.size} indices and " +
526-
s" ${values.size} values.")
549+
s" indices match the dimension of the values. You provided ${indices.length} indices and " +
550+
s" ${values.length} values.")
527551

528552
override def toString: String =
529553
s"($size,${indices.mkString("[", ",", "]")},${values.mkString("[", ",", "]")})"
@@ -547,7 +571,7 @@ class SparseVector(
547571

548572
private[spark] override def foreachActive(f: (Int, Double) => Unit) = {
549573
var i = 0
550-
val localValuesSize = values.size
574+
val localValuesSize = values.length
551575
val localIndices = indices
552576
val localValues = values
553577

@@ -556,6 +580,28 @@ class SparseVector(
556580
i += 1
557581
}
558582
}
583+
584+
/**
 * Hash code computed from `size` and the stored nonzero values whose index is below 16.
 *
 * Stored entries are visited in array order and the scan stops at the first stored index
 * that is >= 16. Each nonzero entry contributes its index and a 32-bit fold of its
 * IEEE-754 bit pattern, combined with the multiplier 31.
 */
override def hashCode(): Int = {
585+
// Seed the hash with the vector size.
var result: Int = 31 + size
586+
val end = values.length
587+
// Loop flag: cleared once a stored index outside the first 16 entries is seen.
var continue = true
588+
var k = 0
589+
// `&` is the non-short-circuit AND; both operands are plain booleans, so the result
// is the same as with `&&`.
while ((k < end) & continue) {
590+
val i = indices(k)
591+
if (i < 16) {
592+
val v = values(k)
593+
// Explicitly stored zeros are skipped and contribute nothing to the hash.
if (v != 0.0) {
594+
result = 31 * result + i
595+
val bits = java.lang.Double.doubleToLongBits(v)
596+
// Fold the 64-bit pattern into 32 bits by XOR-ing the high and low halves.
result = 31 * result + (bits ^ (bits >>> 32)).toInt
597+
}
598+
} else {
599+
// Stop scanning at the first index >= 16.
// NOTE(review): presumes `indices` is in ascending order, otherwise later
// entries with index < 16 would be missed — confirm against the class contract.
continue = false
600+
}
601+
k += 1
602+
}
603+
result
604+
}
559605
}
560606

561607
object SparseVector {

0 commit comments

Comments
 (0)