@@ -32,7 +32,6 @@ import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
 import org.apache.spark.mllib.linalg.{Vector, DenseVector, SparseVector, Matrices}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.Utils
-import org.apache.spark.mllib.rdd.RDDFunctions._
 
 
 /**
@@ -223,10 +222,6 @@ class LDA private (
     this
   }
 
-  object LDAMode extends Enumeration {
-    val EM, Online = Value
-  }
-
   /**
    * Learn an LDA model using the given dataset.
    *
@@ -236,37 +231,30 @@ class LDA private (
    *  Document IDs must be unique and >= 0.
    * @return  Inferred LDA model
    */
-  def run(documents: RDD[(Long, Vector)], mode: LDAMode.Value = LDAMode.EM): LDAModel = {
-    mode match {
-      case LDAMode.EM =>
-        val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed,
-          checkpointInterval)
-        var iter = 0
-        val iterationTimes = Array.fill[Double](maxIterations)(0)
-        while (iter < maxIterations) {
-          val start = System.nanoTime()
-          state.next()
-          val elapsedSeconds = (System.nanoTime() - start) / 1e9
-          iterationTimes(iter) = elapsedSeconds
-          iter += 1
-        }
-        state.graphCheckpointer.deleteAllCheckpoints()
-        new DistributedLDAModel(state, iterationTimes)
-      case LDAMode.Online =>
-        val vocabSize = documents.first._2.size
-        val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k, vocabSize)
-        var iter = 0
-        while (iter < onlineLDA.batchNumber) {
-          onlineLDA.next()
-          iter += 1
-        }
-        new LocalLDAModel(Matrices.fromBreeze(onlineLDA._lambda).transpose)
-      case _ => throw new IllegalArgumentException(s"Do not support mode $mode.")
+  def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = {
+    val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed,
+      checkpointInterval)
+    var iter = 0
+    val iterationTimes = Array.fill[Double](maxIterations)(0)
+    while (iter < maxIterations) {
+      val start = System.nanoTime()
+      state.next()
+      val elapsedSeconds = (System.nanoTime() - start) / 1e9
+      iterationTimes(iter) = elapsedSeconds
+      iter += 1
     }
+    state.graphCheckpointer.deleteAllCheckpoints()
+    new DistributedLDAModel(state, iterationTimes)
+  }
+
+  def runOnlineLDA(documents: RDD[(Long, Vector)]): LDAModel = {
+    val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k)
+    (0 until onlineLDA.batchNumber).map(_ => onlineLDA.next())
+    new LocalLDAModel(Matrices.fromBreeze(onlineLDA.lambda).transpose)
   }
 
   /** Java-friendly version of [[run()]] */
-  def run(documents: JavaPairRDD[java.lang.Long, Vector]): LDAModel = {
+  def run(documents: JavaPairRDD[java.lang.Long, Vector]): DistributedLDAModel = {
     run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]])
   }
 }
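For reference, a minimal sketch of how the two entry points above might be exercised. This is not part of the patch: the SparkContext `sc`, the input path, and the whitespace-separated term-count format are assumed for illustration; only `setK`, `setMaxIterations`, `run`, and `runOnlineLDA` come from the code in this diff.

import org.apache.spark.mllib.clustering.LDA
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD

// Assumed: an existing SparkContext `sc` and a file with one document
// per line as whitespace-separated term counts (hypothetical input).
val corpus: RDD[(Long, Vector)] = sc.textFile("term_counts.txt")
  .map(line => Vectors.dense(line.trim.split(' ').map(_.toDouble)))
  .zipWithIndex()                                // unique, >= 0 document IDs
  .map { case (counts, id) => (id, counts) }
corpus.cache()                                   // both runs below scan it

val lda = new LDA().setK(10).setMaxIterations(20)
val emModel = lda.run(corpus)                    // batch EM: DistributedLDAModel
val onlineModel = lda.runOnlineLDA(corpus)       // online VB: LocalLDAModel

Note that zipWithIndex assigns consecutive IDs 0 to D-1, which also suits the `doc._1 % batchNumber` mini-batch partitioning used by the online optimizer below.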
@@ -418,58 +406,66 @@ private[clustering] object LDA {
 
   }
 
-  // todo: add reference to paper and Hoffman
+  /**
+   * Optimizer for the Online LDA algorithm, which breaks the corpus into mini-batches and scans it only once.
+   * Hoffman, Blei and Bach, "Online Learning for Latent Dirichlet Allocation." NIPS, 2010.
+   */
   private[clustering] class OnlineLDAOptimizer(
-      val documents: RDD[(Long, Vector)],
-      val k: Int,
-      val vocabSize: Int) extends Serializable {
+      private val documents: RDD[(Long, Vector)],
+      private val k: Int) extends Serializable {
 
-    private val kappa = 0.5  // (0.5, 1] how quickly old information is forgotten
-    private val tau0 = 1024  // down-weights early iterations
-    private val D = documents.count()
+    private val vocabSize = documents.first._2.size
+    private val D = documents.count().toInt
     private val batchSize = if (D / 1000 > 4096) 4096
       else if (D / 1000 < 4) 4
       else D / 1000
-    val batchNumber = (D / batchSize + 1).toInt
-    private val batches = documents.sliding(batchNumber).collect()
+    val batchNumber = D / batchSize
 
     // Initialize the variational distribution q(beta|lambda)
-    var _lambda = getGammaMatrix(k, vocabSize)               // K * V
-    private var _Elogbeta = dirichlet_expectation(_lambda)  // K * V
-    private var _expElogbeta = exp(_Elogbeta)                // K * V
+    var lambda = getGammaMatrix(k, vocabSize)              // K * V
+    private var Elogbeta = dirichlet_expectation(lambda)  // K * V
+    private var expElogbeta = exp(Elogbeta)                // K * V
 
-    private var batchCount = 0
+    private var batchId = 0
     def next(): Unit = {
-      // weight of the mini-batch.
-      val rhot = math.pow(tau0 + batchCount, -kappa)
+      require(batchId < batchNumber)
+      // Weight of the mini-batch; the 1024 offset down-weights early iterations.
+      val weight = math.pow(1024 + batchId, -0.5)
+      val batch = documents.filter(doc => doc._1 % batchNumber == batchId)
 
+      // Given a mini-batch of documents, estimate the parameters gamma controlling the
+      // variational distribution over the topic weights for each document in the mini-batch.
       var stat = BDM.zeros[Double](k, vocabSize)
-      stat = batches(batchCount).aggregate(stat)(seqOp, _ += _)
-
-      stat = stat :* _expElogbeta
-      _lambda = _lambda * (1 - rhot) + (stat * D.toDouble / batchSize.toDouble + 1.0 / k) * rhot
-      _Elogbeta = dirichlet_expectation(_lambda)
-      _expElogbeta = exp(_Elogbeta)
-      batchCount += 1
+      stat = batch.aggregate(stat)(seqOp, _ += _)
+      stat = stat :* expElogbeta
+
+      // Update lambda based on the documents in this mini-batch.
+      lambda = lambda * (1 - weight) + (stat * D.toDouble / batchSize.toDouble + 1.0 / k) * weight
+      Elogbeta = dirichlet_expectation(lambda)
+      expElogbeta = exp(Elogbeta)
+      batchId += 1
     }
 
-    private def seqOp(other: BDM[Double], doc: (Long, Vector)): BDM[Double] = {
+    // For each document d, update that document's gamma and phi.
+    private def seqOp(stat: BDM[Double], doc: (Long, Vector)): BDM[Double] = {
       val termCounts = doc._2
       val (ids, cts) = termCounts match {
         case v: DenseVector => (((0 until v.size).toList), v.values)
         case v: SparseVector => (v.indices.toList, v.values)
         case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
       }
 
+      // Initialize the variational distribution q(theta|gamma) for the mini-batch.
       var gammad = new Gamma(100, 1.0 / 100.0).samplesVector(k).t   // 1 * K
       var Elogthetad = vector_dirichlet_expectation(gammad.t).t     // 1 * K
       var expElogthetad = exp(Elogthetad.t).t                       // 1 * K
-      val expElogbetad = _expElogbeta(::, ids).toDenseMatrix        // K * ids
+      val expElogbetad = expElogbeta(::, ids).toDenseMatrix         // K * ids
 
       var phinorm = expElogthetad * expElogbetad + 1e-100           // 1 * ids
       var meanchange = 1D
-      val ctsVector = new BDV[Double](cts).t // 1 * ids
+      val ctsVector = new BDV[Double](cts).t                        // 1 * ids
 
+      // Iterate between gamma and phi until convergence.
       while (meanchange > 1e-6) {
         val lastgamma = gammad
         //        1 * K               1 * ids                 ids * K
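(The loop body continues in the next hunk; the unchanged lines between the two hunks are elided by the diff.) For the record, the fixed point this gamma/phi iteration chases is the per-document E-step of Hoffman et al. (2010), the paper cited in the new class doc. In the paper's notation, with n_{dw} the term counts held in ctsVector and \alpha the document-topic prior:

\phi_{dwk} \propto \exp\{ E_q[\log \theta_{dk}] + E_q[\log \beta_{kw}] \}, \qquad \gamma_{dk} = \alpha + \sum_w n_{dw} \, \phi_{dwk}

iterated until the mean absolute change in \gamma_d falls below the 1e-6 threshold above.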
@@ -480,30 +476,30 @@ private[clustering] object LDA {
         meanchange = sum(abs((gammad - lastgamma).t)) / gammad.t.size.toDouble
       }
 
-      val v1 = expElogthetad.t.toDenseMatrix.t
-      val v2 = (ctsVector / phinorm).t.toDenseMatrix
-      val outerResult = kron(v1, v2) // K * ids
+      val m1 = expElogthetad.t.toDenseMatrix.t
+      val m2 = (ctsVector / phinorm).t.toDenseMatrix
+      val outerResult = kron(m1, m2)  // K * ids
       for (i <- 0 until ids.size) {
-        other(::, ids(i)) := (other(::, ids(i)) + outerResult(::, i))
+        stat(::, ids(i)) := (stat(::, ids(i)) + outerResult(::, i))
       }
-      other
+      stat
     }
 
-    def getGammaMatrix(row: Int, col: Int): BDM[Double] = {
+    private def getGammaMatrix(row: Int, col: Int): BDM[Double] = {
       val gammaRandomGenerator = new Gamma(100, 1.0 / 100.0)
       val temp = gammaRandomGenerator.sample(row * col).toArray
       (new BDM[Double](col, row, temp)).t
     }
 
-    def dirichlet_expectation(alpha: BDM[Double]): BDM[Double] = {
+    private def dirichlet_expectation(alpha: BDM[Double]): BDM[Double] = {
       val rowSum = sum(alpha(breeze.linalg.*, ::))
       val digAlpha = digamma(alpha)
       val digRowSum = digamma(rowSum)
       val result = digAlpha(::, breeze.linalg.*) - digRowSum
       result
     }
 
-    def vector_dirichlet_expectation(v: BDV[Double]): (BDV[Double]) = {
+    private def vector_dirichlet_expectation(v: BDV[Double]): (BDV[Double]) = {
       digamma(v) - digamma(sum(v))
     }
   }
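A note on the constants inlined into next(): the removed tau0 and kappa fields survive as the literals 1024 and -0.5, so the mini-batch weight is still the learning-rate schedule of Hoffman et al.,

\rho_t = (\tau_0 + t)^{-\kappa}, \qquad \tau_0 = 1024, \; \kappa = 0.5,

and the lambda update blends the old estimate with the one from the current mini-batch (stat here is the statistic after the element-wise multiply by expElogbeta, and |B_t| = batchSize):

\lambda \leftarrow (1 - \rho_t)\,\lambda + \rho_t \Big( \frac{D}{|B_t|}\,\mathrm{stat} + \frac{1}{k} \Big).

Note that kappa = 0.5 sits at the boundary of the (0.5, 1] range documented in the deleted comment, which is the range the paper requires for its convergence guarantee.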