@@ -29,7 +29,7 @@ import org.apache.spark.rdd._
2929import org .apache .spark .SparkContext ._
3030import org .apache .spark .mllib .linalg .{Vector , Vectors }
3131import org .apache .spark .HashPartitioner
32-
32+ import org . apache . spark . storage . StorageLevel
3333/**
3434 * Entry in vocabulary
3535 */
@@ -215,18 +215,18 @@ class Word2Vec(
215215 val sc = dataset.context
216216
217217 val expTable = sc.broadcast(createExpTable())
218- val V = sc.broadcast(vocab)
219- val VHash = sc.broadcast(vocabHash)
218+ val bcVocab = sc.broadcast(vocab)
219+ val bcVocabHash = sc.broadcast(vocabHash)
220220
221- val sentences = words.mapPartitions {
221+ val sentences : RDD [ Array [ Int ]] = words.mapPartitions {
222222 iter => { new Iterator [Array [Int ]] {
223223 def hasNext = iter.hasNext
224224
225225 def next = {
226226 var sentence = new ArrayBuffer [Int ]
227227 var sentenceLength = 0
228228 while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH ) {
229- val word = VHash .value.get(iter.next)
229+ val word = bcVocabHash .value.get(iter.next)
230230 word match {
231231 case Some (w) => {
232232 sentence += w
@@ -278,14 +278,14 @@ class Word2Vec(
278278 val neu1e = new Array [Double ](layer1Size)
279279 // Hierarchical softmax
280280 var d = 0
281- while (d < vocab (word).codeLen) {
282- val l2 = vocab (word).point(d) * layer1Size
281+ while (d < bcVocab.value (word).codeLen) {
282+ val l2 = bcVocab.value (word).point(d) * layer1Size
283283 // Propagate hidden -> output
284284 var f = blas.ddot(layer1Size, syn0, l1, 1 , syn1, l2, 1 )
285285 if (f > - MAX_EXP && f < MAX_EXP ) {
286286 val ind = ((f + MAX_EXP ) * (EXP_TABLE_SIZE / MAX_EXP / 2.0 )).toInt
287287 f = expTable.value(ind)
288- val g = (1 - vocab (word).code(d) - f) * alpha
288+ val g = (1 - bcVocab.value (word).code(d) - f) * alpha
289289 blas.daxpy(layer1Size, g, syn1, l2, 1 , neu1e, 0 , 1 )
290290 blas.daxpy(layer1Size, g, syn0, l1, 1 , syn1, l2, 1 )
291291 }
@@ -310,17 +310,21 @@ class Word2Vec(
310310 syn0Global = aggSyn0
311311 syn1Global = aggSyn1
312312 }
313+ newSentences.unpersist()
314+
313315 val wordMap = new Array [(String , Array [Double ])](vocabSize)
314316 var i = 0
315317 while (i < vocabSize) {
316- val word = vocab (i).word
318+ val word = bcVocab.value (i).word
317319 val vector = new Array [Double ](layer1Size)
318320 Array .copy(syn0Global, i * layer1Size, vector, 0 , layer1Size)
319321 wordMap(i) = (word, vector)
320322 i += 1
321323 }
322324 val modelRDD = sc.parallelize(wordMap, modelPartitionNum)
323- .partitionBy(new HashPartitioner (modelPartitionNum)).cache()
325+ .partitionBy(new HashPartitioner (modelPartitionNum))
326+ .persist(StorageLevel .MEMORY_AND_DISK )
327+
324328 new Word2VecModel (modelRDD)
325329 }
326330}
0 commit comments