@@ -31,57 +31,34 @@ class Word2VecSuite extends FunSuite with MLlibTestSparkContext {
3131 import sqlContext .implicits ._
3232
3333 val sentence = " a b " * 100 + " a c " * 10
34- val localDoc = Seq (sentence, sentence)
35- val doc = sc.parallelize(localDoc)
36- .map(line => line.split(" " ))
37- val docDF = doc.map(text => Tuple1 (text)).toDF(" text" )
34+ val numOfWords = sentence.split(" " ).size
35+ val doc = sc.parallelize(Seq (sentence, sentence)).map(line => line.split(" " ))
3836
39- val model = new Word2Vec ()
40- .setVectorSize(3 )
41- .setSeed(42L )
42- .setInputCol(" text" )
43- .setMaxIter(1 )
44- .fit(docDF)
45-
46- val words = sc.parallelize(Seq (" a" , " b" , " c" ))
4737 val codes = Map (
48- " a" -> Vectors .dense (- 0.2811822295188904 ,- 0.6356269121170044 ,- 0.3020961284637451 ),
49- " b" -> Vectors .dense (1.0309048891067505 ,- 1.29472815990448 ,0.22276712954044342 ),
50- " c" -> Vectors .dense (- 0.08456747233867645 ,0.5137411952018738 ,0.11731560528278351 )
38+ " a" -> Array (- 0.2811822295188904 ,- 0.6356269121170044 ,- 0.3020961284637451 ),
39+ " b" -> Array (1.0309048891067505 ,- 1.29472815990448 ,0.22276712954044342 ),
40+ " c" -> Array (- 0.08456747233867645 ,0.5137411952018738 ,0.11731560528278351 )
5141 )
5242
53- val synonyms = Map (
54- " a" -> Map (" b" -> 0.3680490553379059 ),
55- " b" -> Map (" a" -> 0.3680490553379059 ),
56- " c" -> Map (" b" -> - 0.8148014545440674 )
57- )
58- val wordsDF = words.map(word => Tuple3 (word, codes(word), synonyms(word)))
59- .toDF(" word" , " realCode" , " realSynonyms" )
43+ val expected = doc.map { sentence =>
44+ Vectors .dense(sentence.map(codes.apply).reduce((word1, word2) =>
45+ word1.zip(word2).map { case (v1, v2) => v1 + v2 }
46+ ).map(_ / numOfWords))
47+ }
6048
61- val res = model
62- .setInputCol(" word" )
63- .setCodeCol(" code" )
64- .setSynonymsCol(" syn" )
65- .setNumSynonyms(1 )
66- .transform(wordsDF)
49+ val docDF = doc.zip(expected).toDF(" text" , " expected" )
6750
68- assert(
69- res.select(" code" , " realCode" )
70- .map { case Row (c : Vector , rc : Vector ) => (c, rc) }
71- .collect()
72- .forall { case (vector1, vector2) =>
73- vector1 ~== vector2 absTol 1E-5
74- }, " The code is not correct after transforming."
75- )
51+ val model = new Word2Vec ()
52+ .setVectorSize(3 )
53+ .setInputCol(" text" )
54+ .setOutputCol(" result" )
55+ .fit(docDF)
56+
57+ model.transform(docDF).select(" result" , " expected" ).collect().foreach {
58+ case Row (vector1 : Vector , vector2 : Vector ) =>
59+ assert(vector1 ~== vector2 absTol 1E-5 , " Transformed vector is different with expected." )
60+ }
7661
77- assert(
78- res.select(" syn" , " realSynonyms" )
79- .map { case Row (s : Map [String , Double ], rs : Map [String , Double ]) => (s, rs) }
80- .collect()
81- .forall { case (map1, map2) =>
82- map1 == map2
83- }, " The synonyms are not correct after transforming."
84- )
8562 }
8663}
8764
0 commit comments