Skip to content
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.clustering.topicmodeling

import java.util.Random

import org.apache.spark.SparkContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.mllib.clustering.topicmodeling.regulaizers.{TopicsRegularizer, MatrixInPlaceModification}
import org.apache.spark.mllib.feature.Document
import org.apache.spark.rdd.RDD

import breeze.linalg._

/**
 * Building blocks shared by PLSA-style topic models: creation of the initial topic matrix,
 * re-estimation of topics from global counters and the generalized perplexity.
 *
 * Topics are stored as a `numberOfTopics x alphabetSize` matrix indexed as
 * `topics(topic)(word)` and are shipped to the workers as a broadcast variable.
 *
 * @tparam DocumentParameterType type of the per-document parameters
 * @tparam GlobalParameterType type of the global parameters of the model
 * @tparam GlobalCounterType type of the counters the topics are re-estimated from
 */
private[topicmodeling] trait AbstractPLSA[DocumentParameterType <: DocumentParameters,
    GlobalParameterType <: GlobalParameters,
    GlobalCounterType <: GlobalCounters]
  extends TopicModel[DocumentParameterType, GlobalParameterType] with MatrixInPlaceModification {

  /** Number of topics, i.e. the number of rows of the topic matrix. */
  protected val numberOfTopics: Int

  /** Source of randomness used to fill the initial topic matrix. */
  protected val random: Random

  /** Regularizer applied to the topic matrix; its value is also a term of the perplexity. */
  protected val topicRegularizer: TopicsRegularizer

  /** Spark context used to broadcast topic matrices. */
  protected val sc: SparkContext

  /**
   * Computes `exp(-(L + R) / collectionLength)`, where `L` is the sum of
   * [[singleDocumentLikelihood]] over all documents and `R` is the value of `topicRegularizer`
   * on the broadcast topics.
   *
   * @param topicsBC broadcast topic matrix
   * @param parameters parameters of every document in the collection
   * @param collectionLength value the exponent is divided by (see [[getCollectionLength]])
   * @param wordGivenModel for the parameters of a document, yields the function that is mapped
   *                       over the active (word, count) pairs of that document
   * @return the generalized perplexity
   */
  protected def generalizedPerplexity(topicsBC: Broadcast[Array[Array[Float]]],
      parameters: RDD[DocumentParameterType],
      collectionLength: Int,
      wordGivenModel: DocumentParameterType => (Int, Int) => Float): Double = {
    val documentsLikelihood = parameters.aggregate(0f)(
      (accumulated, parameter) =>
        accumulated + singleDocumentLikelihood(parameter, topicsBC, wordGivenModel(parameter)),
      (left, right) => left + right)

    math.exp(-(documentsLikelihood + topicRegularizer(topicsBC.value)) / collectionLength)
  }

  /** Reads the alphabet size from the first document; assumes `documents` is not empty. */
  protected def getAlphabetSize(documents: RDD[Document]) = documents.first().alphabetSize

  /** Sums the token vectors of all documents; assumes `documents` is not empty. */
  protected def getCollectionLength(documents: RDD[Document]) =
    documents.map(doc => sum(doc.tokens)).reduce(_ + _)

  /**
   * Sums `wordGivenModel` over the active (word, count) pairs of the document and adds
   * `priorThetaLogProbability` of the given parameters.
   *
   * @param parameter parameters of the document
   * @param topicsBC broadcast topic matrix (not read by this implementation)
   * @param wordGivenModel function applied to every active (word, count) pair
   */
  protected def singleDocumentLikelihood(parameter: DocumentParameters,
      topicsBC: Broadcast[Array[Array[Float]]],
      wordGivenModel: ((Int, Int) => Float)) = {
    sum(parameter.document.tokens.mapActivePairs(wordGivenModel)) +
      parameter.priorThetaLogProbability
  }

  /**
   * Computes the sum over all topics of `theta(topic) * topics(topic)(word)`.
   *
   * @param word index of the word
   * @param parameter parameters of the document that supply `theta`
   * @param topicsBC broadcast topic matrix
   */
  protected def probabilityOfWordGivenTopic(word: Int,
      parameter: DocumentParameters,
      topicsBC: Broadcast[Array[Array[Float]]]): Float = {
    val topics = topicsBC.value
    val theta = parameter.theta
    var underLog = 0f
    var topic = 0
    while (topic < numberOfTopics) {
      underLog += theta(topic) * topics(topic)(word)
      topic += 1
    }
    underLog
  }

  /**
   * Creates a `numberOfTopics x alphabetSize` matrix of random floats, normalizes every row to
   * sum to one and broadcasts the result.
   */
  protected def getInitialTopics(alphabetSize: Int): Broadcast[Array[Array[Float]]] = {
    val topics = Array.fill[Float](numberOfTopics, alphabetSize)(random.nextFloat)
    normalize(topics)
    sc.broadcast(topics)
  }

  /**
   * Produces the topics for the next iteration.
   *
   * When `foldingIn` is set the old topics are returned as they are. Otherwise the counter
   * matrix of `globalCounters` is passed to the topic regularizer together with the old topics,
   * normalized row by row and broadcast. The normalization happens in place, so the matrix held
   * by `globalCounters` is modified.
   *
   * @param parameters document parameters (not read by this implementation)
   * @param alphabetSize number of unique words (not read by this implementation)
   * @param oldTopics topics of the previous iteration
   * @param globalCounters counters the new topics are estimated from
   * @param foldingIn whether the topics have to stay fixed
   */
  protected def getTopics(parameters: RDD[DocumentParameterType],
      alphabetSize: Int,
      oldTopics: Broadcast[Array[Array[Float]]],
      globalCounters: GlobalCounterType,
      foldingIn: Boolean): Broadcast[Array[Array[Float]]] = {
    if (foldingIn) {
      oldTopics
    } else {
      val newTopicCnt: Array[Array[Float]] = globalCounters.wordsFromTopics

      topicRegularizer.regularize(newTopicCnt, oldTopics.value)
      normalize(newTopicCnt)

      sc.broadcast(newTopicCnt)
    }
  }

  /**
   * Divides every row of the matrix by the sum of its elements, in place.
   *
   * A row that sums to zero is left untouched: dividing it would turn every element into NaN,
   * and that NaN would then spread to every quantity computed from the topic matrix.
   */
  private def normalize(matrix: Array[Array[Float]]): Unit = {
    matrix.foreach { row =>
      val rowSum = row.sum
      if (rowSum != 0f) {
        shift(row, (arr, i) => arr(i) /= rowSum)
      }
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.clustering.topicmodeling

import breeze.linalg.SparseVector
import org.apache.spark.mllib.clustering.topicmodeling.regulaizers.DocumentOverTopicDistributionRegularizer
import org.apache.spark.mllib.feature.Document

/**
 * Per-document state of the PLSA model: a document together with its distribution over topics.
 *
 * @param document the document these parameters belong to
 * @param theta the distribution over topics; it is overwritten in place by [[assignNewTheta]]
 * @param regularizer regularizer applied to the freshly estimated topic weights before they are
 *                    normalized; it also supplies [[priorThetaLogProbability]]
 */
class DocumentParameters(val document: Document,
    val theta: Array[Float],
    private val regularizer: DocumentOverTopicDistributionRegularizer)
  extends Serializable {

  /**
   * For every active word of the document computes the sum over topics of
   * `topics(topic)(word) * theta(topic)`, which is the denominator of the updates below.
   */
  private def getZ(topics: Array[Array[Float]]) = {
    val numberOfTopics = topics.size

    document.tokens.mapActivePairs { case (word, n) =>
      (0 until numberOfTopics).foldLeft(0f)((sum, topic) =>
        sum + topics(topic)(word) * theta(topic))
    }
  }

  /**
   * Computes the per-topic word counts of the document for the given topics.
   *
   * @param topics topic matrix indexed as `topics(topic)(word)`
   * @return one sparse vector per topic, see [[wordsToTopicCnt]]
   */
  private[topicmodeling] def wordsFromTopics(topics: Array[Array[Float]]):
      Array[SparseVector[Float]] = {
    wordsToTopicCnt(topics, getZ(topics))
  }

  /**
   * For every topic builds a sparse vector over the active words of the document holding
   * `count * topics(topic)(word) * theta(topic) / Z(word)`.
   *
   * A word whose `Z` is zero contributes zero: the plain division would produce NaN or an
   * infinity and corrupt the counters the result is added to.
   *
   * @param topics topic matrix indexed as `topics(topic)(word)`
   * @param Z per-word denominators
   * @return one sparse vector per topic
   */
  private[topicmodeling] def wordsToTopicCnt(topics: Array[Array[Float]],
      Z: SparseVector[Float]): Array[SparseVector[Float]] = {
    val counts = Array.ofDim[SparseVector[Float]](theta.size)
    forWithIndex(theta) { (topicWeight, topicNum) =>
      counts(topicNum) = document.tokens.mapActivePairs { case (word, num) =>
        val z = Z(word)
        if (z == 0f) 0f else num * topics(topicNum)(word) * topicWeight / z
      }
    }
    counts
  }

  /** Calls `operation(element, index)` for every element of the array, in index order. */
  private def forWithIndex(array: Array[Float])(operation: (Float, Int) => Unit): Unit = {
    var i = 0
    val size = array.size
    while (i < size) {
      operation(array(i), i)
      i += 1
    }
  }

  /**
   * Re-estimates `theta` in place.
   *
   * The new weight of a topic is its old weight times the sum over active words of
   * `count * topics(topic)(word) / Z(word)`. The weights are then handed to the regularizer
   * together with the old `theta` and finally normalized by their sum.
   *
   * Two degenerate cases are guarded so that `theta` never turns into NaN: a word whose `Z` is
   * zero is skipped, and when the new weights sum to zero (for instance for a document without
   * tokens) `theta` keeps its previous value.
   *
   * @param topics topic matrix indexed as `topics(topic)(word)`
   * @param Z per-word denominators
   */
  private[topicmodeling] def assignNewTheta(topics: Array[Array[Float]],
      Z: SparseVector[Float]): Unit = {
    val newTheta = Array.ofDim[Float](theta.size)
    forWithIndex(theta) { (weight, topicNum) =>
      val expectedWords = document.tokens.activeIterator.foldLeft(0f) {
        case (acc, (word, wordNum)) =>
          val z = Z(word)
          if (z == 0f) acc else acc + wordNum * topics(topicNum)(word) / z
      }
      newTheta(topicNum) = weight * expectedWords
    }
    regularizer.regularize(newTheta, theta)

    val newThetaSum = newTheta.sum
    if (newThetaSum != 0f) {
      forWithIndex(newTheta)((wordsNum, topicNum) => theta(topicNum) = wordsNum / newThetaSum)
    }
  }

  /**
   * Re-estimates `theta` in place for the given topics.
   *
   * @param topics topic matrix indexed as `topics(topic)(word)`
   * @return this instance, with `theta` updated
   */
  private[topicmodeling] def getNewTheta(topics: Array[Array[Float]]) = {
    val Z = getZ(topics)
    assignNewTheta(topics, Z)

    this
  }

  /** Value of the regularizer on the current `theta`. */
  private[topicmodeling] def priorThetaLogProbability = regularizer(theta)

}


/**
 * Factory for [[DocumentParameters]] that start from a uniform distribution over topics.
 */
private[topicmodeling] object DocumentParameters extends SparseVectorFasterSum {

  /**
   * Creates the parameters of a document with every topic weight set to `1 / numberOfTopics`.
   *
   * @param document the document the parameters belong to
   * @param numberOfTopics length of the topic distribution
   * @param regularizer regularizer handed over to the created parameters
   */
  def apply(document: Document,
      numberOfTopics: Int,
      regularizer: DocumentOverTopicDistributionRegularizer) = {
    new DocumentParameters(document, getTheta(numberOfTopics), regularizer)
  }

  /** Builds an array of `numberOfTopics` equal weights. */
  private def getTheta(numberOfTopics: Int) = {
    val uniformWeight = 1f / numberOfTopics
    Array.fill[Float](numberOfTopics)(uniformWeight)
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.mllib.clustering.topicmodeling

/**
 * Global counters of the PLSA model: the n_{tw} counters (Vorontsov's notation) together with
 * the alphabet size.
 *
 * @param wordsFromTopics counter matrix indexed as `wordsFromTopics(topic)(word)`; it is
 *                        updated in place by both `+` and `add`
 * @param alphabetSize number of columns of the matrix that `+` merges
 */
private[topicmodeling] class GlobalCounters(val wordsFromTopics: Array[Array[Float]],
    val alphabetSize: Int) extends Serializable {

  /**
   * Merges the counters of `that` into this instance.
   *
   * The first `alphabetSize` elements of every row of `that` are added to the matching row of
   * this instance in place. The returned object is a new wrapper around that same, already
   * updated matrix.
   *
   * @param that other GlobalCounters
   * @return GlobalCounters sharing the updated matrix of this instance
   */
  private[topicmodeling] def + (that: GlobalCounters) = {
    val otherRows = that.wordsFromTopics
    val pairedRows = math.min(wordsFromTopics.length, otherRows.length)

    var topic = 0
    while (topic < pairedRows) {
      val mine = wordsFromTopics(topic)
      val theirs = otherRows(topic)
      var word = 0
      while (word < alphabetSize) {
        mine(word) += theirs(word)
        word += 1
      }
      topic += 1
    }

    new GlobalCounters(wordsFromTopics, alphabetSize)
  }

  /**
   * Computes the per-topic word counts of a single document and adds them to the counters of
   * this instance in place.
   *
   * @param that parameters of the document
   * @param topics words by topics distribution
   * @param alphabetSize number of unique words; not read by this method
   * @return this instance
   */
  private[topicmodeling] def add(that: DocumentParameters,
      topics: Array[Array[Float]],
      alphabetSize: Int) = {
    val documentCounts = that.wordsFromTopics(topics)

    for ((topicCounts, totals) <- documentCounts.zip(wordsFromTopics);
         (word, count) <- topicCounts.activeIterator) {
      totals(word) += count
    }

    this
  }
}

/** Factory for [[GlobalCounters]] that start from an all-zero counter matrix. */
private[topicmodeling] object GlobalCounters {

  /**
   * Creates counters backed by a freshly allocated `topicNum x alphabetSize` matrix.
   *
   * @param topicNum number of rows of the counter matrix
   * @param alphabetSize number of columns of the counter matrix
   */
  def apply(topicNum: Int, alphabetSize: Int) = {
    val emptyCounters = Array.ofDim[Float](topicNum, alphabetSize)
    new GlobalCounters(emptyCounters, alphabetSize)
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/


package org.apache.spark.mllib.clustering.topicmodeling

/**
 * Global parameters of the PLSA model: the \Phi matrix (topics over words distribution) and
 * the alphabet size.
 *
 * @param phi distribution of topics over words
 * @param alphabetSize alphabet size stored alongside the matrix
 */
class GlobalParameters(
    val phi: Array[Array[Float]],
    val alphabetSize: Int)
Loading