diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
new file mode 100644
index 000000000000..1f0d8e76ca28
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import java.util.Random
+
+import breeze.linalg.{DenseVector => BDV}
+
+import org.apache.spark.{AccumulableParam, Logging, SparkContext}
+import org.apache.spark.mllib.expectation.GibbsSampling
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.rdd.RDD
+
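+/**
+ * A bag-of-words document: `content` holds one term id per token occurrence, so a term that
+ * appears n times in the document contributes n entries.
+ */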
+case class Document(docId: Int, content: Iterable[Int])
+
+case class LDAParams (
+ docCounts: Vector,
+ topicCounts: Vector,
+ docTopicCounts: Array[Vector],
+ topicTermCounts: Array[Vector])
+ extends Serializable {
+
+ def update(docId: Int, term: Int, topic: Int, inc: Int): LDAParams = {
+ docCounts.toBreeze(docId) += inc
+ topicCounts.toBreeze(topic) += inc
+ docTopicCounts(docId).toBreeze(topic) += inc
+ topicTermCounts(topic).toBreeze(term) += inc
+ this
+ }
+
+ def merge(other: LDAParams): LDAParams = {
+ docCounts.toBreeze += other.docCounts.toBreeze
+ topicCounts.toBreeze += other.topicCounts.toBreeze
+
+ var i = 0
+ while (i < docTopicCounts.length) {
+ docTopicCounts(i).toBreeze += other.docTopicCounts(i).toBreeze
+ i += 1
+ }
+
+ i = 0
+ while (i < topicTermCounts.length) {
+ topicTermCounts(i).toBreeze += other.topicTermCounts(i).toBreeze
+ i += 1
+ }
+ this
+ }
+
+ /**
+ * Computes the topic distribution of a term after its current assignment has been removed from
+ * the counts (the "drop-one" step) and samples a new topic from it. This is the core of collapsed
+ * Gibbs sampling for LDA; see Heinrich, "Parameter estimation for text analysis".
+ */
+ def dropOneDistSampler(
+ docTopicSmoothing: Double,
+ topicTermSmoothing: Double,
+ termId: Int,
+ docId: Int,
+ rand: Random): Int = {
+ val (numTopics, numTerms) = (topicCounts.size, topicTermCounts.head.size)
+ val topicThisTerm = BDV.zeros[Double](numTopics)
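+ // Standard collapsed-Gibbs conditional evaluated here:
+ // p(z = k | rest) ∝ (n_{k,t} + beta) / (n_k + V * beta) * (n_{d,k} + alpha),
+ // with beta = topicTermSmoothing, alpha = docTopicSmoothing and V = numTerms; the
+ // document-length denominator is constant in k and therefore omitted.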
+ var i = 0
+ while (i < numTopics) {
+ topicThisTerm(i) =
+ ((topicTermCounts(i)(termId) + topicTermSmoothing)
+ / (topicCounts(i) + (numTerms * topicTermSmoothing))
+ ) * (docTopicCounts(docId)(i) + docTopicSmoothing)
+ i += 1
+ }
+ GibbsSampling.multinomialDistSampler(rand, topicThisTerm)
+ }
+}
+
+object LDAParams {
+ implicit val ldaParamsAP = new LDAParamsAccumulableParam
+
+ def apply(numDocs: Int, numTopics: Int, numTerms: Int): LDAParams = new LDAParams(
+ Vectors.fromBreeze(BDV.zeros[Double](numDocs)),
+ Vectors.fromBreeze(BDV.zeros[Double](numTopics)),
+ Array.fill(numDocs)(Vectors.fromBreeze(BDV.zeros[Double](numTopics))),
+ Array.fill(numTopics)(Vectors.fromBreeze(BDV.zeros[Double](numTerms))))
+}
+
+class LDAParamsAccumulableParam extends AccumulableParam[LDAParams, (Int, Int, Int, Int)] {
+ def addAccumulator(r: LDAParams, t: (Int, Int, Int, Int)): LDAParams = {
+ val (docId, term, topic, inc) = t
+ r.update(docId, term, topic, inc)
+ }
+
+ def addInPlace(r1: LDAParams, r2: LDAParams): LDAParams = r1.merge(r2)
+
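+ // zero() deliberately hands back the supplied counts rather than an empty LDAParams, so
+ // (presumably) each task-local copy of the accumulable starts from the current model state,
+ // which dropOneDistSampler reads through localValue.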
+ def zero(initialValue: LDAParams): LDAParams = initialValue
+}
+
+class LDA private (
+ var numTopics: Int,
+ var docTopicSmoothing: Double,
+ var topicTermSmoothing: Double,
+ var numIteration: Int,
+ var numDocs: Int,
+ var numTerms: Int)
+ extends Serializable with Logging {
+ def run(input: RDD[Document]): (GibbsSampling, LDAParams) = {
+ val trainer = new GibbsSampling(
+ input,
+ numIteration,
+ 1,
+ docTopicSmoothing,
+ topicTermSmoothing)
+ (trainer, trainer.runGibbsSampling(LDAParams(numDocs, numTopics, numTerms)))
+ }
+}
+
+object LDA extends Logging {
+
+ def train(
+ data: RDD[Document],
+ numTopics: Int,
+ docTopicSmoothing: Double,
+ topicTermSmoothing: Double,
+ numIterations: Int,
+ numDocs: Int,
+ numTerms: Int): (Array[Vector], Array[Vector]) = {
+ val lda = new LDA(numTopics,
+ docTopicSmoothing,
+ topicTermSmoothing,
+ numIterations,
+ numDocs,
+ numTerms)
+ val (trainer, model) = lda.run(data)
+ trainer.solvePhiAndTheta(model)
+ }
+
+ def main(args: Array[String]) {
+ if (args.length != 5) {
+ println("Usage: LDA ")
+ System.exit(1)
+ }
+
+ val (master, inputDir, k, iters, minSplit) =
+ (args(0), args(1), args(2).toInt, args(3).toInt, args(4).toInt)
+ val checkPointDir = System.getProperty("spark.gibbsSampling.checkPointDir", "/tmp/lda-cp")
+ val sc = new SparkContext(master, "LDA")
+ sc.setCheckpointDir(checkPointDir)
+ val (data, wordMap, docMap) = MLUtils.loadCorpus(sc, inputDir, minSplit)
+ val numDocs = docMap.size
+ val numTerms = wordMap.size
+
+ val (phi, theta) = LDA.train(data, k, 0.01, 0.01, iters, numDocs, numTerms)
+ val pp = GibbsSampling.perplexity(data, phi, theta)
+ println(s"final mode perplexity is $pp")
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/expectation/GibbsSampling.scala b/mllib/src/main/scala/org/apache/spark/mllib/expectation/GibbsSampling.scala
new file mode 100644
index 000000000000..5e4f776c9166
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/expectation/GibbsSampling.scala
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.expectation
+
+import java.util.Random
+
+import breeze.linalg.{DenseVector => BDV, sum}
+
+import org.apache.spark.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.clustering.{Document, LDAParams}
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+
+/**
+ * Gibbs sampling from a given dataset and model.
+ * @param data Dataset, e.g. a corpus of documents.
+ * @param numOuterIterations Number of outer (full-corpus) iterations.
+ * @param numInnerIterations Number of inner iterations, used within each partition.
+ * @param docTopicSmoothing Document-topic smoothing (alpha).
+ * @param topicTermSmoothing Topic-term smoothing (beta).
+ */
+class GibbsSampling(
+ data: RDD[Document],
+ numOuterIterations: Int,
+ numInnerIterations: Int,
+ docTopicSmoothing: Double,
+ topicTermSmoothing: Double)
+ extends Logging with Serializable {
+
+ import GibbsSampling._
+
+ /**
+ * Main entry point of the Gibbs sampling method. It runs in two phases: an initialization
+ * pass that assigns an initial topic to every term, followed by the actual sampling iterations.
+ */
+ def runGibbsSampling(
+ initParams: LDAParams,
+ data: RDD[Document] = data,
+ numOuterIterations: Int = numOuterIterations,
+ numInnerIterations: Int = numInnerIterations,
+ docTopicSmoothing: Double = docTopicSmoothing,
+ topicTermSmoothing: Double = topicTermSmoothing): LDAParams = {
+
+ val numTerms = initParams.topicTermCounts.head.size
+ val numDocs = initParams.docCounts.size
+ val numTopics = initParams.topicCounts.size
+
+ // Construct topic assignment RDD
+ logInfo("Start initialization")
+
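+ // Periodically checkpoint the topic-assignment RDD so that the lineage built up by the
+ // repeated zip/map below does not grow without bound across iterations.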
+ val cpInterval = System.getProperty("spark.gibbsSampling.checkPointInterval", "10").toInt
+ val sc = data.context
+ val (initialParams, initialChosenTopics) = sampleTermAssignment(initParams, data)
+
+ // Gibbs sampling
+ val (params, _, _) = Iterator.iterate((sc.accumulable(initialParams), initialChosenTopics, 0)) {
+ case (lastParams, lastChosenTopics, i) =>
+ logInfo("Start Gibbs sampling")
+
+ val rand = new Random(42 + i * i)
+ val params = sc.accumulable(LDAParams(numDocs, numTopics, numTerms))
+ val chosenTopics = data.zip(lastChosenTopics).map {
+ case (Document(docId, content), topics) =>
+ content.zip(topics).map { case (term, topic) =>
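+ // Remove this token's current assignment from the counts, resample a topic from the
+ // drop-one conditional, then record the new assignment both in the running counts and
+ // in the fresh accumulator used for the next iteration.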
+ lastParams += (docId, term, topic, -1)
+
+ val chosenTopic = lastParams.localValue.dropOneDistSampler(
+ docTopicSmoothing, topicTermSmoothing, term, docId, rand)
+
+ lastParams += (docId, term, chosenTopic, 1)
+ params += (docId, term, chosenTopic, 1)
+
+ chosenTopic
+ }
+ }.cache()
+
+ if ((i + 1) % cpInterval == 0) {
+ chosenTopics.checkpoint()
+ }
+
+ // Trigger a job to collect accumulable LDA parameters.
+ chosenTopics.count()
+ lastChosenTopics.unpersist()
+
+ (params, chosenTopics, i + 1)
+ }.drop(1 + numOuterIterations).next()
+
+ params.value
+ }
+
+ /**
+ * Infers the model matrices phi (topic-term) and theta (document-topic) from the counts in LDAParams.
+ */
+ def solvePhiAndTheta(
+ params: LDAParams,
+ docTopicSmoothing: Double = docTopicSmoothing,
+ topicTermSmoothing: Double = topicTermSmoothing): (Array[Vector], Array[Vector]) = {
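+ // Point estimates from the accumulated counts:
+ // phi(k, t) = (n_{k,t} + beta) / (n_k + V * beta)
+ // theta(d, k) = (n_{d,k} + alpha) / (n_d + K * alpha)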
+ val numTopics = params.topicCounts.size
+ val numTerms = params.topicTermCounts.head.size
+
+ val docCount = params.docCounts.toBreeze :+ (docTopicSmoothing * numTopics)
+ val topicCount = params.topicCounts.toBreeze :+ (topicTermSmoothing * numTerms)
+ val docTopicCount = params.docTopicCounts.map(vec => vec.toBreeze :+ docTopicSmoothing)
+ val topicTermCount = params.topicTermCounts.map(vec => vec.toBreeze :+ topicTermSmoothing)
+
+ var i = 0
+ while (i < numTopics) {
+ topicTermCount(i) :/= topicCount(i)
+ i += 1
+ }
+
+ i = 0
+ while (i < docCount.length) {
+ docTopicCount(i) :/= docCount(i)
+ i += 1
+ }
+
+ (topicTermCount.map(vec => Vectors.fromBreeze(vec)),
+ docTopicCount.map(vec => Vectors.fromBreeze(vec)))
+ }
+}
+
+object GibbsSampling extends Logging {
+
+ /**
+ * Initialization step of Gibbs sampling: assigns a topic to every term. Reusing the counts of a
+ * previously trained model supports incremental LDA.
+ */
+ private def sampleTermAssignment(
+ params: LDAParams,
+ data: RDD[Document]): (LDAParams, RDD[Iterable[Int]]) = {
+
+ val sc = data.context
+ val initialParams = sc.accumulable(params)
+ val rand = new Random(42)
+ val initialChosenTopics = data.map { case Document(docId, content) =>
+ val docTopics = params.docTopicCounts(docId)
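+ // A fresh document (all-zero topic counts) gets uniformly random topics; a document carried
+ // over from a previous model is sampled from its existing topic and term counts, which is
+ // what makes incremental training possible.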
+ if (docTopics.toBreeze.norm(2) == 0) {
+ content.map { term =>
+ val topic = uniformDistSampler(rand, params.topicCounts.size)
+ initialParams += (docId, term, topic, 1)
+ topic
+ }
+ } else {
+ content.map { term =>
+ val topicTerms = Vectors.dense(params.topicTermCounts.map(_(term))).toBreeze
+ val dist = docTopics.toBreeze :* topicTerms
+ multinomialDistSampler(rand, dist.asInstanceOf[BDV[Double]])
+ }
+ }
+ }.cache()
+
+ // Trigger a job to collect accumulable LDA parameters.
+ initialChosenTopics.count()
+
+ (initialParams.value, initialChosenTopics)
+ }
+
+ /**
+ * A uniform distribution sampler, which is only used for initialization.
+ */
+ private def uniformDistSampler(rand: Random, dimension: Int): Int = rand.nextInt(dimension)
+
+ /**
+ * A multinomial distribution sampler that uses the roulette-wheel method to draw an index.
+ */
+ def multinomialDistSampler(rand: Random, dist: BDV[Double]): Int = {
+ val roulette = rand.nextDouble()
+
+ dist :/= sum[BDV[Double], Double](dist)
+
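+ // Walk the cumulative distribution until it passes the uniform draw (roulette-wheel
+ // selection); the final index is returned if rounding error keeps the sum below the draw.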
+ def loop(index: Int, accum: Double): Int = {
+ if (index == dist.length) return dist.length - 1
+ val sum = accum + dist(index)
+ if (sum >= roulette) index else loop(index + 1, sum)
+ }
+
+ loop(0, 0.0)
+ }
+
+ /**
+ * Perplexity is a standard evaluation metric for LDA, usually computed on held-out data; computing
+ * it on the training documents, as done here, is also valid. When evaluating unseen data, run at
+ * least one iteration of Gibbs sampling before calling this. Lower perplexity means a better model.
+ */
+ def perplexity(data: RDD[Document], phi: Array[Vector], theta: Array[Vector]): Double = {
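+ // perplexity = exp(-(sum over tokens of log p(w | d)) / N), where
+ // p(w | d) = sum over topics k of phi(k, w) * theta(d, k) and N is the total token count.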
+ val (termProb, totalNum) = data.flatMap { case Document(docId, content) =>
+ val currentTheta = BDV.zeros[Double](phi.head.size)
+ var col = 0
+ var row = 0
+ while (col < phi.head.size) {
+ row = 0
+ while (row < phi.length) {
+ currentTheta(col) += phi(row)(col) * theta(docId)(row)
+ row += 1
+ }
+ col += 1
+ }
+ content.map(x => (math.log(currentTheta(x)), 1))
+ }.reduce { (lhs, rhs) =>
+ (lhs._1 + rhs._1, lhs._2 + rhs._2)
+ }
+ math.exp(-1 * termProb / totalNum)
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index 3d6e7e0d5c95..5498736a0cc9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -20,15 +20,17 @@ package org.apache.spark.mllib.util
import scala.reflect.ClassTag
import breeze.linalg.{Vector => BV, SparseVector => BSV, squaredDistance => breezeSquaredDistance}
+import breeze.util.Index
+import chalk.text.tokenize.JavaWordTokenizer
import org.apache.spark.annotation.Experimental
-import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.PartitionwiseSampledRDD
-import org.apache.spark.SparkContext._
+import org.apache.spark.SparkContext
import org.apache.spark.util.random.BernoulliSampler
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.clustering.Document
/**
* Helper methods to load, save and pre-process data used in ML Lib.
@@ -233,4 +235,60 @@ object MLUtils {
}
sqDist
}
+
+ /**
+ * Loads a corpus from a given path. Terms and documents are translated into integer ids,
+ * and the corresponding term index and document index are returned along with the data.
+ *
+ * @param sc SparkContext used to read the corpus.
+ * @param dir The path of the corpus.
+ * @param minSplits Minimum number of splits for the input files.
+ * @param dirStopWords The path of the stop-word list; if empty, no stop words are removed.
+ * @return (RDD[Document], term-to-id index, document-to-id index)
+ */
+ def loadCorpus(
+ sc: SparkContext,
+ dir: String,
+ minSplits: Int,
+ dirStopWords: String = ""): (RDD[Document], Index[String], Index[String]) = {
+
+ // Containers and indexers for terms and documents
+ val termMap = Index[String]()
+ val docMap = Index[String]()
+
+ val stopWords =
+ if (dirStopWords == "") {
+ Set.empty[String]
+ } else {
+ sc.textFile(dirStopWords, minSplits)
+ .map(x => x.replaceAll("""(?m)\s+$""", ""))
+ .distinct().collect().toSet
+ }
+ val broadcastStopWord = sc.broadcast(stopWords)
+
+ // Tokenize and filter terms
+ val almostData = sc.wholeTextFiles(dir, minSplits).map { case (fileName, content) =>
+ val tokens = JavaWordTokenizer(content)
+ .filter(_(0).isLetter)
+ .filter(token => !broadcastStopWord.value.contains(token))
+ (fileName, tokens)
+ }
+
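+ // Build the global document and term indices on the driver. This collects every file name
+ // and token, so it assumes the vocabulary and document list fit in driver memory.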
+ almostData.map(_._1).collect().foreach(x => docMap.index(x))
+
+ almostData.flatMap(_._2).collect().foreach(x => termMap.index(x))
+
+ println(termMap.size)
+ println(docMap.size)
+
+ val broadcastWordMap = sc.broadcast(termMap)
+ val broadcastDocMap = sc.broadcast(docMap)
+
+ val data = almostData.map { case (fileName, tokens) =>
+ val fileId = broadcastDocMap.value.index(fileName)
+ val translatedContent = tokens.map(broadcastWordMap.value.index)
+ Document(fileId, translatedContent)
+ }.cache()
+
+ (data, termMap, docMap)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
new file mode 100644
index 000000000000..e7dd4b819781
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import java.util.Random
+
+import breeze.linalg.{DenseVector => BDV}
+import breeze.stats.distributions.Poisson
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.expectation.GibbsSampling
+import org.apache.spark.mllib.expectation.GibbsSampling._
+import org.apache.spark.SparkContext
+
+class LDASuite extends FunSuite with BeforeAndAfterAll {
+ import LDASuite._
+ @transient private var sc: SparkContext = _
+
+ override def beforeAll() {
+ sc = new SparkContext("local", "LDA test")
+ }
+
+ override def afterAll() {
+ sc.stop()
+ System.clearProperty("spark.driver.port")
+ }
+
+ test("LDA || Gibbs sampling") {
+ val checkPointDir = System.getProperty("spark.gibbsSampling.checkPointDir", "/tmp/lda")
+ sc.setCheckpointDir(checkPointDir)
+ val model = generateRandomLDAModel(numTopics, numTerms)
+ val corpus = sampleCorpus(model, numDocs, numTerms, numTopics)
+ val data = sc.parallelize(corpus, 2)
+ var computedModel = LDAParams(numDocs, numTopics, numTerms)
+ val trainer = new GibbsSampling(
+ data,
+ numOuterIterations,
+ numInnerIterations,
+ docTopicSmoothing,
+ topicTermSmoothing)
+
+ val pps = new Array[Double](incrementalLearning)
+
+ var i = 0
+ while (i < incrementalLearning) {
+ computedModel = trainer.runGibbsSampling(computedModel)
+ val (phi, theta) = trainer.solvePhiAndTheta(computedModel)
+ pps(i) = perplexity(data, phi, theta)
+ i += 1
+ }
+
+ pps.foreach(println)
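+ // Perplexity should go down as incremental learning proceeds: most consecutive rounds
+ // improve it, and the final model is better than the first.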
+ val ppsDiff = pps.init.zip(pps.tail).map { case (lhs, rhs) => lhs - rhs }
+ assert(ppsDiff.count(_ > 0).toDouble / ppsDiff.size > 0.6)
+ assert(pps.head - pps.last > 0)
+ }
+}
+
+object LDASuite {
+ val numTopics = 5
+ val numTerms = 1000
+ val numDocs = 100
+ val expectedDocLength = 300
+ val docTopicSmoothing = 0.01
+ val topicTermSmoothing = 0.01
+ val numOuterIterations = 5
+ val numInnerIterations = 1
+ val incrementalLearning = 10
+
+ /**
+ * Generate a random LDA model, i.e. the topic-term matrix.
+ */
+ def generateRandomLDAModel(numTopics: Int, numTerms: Int): Array[BDV[Double]] = {
+ val model = new Array[BDV[Double]](numTopics)
+ val width = numTerms * 1.0 / numTopics
+ var topic = 0
+ var i = 0
+ while (topic < numTopics) {
+ val topicCentroid = width * (topic + 1)
+ model(topic) = BDV.zeros[Double](numTerms)
+ i = 0
+ while (i < numTerms) {
+ // Treat the term list as circular, so the distance between the first term and the last
+ // is 1, not n - 1.
+ val distance = Math.abs(topicCentroid - i) % (numTerms / 2)
+ // Probability decays with distance from the topic centroid.
+ model(topic)(i) = 1.0 / (1 + Math.abs(distance))
+ i += 1
+ }
+ topic += 1
+ }
+ model
+ }
+
+ /**
+ * Sample one document given the topic-term matrix.
+ */
+ def ldaSampler(
+ model: Array[BDV[Double]],
+ topicDist: BDV[Double],
+ numTermsPerDoc: Int): Array[Int] = {
+ val samples = new Array[Int](numTermsPerDoc)
+ val rand = new Random()
+ (0 until numTermsPerDoc).foreach { i =>
+ samples(i) = multinomialDistSampler(
+ rand,
+ model(multinomialDistSampler(rand, topicDist))
+ )
+ }
+ samples
+ }
+
+ /**
+ * Sample corpus (many documents) from a given topic-term matrix.
+ */
+ def sampleCorpus(
+ model: Array[BDV[Double]],
+ numDocs: Int,
+ numTerms: Int,
+ numTopics: Int): Array[Document] = {
+ (0 until numDocs).map { i =>
+ val rand = new Random()
+ val numTermsPerDoc = Poisson.distribution(expectedDocLength).sample()
+ val numTopicsPerDoc = rand.nextInt(numTopics / 2) + 1
+ val topicDist = BDV.zeros[Double](numTopics)
+ (0 until numTopicsPerDoc).foreach { _ =>
+ topicDist(rand.nextInt(numTopics)) += 1
+ }
+ Document(i, ldaSampler(model, topicDist, numTermsPerDoc))
+ }.toArray
+ }
+}
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 33f9d644ca66..fe217c9dea5e 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -399,6 +399,7 @@ object SparkBuild extends Build {
name := "spark-mllib",
previousArtifact := sparkPreviousArtifact("spark-mllib"),
libraryDependencies ++= Seq(
+ "org.scalanlp" % "chalk" % "1.3.0",
"org.jblas" % "jblas" % jblasVersion,
"org.scalanlp" %% "breeze" % "0.7"
)