diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
new file mode 100644
index 000000000000..1f0d8e76ca28
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import java.util.Random
+
+import breeze.linalg.{DenseVector => BDV}
+
+import org.apache.spark.{AccumulableParam, Logging, SparkContext}
+import org.apache.spark.mllib.expectation.GibbsSampling
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.rdd.RDD
+
+case class Document(docId: Int, content: Iterable[Int])
+
+case class LDAParams(
+    docCounts: Vector,
+    topicCounts: Vector,
+    docTopicCounts: Array[Vector],
+    topicTermCounts: Array[Vector])
+  extends Serializable {
+
+  /** Adds `inc` to every count touched by one (document, term, topic) assignment. */
+  def update(docId: Int, term: Int, topic: Int, inc: Int): LDAParams = {
+    docCounts.toBreeze(docId) += inc
+    topicCounts.toBreeze(topic) += inc
+    docTopicCounts(docId).toBreeze(topic) += inc
+    topicTermCounts(topic).toBreeze(term) += inc
+    this
+  }
+
+  /** Merges the counts of `other` into this instance, in place. */
+  def merge(other: LDAParams): LDAParams = {
+    docCounts.toBreeze += other.docCounts.toBreeze
+    topicCounts.toBreeze += other.topicCounts.toBreeze
+
+    var i = 0
+    while (i < docTopicCounts.length) {
+      docTopicCounts(i).toBreeze += other.docTopicCounts(i).toBreeze
+      i += 1
+    }
+
+    i = 0
+    while (i < topicTermCounts.length) {
+      topicTermCounts(i).toBreeze += other.topicTermCounts(i).toBreeze
+      i += 1
+    }
+    this
+  }
+
+  /**
+   * Computes the topic distribution of a term after its current assignment has been dropped from
+   * the counts (the "drop one" step), which is the core of collapsed Gibbs sampling for LDA.
+   * See the paper "Parameter estimation for text analysis" for the derivation.
+   */
+  def dropOneDistSampler(
+      docTopicSmoothing: Double,
+      topicTermSmoothing: Double,
+      termId: Int,
+      docId: Int,
+      rand: Random): Int = {
+    val (numTopics, numTerms) = (topicCounts.size, topicTermCounts.head.size)
+    val topicThisTerm = BDV.zeros[Double](numTopics)
+    var i = 0
+    while (i < numTopics) {
+      // p(z = i | rest) is proportional to (n_it + beta) / (n_i + V * beta) * (n_di + alpha).
+      topicThisTerm(i) =
+        ((topicTermCounts(i)(termId) + topicTermSmoothing)
+          / (topicCounts(i) + (numTerms * topicTermSmoothing))
+        ) * (docTopicCounts(docId)(i) + docTopicSmoothing)
+      i += 1
+    }
+    GibbsSampling.multinomialDistSampler(rand, topicThisTerm)
+  }
+}
+
+object LDAParams {
+  implicit val ldaParamsAP = new LDAParamsAccumulableParam
+
+  def apply(numDocs: Int, numTopics: Int, numTerms: Int) = new LDAParams(
+    Vectors.fromBreeze(BDV.zeros[Double](numDocs)),
+    Vectors.fromBreeze(BDV.zeros[Double](numTopics)),
+    Array(0 until numDocs: _*).map(_ => Vectors.fromBreeze(BDV.zeros[Double](numTopics))),
+    Array(0 until numTopics: _*).map(_ => Vectors.fromBreeze(BDV.zeros[Double](numTerms))))
+}
+
+class LDAParamsAccumulableParam extends AccumulableParam[LDAParams, (Int, Int, Int, Int)] {
+  def addAccumulator(r: LDAParams, t: (Int, Int, Int, Int)): LDAParams = {
+    val (docId, term, topic, inc) = t
+    r.update(docId, term, topic, inc)
+  }
+
+  def addInPlace(r1: LDAParams, r2: LDAParams): LDAParams = r1.merge(r2)
+
+  def zero(initialValue: LDAParams): LDAParams = initialValue
+}
+
+class LDA private (
+    var numTopics: Int,
+    var docTopicSmoothing: Double,
+    var topicTermSmoothing: Double,
+    var numIteration: Int,
+    var numDocs: Int,
+    var numTerms: Int)
+  extends Serializable with Logging {
+
+  def run(input: RDD[Document]): (GibbsSampling, LDAParams) = {
+    val trainer = new GibbsSampling(
+      input,
+      numIteration,
+      1,
+      docTopicSmoothing,
+      topicTermSmoothing)
+    (trainer, trainer.runGibbsSampling(LDAParams(numDocs, numTopics, numTerms)))
+  }
+}
+
+object LDA extends Logging {
+
+  def train(
+      data: RDD[Document],
+      numTopics: Int,
+      docTopicSmoothing: Double,
+      topicTermSmoothing: Double,
+      numIterations: Int,
+      numDocs: Int,
+      numTerms: Int): (Array[Vector], Array[Vector]) = {
+    val lda = new LDA(numTopics,
+      docTopicSmoothing,
+      topicTermSmoothing,
+      numIterations,
+      numDocs,
+      numTerms)
+    val (trainer, model) = lda.run(data)
+    trainer.solvePhiAndTheta(model)
+  }
+
+  def main(args: Array[String]) {
+    if (args.length != 5) {
+      println("Usage: LDA <master> <input_dir> <num_topics> <max_iterations> <min_splits>")
+      System.exit(1)
+    }
+
+    val (master, inputDir, k, iters, minSplit) =
+      (args(0), args(1), args(2).toInt, args(3).toInt, args(4).toInt)
+    val checkPointDir = System.getProperty("spark.gibbsSampling.checkPointDir", "/tmp/lda-cp")
+    val sc = new SparkContext(master, "LDA")
+    sc.setCheckpointDir(checkPointDir)
+    val (data, wordMap, docMap) = MLUtils.loadCorpus(sc, inputDir, minSplit)
+    val numDocs = docMap.size
+    val numTerms = wordMap.size
+
+    val (phi, theta) = LDA.train(data, k, 0.01, 0.01, iters, numDocs, numTerms)
+    val pp = GibbsSampling.perplexity(data, phi, theta)
+    println(s"final model perplexity is $pp")
+  }
+}
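For reference, the full conditional that dropOneDistSampler draws from can be written out against plain arrays. This is only an illustrative sketch (the helper object and its array-based parameters are assumptions, not part of the patch); it mirrors the formula from "Parameter estimation for text analysis": p(z = k | rest) proportional to (n_kt + beta) / (n_k + V * beta) * (n_dk + alpha).

    // Illustrative sketch, not part of the patch: the collapsed-Gibbs full conditional
    // over topics for a single term, using plain arrays instead of LDAParams.
    object FullConditionalSketch {
      def fullConditional(
          topicTermCounts: Array[Array[Double]], // K x V counts n_kt
          topicCounts: Array[Double],            // length-K counts n_k
          docTopicCounts: Array[Double],         // length-K counts n_dk for one document
          term: Int,
          docTopicSmoothing: Double,             // alpha
          topicTermSmoothing: Double): Array[Double] = { // beta
        val numTerms = topicTermCounts.head.length
        topicTermCounts.indices.map { k =>
          (topicTermCounts(k)(term) + topicTermSmoothing) /
            (topicCounts(k) + numTerms * topicTermSmoothing) *
            (docTopicCounts(k) + docTopicSmoothing)
        }.toArray
      }
    }

The unnormalized weights returned by such a function are exactly what gets fed to the multinomial sampler below.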
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/expectation/GibbsSampling.scala b/mllib/src/main/scala/org/apache/spark/mllib/expectation/GibbsSampling.scala
new file mode 100644
index 000000000000..5e4f776c9166
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/expectation/GibbsSampling.scala
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.expectation
+
+import java.util.Random
+
+import breeze.linalg.{DenseVector => BDV, sum}
+
+import org.apache.spark.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.clustering.{Document, LDAParams}
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+
+/**
+ * Collapsed Gibbs sampling for LDA on a given dataset.
+ * @param data Dataset, e.g. a corpus of documents.
+ * @param numOuterIterations Number of outer iterations.
+ * @param numInnerIterations Number of inner iterations, used within each partition.
+ * @param docTopicSmoothing Document-topic smoothing (alpha).
+ * @param topicTermSmoothing Topic-term smoothing (beta).
+ */
+class GibbsSampling(
+    data: RDD[Document],
+    numOuterIterations: Int,
+    numInnerIterations: Int,
+    docTopicSmoothing: Double,
+    topicTermSmoothing: Double)
+  extends Logging with Serializable {
+
+  import GibbsSampling._
+
+  /**
+   * Runs Gibbs sampling in two phases: initialization of the topic assignments, followed by the
+   * actual sampling iterations.
+   */
+  def runGibbsSampling(
+      initParams: LDAParams,
+      data: RDD[Document] = data,
+      numOuterIterations: Int = numOuterIterations,
+      numInnerIterations: Int = numInnerIterations,
+      docTopicSmoothing: Double = docTopicSmoothing,
+      topicTermSmoothing: Double = topicTermSmoothing): LDAParams = {
+
+    val numTerms = initParams.topicTermCounts.head.size
+    val numDocs = initParams.docCounts.size
+    val numTopics = initParams.topicCounts.size
+
+    // Construct the topic assignment RDD.
+    logInfo("Start initialization")
+
+    val cpInterval = System.getProperty("spark.gibbsSampling.checkPointInterval", "10").toInt
+    val sc = data.context
+    val (initialParams, initialChosenTopics) = sampleTermAssignment(initParams, data)
+
+    // Gibbs sampling iterations.
+    val (params, _, _) = Iterator.iterate((sc.accumulable(initialParams), initialChosenTopics, 0)) {
+      case (lastParams, lastChosenTopics, i) =>
+        logInfo("Start Gibbs sampling")
+
+        val rand = new Random(42 + i * i)
+        val params = sc.accumulable(LDAParams(numDocs, numTopics, numTerms))
+        val chosenTopics = data.zip(lastChosenTopics).map {
+          case (Document(docId, content), topics) =>
+            content.zip(topics).map { case (term, topic) =>
+              lastParams += (docId, term, topic, -1)
+
+              val chosenTopic = lastParams.localValue.dropOneDistSampler(
+                docTopicSmoothing, topicTermSmoothing, term, docId, rand)
+
+              lastParams += (docId, term, chosenTopic, 1)
+              params += (docId, term, chosenTopic, 1)
+
+              chosenTopic
+            }
+        }.cache()
+
+        if ((i + 1) % cpInterval == 0) {
+          chosenTopics.checkpoint()
+        }
+
+        // Trigger a job to collect accumulable LDA parameters.
+        chosenTopics.count()
+        lastChosenTopics.unpersist()
+
+        (params, chosenTopics, i + 1)
+    }.drop(1 + numOuterIterations).next()
+
+    params.value
+  }
+
+  /**
+   * Infers the model matrices Phi (topic-term) and Theta (document-topic) from the counts held
+   * in an LDAParams instance.
+   */
+  def solvePhiAndTheta(
+      params: LDAParams,
+      docTopicSmoothing: Double = docTopicSmoothing,
+      topicTermSmoothing: Double = topicTermSmoothing): (Array[Vector], Array[Vector]) = {
+    val numTopics = params.topicCounts.size
+    val numTerms = params.topicTermCounts.head.size
+
+    val docCount = params.docCounts.toBreeze :+ (docTopicSmoothing * numTopics)
+    val topicCount = params.topicCounts.toBreeze :+ (topicTermSmoothing * numTerms)
+    val docTopicCount = params.docTopicCounts.map(vec => vec.toBreeze :+ docTopicSmoothing)
+    val topicTermCount = params.topicTermCounts.map(vec => vec.toBreeze :+ topicTermSmoothing)
+
+    var i = 0
+    while (i < numTopics) {
+      topicTermCount(i) :/= topicCount(i)
+      i += 1
+    }
+
+    i = 0
+    while (i < docCount.length) {
+      docTopicCount(i) :/= docCount(i)
+      i += 1
+    }
+
+    (topicTermCount.map(vec => Vectors.fromBreeze(vec)),
+      docTopicCount.map(vec => Vectors.fromBreeze(vec)))
+  }
+}
+
+object GibbsSampling extends Logging {
+
+  /**
+   * Initialization step of Gibbs sampling. Existing counts are reused when present, which
+   * supports incremental LDA.
+   */
+  private def sampleTermAssignment(
+      params: LDAParams,
+      data: RDD[Document]): (LDAParams, RDD[Iterable[Int]]) = {
+
+    val sc = data.context
+    val initialParams = sc.accumulable(params)
+    val rand = new Random(42)
+    val initialChosenTopics = data.map { case Document(docId, content) =>
+      val docTopics = params.docTopicCounts(docId)
+      if (docTopics.toBreeze.norm(2) == 0) {
+        content.map { term =>
+          val topic = uniformDistSampler(rand, params.topicCounts.size)
+          initialParams += (docId, term, topic, 1)
+          topic
+        }
+      } else {
+        content.map { term =>
+          val topicTerms = Vectors.dense(params.topicTermCounts.map(_(term))).toBreeze
+          val dist = docTopics.toBreeze :* topicTerms
+          multinomialDistSampler(rand, dist.asInstanceOf[BDV[Double]])
+        }
+      }
+    }.cache()
+
+    // Trigger a job to collect accumulable LDA parameters.
+    initialChosenTopics.count()
+
+    (initialParams.value, initialChosenTopics)
+  }
+
+  /**
+   * A uniform distribution sampler, which is only used during initialization.
+   */
+  private def uniformDistSampler(rand: Random, dimension: Int): Int = rand.nextInt(dimension)
+
+  /**
+   * A multinomial distribution sampler that uses the roulette-wheel method to draw an index.
+   */
+  def multinomialDistSampler(rand: Random, dist: BDV[Double]): Int = {
+    val roulette = rand.nextDouble()
+
+    dist :/= sum[BDV[Double], Double](dist)
+
+    def loop(index: Int, accum: Double): Int = {
+      if (index == dist.length) {
+        dist.length - 1
+      } else {
+        val cumulative = accum + dist(index)
+        if (cumulative >= roulette) index else loop(index + 1, cumulative)
+      }
+    }
+
+    loop(0, 0.0)
+  }
+
+  /**
+   * Perplexity is a standard evaluation metric for LDA, usually computed on unseen data; here it
+   * is computed on the training documents, which is also acceptable. To use it on unseen data,
+   * run one iteration of Gibbs sampling on that data before calling this. Lower perplexity
+   * indicates a better model.
+   */
+  def perplexity(data: RDD[Document], phi: Array[Vector], theta: Array[Vector]): Double = {
+    val (termProb, totalNum) = data.flatMap { case Document(docId, content) =>
+      val currentTheta = BDV.zeros[Double](phi.head.size)
+      var col = 0
+      var row = 0
+      while (col < phi.head.size) {
+        row = 0
+        while (row < phi.length) {
+          currentTheta(col) += phi(row)(col) * theta(docId)(row)
+          row += 1
+        }
+        col += 1
+      }
+      content.map(x => (math.log(currentTheta(x)), 1))
+    }.reduce { (lhs, rhs) =>
+      (lhs._1 + rhs._1, lhs._2 + rhs._2)
+    }
+    math.exp(-1 * termProb / totalNum)
+  }
+}
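The roulette-wheel draw in multinomialDistSampler is an inverse-CDF walk over the normalized weights. A minimal stand-alone sketch of the same technique in plain Scala, with no Breeze dependency (the object and method names here are illustrative, not part of the patch):

    import scala.util.Random

    object RouletteSketch {
      // Draws an index with probability proportional to weights(i).
      def sample(rand: Random, weights: Array[Double]): Int = {
        val total = weights.sum
        val roulette = rand.nextDouble() * total
        var cumulative = 0.0
        var i = 0
        while (i < weights.length - 1) {
          cumulative += weights(i)
          if (cumulative >= roulette) return i
          i += 1
        }
        weights.length - 1
      }
    }

Scaling the roulette value by the total weight avoids mutating the input distribution, which is a design choice; the patch instead normalizes the Breeze vector in place before walking it.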
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index 3d6e7e0d5c95..5498736a0cc9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -20,15 +20,17 @@ package org.apache.spark.mllib.util
 import scala.reflect.ClassTag
 
 import breeze.linalg.{Vector => BV, SparseVector => BSV, squaredDistance => breezeSquaredDistance}
+import breeze.util.Index
+import chalk.text.tokenize.JavaWordTokenizer
 
 import org.apache.spark.annotation.Experimental
-import org.apache.spark.SparkContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.rdd.PartitionwiseSampledRDD
-import org.apache.spark.SparkContext._
+import org.apache.spark.SparkContext
 import org.apache.spark.util.random.BernoulliSampler
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.clustering.Document
 
 /**
  * Helper methods to load, save and pre-process data used in ML Lib.
@@ -233,4 +235,60 @@ object MLUtils {
     }
     sqDist
   }
+
+  /**
+   * Loads a corpus from a given path. Terms and documents are translated into integers, and the
+   * term-to-integer and document-to-integer maps are returned alongside the data.
+   *
+   * @param sc The SparkContext.
+   * @param dir The path of the corpus.
+   * @param minSplits Minimum number of splits for the input files.
+   * @param dirStopWords The path of the stop-word list.
+   * @return (RDD[Document], term-to-integer map, document-to-integer map)
+   */
+  def loadCorpus(
+      sc: SparkContext,
+      dir: String,
+      minSplits: Int,
+      dirStopWords: String = ""): (RDD[Document], Index[String], Index[String]) = {
+
+    // Containers and indexers for terms and documents.
+    val termMap = Index[String]()
+    val docMap = Index[String]()
+
+    val stopWords =
+      if (dirStopWords == "") {
+        Set.empty[String]
+      } else {
+        sc.textFile(dirStopWords, minSplits)
+          .map(x => x.replaceAll("""(?m)\s+$""", ""))
+          .distinct()
+          .collect()
+          .toSet
+      }
+    val broadcastStopWord = sc.broadcast(stopWords)
+
+    // Tokenize and filter terms.
+    val almostData = sc.wholeTextFiles(dir, minSplits).map { case (fileName, content) =>
+      val tokens = JavaWordTokenizer(content)
+        .filter(_(0).isLetter)
+        .filter(!broadcastStopWord.value.contains(_))
+      (fileName, tokens)
+    }
+
+    // Build the document and term indices on the driver.
+    almostData.map(_._1).collect().map(x => docMap.index(x))
+    almostData.flatMap(_._2).collect().map(x => termMap.index(x))
+
+    println(s"Number of terms: ${termMap.size}, number of documents: ${docMap.size}")
+
+    val broadcastWordMap = sc.broadcast(termMap)
+    val broadcastDocMap = sc.broadcast(docMap)
+
+    val data = almostData.map { case (fileName, tokens) =>
+      val fileId = broadcastDocMap.value.index(fileName)
+      val translatedContent = tokens.map(broadcastWordMap.value.index)
+      Document(fileId, translatedContent)
+    }.cache()
+
+    (data, termMap, docMap)
+  }
 }
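loadCorpus ultimately just replaces tokens with integer ids and records the two id maps. A rough driver-side equivalent without chalk or breeze.util.Index, shown only to clarify that mapping (the whitespace tokenizer here is an assumed stand-in for JavaWordTokenizer, not the patch's behavior):

    import scala.collection.mutable

    object IndexingSketch {
      // Assigns a stable integer id to each distinct document name and term,
      // in order of first appearance, mirroring what loadCorpus returns.
      def indexCorpus(
          docs: Seq[(String, String)]): (Seq[(Int, Seq[Int])], Map[String, Int], Map[String, Int]) = {
        val termIds = mutable.LinkedHashMap.empty[String, Int]
        val docIds = mutable.LinkedHashMap.empty[String, Int]
        val indexed = docs.map { case (name, text) =>
          val docId = docIds.getOrElseUpdate(name, docIds.size)
          val tokens = text.toLowerCase.split("\\s+").filter(_.nonEmpty)
          val content = tokens.map(t => termIds.getOrElseUpdate(t, termIds.size)).toSeq
          (docId, content)
        }
        (indexed, termIds.toMap, docIds.toMap)
      }
    }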
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
new file mode 100644
index 000000000000..e7dd4b819781
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import java.util.Random
+
+import breeze.linalg.{DenseVector => BDV}
+import breeze.stats.distributions.Poisson
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.expectation.GibbsSampling
+import org.apache.spark.mllib.expectation.GibbsSampling._
+import org.apache.spark.SparkContext
+
+class LDASuite extends FunSuite with BeforeAndAfterAll {
+  import LDASuite._
+
+  @transient private var sc: SparkContext = _
+
+  override def beforeAll() {
+    sc = new SparkContext("local", "LDA test")
+  }
+
+  override def afterAll() {
+    sc.stop()
+    System.clearProperty("spark.driver.port")
+  }
+
+  test("LDA || Gibbs sampling") {
+    val checkPointDir = System.getProperty("spark.gibbsSampling.checkPointDir", "/tmp/lda")
+    sc.setCheckpointDir(checkPointDir)
+    val model = generateRandomLDAModel(numTopics, numTerms)
+    val corpus = sampleCorpus(model, numDocs, numTerms, numTopics)
+    val data = sc.parallelize(corpus, 2)
+    var computedModel = LDAParams(numDocs, numTopics, numTerms)
+    val trainer = new GibbsSampling(
+      data,
+      numOuterIterations,
+      numInnerIterations,
+      docTopicSmoothing,
+      topicTermSmoothing)
+
+    val pps = new Array[Double](incrementalLearning)
+
+    var i = 0
+    while (i < incrementalLearning) {
+      computedModel = trainer.runGibbsSampling(computedModel)
+      val (phi, theta) = trainer.solvePhiAndTheta(computedModel)
+      pps(i) = perplexity(data, phi, theta)
+      i += 1
+    }
+
+    pps.foreach(println)
+
+    // Perplexity should decrease for most incremental-learning rounds, and overall.
+    val ppsDiff = pps.init.zip(pps.tail).map { case (lhs, rhs) => lhs - rhs }
+    assert(ppsDiff.count(_ > 0).toDouble / ppsDiff.size > 0.6)
+    assert(pps.head - pps.last > 0)
+  }
+}
+
+object LDASuite {
+  val numTopics = 5
+  val numTerms = 1000
+  val numDocs = 100
+  val expectedDocLength = 300
+  val docTopicSmoothing = 0.01
+  val topicTermSmoothing = 0.01
+  val numOuterIterations = 5
+  val numInnerIterations = 1
+  val incrementalLearning = 10
+
+  /**
+   * Generates a random LDA model, i.e. the topic-term matrix.
+   */
+  def generateRandomLDAModel(numTopics: Int, numTerms: Int): Array[BDV[Double]] = {
+    val model = new Array[BDV[Double]](numTopics)
+    val width = numTerms * 1.0 / numTopics
+    var topic = 0
+    var i = 0
+    while (topic < numTopics) {
+      val topicCentroid = width * (topic + 1)
+      model(topic) = BDV.zeros[Double](numTerms)
+      i = 0
+      while (i < numTerms) {
+        // Treat the term list as a circle, so the distance between the first term and the last
+        // term is 1, not numTerms - 1.
+        val distance = Math.abs(topicCentroid - i) % (numTerms / 2)
+        // Probability decays with distance from the topic centroid.
+        model(topic)(i) = 1.0 / (1 + Math.abs(distance))
+        i += 1
+      }
+      topic += 1
+    }
+    model
+  }
+
+  /**
+   * Samples one document given the topic-term matrix.
+   */
+  def ldaSampler(
+      model: Array[BDV[Double]],
+      topicDist: BDV[Double],
+      numTermsPerDoc: Int): Array[Int] = {
+    val samples = new Array[Int](numTermsPerDoc)
+    val rand = new Random()
+    (0 until numTermsPerDoc).foreach { i =>
+      samples(i) = multinomialDistSampler(
+        rand,
+        model(multinomialDistSampler(rand, topicDist))
+      )
+    }
+    samples
+  }
+
+  /**
+   * Samples a corpus (many documents) from a given topic-term matrix.
+   */
+  def sampleCorpus(
+      model: Array[BDV[Double]],
+      numDocs: Int,
+      numTerms: Int,
+      numTopics: Int): Array[Document] = {
+    (0 until numDocs).map { i =>
+      val rand = new Random()
+      val numTermsPerDoc = Poisson.distribution(expectedDocLength).sample()
+      val numTopicsPerDoc = rand.nextInt(numTopics / 2) + 1
+      val topicDist = BDV.zeros[Double](numTopics)
+      (0 until numTopicsPerDoc).foreach { _ =>
+        topicDist(rand.nextInt(numTopics)) += 1
+      }
+      Document(i, ldaSampler(model, topicDist, numTermsPerDoc))
+    }.toArray
+  }
+}
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 33f9d644ca66..fe217c9dea5e 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -399,6 +399,7 @@ object SparkBuild extends Build {
     name := "spark-mllib",
     previousArtifact := sparkPreviousArtifact("spark-mllib"),
     libraryDependencies ++= Seq(
+      "org.scalanlp" % "chalk" % "1.3.0",
       "org.jblas" % "jblas" % jblasVersion,
       "org.scalanlp" %% "breeze" % "0.7"
     )
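To see how the pieces fit together end to end, here is a hedged usage sketch that mirrors LDA.main but runs against a small in-memory corpus; the local master, checkpoint path, and all numbers are illustrative assumptions, not part of the patch:

    import org.apache.spark.SparkContext
    import org.apache.spark.mllib.clustering.{Document, LDA}
    import org.apache.spark.mllib.expectation.GibbsSampling

    object LDAExampleSketch {
      def main(args: Array[String]) {
        val sc = new SparkContext("local[2]", "LDAExampleSketch")
        sc.setCheckpointDir("/tmp/lda-example-cp")

        // A toy corpus: 2 documents over 10 term ids.
        val corpus = Seq(
          Document(0, Seq(0, 1, 2, 2, 3)),
          Document(1, Seq(4, 5, 5, 6, 7, 8, 9)))
        val data = sc.parallelize(corpus, 2)

        // 2 topics, alpha = beta = 0.01, 10 outer iterations.
        val (phi, theta) = LDA.train(data, 2, 0.01, 0.01, 10, numDocs = 2, numTerms = 10)
        println(s"perplexity = ${GibbsSampling.perplexity(data, phi, theta)}")
        sc.stop()
      }
    }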