1818package org .apache .spark .ml .feature
1919
2020import org .apache .spark .annotation .AlphaComponent
21- import org .apache .spark .ml .Estimator
22- import org .apache .spark .ml .Model
23- import org .apache .spark .ml .param .HasInputCol
24- import org .apache .spark .ml .param .ParamMap
25- import org .apache .spark .ml .param .Params
26- import org .apache .spark .ml .param ._
21+ import org .apache .spark .ml .{Estimator , Model }
22+ import org .apache .spark .ml .param .{HasInputCol , ParamMap , Params , _ }
2723import org .apache .spark .mllib .feature
2824import org .apache .spark .mllib .linalg .{Vector , VectorUDT }
29- import org .apache .spark .sql .DataFrame
30- import org .apache .spark .sql .Row
25+ import org .apache .spark .sql .{DataFrame , Row }
3126import org .apache .spark .sql .functions ._
3227import org .apache .spark .sql .types ._
3328import org .apache .spark .util .Utils
3429
3530/**
36- * Params for [[StandardScaler ]] and [[StandardScalerModel ]].
31+ * Params for [[Word2Vec ]] and [[Word2VecModel ]].
3732 */
38- private [feature] trait Word2VecParams extends Params with HasInputCol with HasMaxIter with HasLearningRate {
33+ private [feature] trait Word2VecParams extends Params
34+ with HasInputCol with HasMaxIter with HasLearningRate {
3935
4036 /**
4137 * The dimension of the code that each word is transformed into.
4238 */
43- val vectorSize = new IntParam (this , " vectorSize" , " " , Some (100 ))
39+ val vectorSize = new IntParam (
40+ this , " vectorSize" , " the dimension of codes after transforming from words" , Some (100 ))
4441
4542 /** @group getParam */
4643 def getVectorSize : Int = get(vectorSize)
4744
4845 /**
49- * Number of partitions
46+ * Number of partitions for sentences of words.
5047 */
51- val numPartitions = new IntParam (this , " numPartitions" , " " , Some (1 ))
48+ val numPartitions = new IntParam (
49+ this , " numPartitions" , " number of partitions for sentences of words" , Some (1 ))
5250
5351 /** @group getParam */
5452 def getNumPartitions : Int = get(numPartitions)
5553
5654 /**
57- * The random seed
55+ * A random seed used to randomize the initial vector.
5856 */
59- val seed = new LongParam (this , " seed" , " " , Some (Utils .random.nextLong()))
57+ val seed = new LongParam (
58+ this , " seed" , " a random seed to random an initial vector" , Some (Utils .random.nextLong()))
6059
6160 /** @group getParam */
6261 def getSeed : Long = get(seed)
6362
6463 /**
6564 * The minimum count of words that can be kept in the training set.
6665 */
67- val minCount = new IntParam (this , " minCount" , " " , Some (5 ))
66+ val minCount = new IntParam (
67+ this , " minCount" , " the minimum count of words to filter words" , Some (5 ))
6868
6969 /** @group getParam */
7070 def getMinCount : Int = get(minCount)
@@ -96,8 +96,8 @@ private[feature] trait Word2VecParams extends Params with HasInputCol with HasMa
9696
9797/**
9898 * :: AlphaComponent ::
99- * Standardizes features by removing the mean and scaling to unit variance using column summary
100- * statistics on the samples in the training set .
99+ * Word2Vec trains a model of `Map(String, Vector)`, i.e. transforms a word into a code for further
100+ * natural language processing or machine learning tasks.
101101 */
102102@ AlphaComponent
103103class Word2Vec extends Estimator [Word2VecModel ] with Word2VecParams {
@@ -123,8 +123,6 @@ class Word2Vec extends Estimator[Word2VecModel] with Word2VecParams {
123123 /** @group setParam */
124124 def setMinCount (value : Int ) = set(minCount, value)
125125
126- type S <: Iterable [String ]
127-
128126 override def fit (dataset : DataFrame , paramMap : ParamMap ): Word2VecModel = {
129127 transformSchema(dataset.schema, paramMap, logging = true )
130128 val map = this .paramMap ++ paramMap
@@ -153,7 +151,7 @@ class Word2Vec extends Estimator[Word2VecModel] with Word2VecParams {
153151
154152/**
155153 * :: AlphaComponent ::
156- * Model fitted by [[StandardScaler ]].
154+ * Model fitted by [[Word2Vec ]].
157155 */
158156@ AlphaComponent
159157class Word2VecModel private [ml] (
@@ -174,6 +172,12 @@ class Word2VecModel private[ml] (
174172 /** @group setParam */
175173 def setCodeCol (value : String ): this .type = set(codeCol, value)
176174
175+ /**
176+ * Transforming with a `Word2Vec` model supports two approaches: 1. transform a word of
177+ * `String` into a code of `Vector`; 2. find n (chosen by the caller) synonyms of a given word.
178+ *
179+ * Note: currently, finding synonyms is only supported for a word of `String`, not `Vector`.
180+ */
177181 override def transform (dataset : DataFrame , paramMap : ParamMap ): DataFrame = {
178182 transformSchema(dataset.schema, paramMap, logging = true )
179183 val map = this .paramMap ++ paramMap
@@ -189,6 +193,7 @@ class Word2VecModel private[ml] (
189193 }
190194
191195 if (map(synonymsCol) != " " & map(numSynonyms) > 0 ) {
196+ // TODO: Add support for finding synonyms for a code of `Vector`.
192197 val findSynonyms = udf { (word : String ) =>
193198 wordVectors.findSynonyms(word, map(numSynonyms)).toMap : Map [String , Double ]
194199 }
@@ -216,7 +221,7 @@ class Word2VecModel private[ml] (
216221 if (map(codeCol) != " " ) {
217222 require(! schema.fieldNames.contains(map(codeCol)),
218223 s " Output column ${map(codeCol)} already exists. " )
219- outputFields = outputFields :+ StructField (map(codeCol), new VectorUDT , false )
224+ outputFields = outputFields :+ StructField (map(codeCol), new VectorUDT , nullable = false )
220225 }
221226
222227 if (map(synonymsCol) != " " ) {
@@ -225,7 +230,7 @@ class Word2VecModel private[ml] (
225230 require(map(numSynonyms) > 0 ,
226231 s " Number of synonyms should larger than 0 " )
227232 outputFields = outputFields :+
228- StructField (map(synonymsCol), MapType (StringType , DoubleType ), false )
233+ StructField (map(synonymsCol), MapType (StringType , DoubleType ), nullable = false )
229234 }
230235
231236 StructType (outputFields)
0 commit comments