@@ -17,21 +17,30 @@
 
 package org.apache.spark.mllib.fpm
 
-import java.lang.{Iterable => JavaIterable}
 import java.{util => ju}
+import java.lang.{Iterable => JavaIterable}
 
-import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.collection.JavaConverters._
 import scala.reflect.ClassTag
 
-import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}
+import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
+import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}
 
-class FPGrowthModel[Item](val freqItemsets: RDD[(Array[Item], Long)]) extends Serializable {
-  def javaFreqItemsets(): JavaRDD[(Array[Item], Long)] = {
-    freqItemsets.toJavaRDD()
+/**
+ * Model trained by [[FPGrowth]], which holds frequent itemsets.
+ * @param freqItemsets frequent itemset, which is an RDD of (itemset, frequency) pairs
+ * @tparam Item item type
+ */
+class FPGrowthModel[Item: ClassTag](
+    val freqItemsets: RDD[(Array[Item], Long)]) extends Serializable {
+
+  /** Returns frequent itemsets as a [[org.apache.spark.api.java.JavaPairRDD]]. */
+  def javaFreqItemsets(): JavaPairRDD[Array[Item], java.lang.Long] = {
+    JavaPairRDD.fromRDD(freqItemsets).asInstanceOf[JavaPairRDD[Array[Item], java.lang.Long]]
   }
 }
 
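As a sketch of how the reworked model is meant to be consumed from Scala (hypothetical driver code; it assumes FPGrowth's existing setMinSupport/setNumPartitions setters, which are not shown in this diff):

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.fpm.FPGrowth
import org.apache.spark.rdd.RDD

object FPGrowthExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "FPGrowthExample")

    // Each transaction is now an Array[Item] rather than an Iterable[Item].
    val transactions: RDD[Array[String]] = sc.parallelize(Seq(
      Array("a", "b", "c"),
      Array("a", "b"),
      Array("b", "c"),
      Array("a", "c"))).cache()

    val model = new FPGrowth()
      .setMinSupport(0.5)
      .setNumPartitions(2)
      .run(transactions)

    // freqItemsets is an RDD[(Array[Item], Long)] of (itemset, frequency) pairs.
    model.freqItemsets.collect().foreach { case (itemset, count) =>
      println(itemset.mkString("{", ",", "}") + ": " + count)
    }

    sc.stop()
  }
}
```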
@@ -77,22 +86,22 @@ class FPGrowth private (
    * @param data input data set, each element contains a transaction
    * @return an [[FPGrowthModel]]
    */
-  def run[Item: ClassTag, Basket <: Iterable[Item]](data: RDD[Basket]): FPGrowthModel[Item] = {
+  def run[Item: ClassTag](data: RDD[Array[Item]]): FPGrowthModel[Item] = {
     if (data.getStorageLevel == StorageLevel.NONE) {
       logWarning("Input data is not cached.")
     }
     val count = data.count()
     val minCount = math.ceil(minSupport * count).toLong
     val numParts = if (numPartitions > 0) numPartitions else data.partitions.length
     val partitioner = new HashPartitioner(numParts)
-    val freqItems = genFreqItems[Item, Basket](data, minCount, partitioner)
-    val freqItemsets = genFreqItemsets[Item, Basket](data, minCount, freqItems, partitioner)
+    val freqItems = genFreqItems(data, minCount, partitioner)
+    val freqItemsets = genFreqItemsets(data, minCount, freqItems, partitioner)
     new FPGrowthModel(freqItemsets)
   }
 
-  def run[Item: ClassTag, Basket <: JavaIterable[Item]](
-      data: JavaRDD[Basket]): FPGrowthModel[Item] = {
-    this.run(data.rdd.map(_.asScala))
+  def run[Item, Basket <: JavaIterable[Item]](data: JavaRDD[Basket]): FPGrowthModel[Item] = {
+    implicit val tag = fakeClassTag[Item]
+    run(data.rdd.map(_.asScala.toArray))
   }
 
   /**
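The JavaRDD overload now materializes each basket to an Array and supplies a fake ClassTag. A minimal sketch of exercising that Java-facing path (written in Scala for consistency with the rest of the file; the type parameters are passed explicitly since inference through the Basket bound may not resolve Item, and the setMinSupport setter is assumed from the existing FPGrowth class):

```scala
import java.util.Arrays

import scala.collection.JavaConverters._

import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
import org.apache.spark.mllib.fpm.FPGrowth

object FPGrowthJavaApiExample {
  def main(args: Array[String]): Unit = {
    val jsc = new JavaSparkContext("local", "FPGrowthJavaApiExample")

    // Baskets arrive as java.lang.Iterable[String] (here java.util.List),
    // exactly as a Java caller would pass them.
    val transactions: JavaRDD[java.util.List[String]] = jsc.parallelize(Arrays.asList(
      Arrays.asList("a", "b", "c"),
      Arrays.asList("a", "b"),
      Arrays.asList("b", "c")))

    val model = new FPGrowth()
      .setMinSupport(0.5)
      .run[String, java.util.List[String]](transactions)

    // javaFreqItemsets() exposes the result as a JavaPairRDD[Array[Item], java.lang.Long].
    val freq: JavaPairRDD[Array[String], java.lang.Long] = model.javaFreqItemsets()
    freq.collect().asScala.foreach { pair =>
      println(pair._1.mkString("{", ",", "}") + ": " + pair._2)
    }

    jsc.stop()
  }
}
```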
@@ -101,8 +110,8 @@ class FPGrowth private (
    * @param partitioner partitioner used to distribute items
    * @return array of frequent pattern ordered by their frequencies
    */
-  private def genFreqItems[Item: ClassTag, Basket <: Iterable[Item]](
-      data: RDD[Basket],
+  private def genFreqItems[Item: ClassTag](
+      data: RDD[Array[Item]],
       minCount: Long,
       partitioner: Partitioner): Array[Item] = {
     data.flatMap { t =>
@@ -127,8 +136,8 @@ class FPGrowth private (
    * @param partitioner partitioner used to distribute transactions
    * @return an RDD of (frequent itemset, count)
    */
-  private def genFreqItemsets[Item: ClassTag, Basket <: Iterable[Item]](
-      data: RDD[Basket],
+  private def genFreqItemsets[Item: ClassTag](
+      data: RDD[Array[Item]],
       minCount: Long,
       freqItems: Array[Item],
       partitioner: Partitioner): RDD[(Array[Item], Long)] = {
@@ -152,13 +161,13 @@ class FPGrowth private (
    * @param partitioner partitioner used to distribute transactions
    * @return a map of (target partition, conditional transaction)
    */
-  private def genCondTransactions[Item: ClassTag, Basket <: Iterable[Item]](
-      transaction: Basket,
+  private def genCondTransactions[Item: ClassTag](
+      transaction: Array[Item],
       itemToRank: Map[Item, Int],
       partitioner: Partitioner): mutable.Map[Int, Array[Int]] = {
     val output = mutable.Map.empty[Int, Array[Int]]
     // Filter the basket by frequent items pattern and sort their ranks.
-    val filtered = transaction.flatMap(itemToRank.get).toArray
+    val filtered = transaction.flatMap(itemToRank.get)
     ju.Arrays.sort(filtered)
     val n = filtered.length
     var i = n - 1
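For intuition on this last hunk: with transactions as plain arrays, genCondTransactions only needs to replace items by their frequency ranks, drop infrequent items, sort the ranks, and emit each rank-prefix keyed by the partition of its last rank. A self-contained sketch of that idea (hypothetical helper, plain Scala without Spark; a simple modulo stands in for the HashPartitioner used in the real code):

```scala
import java.{util => ju}
import scala.collection.mutable

// Hypothetical stand-in for genCondTransactions: for each suffix rank of the
// filtered, sorted transaction, emit the prefix ending at that rank to the
// rank's partition (first writer per partition wins, i.e. the longest prefix).
def condTransactions(
    transaction: Array[String],
    itemToRank: Map[String, Int],
    numPartitions: Int): mutable.Map[Int, Array[Int]] = {
  val output = mutable.Map.empty[Int, Array[Int]]
  // Keep only frequent items, replace them by their ranks, and sort the ranks.
  val filtered = transaction.flatMap(itemToRank.get)
  ju.Arrays.sort(filtered)
  var i = filtered.length - 1
  while (i >= 0) {
    val part = filtered(i) % numPartitions  // assumption: modulo in place of HashPartitioner
    if (!output.contains(part)) {
      output(part) = filtered.slice(0, i + 1)
    }
    i -= 1
  }
  output
}

// Example: ranks a -> 0, b -> 1, c -> 2; "d" is infrequent and dropped.
// condTransactions(Array("d", "c", "a", "b"), Map("a" -> 0, "b" -> 1, "c" -> 2), 2)
// yields Map(0 -> Array(0, 1, 2), 1 -> Array(0, 1))
```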