Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
package org.apache.spark.mllib.clustering

import breeze.linalg.{DenseVector => BDV}

import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaPairRDD
import org.apache.spark.graphx._
import org.apache.spark.mllib.linalg.Vector
Expand Down Expand Up @@ -197,20 +198,28 @@ class LDA private (
}


/** LDAOptimizer used to perform the actual calculation */
/**
 * :: DeveloperApi ::
 *
 * Returns the [[LDAOptimizer]] this instance currently holds, i.e. the optimizer used to
 * perform the actual calculation.
 */
@DeveloperApi
def getOptimizer: LDAOptimizer = this.ldaOptimizer

/**
 * :: DeveloperApi ::
 *
 * Sets the [[LDAOptimizer]] used to perform the actual calculation
 * (default = EMLDAOptimizer).
 *
 * @param optimizer the optimizer instance to store on this object
 * @return this instance, so that setter calls can be chained
 */
@DeveloperApi
def setOptimizer(optimizer: LDAOptimizer): this.type = {
  // Store the supplied optimizer, then return the receiver for chaining.
  ldaOptimizer = optimizer
  this
}

/**
* Set the LDAOptimizer used to perform the actual calculation by algorithm name.
* Currently "em", "online" is supported.
* Currently "em", "online" are supported.
*/
def setOptimizer(optimizerName: String): this.type = {
this.ldaOptimizer =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,21 @@ import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, sum, normalize, kr
import breeze.numerics.{digamma, exp, abs}
import breeze.stats.distributions.{Gamma, RandBasis}

import org.apache.spark.annotation.Experimental
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.graphx._
import org.apache.spark.graphx.impl.GraphImpl
import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
import org.apache.spark.mllib.linalg.{Matrices, SparseVector, DenseVector, Vector}
import org.apache.spark.rdd.RDD

/**
* :: Experimental ::
* :: DeveloperApi ::
*
* An LDAOptimizer specifies which optimization/learning/inference algorithm to use, and it can
* hold optimizer-specific parameters for users to set.
*/
@Experimental
trait LDAOptimizer {
@DeveloperApi
sealed trait LDAOptimizer {

/*
DEVELOPERS NOTE:
Expand All @@ -59,7 +59,7 @@ trait LDAOptimizer {
}

/**
* :: Experimental ::
* :: DeveloperApi ::
*
* Optimizer for EM algorithm which stores data + parameter graph, plus algorithm parameters.
*
Expand All @@ -75,8 +75,8 @@ trait LDAOptimizer {
* "On Smoothing and Inference for Topic Models." UAI, 2009.
*
*/
@Experimental
class EMLDAOptimizer extends LDAOptimizer {
@DeveloperApi
final class EMLDAOptimizer extends LDAOptimizer {

import LDA._

Expand Down Expand Up @@ -211,7 +211,7 @@ class EMLDAOptimizer extends LDAOptimizer {


/**
* :: Experimental ::
* :: DeveloperApi ::
*
* An online optimizer for LDA. The Optimizer implements the Online variational Bayes LDA
* algorithm, which processes a subset of the corpus on each iteration, and updates the term-topic
Expand All @@ -220,8 +220,8 @@ class EMLDAOptimizer extends LDAOptimizer {
* Original Online LDA paper:
* Hoffman, Blei and Bach, "Online Learning for Latent Dirichlet Allocation." NIPS, 2010.
*/
@Experimental
class OnlineLDAOptimizer extends LDAOptimizer {
@DeveloperApi
final class OnlineLDAOptimizer extends LDAOptimizer {

// LDA common parameters
private var k: Int = 0
Expand All @@ -243,8 +243,8 @@ class OnlineLDAOptimizer extends LDAOptimizer {
private var randomGenerator: java.util.Random = null

// Online LDA specific parameters
// Learning rate is: (tau_0 + t)^{-kappa}
private var tau_0: Double = 1024
// Learning rate is: (tau0 + t)^{-kappa}
private var tau0: Double = 1024
private var kappa: Double = 0.51
private var miniBatchFraction: Double = 0.05

Expand All @@ -265,16 +265,16 @@ class OnlineLDAOptimizer extends LDAOptimizer {
* A (positive) learning parameter that downweights early iterations. Larger values make early
* iterations count less.
*/
def getTau_0: Double = this.tau_0
def getTau0: Double = {
  // Plain accessor for the tau0 term of the learning rate (tau0 + t)^{-kappa}.
  tau0
}

/**
* A (positive) learning parameter that downweights early iterations. Larger values make early
* iterations count less.
* Default: 1024, following the original Online LDA paper.
*/
def setTau_0(tau_0: Double): this.type = {
require(tau_0 > 0, s"LDA tau_0 must be positive, but was set to $tau_0")
this.tau_0 = tau_0
def setTau0(tau0: Double): this.type = {
  // Only strictly positive values pass; anything else fails the requirement below.
  require(0 < tau0, s"LDA tau0 must be positive, but was set to $tau0")
  // The parameter shadows the field, so the assignment must go through `this`.
  this.tau0 = tau0
  // Return the receiver so that setter calls can be chained.
  this
}

Expand Down Expand Up @@ -434,11 +434,8 @@ class OnlineLDAOptimizer extends LDAOptimizer {
* Update lambda based on the batch submitted. batchSize can be different for each iteration.
*/
private[clustering] def update(stat: BDM[Double], iter: Int, batchSize: Int): Unit = {
val tau_0 = this.getTau_0
val kappa = this.getKappa

// weight of the mini-batch.
val weight = math.pow(tau_0 + iter, -kappa)
val weight = math.pow(getTau0 + iter, -getKappa)

// Update lambda based on documents.
lambda = lambda * (1 - weight) +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ public void OnlineOptimizerCompatibility() {

// Train a model
OnlineLDAOptimizer op = new OnlineLDAOptimizer()
.setTau_0(1024)
.setTau0(1024)
.setKappa(0.51)
.setGammaShape(1e40)
.setMiniBatchFraction(0.5);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,12 +138,12 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
val lda = new LDA().setK(2)
val corpus = sc.parallelize(tinyCorpus, 2)
val op = new OnlineLDAOptimizer().initialize(corpus, lda)
op.setKappa(0.9876).setMiniBatchFraction(0.123).setTau_0(567)
op.setKappa(0.9876).setMiniBatchFraction(0.123).setTau0(567)
assert(op.getAlpha == 0.5) // default 1.0 / k
assert(op.getEta == 0.5) // default 1.0 / k
assert(op.getKappa == 0.9876)
assert(op.getMiniBatchFraction == 0.123)
assert(op.getTau_0 == 567)
assert(op.getTau0 == 567)
}

test("OnlineLDAOptimizer one iteration") {
Expand All @@ -159,7 +159,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
val corpus = sc.parallelize(docs, 2)

// Set GammaShape large to avoid the stochastic impact.
val op = new OnlineLDAOptimizer().setTau_0(1024).setKappa(0.51).setGammaShape(1e40)
val op = new OnlineLDAOptimizer().setTau0(1024).setKappa(0.51).setGammaShape(1e40)
.setMiniBatchFraction(1)
val lda = new LDA().setK(k).setMaxIterations(1).setOptimizer(op).setSeed(12345)

Expand Down Expand Up @@ -192,7 +192,7 @@ class LDASuite extends FunSuite with MLlibTestSparkContext {
).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }

val docs = sc.parallelize(toydata)
val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau_0(1024).setKappa(0.51)
val op = new OnlineLDAOptimizer().setMiniBatchFraction(1).setTau0(1024).setKappa(0.51)
.setGammaShape(1e10)
val lda = new LDA().setK(2)
.setDocConcentration(0.01)
Expand Down