@@ -31,6 +31,9 @@ import org.apache.spark.SparkContext._
3131import  org .apache .spark .mllib .linalg .Vector 
3232import  org .apache .spark .HashPartitioner 
3333
34+ /** 
35+  *  Entry in vocabulary  
36+  */  
3437private  case  class  VocabWord (
3538  var  word :  String ,
3639  var  cn :  Int ,
@@ -39,6 +42,9 @@ private case class VocabWord(
3942  var  codeLen : Int 
4043)
4144
45+ /** 
46+  *  Vector representation of word 
47+  */  
4248class  Word2Vec (
4349    val  size :  Int ,
4450    val  startingAlpha :  Double ,
@@ -51,7 +57,8 @@ class Word2Vec(
5157  private  val  MAX_CODE_LENGTH  =  40 
5258  private  val  MAX_SENTENCE_LENGTH  =  1000 
5359  private  val  layer1Size  =  size 
54- 
60+   private  val  modelPartitionNum  =  100 
61+   
5562  private  var  trainWordsCount  =  0 
5663  private  var  vocabSize  =  0 
5764  private  var  vocab :  Array [VocabWord ] =  null 
@@ -169,6 +176,7 @@ class Word2Vec(
169176   * Computes the vector representation of each word in  
170177   * vocabulary 
171178   * @param  dataset  an RDD of strings 
179+    * @return  a Word2VecModel 
172180   */  
173181
174182  def  fit (dataset: RDD [String ]):  Word2VecModel  =  {
@@ -274,11 +282,14 @@ class Word2Vec(
274282      wordMap(i) =  (word, vector)
275283      i +=  1 
276284    }
277-     val  modelRDD  =  sc.parallelize(wordMap,100 ).partitionBy(new  HashPartitioner (100 ))
285+     val  modelRDD  =  sc.parallelize(wordMap, modelPartitionNum ).partitionBy(new  HashPartitioner (modelPartitionNum ))
278286    new  Word2VecModel (modelRDD)
279287  }
280288}
281289
290+ /** 
291+ * Word2Vec model 
292+ */ 
282293class  Word2VecModel  (val  _model : RDD [(String , Array [Double ])]) extends  Serializable  {
283294
284295  val  model  =  _model
@@ -292,22 +303,46 @@ class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializab
292303    blas.ddot(n, v1, 1 , v2,1 ) /  norm1 /  norm2
293304  }
294305
/**
 * Transforms a word to its vector representation.
 *
 * @param word a word
 * @return the first vector the model holds for `word`, or an empty array if the
 *         word is not present in the model
 */
def transform(word: String): Array[Double] = {
  // lookup returns every value stored under the key; only the first one is used.
  val matches = model.lookup(word)
  matches.headOption.getOrElse(Array.empty[Double])
}
300317
/**
 * Transforms an RDD of words to their vector representations.
 *
 * The words are resolved with a left outer join against the model RDD instead of
 * calling the single-word `transform` inside `map`: that overload runs
 * `model.lookup`, which is an RDD action, and Spark does not support invoking RDD
 * operations from within another RDD's closure. Because a join shuffles the data,
 * the output order is not guaranteed to follow the input order.
 *
 * @param dataset an RDD of words
 * @return RDD of vector representations; a word that is absent from the model
 *         yields an empty array (one output element per input word, assuming each
 *         word occurs at most once as a key in the model)
 */
def transform(dataset: RDD[String]): RDD[Array[Double]] = {
  dataset
    .map(word => (word, ()))  // key by word so it can be joined with the model
    .leftOuterJoin(model)     // keeps words that have no entry in the model
    .map { case (_, (_, vector)) => vector.getOrElse(Array.empty[Double]) }
}
304327
/**
 * Finds synonyms of a word.
 *
 * @param word a word
 * @param num number of synonyms to find
 * @return array of (word, similarity); empty if `word` has no vector in the model
 */
def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
  val wordVector = transform(word)
  // An empty vector is how transform signals a word that is not in the model.
  if (wordVector.nonEmpty) findSynonyms(wordVector, num)
  else Array.empty[(String, Double)]
}
310339
340+   /**  
341+    * Find synonyms of the vector representation of a word 
342+    * @param  vector  vector representation of a word 
343+    * @param  num  number of synonyms to find   
344+    * @return  array of (word, similarity) 
345+    */  
311346  def  findSynonyms (vector : Array [Double ], num : Int ):  Array [(String , Double )] =  {
312347    require(num >  0 , " Number of similar words should > 0" 
313348    val  topK  =  model.map(
@@ -321,6 +356,15 @@ class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializab
321356}
322357
323358object  Word2Vec  extends  Serializable  with  Logging  {
359+   /**  
360+    * Train Word2Vec model 
361+    * @param  input  RDD of words 
362+    * @param  size  vector dimension 
363+    * @param  startingAlpha  initial learning rate 
364+    * @param  window  context words from [-window, window] 
365+    * @param  minCount  minimum frequency to consider a vocabulary word 
366+    * @return  Word2Vec model  
367+   */  
324368  def  train (
325369    input : RDD [String ],
326370    size : Int ,
@@ -329,25 +373,4 @@ object Word2Vec extends Serializable with Logging {
329373    minCount : Int ):  Word2VecModel  =  {
330374    new  Word2Vec (size,startingAlpha, window, minCount).fit(input)
331375  }
332- 
333-   def  main (args : Array [String ]) {
334-     if  (args.length <  6 ) {
335-       println(" Usage: word2vec input size startingAlpha window minCount num" 
336-       sys.exit(1 )
337-     }
338-     val  conf  =  new  SparkConf ()
339-       .setAppName(" word2vec" 
340-       
341-     val  sc  =  new  SparkContext (conf)
342-     val  input  =  sc.textFile(args(0 ))
343-     val  size  =  args(1 ).toInt
344-     val  startingAlpha  =  args(2 ).toDouble
345-     val  window  =  args(3 ).toInt
346-     val  minCount  =  args(4 ).toInt
347-     val  num  =  args(5 ).toInt
348-     val  model  =  train(input, size, startingAlpha, window, minCount)
349-     val  vec  =  model.findSynonyms(" china" 
350-     for ((w, dist) <-  vec) logInfo(w.toString +  "  " +  dist.toString)
351-     sc.stop()
352-   }
353376}
0 commit comments