@@ -87,7 +87,7 @@ class Word2Vec(
8787 private var vocabHash = mutable.HashMap .empty[String , Int ]
8888 private var alpha = startingAlpha
8989
90- private def learnVocab (words: RDD [String ]) {
90+ private def learnVocab (words: RDD [String ]){
9191 vocab = words.map(w => (w, 1 ))
9292 .reduceByKey(_ + _)
9393 .map(x => VocabWord (
@@ -110,6 +110,10 @@ class Word2Vec(
110110 logInfo(" trainWordsCount = " + trainWordsCount)
111111 }
112112
// Placeholder: the body below is empty, so calling this method currently does nothing
// with `words` and returns Unit.
// NOTE(review): presumably meant to be a per-partition variant of learnVocab -- confirm
// the intent, then implement it or drop the stub before merging.
113+ private def learnVocabPerPartition (words: RDD [String ]) {
114+
115+ }
116+
113117 private def createExpTable (): Array [Double ] = {
114118 val expTable = new Array [Double ](EXP_TABLE_SIZE )
115119 var i = 0
@@ -303,8 +307,12 @@ class Word2Vec(
303307 combOp = (c1, c2) => (c1, c2) match {
304308 case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
305309 val n = syn0_1.length
306- blas.daxpy(n, 1.0 , syn0_2, 1 , syn0_1, 1 )
307- blas.daxpy(n, 1.0 , syn1_2, 1 , syn1_1, 1 )
310+ val weight1 = 1.0 * wc_1 / (wc_1 + wc_2)
311+ val weight2 = 1.0 * wc_2 / (wc_1 + wc_2)
312+ blas.dscal(n, weight1, syn0_1, 1 )
313+ blas.dscal(n, weight1, syn1_1, 1 )
314+ blas.daxpy(n, weight2, syn0_2, 1 , syn0_1, 1 )
315+ blas.daxpy(n, weight2, syn1_2, 1 , syn1_1, 1 )
308316 (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
309317 })
310318 syn0Global = aggSyn0
0 commit comments