@@ -19,15 +19,17 @@ package org.apache.spark.mllib.clustering
 
 import java.util.Random
 
-import breeze.linalg.{DenseVector => BDV, normalize, axpy => brzAxpy}
+import breeze.linalg.{DenseVector => BDV, normalize, kron, sum, axpy => brzAxpy, DenseMatrix => BDM}
+import breeze.numerics.{exp, abs, digamma}
+import breeze.stats.distributions.Gamma
 
 import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaPairRDD
 import org.apache.spark.graphx._
 import org.apache.spark.graphx.impl.GraphImpl
 import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{Vector, DenseVector, SparseVector, Matrices}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.Utils
 
@@ -250,6 +252,10 @@ class LDA private (
     this
   }
 
+  object LDAMode extends Enumeration {
+    val EM, Online = Value
+  }
+
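+  // Usage sketch (hypothetical caller code; note that LDAMode is a member of class LDA,
+  // so external callers reference it through the instance):
+  //   val lda = new LDA().setK(10)
+  //   val model = lda.run(corpus, lda.LDAMode.Online)  // corpus: RDD[(Long, Vector)]
+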
   /**
    * Learn an LDA model using the given dataset.
    *
@@ -259,24 +265,39 @@ class LDA private (
    * Document IDs must be unique and >= 0.
    * @return Inferred LDA model
    */
-  def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = {
-    val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed,
-      checkpointDir, checkpointInterval)
-    var iter = 0
-    val iterationTimes = Array.fill[Double](maxIterations)(0)
-    while (iter < maxIterations) {
-      val start = System.nanoTime()
-      state.next()
-      val elapsedSeconds = (System.nanoTime() - start) / 1e9
-      iterationTimes(iter) = elapsedSeconds
-      iter += 1
+  def run(documents: RDD[(Long, Vector)], mode: LDAMode.Value = LDAMode.EM): LDAModel = {
+    mode match {
+      case LDAMode.EM =>
+        val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration,
+          seed, checkpointDir, checkpointInterval)
+        var iter = 0
+        val iterationTimes = Array.fill[Double](maxIterations)(0)
+        while (iter < maxIterations) {
+          val start = System.nanoTime()
+          state.next()
+          val elapsedSeconds = (System.nanoTime() - start) / 1e9
+          iterationTimes(iter) = elapsedSeconds
+          iter += 1
+        }
+        state.graphCheckpointer.deleteAllCheckpoints()
+        new DistributedLDAModel(state, iterationTimes)
+      case LDAMode.Online =>
+        // todo: delete the comment in next line
+        // I changed the return type to LDAModel, as DistributedLDAModel is based on Graph.
+        val vocabSize = documents.first._2.size
+        val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k, vocabSize)
+        var iter = 0
+        while (iter < onlineLDA.batchNumber) {
+          onlineLDA.next()
+          iter += 1
+        }
+        new LocalLDAModel(Matrices.fromBreeze(onlineLDA._lambda).transpose)
+      case _ => throw new IllegalArgumentException(s"LDA mode $mode is not supported.")
+    }
     }
-    state.graphCheckpointer.deleteAllCheckpoints()
-    new DistributedLDAModel(state, iterationTimes)
   }
 
   /** Java-friendly version of [[run()]] */
-  def run(documents: JavaPairRDD[java.lang.Long, Vector]): DistributedLDAModel = {
+  def run(documents: JavaPairRDD[java.lang.Long, Vector]): LDAModel = {
     run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]])
   }
 }
@@ -429,6 +450,97 @@ private[clustering] object LDA {
 
   }
 
+  /**
+   * Optimizer for the online variational Bayes LDA algorithm of Hoffman, Blei and Bach,
+   * "Online Learning for Latent Dirichlet Allocation" (NIPS 2010).
+   */
+  class OnlineLDAOptimizer(
+      val documents: RDD[(Long, Vector)],
+      val k: Int,
+      val vocabSize: Int) extends Serializable {
+
+    private val kappa = 0.5   // (0.5, 1]: controls how quickly old information is forgotten
+    private val tau0 = 1024   // downweights early iterations
+    private val D = documents.count()
+    private val batchSize = if (D / 1000 > 4096) 4096
+      else if (D / 1000 < 4) 4
+      else D / 1000
+    val batchNumber = (D / batchSize + 1).toInt
+    // todo: performance killer, needs to be replaced
+    private val batches = documents.randomSplit(Array.fill[Double](batchNumber)(1.0))
+
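+    // The variational parameter lambda is a K * V matrix; row k parameterizes a Dirichlet
+    // distribution over the vocabulary for topic k. E[log beta] and exp(E[log beta]) are
+    // cached because they are reused for every document in a mini-batch.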
+    // Initialize the variational distribution q(beta | lambda)
+    var _lambda = getGammaMatrix(k, vocabSize)               // K * V
+    private var _Elogbeta = dirichlet_expectation(_lambda)   // K * V
+    private var _expElogbeta = exp(_Elogbeta)                // K * V
+
+    private var batchCount = 0
+    def next(): Unit = {
+      // weight of the mini-batch
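+      // rhot = (tau0 + t)^(-kappa), the decaying step size of Hoffman et al.; kappa in
+      // (0.5, 1] satisfies the Robbins-Monro conditions and tau0 downweights early batches.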
+      val rhot = math.pow(tau0 + batchCount, -kappa)
+
+      var stat = BDM.zeros[Double](k, vocabSize)
+      stat = batches(batchCount).aggregate(stat)(seqOp, _ += _)
+
+      stat = stat :* _expElogbeta
+      _lambda = _lambda * (1 - rhot) + (stat * D.toDouble / batchSize.toDouble + 1.0 / k) * rhot
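+      // The mini-batch statistics are scaled by D / batchSize to approximate a full corpus
+      // pass, offset by the prior (here 1.0 / k), and blended with the old lambda.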
+      _Elogbeta = dirichlet_expectation(_lambda)
+      _expElogbeta = exp(_Elogbeta)
+      batchCount += 1
+    }
+
+    private def seqOp(other: BDM[Double], doc: (Long, Vector)): BDM[Double] = {
+      val termCounts = doc._2
+      val (ids, cts) = termCounts match {
+        case v: DenseVector => ((0 until v.size).toList, v.values)
+        case v: SparseVector => (v.indices.toList, v.values)
+        case v => throw new IllegalArgumentException("Unsupported vector type " + v.getClass)
+      }
+
+      var gammad = new Gamma(100, 1.0 / 100.0).samplesVector(k).t   // 1 * K
+      var Elogthetad = vector_dirichlet_expectation(gammad.t).t     // 1 * K
+      var expElogthetad = exp(Elogthetad.t).t                       // 1 * K
+      val expElogbetad = _expElogbeta(::, ids).toDenseMatrix        // K * ids
+
+      var phinorm = expElogthetad * expElogbetad + 1e-100           // 1 * ids
+      var meanchange = 1D
+      val ctsVector = new BDV[Double](cts).t                        // 1 * ids
+
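+      // Per-document E-step: alternate updates of gamma_d and the implicit phi_d until the
+      // average absolute change in gamma_d falls below 1e-6.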
+      while (meanchange > 1e-6) {
+        val lastgamma = gammad
+        //          1 * K                    1 * ids               ids * K
+        gammad = (expElogthetad :* ((ctsVector / phinorm) * expElogbetad.t)) + 1.0 / k
+        Elogthetad = vector_dirichlet_expectation(gammad.t).t
+        expElogthetad = exp(Elogthetad.t).t
+        phinorm = expElogthetad * expElogbetad + 1e-100
+        meanchange = sum(abs((gammad - lastgamma).t)) / gammad.t.size.toDouble
+      }
+
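+      // Contribution of this document to the expected sufficient statistics: the outer
+      // product of expElogthetad (K * 1) and cts / phinorm (1 * ids), computed via kron.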
+      val v1 = expElogthetad.t.toDenseMatrix.t
+      val v2 = (ctsVector / phinorm).t.toDenseMatrix
+      val outerResult = kron(v1, v2)   // K * ids
+      for (i <- 0 until ids.size) {
+        other(::, ids(i)) := other(::, ids(i)) + outerResult(::, i)
+      }
+      other
+    }
+
+    def getGammaMatrix(row: Int, col: Int): BDM[Double] = {
+      // Gamma(100, 1/100) has mean 1 and small variance, the standard online LDA init
+      val gammaRandomGenerator = new Gamma(100, 1.0 / 100.0)
+      val temp = gammaRandomGenerator.sample(row * col).toArray
+      new BDM[Double](col, row, temp).t
+    }
+
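+    // For X ~ Dirichlet(alpha), E[log X_i] = digamma(alpha_i) - digamma(sum_j alpha_j);
+    // each row of alpha is treated as an independent Dirichlet parameter vector.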
+    def dirichlet_expectation(alpha: BDM[Double]): BDM[Double] = {
+      val rowSum = sum(alpha(breeze.linalg.*, ::))
+      val digAlpha = digamma(alpha)
+      val digRowSum = digamma(rowSum)
+      val result = digAlpha(::, breeze.linalg.*) - digRowSum
+      result
+    }
+
+    def vector_dirichlet_expectation(v: BDV[Double]): BDV[Double] = {
+      digamma(v) - digamma(sum(v))
+    }
+  }
+
   /**
    * Compute gamma_{wjk}, a distribution over topics k.
    */