/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * Add a comment to this line
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */

package org.apache.spark.mllib.feature

-import scala.util.{Random => Random}
+import scala.util.Random
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable

import com.github.fommil.netlib.BLAS.{getInstance => blas}

-import org.apache.spark._
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.Logging
import org.apache.spark.rdd._
import org.apache.spark.SparkContext._
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.HashPartitioner

/**
@@ -42,8 +42,27 @@ private case class VocabWord(
)

/**
- * Vector representation of word
+ * :: Experimental ::
+ * Word2Vec creates vector representation of words in a text corpus.
+ * The algorithm first constructs a vocabulary from the corpus
+ * and then learns vector representation of words in the vocabulary.
+ * The vector representation can be used as features in
+ * natural language processing and machine learning algorithms.
+ *
+ * We use the skip-gram model in our implementation and the hierarchical softmax
+ * method to train the model.
+ *
+ * For the original C implementation, see https://code.google.com/p/word2vec/
+ * For the research papers, see
+ * Efficient Estimation of Word Representations in Vector Space
+ * and
+ * Distributed Representations of Words and Phrases and their Compositionality
+ * @param size vector dimension
+ * @param startingAlpha initial learning rate
+ * @param window context words from [-window, window]
+ * @param minCount minimum frequency to consider a vocabulary word
 */
+@Experimental
class Word2Vec(
    val size: Int,
    val startingAlpha: Double,
@@ -64,11 +83,15 @@ class Word2Vec(
  private var vocabHash = mutable.HashMap.empty[String, Int]
  private var alpha = startingAlpha

-  private def learnVocab(dataset: RDD[String]) {
-    vocab = dataset.flatMap(line => line.split(" "))
-      .map(w => (w, 1))
+  private def learnVocab(words: RDD[String]) {
+    vocab = words.map(w => (w, 1))
      .reduceByKey(_ + _)
-      .map(x => VocabWord(x._1, x._2, new Array[Int](MAX_CODE_LENGTH), new Array[Int](MAX_CODE_LENGTH), 0))
+      .map(x => VocabWord(
+        x._1,
+        x._2,
+        new Array[Int](MAX_CODE_LENGTH),
+        new Array[Int](MAX_CODE_LENGTH),
+        0))
      .filter(_.cn >= minCount)
      .collect()
      .sortWith((a, b) => a.cn > b.cn)
@@ -172,15 +195,16 @@ class Word2Vec(
  }

  /**
-   * Computes the vector representation of each word in
-   * vocabulary
-   * @param dataset an RDD of strings
+   * Computes the vector representation of each word in vocabulary.
+   * @param dataset an RDD of sentences, each expressed as an iterable collection of words
   * @return a Word2VecModel
   */

-  def fit(dataset: RDD[String]): Word2VecModel = {
+  def fit[S <: Iterable[String]](dataset: RDD[S]): Word2VecModel = {

-    learnVocab(dataset)
+    val words = dataset.flatMap(x => x)
+
+    learnVocab(words)

    createBinaryTree()

@@ -190,9 +214,10 @@ class Word2Vec(
    val V = sc.broadcast(vocab)
    val VHash = sc.broadcast(vocabHash)

-    val sentences = dataset.flatMap(line => line.split(" ")).mapPartitions {
+    val sentences = words.mapPartitions {
      iter => { new Iterator[Array[Int]] {
          def hasNext = iter.hasNext
+
          def next = {
            var sentence = new ArrayBuffer[Int]
            var sentenceLength = 0
@@ -215,7 +240,8 @@ class Word2Vec(
    val newSentences = sentences.repartition(1).cache()
    val temp = Array.fill[Double](vocabSize * layer1Size)((Random.nextDouble - 0.5) / layer1Size)
    val (aggSyn0, _, _, _) =
-      // TODO: broadcast temp instead of serializing it directly or initialize the model in each executor
+      // TODO: broadcast temp instead of serializing it directly
+      // or initialize the model in each executor
      newSentences.aggregate((temp.clone(), new Array[Double](vocabSize * layer1Size), 0, 0))(
        seqOp = (c, v) => (c, v) match { case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
          var lwc = lastWordCount
@@ -241,7 +267,7 @@ class Word2Vec(
                  val lastWord = sentence(c)
                  val l1 = lastWord * layer1Size
                  val neu1e = new Array[Double](layer1Size)
-                  // HS
+                  // Hierarchical softmax
                  var d = 0
                  while (d < vocab(word).codeLen) {
                    val l2 = vocab(word).point(d) * layer1Size
@@ -265,11 +291,12 @@ class Word2Vec(
          }
          (syn0, syn1, lwc, wc)
        },
-        combOp = (c1, c2) => (c1, c2) match { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
-          val n = syn0_1.length
-          blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
-          blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
-          (syn0_1, syn0_2, lwc_1 + lwc_2, wc_1 + wc_2)
+        combOp = (c1, c2) => (c1, c2) match {
+          case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
+            val n = syn0_1.length
+            blas.daxpy(n, 1.0, syn0_2, 1, syn0_1, 1)
+            blas.daxpy(n, 1.0, syn1_2, 1, syn1_1, 1)
+            (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
        })

    val wordMap = new Array[(String, Array[Double])](vocabSize)
@@ -281,19 +308,18 @@ class Word2Vec(
      wordMap(i) = (word, vector)
      i += 1
    }
-    val modelRDD = sc.parallelize(wordMap, modelPartitionNum).partitionBy(new HashPartitioner(modelPartitionNum))
+    val modelRDD = sc.parallelize(wordMap, modelPartitionNum)
+      .partitionBy(new HashPartitioner(modelPartitionNum))
    new Word2VecModel(modelRDD)
  }
}

/**
* Word2Vec model
*/
-class Word2VecModel (val _model: RDD[(String, Array[Double])]) extends Serializable {
-
-  val model = _model
+class Word2VecModel (private val model: RDD[(String, Array[Double])]) extends Serializable {

-  private def distance(v1: Array[Double], v2: Array[Double]): Double = {
+  private def cosineSimilarity(v1: Array[Double], v2: Array[Double]): Double = {
    require(v1.length == v2.length, "Vectors should have the same length")
    val n = v1.length
    val norm1 = blas.dnrm2(n, v1, 1)
@@ -307,20 +333,20 @@ class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializab
   * @param word a word
   * @return vector representation of word
   */
-
-  def transform(word: String): Array[Double] = {
+  def transform(word: String): Vector = {
    val result = model.lookup(word)
-    if (result.isEmpty) Array[Double]()
-    else result(0)
+    if (result.isEmpty) {
+      throw new IllegalStateException(s"${word} not in vocabulary")
+    }
+    else Vectors.dense(result(0))
  }

  /**
   * Transforms an RDD to its vector representation
   * @param dataset an RDD of words
   * @return RDD of vector representation
   */
-
-  def transform(dataset: RDD[String]): RDD[Array[Double]] = {
+  def transform(dataset: RDD[String]): RDD[Vector] = {
    dataset.map(word => transform(word))
  }

@@ -332,44 +358,44 @@ class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializab
   */
  def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
    val vector = transform(word)
-    if (vector.isEmpty) Array[(String, Double)]()
-    else findSynonyms(vector,num)
+    findSynonyms(vector, num)
  }

  /**
   * Find synonyms of the vector representation of a word
   * @param vector vector representation of a word
   * @param num number of synonyms to find
-   * @return array of (word, similarity)
+   * @return array of (word, cosineSimilarity)
   */
-  def findSynonyms(vector: Array[Double], num: Int): Array[(String, Double)] = {
+  def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
    require(num > 0, "Number of similar words should > 0")
-    val topK = model.map(
-      { case (w, vec) => (distance(vector, vec), w)})
+    val topK = model.map { case (w, vec) =>
+      (cosineSimilarity(vector.toArray, vec), w) }
    .sortByKey(ascending = false)
    .take(num + 1)
-    .map({case (dist, w) => (w, dist)}).drop(1)
+    .map(_.swap)
+    .tail

    topK
  }
}

-object Word2Vec extends Serializable with Logging {
+object Word2Vec {
  /**
   * Train Word2Vec model
   * @param input RDD of words
-   * @param size vectoer dimension
+   * @param size vector dimension
   * @param startingAlpha initial learning rate
   * @param window context words from [-window, window]
   * @param minCount minimum frequency to consider a vocabulary word
   * @return Word2Vec model
  */
-  def train(
-    input: RDD[String],
+  def train[S <: Iterable[String]](
+    input: RDD[S],
    size: Int,
    startingAlpha: Double,
    window: Int,
    minCount: Int): Word2VecModel = {
-    new Word2Vec(size,startingAlpha, window, minCount).fit(input)
+    new Word2Vec(size, startingAlpha, window, minCount).fit[S](input)
  }
}
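
Below is a minimal usage sketch of the API as changed by this diff; it is not part of the patch. The SparkContext setup, the corpus path, the query word, and the hyperparameter values are assumptions for illustration only.

// Hypothetical example: train a Word2VecModel on a tokenized corpus and query it.
import org.apache.spark.SparkContext
import org.apache.spark.mllib.feature.Word2Vec

object Word2VecExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local[2]", "Word2VecExample")

    // Each RDD element is one sentence as an Iterable[String], matching the
    // new train[S <: Iterable[String]](input: RDD[S], ...) signature.
    val sentences = sc.textFile("data/corpus.txt")   // assumed path
      .map(line => line.split(" ").toSeq)

    // Illustrative hyperparameters: size = 100, startingAlpha = 0.025, window = 5, minCount = 5
    val model = Word2Vec.train(sentences, 100, 0.025, 5, 5)

    // transform now returns a Vector and throws if the word is out of vocabulary;
    // findSynonyms returns (word, cosineSimilarity) pairs.
    val synonyms = model.findSynonyms("day", 10)
    synonyms.foreach { case (word, sim) => println(s"$word $sim") }

    sc.stop()
  }
}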