Skip to content

Commit 7efbb6f

Browse files
author
Liquan Pei
committed
use broadcast version of vocab in aggregate
1 parent 6bcc8be commit 7efbb6f

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ import org.apache.spark.rdd._
2929
import org.apache.spark.SparkContext._
3030
import org.apache.spark.mllib.linalg.{Vector, Vectors}
3131
import org.apache.spark.HashPartitioner
32-
32+
import org.apache.spark.storage.StorageLevel
3333
/**
3434
* Entry in vocabulary
3535
*/
@@ -215,18 +215,18 @@ class Word2Vec(
215215
val sc = dataset.context
216216

217217
val expTable = sc.broadcast(createExpTable())
218-
val V = sc.broadcast(vocab)
219-
val VHash = sc.broadcast(vocabHash)
218+
val bcVocab = sc.broadcast(vocab)
219+
val bcVocabHash = sc.broadcast(vocabHash)
220220

221-
val sentences = words.mapPartitions {
221+
val sentences: RDD[Array[Int]] = words.mapPartitions {
222222
iter => { new Iterator[Array[Int]] {
223223
def hasNext = iter.hasNext
224224

225225
def next = {
226226
var sentence = new ArrayBuffer[Int]
227227
var sentenceLength = 0
228228
while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
229-
val word = VHash.value.get(iter.next)
229+
val word = bcVocabHash.value.get(iter.next)
230230
word match {
231231
case Some(w) => {
232232
sentence += w
@@ -278,14 +278,14 @@ class Word2Vec(
278278
val neu1e = new Array[Double](layer1Size)
279279
// Hierarchical softmax
280280
var d = 0
281-
while (d < vocab(word).codeLen) {
282-
val l2 = vocab(word).point(d) * layer1Size
281+
while (d < bcVocab.value(word).codeLen) {
282+
val l2 = bcVocab.value(word).point(d) * layer1Size
283283
// Propagate hidden -> output
284284
var f = blas.ddot(layer1Size, syn0, l1, 1, syn1, l2, 1)
285285
if (f > -MAX_EXP && f < MAX_EXP) {
286286
val ind = ((f + MAX_EXP) * (EXP_TABLE_SIZE / MAX_EXP / 2.0)).toInt
287287
f = expTable.value(ind)
288-
val g = (1 - vocab(word).code(d) - f) * alpha
288+
val g = (1 - bcVocab.value(word).code(d) - f) * alpha
289289
blas.daxpy(layer1Size, g, syn1, l2, 1, neu1e, 0, 1)
290290
blas.daxpy(layer1Size, g, syn0, l1, 1, syn1, l2, 1)
291291
}
@@ -310,17 +310,21 @@ class Word2Vec(
310310
syn0Global = aggSyn0
311311
syn1Global = aggSyn1
312312
}
313+
newSentences.unpersist()
314+
313315
val wordMap = new Array[(String, Array[Double])](vocabSize)
314316
var i = 0
315317
while (i < vocabSize) {
316-
val word = vocab(i).word
318+
val word = bcVocab.value(i).word
317319
val vector = new Array[Double](layer1Size)
318320
Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size)
319321
wordMap(i) = (word, vector)
320322
i += 1
321323
}
322324
val modelRDD = sc.parallelize(wordMap, modelPartitionNum)
323-
.partitionBy(new HashPartitioner(modelPartitionNum)).cache()
325+
.partitionBy(new HashPartitioner(modelPartitionNum))
326+
.persist(StorageLevel.MEMORY_AND_DISK)
327+
324328
new Word2VecModel(modelRDD)
325329
}
326330
}

0 commit comments

Comments
 (0)