@@ -24,7 +24,6 @@ import org.apache.spark.annotation.Experimental
2424import org .apache .spark .mllib .linalg ._
2525import org .apache .spark .rdd .RDD
2626
27-
2827/**
2928 * :: Experimental ::
3029 * Latent Dirichlet Allocation (LDA), a topic model designed for text documents.
@@ -37,7 +36,58 @@ import org.apache.spark.rdd.RDD
3736 * Hoffman, Blei and Bach, "Online Learning for Latent Dirichlet Allocation." NIPS, 2010.
3837 */
@Experimental
class OnlineLDA(
    private var k: Int,
    private var numIterations: Int,
    private var miniBatchFraction: Double,
    private var tau_0: Double,
    private var kappa: Double) {

  /**
   * Constructs an OnlineLDA instance with the default parameters:
   * k = 10, numIterations = 100, miniBatchFraction = 0.01, tau_0 = 1024, kappa = 0.5.
   */
  def this() = this(k = 10, numIterations = 100, miniBatchFraction = 0.01,
    tau_0 = 1024, kappa = 0.5)
49+ /**
50+ * Number of topics to infer. I.e., the number of soft cluster centers.
51+ * (default = 10)
52+ */
53+ def setK (k : Int ): this .type = {
54+ require(k > 0 , s " OnlineLDA k (number of clusters) must be > 0, but was set to $k" )
55+ this .k = k
56+ this
57+ }
58+
59+ /**
60+ * Set the number of iterations for OnlineLDA. Default 100.
61+ */
62+ def setNumIterations (iters : Int ): this .type = {
63+ this .numIterations = iters
64+ this
65+ }
66+
67+ /**
68+ * Set fraction of data to be used for each iteration. Default 0.01.
69+ */
70+ def setMiniBatchFraction (fraction : Double ): this .type = {
71+ this .miniBatchFraction = fraction
72+ this
73+ }
74+
75+ /**
76+ * A (positive) learning parameter that downweights early iterations. Default 1024.
77+ */
78+ def setTau_0 (t : Double ): this .type = {
79+ this .tau_0 = t
80+ this
81+ }
82+
83+ /**
84+ * Learning rate: exponential decay rate. Default 0.5.
85+ */
86+ def setKappa (kappa : Double ): this .type = {
87+ this .kappa = kappa
88+ this
89+ }
90+
4191
4292 /**
4393 * Learns an LDA model from the given data set, using online variational Bayes (VB) algorithm.
@@ -49,33 +99,18 @@ object OnlineLDA{
4999 * The term count vectors are "bags of words" with a fixed-size vocabulary
50100 * (where the vocabulary size is the length of the vector).
51101 * Document IDs must be unique and >= 0.
52- * @param k Number of topics to infer.
53- * @param batchNumber Number of batches to split input corpus. For each batch, recommendation
54- * size is [4, 16384]. -1 for automatic batchNumber.
55102 * @return Inferred LDA model
56103 */
57- def run (documents : RDD [(Long , Vector )], k : Int , batchNumber : Int = - 1 ): LDAModel = {
58- require(batchNumber > 0 || batchNumber == - 1 ,
59- s " batchNumber must be greater or -1, but was set to $batchNumber" )
60- require(k > 0 , s " LDA k (number of clusters) must be > 0, but was set to $k" )
61-
104+ def run (documents : RDD [(Long , Vector )]): LDAModel = {
62105 val vocabSize = documents.first._2.size
63106 val D = documents.count().toInt // total documents count
64- val batchSize =
65- if (batchNumber == - 1 ) { // auto mode
66- if (D / 100 > 16384 ) 16384
67- else if (D / 100 < 4 ) 4
68- else D / 100
69- }
70- else {
71- D / batchNumber
72- }
73-
74- val onlineLDA = new OnlineLDAOptimizer (k, D , vocabSize)
75- val actualBatchNumber = Math .ceil(D .toDouble / batchSize).toInt
76- for (i <- 1 to actualBatchNumber){
77- val batch = documents.sample(true , batchSize.toDouble / D )
78- onlineLDA.submitMiniBatch(batch)
107+ val onlineLDA = new OnlineLDAOptimizer (k, D , vocabSize, tau_0, kappa)
108+
109+ val arr = Array .fill(math.ceil(1.0 / miniBatchFraction).toInt)(miniBatchFraction)
110+ for (i <- 0 until numIterations){
111+ val splits = documents.randomSplit(arr)
112+ val index = i % splits.size
113+ onlineLDA.submitMiniBatch(splits(index))
79114 }
80115 onlineLDA.getTopicDistribution()
81116 }
@@ -93,10 +128,12 @@ object OnlineLDA{
93128 * Hoffman, Blei and Bach, "Online Learning for Latent Dirichlet Allocation." NIPS, 2010.
94129 */
@Experimental
private[clustering] class OnlineLDAOptimizer(
    private var k: Int,
    private var D: Int,
    private val vocabSize: Int,
    private val tau_0: Double,
    private val kappa: Double) extends Serializable {

  // Initialize the variational distribution q(beta|lambda)
  private var lambda = getGammaMatrix(k, vocabSize) // K * V
@@ -115,7 +152,11 @@ class OnlineLDAOptimizer (
115152 * Document IDs must be unique and >= 0.
116153 * @return Inferred LDA model
117154 */
118- def submitMiniBatch (documents : RDD [(Long , Vector )]): Unit = {
155+ private [clustering] def submitMiniBatch (documents : RDD [(Long , Vector )]): Unit = {
156+ if (documents.isEmpty()){
157+ return
158+ }
159+
119160 var stat = BDM .zeros[Double ](k, vocabSize)
120161 stat = documents.treeAggregate(stat)(gradient, _ += _)
121162 update(stat, i, documents.count().toInt)
@@ -125,13 +166,13 @@ class OnlineLDAOptimizer (
125166 /**
126167 * get the topic-term distribution
127168 */
128- def getTopicDistribution (): LDAModel = {
169+ private [clustering] def getTopicDistribution (): LDAModel = {
129170 new LocalLDAModel (Matrices .fromBreeze(lambda).transpose)
130171 }
131172
132173 private def update (raw : BDM [Double ], iter: Int , batchSize : Int ): Unit = {
133- // weight of the mini-batch. 1024 helps down weights early iterations
134- val weight = math.pow(1024 + iter, - 0.5 )
174+ // weight of the mini-batch.
175+ val weight = math.pow(tau_0 + iter, - kappa )
135176
136177 // This step finishes computing the sufficient statistics for the M step
137178 val stat = raw :* expElogbeta
0 commit comments