@@ -31,6 +31,9 @@ import org.apache.spark.SparkContext._
3131import org .apache .spark .mllib .linalg .Vector
3232import org .apache .spark .HashPartitioner
3333
34+ /**
35+ * Entry in vocabulary
36+ */
3437private case class VocabWord (
3538 var word : String ,
3639 var cn : Int ,
@@ -39,6 +42,9 @@ private case class VocabWord(
3942 var codeLen : Int
4043)
4144
45+ /**
46+ * Computes vector representations of the words in a vocabulary
47+ */
4248class Word2Vec (
4349 val size : Int ,
4450 val startingAlpha : Double ,
@@ -51,7 +57,8 @@ class Word2Vec(
5157 private val MAX_CODE_LENGTH = 40
5258 private val MAX_SENTENCE_LENGTH = 1000
5359 private val layer1Size = size
54-
60+ private val modelPartitionNum = 100
61+
5562 private var trainWordsCount = 0
5663 private var vocabSize = 0
5764 private var vocab : Array [VocabWord ] = null
@@ -169,6 +176,7 @@ class Word2Vec(
169176 * Computes the vector representation of each word in
170177 * vocabulary
171178 * @param dataset an RDD of strings
179+ * @return a Word2VecModel
172180 */
173181
174182 def fit (dataset: RDD [String ]): Word2VecModel = {
@@ -274,11 +282,14 @@ class Word2Vec(
274282 wordMap(i) = (word, vector)
275283 i += 1
276284 }
277- val modelRDD = sc.parallelize(wordMap,100 ).partitionBy(new HashPartitioner (100 ))
285+ val modelRDD = sc.parallelize(wordMap, modelPartitionNum ).partitionBy(new HashPartitioner (modelPartitionNum ))
278286 new Word2VecModel (modelRDD)
279287 }
280288}
281289
290+ /**
291+ * Word2Vec model
292+ */
282293class Word2VecModel (val _model : RDD [(String , Array [Double ])]) extends Serializable {
283294
284295 val model = _model
@@ -292,22 +303,46 @@ class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializab
292303 blas.ddot(n, v1, 1 , v2,1 ) / norm1 / norm2
293304 }
294305
306+ /**
307+ * Transforms a word to its vector representation
308+ * @param word a word
309+ * @return vector representation of word
310+ */
311+
def transform(word: String): Array[Double] = {
  // lookup returns every value stored under `word` in the model RDD;
  // take the first one, or fall back to an empty array when the word is absent.
  val matches = model.lookup(word)
  matches.headOption.getOrElse(Array.empty[Double])
}
300317
318+ /**
319+ * Transforms an RDD to its vector representation
320+ * @param dataset an RDD of words
321+ * @return RDD of vector representation
322+ */
323+
def transform(dataset: RDD[String]): RDD[Array[Double]] = {
  // The previous body called transform(word) -- and therefore model.lookup, an RDD
  // action -- from inside dataset.map. Spark does not support invoking RDD operations
  // from within another RDD's transformation, so that version fails once the job runs.
  // Resolve the vectors with a join against the model RDD instead.
  //
  // Each word is tagged with its position so the output can be put back in input
  // order, keeping the i-th output vector aligned with the i-th input word.
  // Words missing from the model map to an empty array, matching transform(word).
  // NOTE(review): assumes each word appears at most once as a key in `model`;
  // duplicate keys would yield one output row per duplicate -- confirm with fit().
  dataset
    .zipWithIndex()
    .leftOuterJoin(model)
    .map { case (_, (position, vectorOpt)) =>
      (position, vectorOpt.getOrElse(Array.empty[Double]))
    }
    .sortByKey()
    .values
}
304327
328+ /**
329+ * Find synonyms of a word
330+ * @param word a word
331+ * @param num number of synonyms to find
332+ * @return array of (word, similarity)
333+ */
def findSynonyms(word: String, num: Int): Array[(String, Double)] = {
  // Resolve the word to its vector first; transform(word) yields an empty
  // array when the word is not found, in which case there is nothing to compare.
  val wordVector = transform(word)
  if (wordVector.nonEmpty) {
    findSynonyms(wordVector, num)
  } else {
    Array.empty[(String, Double)]
  }
}
310339
340+ /**
341+ * Find synonyms of the vector representation of a word
342+ * @param vector vector representation of a word
343+ * @param num number of synonyms to find
344+ * @return array of (word, similarity)
345+ */
311346 def findSynonyms (vector : Array [Double ], num : Int ): Array [(String , Double )] = {
312347 require(num > 0 , " Number of similar words should > 0" )
313348 val topK = model.map(
@@ -321,6 +356,15 @@ class Word2VecModel (val _model:RDD[(String, Array[Double])]) extends Serializab
321356}
322357
323358object Word2Vec extends Serializable with Logging {
359+ /**
360+ * Train Word2Vec model
361+ * @param input RDD of words
362+ * @param size vector dimension
363+ * @param startingAlpha initial learning rate
364+ * @param window context words from [-window, window]
365+ * @param minCount minimum frequency to consider a vocabulary word
366+ * @return Word2Vec model
367+ */
324368 def train (
325369 input : RDD [String ],
326370 size : Int ,
@@ -329,25 +373,4 @@ object Word2Vec extends Serializable with Logging {
329373 minCount : Int ): Word2VecModel = {
330374 new Word2Vec (size,startingAlpha, window, minCount).fit(input)
331375 }
332-
333- def main (args : Array [String ]) {
334- if (args.length < 6 ) {
335- println(" Usage: word2vec input size startingAlpha window minCount num" )
336- sys.exit(1 )
337- }
338- val conf = new SparkConf ()
339- .setAppName(" word2vec" )
340-
341- val sc = new SparkContext (conf)
342- val input = sc.textFile(args(0 ))
343- val size = args(1 ).toInt
344- val startingAlpha = args(2 ).toDouble
345- val window = args(3 ).toInt
346- val minCount = args(4 ).toInt
347- val num = args(5 ).toInt
348- val model = train(input, size, startingAlpha, window, minCount)
349- val vec = model.findSynonyms(" china" , num)
350- for ((w, dist) <- vec) logInfo(w.toString + " " + dist.toString)
351- sc.stop()
352- }
353376}
0 commit comments