
Commit 859f724

Jacky Li authored and mengxr committed
[SPARK-4001][MLlib] adding parallel FP-Growth algorithm for frequent pattern mining in MLlib
Apriori is the classic algorithm for frequent item set mining in a transactional data set, and it would be useful to have it in MLlib in Spark. This PR adds an implementation. There is one point I am not sure about: in order to filter out the eligible frequent item sets, I am currently using a cartesian operation on two RDDs to compute the degree of support of each item set, and I am not sure whether it would be better to use a broadcast variable to achieve the same. I will add an example of using this algorithm if required.

Author: Jacky Li <[email protected]>
Author: Jacky Li <[email protected]>
Author: Xiangrui Meng <[email protected]>

Closes apache#2847 from jackylk/apriori and squashes the following commits:

bee3093 [Jacky Li] Merge pull request #1 from mengxr/SPARK-4001
7e69725 [Xiangrui Meng] simplify FPTree and update FPGrowth
ec21f7d [Jacky Li] fix scalastyle
93f3280 [Jacky Li] create FPTree class
d110ab2 [Jacky Li] change test case to use MLlibTestSparkContext
a6c5081 [Jacky Li] Add Parallel FPGrowth algorithm
eb3e4ca [Jacky Li] add FPGrowth
03df2b6 [Jacky Li] refactory according to comments
7b77ad7 [Jacky Li] fix scalastyle check
f68a0bd [Jacky Li] add 2 apriori implemenation and fp-growth implementation
889b33f [Jacky Li] modify per scalastyle check
da2cba7 [Jacky Li] adding apriori algorithm for frequent item set mining in Spark
1 parent d85cd4e commit 859f724
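
Before the diff itself, here is a minimal usage sketch of the API this commit adds, mirroring the test suite below; the four transactions and the in-scope SparkContext sc are illustrative assumptions, not part of the commit:

import org.apache.spark.mllib.fpm.FPGrowth

// Each transaction is an array of unique items (duplicates throw a SparkException).
val transactions = sc.parallelize(Seq(
  Array("a", "b", "c"),
  Array("a", "b"),
  Array("a", "c"),
  Array("b", "c")), 2).cache()

val model = new FPGrowth()
  .setMinSupport(0.5)   // a pattern must appear in at least ceil(0.5 * 4) = 2 transactions
  .setNumPartitions(2)
  .run(transactions)

// freqItemsets is an RDD[(Array[String], Long)] of (itemset, count) pairs.
model.freqItemsets.collect().foreach { case (items, count) =>
  println(items.mkString("{", ",", "}") + ": " + count)
}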

File tree

4 files changed: +484, -0 lines

FPGrowth.scala: 162 additions & 0 deletions (new file)

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.fpm

import java.{util => ju}

import scala.collection.mutable

import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

class FPGrowthModel(val freqItemsets: RDD[(Array[String], Long)]) extends Serializable

/**
 * This class implements the parallel FP-growth algorithm for frequent pattern mining.
 * Parallel FP-growth (PFP) partitions computation in such a way that each machine executes an
 * independent group of mining tasks. More detail on this algorithm can be found at
 * [[http://dx.doi.org/10.1145/1454008.1454027, PFP]], and the original FP-growth paper can be
 * found at [[http://dx.doi.org/10.1145/335191.335372, FP-growth]].
 *
 * @param minSupport the minimal support level of a frequent pattern; any pattern that appears
 *                   more than (minSupport * size-of-the-dataset) times will be output
 * @param numPartitions number of partitions used by parallel FP-growth
 */
class FPGrowth private (
    private var minSupport: Double,
    private var numPartitions: Int) extends Logging with Serializable {

  /**
   * Constructs an FPGrowth instance with default parameters:
   * {minSupport: 0.3, numPartitions: auto}
   */
  def this() = this(0.3, -1)

  /**
   * Sets the minimal support level (default: 0.3).
   */
  def setMinSupport(minSupport: Double): this.type = {
    this.minSupport = minSupport
    this
  }

  /**
   * Sets the number of partitions used by parallel FP-growth (default: same as input data).
   */
  def setNumPartitions(numPartitions: Int): this.type = {
    this.numPartitions = numPartitions
    this
  }

  /**
   * Computes an FP-Growth model that contains frequent itemsets.
   * @param data input data set, where each element contains a transaction
   * @return an [[FPGrowthModel]]
   */
  def run(data: RDD[Array[String]]): FPGrowthModel = {
    if (data.getStorageLevel == StorageLevel.NONE) {
      logWarning("Input data is not cached.")
    }
    val count = data.count()
    val minCount = math.ceil(minSupport * count).toLong
    val numParts = if (numPartitions > 0) numPartitions else data.partitions.length
    val partitioner = new HashPartitioner(numParts)
    val freqItems = genFreqItems(data, minCount, partitioner)
    val freqItemsets = genFreqItemsets(data, minCount, freqItems, partitioner)
    new FPGrowthModel(freqItemsets)
  }

  /**
   * Generates frequent items by filtering the input data using the minimal support level.
   * @param data transactions
   * @param minCount minimum count for frequent items
   * @param partitioner partitioner used to distribute items
   * @return array of frequent items, ordered by descending frequency
   */
  private def genFreqItems(
      data: RDD[Array[String]],
      minCount: Long,
      partitioner: Partitioner): Array[String] = {
    data.flatMap { t =>
      val uniq = t.toSet
      if (t.length != uniq.size) {
        throw new SparkException(s"Items in a transaction must be unique but got ${t.toSeq}.")
      }
      t
    }.map(v => (v, 1L))
      .reduceByKey(partitioner, _ + _)
      .filter(_._2 >= minCount)
      .collect()
      .sortBy(-_._2)
      .map(_._1)
  }

  /**
   * Generates frequent itemsets by building FP-Trees; the extraction is done on each partition.
   * @param data transactions
   * @param minCount minimum count for frequent itemsets
   * @param freqItems frequent items
   * @param partitioner partitioner used to distribute transactions
   * @return an RDD of (frequent itemset, count)
   */
  private def genFreqItemsets(
      data: RDD[Array[String]],
      minCount: Long,
      freqItems: Array[String],
      partitioner: Partitioner): RDD[(Array[String], Long)] = {
    val itemToRank = freqItems.zipWithIndex.toMap
    data.flatMap { transaction =>
      genCondTransactions(transaction, itemToRank, partitioner)
    }.aggregateByKey(new FPTree[Int], partitioner.numPartitions)(
      (tree, transaction) => tree.add(transaction, 1L),
      (tree1, tree2) => tree1.merge(tree2))
    .flatMap { case (part, tree) =>
      tree.extract(minCount, x => partitioner.getPartition(x) == part)
    }.map { case (ranks, count) =>
      (ranks.map(i => freqItems(i)).toArray, count)
    }
  }

  /**
   * Generates conditional transactions.
   * @param transaction a transaction
   * @param itemToRank map from item to its rank
   * @param partitioner partitioner used to distribute transactions
   * @return a map of (target partition, conditional transaction)
   */
  private def genCondTransactions(
      transaction: Array[String],
      itemToRank: Map[String, Int],
      partitioner: Partitioner): mutable.Map[Int, Array[Int]] = {
    val output = mutable.Map.empty[Int, Array[Int]]
    // Filter the transaction down to frequent items and sort them by rank.
    val filtered = transaction.flatMap(itemToRank.get)
    ju.Arrays.sort(filtered)
    val n = filtered.length
    var i = n - 1
    while (i >= 0) {
      val item = filtered(i)
      val part = partitioner.getPartition(item)
      if (!output.contains(part)) {
        output(part) = filtered.slice(0, i + 1)
      }
      i -= 1
    }
    output
  }
}
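
To make genCondTransactions concrete, here is a standalone re-creation of its prefix logic under assumed inputs: three frequent items with ranks 0 to 2, and a 2-partition HashPartitioner, which for non-negative Int keys reduces to partition = rank % 2.

// Ranks of the frequent items in one transaction, already filtered and sorted.
val filtered = Array(0, 1, 2)
val numPartitions = 2
val output = scala.collection.mutable.Map.empty[Int, Array[Int]]
var i = filtered.length - 1
while (i >= 0) {
  val part = filtered(i) % numPartitions // stands in for partitioner.getPartition
  if (!output.contains(part)) {
    // Each partition receives the longest prefix that ends at one of its own items.
    output(part) = filtered.slice(0, i + 1)
  }
  i -= 1
}
// output == Map(0 -> Array(0, 1, 2), 1 -> Array(0, 1)): rank 2 hashes to partition 0,
// so partition 0 gets [0, 1, 2]; rank 1 hashes to partition 1, which gets [0, 1].

Shipping only the longest prefix per partition is what makes each group's mining task independent: the partition that owns a suffix item sees every conditional transaction it needs and nothing more.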
FPTree.scala: 134 additions & 0 deletions (new file)

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.fpm

import scala.collection.mutable
import scala.collection.mutable.ListBuffer

/**
 * FP-Tree data structure used in FP-Growth.
 * @tparam T item type
 */
private[fpm] class FPTree[T] extends Serializable {

  import FPTree._

  val root: Node[T] = new Node(null)

  private val summaries: mutable.Map[T, Summary[T]] = mutable.Map.empty

  /** Adds a transaction with count. */
  def add(t: Iterable[T], count: Long = 1L): this.type = {
    require(count > 0)
    var curr = root
    curr.count += count
    t.foreach { item =>
      val summary = summaries.getOrElseUpdate(item, new Summary)
      summary.count += count
      val child = curr.children.getOrElseUpdate(item, {
        val newNode = new Node(curr)
        newNode.item = item
        summary.nodes += newNode
        newNode
      })
      child.count += count
      curr = child
    }
    this
  }

  /** Merges another FP-Tree. */
  def merge(other: FPTree[T]): this.type = {
    other.transactions.foreach { case (t, c) =>
      add(t, c)
    }
    this
  }

  /** Gets the subtree conditioned on the given suffix item. */
  private def project(suffix: T): FPTree[T] = {
    val tree = new FPTree[T]
    if (summaries.contains(suffix)) {
      val summary = summaries(suffix)
      summary.nodes.foreach { node =>
        var t = List.empty[T]
        var curr = node.parent
        while (!curr.isRoot) {
          t = curr.item :: t
          curr = curr.parent
        }
        tree.add(t, node.count)
      }
    }
    tree
  }

  /** Returns all transactions in an iterator. */
  def transactions: Iterator[(List[T], Long)] = getTransactions(root)

  /** Returns all transactions under this node. */
  private def getTransactions(node: Node[T]): Iterator[(List[T], Long)] = {
    var count = node.count
    node.children.iterator.flatMap { case (item, child) =>
      getTransactions(child).map { case (t, c) =>
        count -= c
        (item :: t, c)
      }
    } ++ {
      if (count > 0) {
        Iterator.single((Nil, count))
      } else {
        Iterator.empty
      }
    }
  }

  /** Extracts all patterns with a valid suffix and minimum count. */
  def extract(
      minCount: Long,
      validateSuffix: T => Boolean = _ => true): Iterator[(List[T], Long)] = {
    summaries.iterator.flatMap { case (item, summary) =>
      if (validateSuffix(item) && summary.count >= minCount) {
        Iterator.single((item :: Nil, summary.count)) ++
          project(item).extract(minCount).map { case (t, c) =>
            (item :: t, c)
          }
      } else {
        Iterator.empty
      }
    }
  }
}

private[fpm] object FPTree {

  /** Represents a node in an FP-Tree. */
  class Node[T](val parent: Node[T]) extends Serializable {
    var item: T = _
    var count: Long = 0L
    val children: mutable.Map[T, Node[T]] = mutable.Map.empty

    def isRoot: Boolean = parent == null
  }

  /** Summary of an item in an FP-Tree. */
  private class Summary[T] extends Serializable {
    var count: Long = 0L
    val nodes: ListBuffer[Node[T]] = ListBuffer.empty
  }
}
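
FPTree is package-private, but its behavior is easy to check in isolation. A minimal sketch, assuming it is run from within the org.apache.spark.mllib.fpm package (the three transactions are illustrative):

val tree = new FPTree[String]
tree.add(List("a", "b"))      // root -> a(1) -> b(1)
tree.add(List("a", "b", "c")) // shares the a -> b path and appends c
tree.add(List("a"))           // only bumps the count of the a node
// extract returns suffix-first lists, so List(b, a) denotes the itemset {a, b}.
// With minCount = 2 this prints (in some order): a: 3, b: 2, b,a: 2; c is pruned.
tree.extract(2L).foreach { case (items, count) =>
  println(items.mkString(",") + ": " + count)
}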
FPGrowthSuite.scala: 73 additions & 0 deletions (new file)

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.mllib.fpm

import org.scalatest.FunSuite

import org.apache.spark.mllib.util.MLlibTestSparkContext

class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {

  test("FP-Growth") {
    val transactions = Seq(
      "r z h k p",
      "z y x w v u t s",
      "s x o n r",
      "x z y m t s q e",
      "z",
      "x z y r q t p")
      .map(_.split(" "))
    val rdd = sc.parallelize(transactions, 2).cache()

    val fpg = new FPGrowth()

    val model6 = fpg
      .setMinSupport(0.9)
      .setNumPartitions(1)
      .run(rdd)
    assert(model6.freqItemsets.count() === 0)

    val model3 = fpg
      .setMinSupport(0.5)
      .setNumPartitions(2)
      .run(rdd)
    val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) =>
      (items.toSet, count)
    }
    val expected = Set(
      (Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L),
      (Set("r"), 3L),
      (Set("x", "z"), 3L), (Set("t", "y"), 3L), (Set("t", "x"), 3L), (Set("s", "x"), 3L),
      (Set("y", "x"), 3L), (Set("y", "z"), 3L), (Set("t", "z"), 3L),
      (Set("y", "x", "z"), 3L), (Set("t", "x", "z"), 3L), (Set("t", "y", "z"), 3L),
      (Set("t", "y", "x"), 3L),
      (Set("t", "y", "x", "z"), 3L))
    assert(freqItemsets3.toSet === expected)

    val model2 = fpg
      .setMinSupport(0.3)
      .setNumPartitions(4)
      .run(rdd)
    assert(model2.freqItemsets.count() === 54)

    val model1 = fpg
      .setMinSupport(0.1)
      .setNumPartitions(8)
      .run(rdd)
    assert(model1.freqItemsets.count() === 625)
  }
}
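
The expected counts follow from the minCount arithmetic in FPGrowth.run: with 6 transactions, minSupport = 0.9 gives minCount = ceil(0.9 * 6) = 6, and even the most frequent item ("z", 5 occurrences) falls short, so model6 is empty; minSupport = 0.5 gives minCount = 3, which admits exactly the itemsets listed in expected.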
