@@ -17,20 +17,19 @@
 
 package org.apache.spark.mllib.feature
 
-import scala.util.Random
-import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.Random
 
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
-
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.Logging
-import org.apache.spark.rdd._
+import org.apache.spark.{HashPartitioner, Logging}
 import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
-import org.apache.spark.HashPartitioner
-import org.apache.spark.storage.StorageLevel
 import org.apache.spark.mllib.rdd.RDDFunctions._
+import org.apache.spark.rdd._
+import org.apache.spark.storage.StorageLevel
+
 /**
  * Entry in vocabulary
  */
@@ -52,7 +51,7 @@ private case class VocabWord(
  *
  * We used skip-gram model in our implementation and hierarchical softmax
  * method to train the model. The variable names in the implementation
- * mathes the original C implementation.
+ * match the original C implementation.
  *
  * For original C implementation, see https://code.google.com/p/word2vec/
  * For research papers, see
@@ -61,34 +60,41 @@ private case class VocabWord(
  * Distributed Representations of Words and Phrases and their Compositionality.
  * @param size vector dimension
  * @param startingAlpha initial learning rate
- * @param window context words from [-window, window]
- * @param minCount minimum frequncy to consider a vocabulary word
- * @param parallelisum number of partitions to run Word2Vec
+ * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
+ * @param numIterations number of iterations to run, should be smaller than or equal to parallelism
  */
 @Experimental
 class Word2Vec(
     val size: Int,
     val startingAlpha: Double,
-    val window: Int,
-    val minCount: Int,
-    val parallelism: Int = 1,
-    val numIterations: Int = 1)
-  extends Serializable with Logging {
-
+    val parallelism: Int,
+    val numIterations: Int) extends Serializable with Logging {
+
+  /**
+   * Word2Vec with a single thread.
+   */
+  def this(size: Int, startingAlpha: Double) = this(size, startingAlpha, 1, 1)
+
   private val EXP_TABLE_SIZE = 1000
   private val MAX_EXP = 6
   private val MAX_CODE_LENGTH = 40
   private val MAX_SENTENCE_LENGTH = 1000
   private val layer1Size = size
   private val modelPartitionNum = 100
-
+
+  /** context words from [-window, window] */
+  private val window = 5
+
+  /** minimum frequency to consider a vocabulary word */
+  private val minCount = 5
+
   private var trainWordsCount = 0
   private var vocabSize = 0
   private var vocab: Array[VocabWord] = null
   private var vocabHash = mutable.HashMap.empty[String, Int]
   private var alpha = startingAlpha
 
-  private def learnVocab(words: RDD[String]){
+  private def learnVocab(words: RDD[String]): Unit = {
     vocab = words.map(w => (w, 1))
       .reduceByKey(_ + _)
       .map(x => VocabWord(
@@ -99,7 +105,7 @@ class Word2Vec(
         0))
       .filter(_.cn >= minCount)
       .collect()
-      .sortWith((a, b)=> a.cn > b.cn)
+      .sortWith((a, b) => a.cn > b.cn)
 
     vocabSize = vocab.length
     var a = 0
@@ -111,22 +117,18 @@ class Word2Vec(
     logInfo("trainWordsCount = " + trainWordsCount)
   }
 
-  private def learnVocabPerPartition(words: RDD[String]) {
-
-  }
-
-  private def createExpTable(): Array[Double] = {
-    val expTable = new Array[Double](EXP_TABLE_SIZE)
+  private def createExpTable(): Array[Float] = {
+    val expTable = new Array[Float](EXP_TABLE_SIZE)
     var i = 0
     while (i < EXP_TABLE_SIZE) {
       val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP)
-      expTable(i) = tmp / (tmp + 1)
+      expTable(i) = (tmp / (tmp + 1.0)).toFloat
       i += 1
     }
     expTable
   }
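The table above replaces a math.exp call for every Huffman-tree node visited during training. A minimal standalone sketch of the lookup, mirroring the same constants and the index arithmetic used in the training loop further down (sigmoidFromTable is a hypothetical helper, not part of the patch):

    val EXP_TABLE_SIZE = 1000
    val MAX_EXP = 6
    // sigmoid((2i / EXP_TABLE_SIZE - 1) * MAX_EXP) for i in [0, EXP_TABLE_SIZE)
    val expTable: Array[Float] = Array.tabulate(EXP_TABLE_SIZE) { i =>
      val tmp = math.exp((2.0 * i / EXP_TABLE_SIZE - 1.0) * MAX_EXP)
      (tmp / (tmp + 1.0)).toFloat
    }
    // Map f in (-MAX_EXP, MAX_EXP) to a table index; sigmoidFromTable(0.0) is about 0.5.
    def sigmoidFromTable(f: Double): Float = {
      val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
      expTable(ind)
    }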
128130
-  private def createBinaryTree() {
+  private def createBinaryTree(): Unit = {
     val count = new Array[Long](vocabSize * 2 + 1)
     val binary = new Array[Int](vocabSize * 2 + 1)
     val parentNode = new Array[Int](vocabSize * 2 + 1)
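The rest of createBinaryTree (elided from this hunk) follows the original C code: it Huffman-codes the vocabulary by frequency, so common words sit near the root and get short codes, bounding the hierarchical-softmax work per word. The same assignment can be sketched with a priority queue (illustrative only; the patch keeps the C implementation's array-based construction):

    import scala.collection.mutable

    case class HNode(count: Long, word: Option[Int],
        left: Option[HNode] = None, right: Option[HNode] = None)

    // Repeatedly merge the two rarest nodes; frequent words end up near the root.
    def huffmanTree(counts: Array[Long]): HNode = {
      val pq = mutable.PriorityQueue.empty[HNode](Ordering.by[HNode, Long](n => -n.count))
      counts.zipWithIndex.foreach { case (c, w) => pq.enqueue(HNode(c, Some(w))) }
      while (pq.size > 1) {
        val (a, b) = (pq.dequeue(), pq.dequeue())
        pq.enqueue(HNode(a.count + b.count, None, Some(a), Some(b)))
      }
      pq.dequeue()
    }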
@@ -208,8 +210,7 @@ class Word2Vec(
   * @param dataset an RDD of words
   * @return a Word2VecModel
   */
-
-  def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {
+  def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {
 
     val words = dataset.flatMap(x => x)
 
@@ -223,39 +224,37 @@ class Word2Vec(
     val bcVocab = sc.broadcast(vocab)
     val bcVocabHash = sc.broadcast(vocabHash)
 
-    val sentences: RDD[Array[Int]] = words.mapPartitions {
-      iter => { new Iterator[Array[Int]] {
-        def hasNext = iter.hasNext
-
-        def next = {
-          var sentence = new ArrayBuffer[Int]
-          var sentenceLength = 0
-          while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
-            val word = bcVocabHash.value.get(iter.next)
-            word match {
-              case Some(w) => {
-                sentence += w
-                sentenceLength += 1
-              }
-              case None =>
-            }
+    val sentences: RDD[Array[Int]] = words.mapPartitions { iter =>
+      new Iterator[Array[Int]] {
+        def hasNext: Boolean = iter.hasNext
+
+        def next(): Array[Int] = {
+          var sentence = new ArrayBuffer[Int]
+          var sentenceLength = 0
+          while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
+            val word = bcVocabHash.value.get(iter.next())
+            word match {
+              case Some(w) =>
+                sentence += w
+                sentenceLength += 1
+              case None =>
             }
-          sentence.toArray
           }
+          sentence.toArray
         }
       }
     }
 
     val newSentences = sentences.repartition(parallelism).cache()
-    var syn0Global
-      = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
-    var syn1Global = new Array[Double](vocabSize * layer1Size)
+    var syn0Global =
+      Array.fill[Float](vocabSize * layer1Size)((Random.nextFloat() - 0.5f) / layer1Size)
+    var syn1Global = new Array[Float](vocabSize * layer1Size)
 
     for (iter <- 1 to numIterations) {
       val (aggSyn0, aggSyn1, _, _) =
-      // TODO: broadcast temp instead of serializing it directly
+        // TODO: broadcast temp instead of serializing it directly
         // or initialize the model in each executor
-      newSentences.treeAggregate((syn0Global.clone(), syn1Global.clone(), 0, 0))(
+        newSentences.treeAggregate((syn0Global, syn1Global, 0, 0))(
         seqOp = (c, v) => (c, v) match {
           case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
             var lwc = lastWordCount
@@ -280,23 +279,23 @@ class Word2Vec(
                 if (c >= 0 && c < sentence.size) {
                   val lastWord = sentence(c)
                   val l1 = lastWord * layer1Size
-                  val neu1e = new Array[Double](layer1Size)
+                  val neu1e = new Array[Float](layer1Size)
                   // Hierarchical softmax
                   var d = 0
                   while (d < bcVocab.value(word).codeLen) {
                     val l2 = bcVocab.value(word).point(d) * layer1Size
                     // Propagate hidden -> output
-                    var f = blas.ddot(layer1Size, syn0, l1, 1, syn1, l2, 1)
+                    var f = blas.sdot(layer1Size, syn0, l1, 1, syn1, l2, 1)
                     if (f > -MAX_EXP && f < MAX_EXP) {
                       val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
                       f = expTable.value(ind)
-                      val g = (1 - bcVocab.value(word).code(d) - f) * alpha
-                      blas.daxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
-                      blas.daxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
+                      val g = ((1 - bcVocab.value(word).code(d) - f) * alpha).toFloat
+                      blas.saxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
+                      blas.saxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
                     }
                     d += 1
                   }
-                  blas.daxpy(layer1Size, 1.0, neu1e, 0, 1, syn0, l1, 1)
+                  blas.saxpy(layer1Size, 1.0f, neu1e, 0, 1, syn0, l1, 1)
                 }
               }
               a += 1
@@ -308,24 +307,24 @@ class Word2Vec(
         combOp = (c1, c2) => (c1, c2) match {
           case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
             val n = syn0_1.length
-            val weight1 = 1.0 * wc_1 / (wc_1 + wc_2)
-            val weight2 = 1.0 * wc_2 / (wc_1 + wc_2)
-            blas.dscal(n, weight1, syn0_1, 1)
-            blas.dscal(n, weight1, syn1_1, 1)
-            blas.daxpy(n, weight2, syn0_2, 1, syn0_1, 1)
-            blas.daxpy(n, weight2, syn1_2, 1, syn1_1, 1)
+            val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
+            val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
+            blas.sscal(n, weight1, syn0_1, 1)
+            blas.sscal(n, weight1, syn1_1, 1)
+            blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
+            blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
             (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
         })
       syn0Global = aggSyn0
       syn1Global = aggSyn1
     }
     newSentences.unpersist()
 
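Each partition returns its locally updated (syn0, syn1) pair plus the number of words it processed, and combOp averages two partial models weighted by those counts: sscal scales the left operand in place, then saxpy adds the scaled right operand into it. Written without BLAS, the merge amounts to (a sketch, not the patch's code):

    // result = (wc1 * a + wc2 * b) / (wc1 + wc2), elementwise
    def mergeModels(a: Array[Float], wc1: Int, b: Array[Float], wc2: Int): Array[Float] = {
      val w1 = wc1.toFloat / (wc1 + wc2)
      val w2 = wc2.toFloat / (wc1 + wc2)
      Array.tabulate(a.length)(i => w1 * a(i) + w2 * b(i))
    }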
-    val wordMap = new Array[(String, Array[Double])](vocabSize)
+    val wordMap = new Array[(String, Array[Float])](vocabSize)
     var i = 0
     while (i < vocabSize) {
       val word = bcVocab.value(i).word
-      val vector = new Array[Double](layer1Size)
+      val vector = new Array[Float](layer1Size)
       Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size)
       wordMap(i) = (word, vector)
       i += 1
@@ -341,15 +340,15 @@ class Word2Vec(
 /**
  * Word2Vec model
  */
-class Word2VecModel(private val model: RDD[(String, Array[Double])]) extends Serializable {
+class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Serializable {
 
-  private def cosineSimilarity(v1: Array[Double], v2: Array[Double]): Double = {
+  private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
     require(v1.length == v2.length, "Vectors should have the same length")
     val n = v1.length
-    val norm1 = blas.dnrm2(n, v1, 1)
-    val norm2 = blas.dnrm2(n, v2, 1)
+    val norm1 = blas.snrm2(n, v1, 1)
+    val norm2 = blas.snrm2(n, v2, 1)
     if (norm1 == 0 || norm2 == 0) return 0.0
-    blas.ddot(n, v1, 1, v2, 1) / norm1 / norm2
+    blas.sdot(n, v1, 1, v2, 1) / norm1 / norm2
   }
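All three BLAS calls now use the single-precision variants, and the result is promoted to Double only at the final division. Two hand-checked cases (illustrative, not from the patch):

    cosineSimilarity(Array(1.0f, 0.0f), Array(2.0f, 0.0f))  // parallel vectors   => 1.0
    cosineSimilarity(Array(1.0f, 0.0f), Array(0.0f, 3.0f))  // orthogonal vectors => 0.0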
 
   /**
@@ -360,9 +359,9 @@ class Word2VecModel(private val model: RDD[(String, Array[Double])]) extends Ser
   def transform(word: String): Vector = {
     val result = model.lookup(word)
     if (result.isEmpty) {
-      throw new IllegalStateException(s"${word} not in vocabulary")
+      throw new IllegalStateException(s"$word not in vocabulary")
     }
-    else Vectors.dense(result(0))
+    else Vectors.dense(result(0).map(_.toDouble))
   }
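Keeping Array[Float] internally while converting to Double here leaves the public Vector-based API unchanged. Typical use against a trained model (the word "spark" is a stand-in for any in-vocabulary token):

    val vec: Vector = model.transform("spark")                         // size-dimensional embedding
    val nearest: Array[(String, Double)] = model.findSynonyms(vec, 5)  // five closest words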
 
   /**
@@ -394,7 +393,7 @@ class Word2VecModel(private val model: RDD[(String, Array[Double])]) extends Ser
   def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
     require(num > 0, "Number of similar words should > 0")
     val topK = model.map { case (w, vec) =>
-      (cosineSimilarity(vector.toArray, vec), w) }
+      (cosineSimilarity(vector.toArray.map(_.toFloat), vec), w) }
       .sortByKey(ascending = false)
       .take(num + 1)
       .map(_.swap)
@@ -410,18 +409,16 @@ object Word2Vec{
   * @param input RDD of words
   * @param size vector dimension
   * @param startingAlpha initial learning rate
-  * @param window context words from [-window, window]
-  * @param minCount minimum frequncy to consider a vocabulary word
-  * @return Word2Vec model
-  */
+  * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
+  * @param numIterations number of iterations, should be smaller than or equal to parallelism
+  * @return Word2Vec model
+  */
   def train[S <: Iterable[String]](
       input: RDD[S],
       size: Int,
       startingAlpha: Double,
-      window: Int,
-      minCount: Int,
       parallelism: Int = 1,
       numIterations: Int = 1): Word2VecModel = {
-    new Word2Vec(size, startingAlpha, window, minCount, parallelism, numIterations).fit[S](input)
+    new Word2Vec(size, startingAlpha, parallelism, numIterations).fit[S](input)
   }
 }
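With window and minCount now fixed inside the class, training takes only a corpus, the vector dimension, and the learning rate. An end-to-end sketch against this API (sc is an existing SparkContext; the text8 path is hypothetical):

    import org.apache.spark.mllib.feature.Word2Vec

    val input = sc.textFile("text8").map(_.split(" ").toSeq)  // RDD[Seq[String]]
    val model = Word2Vec.train(input, size = 100, startingAlpha = 0.025)
    val synonyms = model.findSynonyms(model.transform("china"), 5)
    synonyms.foreach { case (word, sim) => println(s"$word $sim") }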