
Commit e7bf3b0

move to separate file
1 parent f367cc9 commit e7bf3b0

2 files changed: +207 -137 lines changed


mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala

Lines changed: 2 additions & 137 deletions
@@ -19,17 +19,15 @@ package org.apache.spark.mllib.clustering
 
 import java.util.Random
 
-import breeze.linalg.{DenseVector => BDV, normalize, kron, sum, axpy => brzAxpy, DenseMatrix => BDM}
-import breeze.numerics.{exp, abs, digamma}
-import breeze.stats.distributions.Gamma
+import breeze.linalg.{DenseVector => BDV, normalize, axpy => brzAxpy}
 
 import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaPairRDD
 import org.apache.spark.graphx._
 import org.apache.spark.graphx.impl.GraphImpl
 import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
-import org.apache.spark.mllib.linalg.{Vector, DenseVector, SparseVector, Matrices}
+import org.apache.spark.mllib.linalg.Vector
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.Utils
 
@@ -247,28 +245,6 @@ class LDA private (
     new DistributedLDAModel(state, iterationTimes)
   }
 
-
-  /**
-   * TODO: add API to take documents paths once tokenizer is ready.
-   * Learn an LDA model using the given dataset, using online variational Bayes (VB) algorithm.
-   *
-   * @param documents  RDD of documents, which are term (word) count vectors paired with IDs.
-   *                   The term count vectors are "bags of words" with a fixed-size vocabulary
-   *                   (where the vocabulary size is the length of the vector).
-   *                   Document IDs must be unique and >= 0.
-   * @param batchNumber  Number of batches to split input corpus. For each batch, recommendation
-   *                     size is [4, 16384]. -1 for automatic batchNumber.
-   * @return  Inferred LDA model
-   */
-  def runOnlineLDA(documents: RDD[(Long, Vector)], batchNumber: Int = -1): LDAModel = {
-    require(batchNumber > 0 || batchNumber == -1,
-      s"batchNumber must be greater or -1, but was set to $batchNumber")
-
-    val onlineLDA = new LDA.OnlineLDAOptimizer(documents, k, batchNumber)
-    val model = onlineLDA.optimize()
-    new LocalLDAModel(Matrices.fromBreeze(model).transpose)
-  }
-
   /** Java-friendly version of [[run()]] */
   def run(documents: JavaPairRDD[java.lang.Long, Vector]): DistributedLDAModel = {
     run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]])
@@ -422,117 +398,6 @@ private[clustering] object LDA {
 
   }
 
-  /**
-   * Optimizer for Online LDA algorithm which breaks corpus into mini-batches and scans only once.
-   * Hoffman, Blei and Bach, "Online Learning for Latent Dirichlet Allocation." NIPS, 2010.
-   */
-  private[clustering] class OnlineLDAOptimizer(
-      private val documents: RDD[(Long, Vector)],
-      private val k: Int,
-      private val batchNumber: Int) extends Serializable {
-
-    private val vocabSize = documents.first._2.size
-    private val D = documents.count().toInt
-    private val batchSize =
-      if (batchNumber == -1) { // auto mode
-        if (D / 100 > 16384) 16384
-        else if (D / 100 < 4) 4
-        else D / 100
-      }
-      else {
-        D / batchNumber
-      }
-
-    // Initialize the variational distribution q(beta|lambda)
-    private var lambda = getGammaMatrix(k, vocabSize)      // K * V
-    private var Elogbeta = dirichlet_expectation(lambda)   // K * V
-    private var expElogbeta = exp(Elogbeta)                 // K * V
-
-    def optimize(): BDM[Double] = {
-      val actualBatchNumber = Math.ceil(D.toDouble / batchSize).toInt
-      for (i <- 1 to actualBatchNumber) {
-        val batch = documents.sample(true, batchSize.toDouble / D)
-
-        // Given a mini-batch of documents, estimates the parameters gamma controlling the
-        // variational distribution over the topic weights for each document in the mini-batch.
-        var stat = BDM.zeros[Double](k, vocabSize)
-        stat = batch.treeAggregate(stat)(gradient, _ += _)
-        update(stat, i)
-      }
-      lambda
-    }
-
-    private def update(raw: BDM[Double], iter: Int): Unit = {
-      // weight of the mini-batch. 1024 helps down weights early iterations
-      val weight = math.pow(1024 + iter, -0.5)
-
-      // This step finishes computing the sufficient statistics for the M step
-      val stat = raw :* expElogbeta
-
-      // Update lambda based on documents.
-      lambda = lambda * (1 - weight) + (stat * D.toDouble / batchSize.toDouble + 1.0 / k) * weight
-      Elogbeta = dirichlet_expectation(lambda)
-      expElogbeta = exp(Elogbeta)
-    }
-
-    // for each document d update that document's gamma and phi
-    private def gradient(stat: BDM[Double], doc: (Long, Vector)): BDM[Double] = {
-      val termCounts = doc._2
-      val (ids, cts) = termCounts match {
-        case v: DenseVector => ((0 until v.size).toList, v.values)
-        case v: SparseVector => (v.indices.toList, v.values)
-        case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
-      }
-
-      // Initialize the variational distribution q(theta|gamma) for the mini-batch
-      var gammad = new Gamma(100, 1.0 / 100.0).samplesVector(k).t  // 1 * K
-      var Elogthetad = vector_dirichlet_expectation(gammad.t).t    // 1 * K
-      var expElogthetad = exp(Elogthetad.t).t                      // 1 * K
-      val expElogbetad = expElogbeta(::, ids).toDenseMatrix        // K * ids
-
-      var phinorm = expElogthetad * expElogbetad + 1e-100          // 1 * ids
-      var meanchange = 1D
-      val ctsVector = new BDV[Double](cts).t                       // 1 * ids
-
-      // Iterate between gamma and phi until convergence
-      while (meanchange > 1e-5) {
-        val lastgamma = gammad
-        //        1 * K             1 * ids              ids * k
-        gammad = (expElogthetad :* ((ctsVector / phinorm) * (expElogbetad.t))) + 1.0 / k
-        Elogthetad = vector_dirichlet_expectation(gammad.t).t
-        expElogthetad = exp(Elogthetad.t).t
-        phinorm = expElogthetad * expElogbetad + 1e-100
-        meanchange = sum(abs((gammad - lastgamma).t)) / gammad.t.size.toDouble
-      }
-
-      val m1 = expElogthetad.t.toDenseMatrix.t
-      val m2 = (ctsVector / phinorm).t.toDenseMatrix
-      val outerResult = kron(m1, m2)  // K * ids
-      for (i <- 0 until ids.size) {
-        stat(::, ids(i)) := (stat(::, ids(i)) + outerResult(::, i))
-      }
-      stat
-    }
-
-    private def getGammaMatrix(row: Int, col: Int): BDM[Double] = {
-      val gammaRandomGenerator = new Gamma(100, 1.0 / 100.0)
-      val temp = gammaRandomGenerator.sample(row * col).toArray
-      (new BDM[Double](col, row, temp)).t
-    }
-
-    private def dirichlet_expectation(alpha: BDM[Double]): BDM[Double] = {
-      val rowSum = sum(alpha(breeze.linalg.*, ::))
-      val digAlpha = digamma(alpha)
-      val digRowSum = digamma(rowSum)
-      val result = digAlpha(::, breeze.linalg.*) - digRowSum
-      result
-    }
-
-    private def vector_dirichlet_expectation(v: BDV[Double]): BDV[Double] = {
-      digamma(v) - digamma(sum(v))
-    }
-  }
-
   /**
    * Compute gamma_{wjk}, a distribution over topics k.
    */
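
For context on what the moved optimizer computes: the mini-batch update performed by update() (following Hoffman, Blei and Bach, "Online Learning for Latent Dirichlet Allocation", NIPS 2010) can be summarized, in the notation of the code above, roughly as

    \rho_t = (\tau_0 + t)^{-\kappa}, \qquad \tau_0 = 1024,\ \kappa = 0.5

    \lambda \leftarrow (1 - \rho_t)\,\lambda + \rho_t \Big( \tfrac{D}{|S_t|}\,\widehat{\lambda}_t + \tfrac{1}{k} \Big)

where \widehat{\lambda}_t is the k-by-vocabSize sufficient-statistics matrix accumulated by gradient() over the mini-batch S_t (the per-document counts weighted by \exp(E[\log \beta])), D is the total number of documents, and 1/k stands in for the topic-word Dirichlet prior \eta. This is a summary of the code's behavior, not notation taken from the commit itself.
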
Lines changed: 205 additions & 0 deletions

@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, kron, sum}
+import breeze.numerics._
+import breeze.stats.distributions.Gamma
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.mllib.linalg._
+import org.apache.spark.rdd.RDD
+
+
+/**
+ * :: Experimental ::
+ * Latent Dirichlet Allocation (LDA), a topic model designed for text documents.
+ *
+ * Online LDA breaks the massive corpus into mini-batches and scans the corpus (doc sets) only
+ * once. Thus it does not need to locally store or collect the documents and can be handily
+ * applied to streaming document collections.
+ *
+ * References:
+ *   Hoffman, Blei and Bach, "Online Learning for Latent Dirichlet Allocation." NIPS, 2010.
+ */
+@Experimental
+object OnlineLDA {
+
+  /**
+   * Learns an LDA model from the given data set, using the online variational Bayes (VB)
+   * algorithm. This is just designed as a handy API. For a massive corpus, it is recommended
+   * to use OnlineLDAOptimizer directly and call submitMiniBatch in your application, which can
+   * reduce time and space complexity by not loading the entire corpus.
+   *
+   * @param documents  RDD of documents, which are term (word) count vectors paired with IDs.
+   *                   The term count vectors are "bags of words" with a fixed-size vocabulary
+   *                   (where the vocabulary size is the length of the vector).
+   *                   Document IDs must be unique and >= 0.
+   * @param k  Number of topics to infer.
+   * @param batchNumber  Number of batches to split the input corpus into. For each batch, the
+   *                     recommended size is [4, 16384]. -1 for automatic batchNumber.
+   * @return  Inferred LDA model
+   */
+  def run(documents: RDD[(Long, Vector)], k: Int, batchNumber: Int = -1): LDAModel = {
+    require(batchNumber > 0 || batchNumber == -1,
+      s"batchNumber must be greater than 0 or -1, but was set to $batchNumber")
+    require(k > 0, s"LDA k (number of clusters) must be > 0, but was set to $k")
+
+    val vocabSize = documents.first._2.size
+    val D = documents.count().toInt // total documents count
+    val batchSize =
+      if (batchNumber == -1) { // auto mode
+        if (D / 100 > 16384) 16384
+        else if (D / 100 < 4) 4
+        else D / 100
+      }
+      else {
+        D / batchNumber
+      }
+
+    val onlineLDA = new OnlineLDAOptimizer(k, D, vocabSize)
+    val actualBatchNumber = Math.ceil(D.toDouble / batchSize).toInt
+    for (i <- 1 to actualBatchNumber) {
+      val batch = documents.sample(true, batchSize.toDouble / D)
+      onlineLDA.submitMiniBatch(batch)
+    }
+    onlineLDA.getTopicDistribution()
+  }
+}
+
+/**
+ * :: Experimental ::
+ * Latent Dirichlet Allocation (LDA), a topic model designed for text documents.
+ *
+ * An online training optimizer for LDA. The optimizer processes a subset (e.g. 1%) of the corpus
+ * with each call to submitMiniBatch, and updates the term-topic distribution adaptively. Users
+ * can get the result from getTopicDistribution.
+ *
+ * References:
+ *   Hoffman, Blei and Bach, "Online Learning for Latent Dirichlet Allocation." NIPS, 2010.
+ */
+@Experimental
+class OnlineLDAOptimizer(
+    private var k: Int,
+    private var D: Int,
+    private val vocabSize: Int) extends Serializable {
+
+  // Initialize the variational distribution q(beta|lambda)
+  private var lambda = getGammaMatrix(k, vocabSize)     // K * V
+  private var Elogbeta = dirichlet_expectation(lambda)  // K * V
+  private var expElogbeta = exp(Elogbeta)                // K * V
+  private var i = 0
+
+  /**
+   * Submits a subset (e.g. 1%) of the corpus to the online LDA model, which updates
+   * the topic distribution adaptively for the terms appearing in the subset (mini-batch).
+   * The documents RDD can be discarded after submitMiniBatch has finished.
+   *
+   * @param documents  RDD of documents, which are term (word) count vectors paired with IDs.
+   *                   The term count vectors are "bags of words" with a fixed-size vocabulary
+   *                   (where the vocabulary size is the length of the vector).
+   *                   Document IDs must be unique and >= 0.
+   * @note  The inferred topics can be read at any time via getTopicDistribution.
+   */
+  def submitMiniBatch(documents: RDD[(Long, Vector)]): Unit = {
+    var stat = BDM.zeros[Double](k, vocabSize)
+    stat = documents.treeAggregate(stat)(gradient, _ += _)
+    update(stat, i, documents.count().toInt)
+    i += 1
+  }
+
+  /**
+   * Get the inferred topic-term distribution as an LDA model.
+   */
+  def getTopicDistribution(): LDAModel = {
+    new LocalLDAModel(Matrices.fromBreeze(lambda).transpose)
+  }
+
+  private def update(raw: BDM[Double], iter: Int, batchSize: Int): Unit = {
+    // weight of the mini-batch. 1024 down-weights early iterations
+    val weight = math.pow(1024 + iter, -0.5)
+
+    // This step finishes computing the sufficient statistics for the M step
+    val stat = raw :* expElogbeta
+
+    // Update lambda based on documents.
+    lambda = lambda * (1 - weight) + (stat * D.toDouble / batchSize.toDouble + 1.0 / k) * weight
+    Elogbeta = dirichlet_expectation(lambda)
+    expElogbeta = exp(Elogbeta)
+  }
+
+  // For each document d, update that document's gamma and phi and accumulate the statistics
+  private def gradient(stat: BDM[Double], doc: (Long, Vector)): BDM[Double] = {
+    val termCounts = doc._2
+    val (ids, cts) = termCounts match {
+      case v: DenseVector => ((0 until v.size).toList, v.values)
+      case v: SparseVector => (v.indices.toList, v.values)
+      case v => throw new IllegalArgumentException("Unsupported vector type " + v.getClass)
+    }
+
+    // Initialize the variational distribution q(theta|gamma) for the mini-batch
+    var gammad = new Gamma(100, 1.0 / 100.0).samplesVector(k).t   // 1 * K
+    var Elogthetad = vector_dirichlet_expectation(gammad.t).t     // 1 * K
+    var expElogthetad = exp(Elogthetad.t).t                       // 1 * K
+    val expElogbetad = expElogbeta(::, ids).toDenseMatrix         // K * ids
+
+    var phinorm = expElogthetad * expElogbetad + 1e-100           // 1 * ids
+    var meanchange = 1D
+    val ctsVector = new BDV[Double](cts).t                        // 1 * ids
+
+    // Iterate between gamma and phi until convergence
+    while (meanchange > 1e-5) {
+      val lastgamma = gammad
+      //        1 * K             1 * ids              ids * k
+      gammad = (expElogthetad :* ((ctsVector / phinorm) * (expElogbetad.t))) + 1.0 / k
+      Elogthetad = vector_dirichlet_expectation(gammad.t).t
+      expElogthetad = exp(Elogthetad.t).t
+      phinorm = expElogthetad * expElogbetad + 1e-100
+      meanchange = sum(abs((gammad - lastgamma).t)) / gammad.t.size.toDouble
+    }
+
+    val m1 = expElogthetad.t.toDenseMatrix.t
+    val m2 = (ctsVector / phinorm).t.toDenseMatrix
+    val outerResult = kron(m1, m2)  // K * ids
+    for (i <- 0 until ids.size) {
+      stat(::, ids(i)) := (stat(::, ids(i)) + outerResult(::, i))
+    }
+    stat
+  }
+
+  private def getGammaMatrix(row: Int, col: Int): BDM[Double] = {
+    val gammaRandomGenerator = new Gamma(100, 1.0 / 100.0)
+    val temp = gammaRandomGenerator.sample(row * col).toArray
+    (new BDM[Double](col, row, temp)).t
+  }
+
+  private def dirichlet_expectation(alpha: BDM[Double]): BDM[Double] = {
+    val rowSum = sum(alpha(breeze.linalg.*, ::))
+    val digAlpha = digamma(alpha)
+    val digRowSum = digamma(rowSum)
+    val result = digAlpha(::, breeze.linalg.*) - digRowSum
+    result
+  }
+
+  private def vector_dirichlet_expectation(v: BDV[Double]): BDV[Double] = {
+    digamma(v) - digamma(sum(v))
+  }
+}
+
+
+
+
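
A minimal usage sketch of the two new APIs (not part of this commit; the SparkContext `sc`, the toy corpus, the 4-term vocabulary, and the parameter values are illustrative assumptions):

    import org.apache.spark.mllib.clustering.{OnlineLDA, OnlineLDAOptimizer}
    import org.apache.spark.mllib.linalg.{Vector, Vectors}
    import org.apache.spark.rdd.RDD

    // Toy corpus: (docId, term-count vector) pairs over a 4-term vocabulary.
    val corpus: RDD[(Long, Vector)] = sc.parallelize(Seq(
      (0L, Vectors.dense(1.0, 2.0, 0.0, 5.0)),
      (1L, Vectors.dense(0.0, 1.0, 3.0, 2.0)),
      (2L, Vectors.dense(4.0, 0.0, 1.0, 1.0))))

    // One-shot convenience API: split the corpus into 3 mini-batches and infer 2 topics.
    val model = OnlineLDA.run(corpus, k = 2, batchNumber = 3)

    // Lower-level API: drive the optimizer yourself, submitting mini-batches as they arrive,
    // so the full corpus never has to be materialized at once.
    val optimizer = new OnlineLDAOptimizer(2, corpus.count().toInt, 4)
    optimizer.submitMiniBatch(corpus.sample(true, 0.5))
    optimizer.submitMiniBatch(corpus.sample(true, 0.5))
    val onlineModel = optimizer.getTopicDistribution()
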
