Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit 1b5369e

Browse files
committed
add docs of word2vec
1 parent 48fc38f commit 1b5369e

File tree

2 files changed

+144
-0
lines changed

2 files changed

+144
-0
lines changed

docs/ml-features.md

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,84 @@ for features_label in featurized.select("features", "label").take(3):
106106
</div>
107107
</div>
108108

109+
## Word2Vec
110+
111+
`Word2Vec` is an `Estimator` which takes sequences of words representing documents and trains a `Word2VecModel`. The model is essentially a `Map(String, Vector)`, which maps each word to a unique fixed-size vector. The `Word2VecModel` transforms each document into a vector by averaging the vectors of all words in the document; this vector can then be used in downstream computations over documents, such as similarity calculation. Please refer to the [MLlib user guide on Word2Vec](mllib-feature-extraction.html#word2vec) for more details on Word2Vec.
112+
113+
Word2Vec is implemented in [Word2Vec](api/scala/index.html#org.apache.spark.ml.feature.Word2Vec). In the following code segment, we start with a set of documents, each of which is represented as a sequence of words. For each document, we transform it into a feature vector. This feature vector could then be passed to a learning algorithm.
114+
115+
<div class="codetabs">
116+
<div data-lang="scala" markdown="1">
117+
{% highlight scala %}
118+
import org.apache.spark.ml.feature.Word2Vec
119+
120+
val documentDF = sqlContext.createDataFrame(Seq(
121+
"Hi I heard about Spark".split(" "),
122+
"I wish Java could use case classes".split(" "),
123+
"Logistic regression models are neat".split(" ")
124+
)).map(Tuple1.apply).toDF("text")
125+
126+
// Learn a mapping from words to Vectors.
val word2Vec = new Word2Vec()
  .setInputCol("text")
  .setOutputCol("result")
  .setVectorSize(3)
  .setMinCount(0)
127+
val model = word2Vec.fit(documentDF)
128+
model.transform(documentDF).select("result").take(3).foreach(println)
129+
{% endhighlight %}
130+
</div>
131+
132+
<div data-lang="java" markdown="1">
133+
{% highlight java %}
134+
import com.google.common.collect.Lists;
135+
136+
import org.apache.spark.api.java.JavaRDD;
137+
import org.apache.spark.api.java.JavaSparkContext;
138+
import org.apache.spark.sql.DataFrame;
139+
import org.apache.spark.sql.Row;
140+
import org.apache.spark.sql.RowFactory;
141+
import org.apache.spark.sql.SQLContext;
142+
import org.apache.spark.sql.types.*;
143+
144+
JavaSparkContext jsc = ...
145+
SQLContext sqlContext = ...
146+
JavaRDD<Row> jrdd = jsc.parallelize(Lists.newArrayList(
147+
RowFactory.create(Lists.newArrayList("Hi I heard about Spark".split(" "))),
148+
RowFactory.create(Lists.newArrayList("I wish Java could use case classes".split(" "))),
149+
RowFactory.create(Lists.newArrayList("Logistic regression models are neat".split(" ")))
150+
));
151+
StructType schema = new StructType(new StructField[]{
152+
new StructField("text", new ArrayType(StringType$.MODULE$, true), false, Metadata.empty())
153+
});
154+
DataFrame documentDF = sqlContext.createDataFrame(jrdd, schema);
155+
156+
Word2Vec word2Vec = new Word2Vec()
157+
.setInputCol("text")
158+
.setOutputCol("result")
159+
.setVectorSize(3)
160+
.setMinCount(0);
161+
Word2VecModel model = word2Vec.fit(documentDF);
162+
DataFrame result = model.transform(documentDF);
163+
164+
for (Row r: result.select("result").take(3)) {
165+
System.out.println(r);
166+
}
167+
{% endhighlight %}
168+
</div>
169+
170+
<div data-lang="python" markdown="1">
171+
{% highlight python %}
172+
from pyspark.ml.feature import Word2Vec
173+
174+
documentDF = sqlContext.createDataFrame([
175+
("Hi I heard about Spark".split(" "), ),
176+
("I wish Java could use case classes".split(" "), ),
177+
("Logistic regression models are neat".split(" "))
178+
], ["text"])
179+
word2Vec = Word2Vec(vectorSize = 3, minCount = 0, inputCol = "text", outputCol = "result")
180+
model = word2Vec.fit(documentDF)
181+
result = model.transform(documentDF)
182+
for feature in result.select("result").take(3):
183+
print feature
184+
{% endhighlight %}
185+
</div>
186+
</div>
109187

110188
# Feature Transformers
111189

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package org.apache.spark.ml.feature;
2+
3+
import com.google.common.collect.Lists;
4+
import org.junit.After;
5+
import org.junit.Assert;
6+
import org.junit.Before;
7+
import org.junit.Test;
8+
9+
import org.apache.spark.api.java.JavaRDD;
10+
import org.apache.spark.api.java.JavaSparkContext;
11+
import org.apache.spark.mllib.linalg.Vector;
12+
import org.apache.spark.mllib.linalg.VectorUDT;
13+
import org.apache.spark.mllib.linalg.Vectors;
14+
import org.apache.spark.sql.DataFrame;
15+
import org.apache.spark.sql.Row;
16+
import org.apache.spark.sql.RowFactory;
17+
import org.apache.spark.sql.SQLContext;
18+
import org.apache.spark.sql.types.*;
19+
20+
public class JavaWord2VecSuite {
21+
private transient JavaSparkContext jsc;
22+
private transient SQLContext sqlContext;
23+
24+
@Before
25+
public void setUp() {
26+
jsc = new JavaSparkContext("local", "JavaWord2VecSuite");
27+
sqlContext = new SQLContext(jsc);
28+
}
29+
30+
@After
31+
public void tearDown() {
32+
jsc.stop();
33+
jsc = null;
34+
}
35+
36+
@Test
37+
public void testJavaWord2Vec() {
38+
JavaRDD<Row> jrdd = jsc.parallelize(Lists.newArrayList(
39+
RowFactory.create(Lists.newArrayList("Hi I heard about Spark".split(" ")),
40+
Vectors.dense(0.017877750098705292, -0.018388677015900613, -0.01183266043663025)),
41+
RowFactory.create(Lists.newArrayList("I wish Java could use case classes".split(" ")),
42+
Vectors.dense(0.0038498884865215844, -0.07299017374004636, 0.010990704176947474)),
43+
RowFactory.create(Lists.newArrayList("Logistic regression models are neat".split(" ")),
44+
Vectors.dense(0.017819208838045598, -0.006920230574905872, 0.022744188457727434))
45+
));
46+
StructType schema = new StructType(new StructField[]{
47+
new StructField("text", new ArrayType(StringType$.MODULE$, true), false, Metadata.empty()),
48+
new StructField("expected", new VectorUDT(), false, Metadata.empty())
49+
});
50+
DataFrame documentDF = sqlContext.createDataFrame(jrdd, schema);
51+
52+
Word2Vec word2Vec = new Word2Vec()
53+
.setInputCol("text")
54+
.setOutputCol("result")
55+
.setVectorSize(3)
56+
.setMinCount(0);
57+
Word2VecModel model = word2Vec.fit(documentDF);
58+
DataFrame result = model.transform(documentDF);
59+
60+
for (Row r: result.select("result", "expected").collect()) {
61+
double[] polyFeatures = ((Vector)r.get(0)).toArray();
62+
double[] expected = ((Vector)r.get(1)).toArray();
63+
Assert.assertArrayEquals(polyFeatures, expected, 1e-1);
64+
}
65+
}
66+
}

0 commit comments

Comments
 (0)