@@ -49,6 +49,7 @@ class PrefixSpan private (
4949 * The maximum number of items allowed in a projected database before local processing. If a
5050 * projected database exceeds this size, another iteration of distributed PrefixSpan is run.
5151 */
52+ // TODO: make configurable with a better default value, 10000 may be too small
5253 private val maxLocalProjDBSize : Long = 10000
5354
5455 /**
@@ -61,7 +62,7 @@ class PrefixSpan private (
6162 * Get the minimal support (i.e. the frequency of occurrence before a pattern is considered
6263 * frequent).
6364 */
64- def getMinSupport () : Double = this .minSupport
65+ def getMinSupport : Double = this .minSupport
6566
6667 /**
6768 * Sets the minimal support level (default: `0.1`).
@@ -75,7 +76,7 @@ class PrefixSpan private (
7576 /**
7677 * Gets the maximal pattern length (i.e. the length of the longest sequential pattern to consider.
7778 */
78- def getMaxPatternLength () : Double = this .maxPatternLength
79+ def getMaxPatternLength : Double = this .maxPatternLength
7980
8081 /**
8182 * Sets maximal pattern length (default: `10`).
@@ -96,6 +97,8 @@ class PrefixSpan private (
9697 * the value of pair is the pattern's count.
9798 */
9899 def run (sequences : RDD [Array [Int ]]): RDD [(Array [Int ], Long )] = {
100+ val sc = sequences.sparkContext
101+
99102 if (sequences.getStorageLevel == StorageLevel .NONE ) {
100103 logWarning(" Input data is not cached." )
101104 }
@@ -108,10 +111,11 @@ class PrefixSpan private (
108111 .flatMap(seq => seq.distinct.map(item => (item, 1L )))
109112 .reduceByKey(_ + _)
110113 .filter(_._2 >= minCount)
114+ .collect()
111115
112116 // Pairs of (length 1 prefix, suffix consisting of frequent items)
113117 val itemSuffixPairs = {
114- val freqItems = freqItemCounts.keys.collect( ).toSet
118+ val freqItems = freqItemCounts.map(_._1 ).toSet
115119 sequences.flatMap { seq =>
116120 val filteredSeq = seq.filter(freqItems.contains(_))
117121 freqItems.flatMap { item =>
@@ -141,13 +145,14 @@ class PrefixSpan private (
141145 pairsForDistributed = largerPairsPart
142146 pairsForDistributed.persist(StorageLevel .MEMORY_AND_DISK )
143147 pairsForLocal ++= smallerPairsPart
144- resultsAccumulator ++= nextPatternAndCounts
148+ resultsAccumulator ++= nextPatternAndCounts.collect()
145149 }
146150
147151 // Process the small projected databases locally
148- resultsAccumulator ++ = getPatternsInLocal(minCount, pairsForLocal.groupByKey())
152+ val remainingResults = getPatternsInLocal(minCount, pairsForLocal.groupByKey())
149153
150- resultsAccumulator.map { case (pattern, count) => (pattern.toArray, count) }
154+ (sc.parallelize(resultsAccumulator, 1 ) ++ remainingResults)
155+ .map { case (pattern, count) => (pattern.toArray, count) }
151156 }
152157
153158
0 commit comments