Skip to content

Commit 5af803d

Browse files
committed
SPARK-7579 [MLLIB] User guide update for OneHotEncoder
1 parent 4de74d2 commit 5af803d

File tree

1 file changed

+95
-0
lines changed

1 file changed

+95
-0
lines changed

docs/ml-features.md

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,5 +440,100 @@ for expanded in polyDF.select("polyFeatures").take(3):
440440
</div>
441441
</div>
442442

443+
## OneHotEncoder
444+
445+
[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features
446+
447+
<div class="codetabs">
448+
<div data-lang="scala" markdown="1">
449+
{% highlight scala %}
450+
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer}
451+
452+
val df = sqlContext.createDataFrame(Seq(
453+
(0, "a"),
454+
(1, "b"),
455+
(2, "c"),
456+
(3, "a"),
457+
(4, "a"),
458+
(5, "c")
459+
)).toDF("id", "category")
460+
461+
val indexer = new StringIndexer()
462+
.setInputCol("category")
463+
.setOutputCol("categoryIndex")
464+
.fit(df)
465+
val indexed = indexer.transform(df)
466+
467+
val encoder = new OneHotEncoder().setInputCol("categoryIndex").
468+
setOutputCol("categoryVec")
469+
val encoded = encoder.transform(indexed)
470+
encoded.select("id", "categoryVec").foreach(println)
471+
{% endhighlight %}
472+
</div>
473+
474+
<div data-lang="java" markdown="1">
475+
{% highlight java %}
476+
import com.google.common.collect.Lists;
477+
478+
import org.apache.spark.api.java.JavaRDD;
479+
import org.apache.spark.ml.feature.OneHotEncoder;
480+
import org.apache.spark.ml.feature.StringIndexer;
481+
import org.apache.spark.ml.feature.StringIndexerModel;
482+
import org.apache.spark.sql.DataFrame;
483+
import org.apache.spark.sql.Row;
484+
import org.apache.spark.sql.RowFactory;
485+
import org.apache.spark.sql.types.DataTypes;
486+
import org.apache.spark.sql.types.Metadata;
487+
import org.apache.spark.sql.types.StructField;
488+
import org.apache.spark.sql.types.StructType;
489+
490+
JavaRDD<Row> jrdd = jsc.parallelize(Lists.newArrayList(
491+
RowFactory.create(0, "a"),
492+
RowFactory.create(1, "b"),
493+
RowFactory.create(2, "c"),
494+
RowFactory.create(3, "a"),
495+
RowFactory.create(4, "a"),
496+
RowFactory.create(5, "c")
497+
));
498+
StructType schema = new StructType(new StructField[]{
499+
new StructField("id", DataTypes.DoubleType, false, Metadata.empty()),
500+
new StructField("category", DataTypes.StringType, false, Metadata.empty())
501+
});
502+
DataFrame df = sqlContext.createDataFrame(jrdd, schema);
503+
StringIndexerModel indexer = new StringIndexer()
504+
.setInputCol("category")
505+
.setOutputCol("categoryIndex")
506+
.fit(df);
507+
DataFrame indexed = indexer.transform(df);
508+
509+
OneHotEncoder encoder = new OneHotEncoder()
510+
.setInputCol("categoryIndex")
511+
.setOutputCol("categoryVec");
512+
DataFrame encoded = encoder.transform(indexed);
513+
{% endhighlight %}
514+
</div>
515+
516+
<div data-lang="python" markdown="1">
517+
{% highlight python %}
518+
from pyspark.ml.feature import OneHotEncoder, StringIndexer
519+
520+
df = sqlContext.createDataFrame([
521+
(0, "a"),
522+
(1, "b"),
523+
(2, "c"),
524+
(3, "a"),
525+
(4, "a"),
526+
(5, "c")
527+
], ["id", "category"])
528+
529+
stringIndexer = StringIndexer(inputCol="category", outputCol="categoryIndex")
530+
model = stringIndexer.fit(df)
531+
indexed = model.transform(df)
532+
encoder = OneHotEncoder(includeFirst=False, inputCol="categoryIndex", outputCol="categoryVec")
533+
encoded = encoder.transform(indexed)
534+
{% endhighlight %}
535+
</div>
536+
</div>
537+
443538
# Feature Selectors
444539

0 commit comments

Comments
 (0)