@@ -43,6 +43,8 @@ class PrefixSpan private (
4343 private var minSupport : Double ,
4444 private var maxPatternLength : Int ) extends Logging with Serializable {
4545
// Threshold compared against the number of patterns produced by the latest
// distributed iteration: the visible loop keeps extending patterns with
// distributed jobs while that count is <= this value (and the projected
// database is non-empty); afterwards the remaining projected database is
// grouped by prefix (a shuffle) and finished with local processing.
// NOTE(review): the value 20 is not justified anywhere in view — confirm.
46+ private val minPatternsBeforeShuffle : Int = 20
47+
4648 /**
4749 * Constructs a default instance with default parameters
4850 * {minSupport: `0.1`, maxPatternLength: `10`}.
@@ -86,16 +88,69 @@ class PrefixSpan private (
8688 getFreqItemAndCounts(minCount, sequences).collect()
8789 val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase(
8890 lengthOnePatternsAndCounts.map(_._1), sequences)
89- val groupedProjectedDatabase = prefixAndProjectedDatabase
90- .map(x => (x._1.toSeq, x._2))
91- .groupByKey()
92- .map(x => (x._1.toArray, x._2.toArray))
93- val nextPatterns = getPatternsInLocal(minCount, groupedProjectedDatabase)
94- val lengthOnePatternsAndCountsRdd =
95- sequences.sparkContext.parallelize(
96- lengthOnePatternsAndCounts.map(x => (Array (x._1), x._2)))
97- val allPatterns = lengthOnePatternsAndCountsRdd ++ nextPatterns
98- allPatterns
91+
92+ var patternsCount = lengthOnePatternsAndCounts.length
93+ var allPatternAndCounts = sequences.sparkContext.parallelize(
94+ lengthOnePatternsAndCounts.map(x => (Array (x._1), x._2)))
95+ var currentProjectedDatabase = prefixAndProjectedDatabase
96+ while (patternsCount <= minPatternsBeforeShuffle &&
97+ currentProjectedDatabase.count() != 0 ) {
98+ val (nextPatternAndCounts, nextProjectedDatabase) =
99+ getPatternCountsAndProjectedDatabase(minCount, currentProjectedDatabase)
100+ patternsCount = nextPatternAndCounts.count().toInt
101+ currentProjectedDatabase = nextProjectedDatabase
102+ allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts
103+ }
104+ if (patternsCount > 0 ) {
105+ val groupedProjectedDatabase = currentProjectedDatabase
106+ .map(x => (x._1.toSeq, x._2))
107+ .groupByKey()
108+ .map(x => (x._1.toArray, x._2.toArray))
109+ val nextPatternAndCounts = getPatternsInLocal(minCount, groupedProjectedDatabase)
110+ allPatternAndCounts = allPatternAndCounts ++ nextPatternAndCounts
111+ }
112+ allPatternAndCounts
113+ }
114+
/**
 * Extends every prefix in the projected database by one frequent item and builds
 * the projected database for the extended prefixes.
 *
 * An item is counted once per projected sequence (duplicates inside a sequence are
 * ignored) and is kept only if its count is at least `minCount`.
 *
 * @param minCount minimum count a (prefix, item) pair must reach to be kept
 * @param prefixAndProjectedDatabase pairs of (prefix, projected sequence)
 * @return a pair of RDDs:
 *         - (pattern, count), where each pattern is a prefix extended by one frequent item;
 *         - (extended prefix, projected sequence) for the next iteration. This RDD is
 *           empty when the input is empty or when the extended patterns already reach
 *           `maxPatternLength`.
 */
private def getPatternCountsAndProjectedDatabase(
    minCount: Long,
    prefixAndProjectedDatabase: RDD[(Array[Int], Array[Int])]):
  (RDD[(Array[Int], Long)], RDD[(Array[Int], Array[Int])]) = {
  // Arrays compare by reference, so the prefix is converted to a Seq before it is
  // used as (part of) a key.
  val prefixAndFrequentItemAndCounts = prefixAndProjectedDatabase.flatMap {
    case (prefix, suffix) =>
      val prefixKey = prefix.toSeq
      suffix.distinct.map(item => ((prefixKey, item), 1L))
  }.reduceByKey(_ + _)
    .filter(_._2 >= minCount)

  val patternAndCounts = prefixAndFrequentItemAndCounts.map {
    case ((prefix, item), count) => (prefix.toArray :+ item, count)
  }

  // The prefix length is read from a single element.
  // NOTE(review): this assumes all prefixes in the RDD have the same length — confirm
  // against the caller.
  prefixAndProjectedDatabase.take(1).headOption.map(_._1.length) match {
    case None =>
      // Empty input: there is nothing to extend, and the input already is an empty
      // projected database. (Indexing `take(1)(0)` here would throw.)
      (patternAndCounts, prefixAndProjectedDatabase)

    case Some(prefixLength) if prefixLength + 1 >= maxPatternLength =>
      // The extended patterns have reached the maximum length; stop projecting.
      (patternAndCounts, prefixAndProjectedDatabase.filter(_ => false))

    case Some(_) =>
      // prefix -> set of frequent items, collected to the driver and captured by the
      // closure below.
      val frequentItemsMap = prefixAndFrequentItemAndCounts
        .keys
        .groupByKey()
        .mapValues(_.toSet)
        .collect()
        .toMap

      val nextPrefixAndProjectedDatabase = prefixAndProjectedDatabase.flatMap {
        case (prefix, suffix) =>
          // Prefixes without any frequent item yield an empty set and therefore no output.
          val frequentItemSet = frequentItemsMap.getOrElse(prefix.toSeq, Set.empty[Int])
          val filteredSequence = suffix.filter(frequentItemSet.contains)
          frequentItemSet.toSeq
            .map(item => (item, LocalPrefixSpan.getSuffix(item, filteredSequence)))
            .filter(_._2.nonEmpty)
            .map { case (item, newSuffix) => (prefix :+ item, newSuffix) }
      }
      (patternAndCounts, nextPrefixAndProjectedDatabase)
  }
}
100155
101156 /**
0 commit comments