
Commit 04c48e9

ensure the functionality
1 parent a190f2c commit 04c48e9

2 files changed: 34 additions & 48 deletions

mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala

Lines changed: 13 additions & 4 deletions
@@ -62,7 +62,7 @@ private[feature] trait Word2VecBase extends Params
    */
   final val seed = new LongParam(this, "seed", "a random seed to random an initial vector")
 
-  setDefault(seed -> Utils.random.nextLong())
+  setDefault(seed -> 42L)
 
   /** @group getParam */
   def getSeed: Long = getOrDefault(seed)
@@ -77,12 +77,15 @@ private[feature] trait Word2VecBase extends Params
   /** @group getParam */
   def getMinCount: Int = getOrDefault(minCount)
 
+  setDefault(stepSize -> 0.025)
+  setDefault(maxIter -> 1)
+
   /**
    * Validate and transform the input schema.
    */
   protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     val map = extractParamMap(paramMap)
-    SchemaUtils.checkColumnType(schema, map(inputCol), new ArrayType(new StringType, false))
+    SchemaUtils.checkColumnType(schema, map(inputCol), new ArrayType(StringType, true))
     SchemaUtils.appendColumn(schema, map(outputCol), new VectorUDT)
   }
 }
@@ -166,8 +169,14 @@ class Word2VecModel private[ml] (
     val map = extractParamMap(paramMap)
     val bWordVectors = dataset.sqlContext.sparkContext.broadcast(wordVectors)
     val word2Vec = udf { v: Seq[String] =>
-      v.map(bWordVectors.value.transform).foldLeft(Vectors.zeros(map(vectorSize))) { (cum, vec) =>
-        Vectors.dense(cum.toArray.zip(vec.toArray).map(x => x._1 + x._2))
+      if (v.size == 0) {
+        Vectors.zeros(map(vectorSize))
+      } else {
+        Vectors.dense(
+          v.map(bWordVectors.value.getVectors).foldLeft(Array.fill[Double](map(vectorSize))(0)) {
+            (cum, vec) => cum.zip(vec).map(x => x._1 + x._2)
+          }.map(_ / v.size)
+        )
       }
     }
     dataset.withColumn(map(outputCol), word2Vec(col(map(inputCol))))
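
In effect, the transform UDF now averages the word vectors of each input token sequence instead of summing them, and returns a zero vector for an empty sequence. Below is a minimal standalone sketch of that averaging logic in plain Scala, assuming an in-memory Map in place of the broadcast word vectors; the helper name and types are illustrative, not part of the patch.

  // Average the vectors of all words in a sequence; an empty sequence yields the zero vector.
  def averageWordVectors(
      words: Seq[String],
      wordVectors: Map[String, Array[Double]],
      vectorSize: Int): Array[Double] = {
    if (words.isEmpty) {
      Array.fill(vectorSize)(0.0)
    } else {
      val sum = words.map(wordVectors).foldLeft(Array.fill(vectorSize)(0.0)) { (cum, vec) =>
        cum.zip(vec).map { case (a, b) => a + b }
      }
      sum.map(_ / words.size)
    }
  }

  // Example: two words with vectors (1.0, 2.0) and (3.0, 4.0) average to (2.0, 3.0).
  // averageWordVectors(Seq("a", "b"), Map("a" -> Array(1.0, 2.0), "b" -> Array(3.0, 4.0)), 2)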

mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala

Lines changed: 21 additions & 44 deletions
@@ -31,57 +31,34 @@ class Word2VecSuite extends FunSuite with MLlibTestSparkContext {
     import sqlContext.implicits._
 
     val sentence = "a b " * 100 + "a c " * 10
-    val localDoc = Seq(sentence, sentence)
-    val doc = sc.parallelize(localDoc)
-      .map(line => line.split(" "))
-    val docDF = doc.map(text => Tuple1(text)).toDF("text")
+    val numOfWords = sentence.split(" ").size
+    val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
 
-    val model = new Word2Vec()
-      .setVectorSize(3)
-      .setSeed(42L)
-      .setInputCol("text")
-      .setMaxIter(1)
-      .fit(docDF)
-
-    val words = sc.parallelize(Seq("a", "b", "c"))
     val codes = Map(
-      "a" -> Vectors.dense(-0.2811822295188904,-0.6356269121170044,-0.3020961284637451),
-      "b" -> Vectors.dense(1.0309048891067505,-1.29472815990448,0.22276712954044342),
-      "c" -> Vectors.dense(-0.08456747233867645,0.5137411952018738,0.11731560528278351)
+      "a" -> Array(-0.2811822295188904,-0.6356269121170044,-0.3020961284637451),
+      "b" -> Array(1.0309048891067505,-1.29472815990448,0.22276712954044342),
+      "c" -> Array(-0.08456747233867645,0.5137411952018738,0.11731560528278351)
     )
 
-    val synonyms = Map(
-      "a" -> Map("b" -> 0.3680490553379059),
-      "b" -> Map("a" -> 0.3680490553379059),
-      "c" -> Map("b" -> -0.8148014545440674)
-    )
-    val wordsDF = words.map(word => Tuple3(word, codes(word), synonyms(word)))
-      .toDF("word", "realCode", "realSynonyms")
+    val expected = doc.map { sentence =>
+      Vectors.dense(sentence.map(codes.apply).reduce((word1, word2) =>
+        word1.zip(word2).map { case (v1, v2) => v1 + v2 }
+      ).map(_ / numOfWords))
+    }
 
-    val res = model
-      .setInputCol("word")
-      .setCodeCol("code")
-      .setSynonymsCol("syn")
-      .setNumSynonyms(1)
-      .transform(wordsDF)
+    val docDF = doc.zip(expected).toDF("text", "expected")
 
-    assert(
-      res.select("code", "realCode")
-        .map { case Row(c: Vector, rc: Vector) => (c, rc) }
-        .collect()
-        .forall { case (vector1, vector2) =>
-          vector1 ~== vector2 absTol 1E-5
-        }, "The code is not correct after transforming."
-    )
+    val model = new Word2Vec()
+      .setVectorSize(3)
+      .setInputCol("text")
+      .setOutputCol("result")
+      .fit(docDF)
+
+    model.transform(docDF).select("result", "expected").collect().foreach {
+      case Row(vector1: Vector, vector2: Vector) =>
+        assert(vector1 ~== vector2 absTol 1E-5, "Transformed vector is different with expected.")
+    }
 
-    assert(
-      res.select("syn", "realSynonyms")
-        .map { case Row(s: Map[String, Double], rs: Map[String, Double]) => (s, rs) }
-        .collect()
-        .forall { case (map1, map2) =>
-          map1 == map2
-        }, "The synonyms are not correct after transforming."
-    )
   }
 }
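
For intuition about the expected values: "a b " * 100 + "a c " * 10 splits into 220 tokens (110 "a", 100 "b", 10 "c"), so each row's expected vector is the element-wise weighted average (110 * codes("a") + 100 * codes("b") + 10 * codes("c")) / 220. A rough standalone check of that arithmetic in plain Scala (no Spark; vector values copied from the test above):

  // Reproduce the test's expected vector by hand.
  val codes = Map(
    "a" -> Array(-0.2811822295188904, -0.6356269121170044, -0.3020961284637451),
    "b" -> Array(1.0309048891067505, -1.29472815990448, 0.22276712954044342),
    "c" -> Array(-0.08456747233867645, 0.5137411952018738, 0.11731560528278351))
  val counts = Map("a" -> 110, "b" -> 100, "c" -> 10)  // token counts in the 220-word sentence
  val expected = (0 until 3).map { i =>
    counts.map { case (word, n) => codes(word)(i) * n }.sum / 220.0
  }
  println(expected)  // element-wise weighted average over all 220 tokens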
