Skip to content

Commit 59db2f5

Browse files
committed
Merge pull request #1 from mengxr/SPARK-8997
update LocalPrefixSpan impl
2 parents 9212256 + 91e4357 commit 59db2f5

File tree

2 files changed

+23
-30
lines changed

2 files changed

+23
-30
lines changed

mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala

Lines changed: 21 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -30,34 +30,25 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
3030
* Calculate all patterns of a projected database.
3131
* @param minCount minimum count
3232
* @param maxPatternLength maximum pattern length
33-
* @param prefix prefix
34-
* @param database the projected dabase
33+
* @param prefixes prefixes in reversed order
34+
* @param database the projected database
3535
* @return a set of sequential pattern pairs,
36-
* the key of pair is sequential pattern (a list of items),
36+
* the key of pair is sequential pattern (a list of items in reversed order),
3737
* the value of pair is the pattern's count.
3838
*/
3939
def run(
4040
minCount: Long,
4141
maxPatternLength: Int,
42-
prefix: List[Int],
42+
prefixes: List[Int],
4343
database: Array[Array[Int]]): Iterator[(List[Int], Long)] = {
44-
45-
if (database.isEmpty) return Iterator.empty
46-
44+
if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty
4745
val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
48-
val frequentItems = frequentItemAndCounts.map(_._1).toSet
49-
val frequentPatternAndCounts = frequentItemAndCounts
50-
.map { case (item, count) => ((item :: prefix), count) }
51-
52-
53-
if (prefix.length + 1 < maxPatternLength) {
54-
val filteredProjectedDatabase = database.map(x => x.filter(frequentItems.contains(_)))
55-
frequentPatternAndCounts.iterator ++ frequentItems.flatMap { item =>
56-
val nextProjected = project(filteredProjectedDatabase, item)
57-
run(minCount, maxPatternLength, item :: prefix, nextProjected)
58-
}
59-
} else {
60-
frequentPatternAndCounts.iterator
46+
val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains))
47+
frequentItemAndCounts.iterator.flatMap { case (item, count) =>
48+
val newPrefixes = item :: prefixes
49+
val newProjected = project(filteredDatabase, item)
50+
Iterator.single((newPrefixes, count)) ++
51+
run(minCount, maxPatternLength, newPrefixes, newProjected)
6152
}
6253
}
6354

@@ -78,24 +69,26 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
7869

7970
def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = {
8071
database
81-
.map(candidateSeq => getSuffix(prefix, candidateSeq))
72+
.map(getSuffix(prefix, _))
8273
.filter(_.nonEmpty)
8374
}
8475

8576
/**
8677
* Generates frequent items by filtering the input data using minimal count level.
8778
* @param minCount the minimum count for an item to be frequent
8879
* @param database database of sequences
89-
* @return item and count pairs
80+
* @return freq item to count map
9081
*/
9182
private def getFreqItemAndCounts(
9283
minCount: Long,
93-
database: Array[Array[Int]]): Iterable[(Int, Long)] = {
94-
database.flatMap(_.distinct)
95-
.foldRight(mutable.Map[Int, Long]().withDefaultValue(0L)) { case (item, ctr) =>
96-
ctr(item) += 1
97-
ctr
84+
database: Array[Array[Int]]): mutable.Map[Int, Long] = {
85+
// TODO: use PrimitiveKeyOpenHashMap
86+
val counts = mutable.Map[Int, Long]().withDefaultValue(0L)
87+
database.foreach { sequence =>
88+
sequence.distinct.foreach { item =>
89+
counts(item) += 1L
9890
}
99-
.filter(_._2 >= minCount)
91+
}
92+
counts.filter(_._2 >= minCount)
10093
}
10194
}

mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
4747
def compareResult(
4848
expectedValue: Array[(Array[Int], Long)],
4949
actualValue: Array[(Array[Int], Long)]): Boolean = {
50-
expectedValue.map(x => (x._1.toList, x._2)).toSet ==
51-
actualValue.map(x => (x._1.toList, x._2)).toSet
50+
expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
51+
actualValue.map(x => (x._1.toSeq, x._2)).toSet
5252
}
5353

5454
val prefixspan = new PrefixSpan()

0 commit comments

Comments
 (0)