Skip to content

Commit dc77e38

Browse files
committed
test sparse vector RDD
1 parent 18cf072 commit dc77e38

File tree

2 files changed

+43
-35
lines changed

2 files changed

+43
-35
lines changed

mllib/src/main/scala/org/apache/spark/mllib/rdd/VectorRDDFunctions.scala

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ private class Aggregator(
4949
val deltaMean = currMean
5050
var i = 0
5151
while(i < currM2n.size) {
52-
currM2n(i) -= deltaMean(i) * deltaMean(i) * nnz(i) * (nnz(i)-totalCnt) / totalCnt
52+
currM2n(i) += deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt-nnz(i)) / totalCnt
5353
currM2n(i) /= totalCnt
5454
i += 1
5555
}
@@ -61,15 +61,15 @@ private class Aggregator(
6161
override lazy val numNonZeros: Vector = Vectors.fromBreeze(nnz)
6262

6363
override lazy val max: Vector = {
64-
nnz.activeIterator.foreach {
64+
nnz.iterator.foreach {
6565
case (id, count) =>
6666
if ((count == 0.0) || ((count < totalCnt) && (currMax(id) < 0.0))) currMax(id) = 0.0
6767
}
6868
Vectors.fromBreeze(currMax)
6969
}
7070

7171
override lazy val min: Vector = {
72-
nnz.activeIterator.foreach {
72+
nnz.iterator.foreach {
7373
case (id, count) =>
7474
if ((count == 0.0) || ((count < totalCnt) && (currMin(id) > 0.0))) currMin(id) = 0.0
7575
}
@@ -88,7 +88,7 @@ private class Aggregator(
8888
if (currMin(id) > value) currMin(id) = value
8989

9090
val tmpPrevMean = currMean(id)
91-
currMean(id) = (currMean(id) * totalCnt + value) / (totalCnt + 1.0)
91+
currMean(id) = (currMean(id) * nnz(id) + value) / (nnz(id) + 1.0)
9292
currM2n(id) += (value - currMean(id)) * (value - tmpPrevMean)
9393

9494
nnz(id) += 1.0
@@ -114,11 +114,14 @@ private class Aggregator(
114114
(currMean(id) * nnz(id) + other.currMean(id) * other.nnz(id)) / (nnz(id) + other.nnz(id))
115115
}
116116

117-
other.currM2n.activeIterator.foreach {
118-
case (id, 0.0) =>
119-
case (id, value) =>
120-
currM2n(id) +=
121-
value + deltaMean(id) * deltaMean(id) * nnz(id) * other.nnz(id) / (nnz(id)+other.nnz(id))
117+
var i = 0
118+
while(i < currM2n.size) {
119+
(nnz(i), other.nnz(i)) match {
120+
case (0.0, 0.0) =>
121+
case _ => currM2n(i) +=
122+
other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) / (nnz(i)+other.nnz(i))
123+
}
124+
i += 1
122125
}
123126

124127
other.currMax.activeIterator.foreach {

mllib/src/test/scala/org/apache/spark/mllib/rdd/VectorRDDFunctionsSuite.scala

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -38,54 +38,59 @@ class VectorRDDFunctionsSuite extends FunSuite with LocalSparkContext {
3838
Vectors.dense(7.0, 8.0, 9.0)
3939
)
4040

41-
val sparseData = ArrayBuffer(Vectors.sparse(20, Seq((0, 1.0), (9, 2.0), (10, 7.0))))
42-
for (i <- 0 until 100) sparseData += Vectors.sparse(20, Seq((9, 0.0)))
43-
sparseData += Vectors.sparse(20, Seq((0, 5.0), (9, 13.0), (16, 2.0)))
44-
sparseData += Vectors.sparse(20, Seq((3, 5.0), (9, 13.0), (18, 2.0)))
41+
val sparseData = ArrayBuffer(Vectors.sparse(3, Seq((0, 1.0))))
42+
for (i <- 0 until 97) sparseData += Vectors.sparse(3, Seq((2, 0.0)))
43+
sparseData += Vectors.sparse(3, Seq((0, 5.0)))
44+
sparseData += Vectors.sparse(3, Seq((1, 5.0)))
4545

46-
test("full-statistics") {
46+
test("dense statistical summary") {
4747
val data = sc.parallelize(localData, 2)
48-
val (summary, denseTime) =
49-
time(data.summarizeStatistics())
48+
val summary = data.summarizeStatistics()
5049

5150
assert(equivVector(summary.mean, Vectors.dense(4.0, 5.0, 6.0)),
52-
"Column mean do not match.")
51+
"Dense column mean do not match.")
5352

5453
assert(equivVector(summary.variance, Vectors.dense(6.0, 6.0, 6.0)),
55-
"Column variance do not match.")
54+
"Dense column variance do not match.")
5655

57-
assert(summary.totalCount === 3, "Column cnt do not match.")
56+
assert(summary.totalCount === 3, "Dense column cnt do not match.")
5857

5958
assert(equivVector(summary.numNonZeros, Vectors.dense(3.0, 3.0, 3.0)),
60-
"Column nnz do not match.")
59+
"Dense column nnz do not match.")
6160

6261
assert(equivVector(summary.max, Vectors.dense(7.0, 8.0, 9.0)),
63-
"Column max do not match.")
62+
"Dense column max do not match.")
6463

6564
assert(equivVector(summary.min, Vectors.dense(1.0, 2.0, 3.0)),
66-
"Column min do not match.")
65+
"Dense column min do not match.")
66+
}
6767

68+
test("sparse statistical summary") {
6869
val dataForSparse = sc.parallelize(sparseData.toSeq, 2)
69-
val (_, sparseTime) = time(dataForSparse.summarizeStatistics())
70+
val summary = dataForSparse.summarizeStatistics()
71+
72+
assert(equivVector(summary.mean, Vectors.dense(0.06, 0.05, 0.0)),
73+
"Sparse column mean do not match.")
74+
75+
assert(equivVector(summary.variance, Vectors.dense(0.2564, 0.2475, 0.0)),
76+
"Sparse column variance do not match.")
77+
78+
assert(summary.totalCount === 100, "Sparse column cnt do not match.")
79+
80+
assert(equivVector(summary.numNonZeros, Vectors.dense(2.0, 1.0, 0.0)),
81+
"Sparse column nnz do not match.")
7082

71-
println(s"dense time is $denseTime, sparse time is $sparseTime.")
83+
assert(equivVector(summary.max, Vectors.dense(5.0, 5.0, 0.0)),
84+
"Sparse column max do not match.")
85+
86+
assert(equivVector(summary.min, Vectors.dense(0.0, 0.0, 0.0)),
87+
"Sparse column min do not match.")
7288
}
7389
}
7490

7591
object VectorRDDFunctionsSuite {
76-
def time[R](block: => R): (R, Double) = {
77-
val t0 = System.nanoTime()
78-
val result = block
79-
val t1 = System.nanoTime()
80-
(result, (t1 - t0).toDouble / 1.0e9)
81-
}
8292

8393
def equivVector(lhs: Vector, rhs: Vector): Boolean = {
8494
(lhs.toBreeze - rhs.toBreeze).norm(2) < 1e-9
8595
}
86-
87-
def relativeTime(lhs: Double, rhs: Double): Boolean = {
88-
val denominator = math.max(lhs, rhs)
89-
math.abs(lhs - rhs) / denominator < 0.3
90-
}
9196
}

0 commit comments

Comments (0)