This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit 7fa18d1

add user guide for StringIndexer
1 parent 136cb93 commit 7fa18d1

1 file changed

+116
-0
lines changed

docs/ml-features.md

Lines changed: 116 additions & 0 deletions
@@ -456,6 +456,122 @@ for expanded in polyDF.select("polyFeatures").take(3):
</div>
</div>

## StringIndexer

`StringIndexer` encodes a string column of labels to a column of label indices.
The indices are in `[0, numLabels)`, ordered by descending label frequency,
so the most frequent label gets index `0`.
If the input column is numeric, we cast it to string and index the string values.

**Examples**

Assume that we have the following DataFrame with columns `id` and `category`:

~~~~
 id | category
----|----------
 0  | a
 1  | b
 2  | c
 3  | a
 4  | a
 5  | c
~~~~

`category` is a string column with three labels: "a", "b", and "c".
Applying `StringIndexer` with `category` as the input column and `categoryIndex` as the output
column, we should get the following:

~~~~
 id | category | categoryIndex
----|----------|---------------
 0  | a        | 0.0
 1  | b        | 2.0
 2  | c        | 1.0
 3  | a        | 0.0
 4  | a        | 0.0
 5  | c        | 1.0
~~~~

"a" gets index `0` because it is the most frequent, followed by "c" with index `1` and "b" with
index `2`.

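The index assignment described above can be sketched in plain Python, without Spark. This is only an illustration of the frequency ordering, not the actual `StringIndexer` implementation: the helper name `index_labels` is hypothetical, and the alphabetical tie-breaking for labels with equal counts is an assumption of this sketch that the text above does not specify.

```python
from collections import Counter


def index_labels(values):
    """Map each value to a float index, with the most frequent label at 0.0."""
    # Cast every value to string first, mirroring the numeric-input rule above.
    labels = [str(v) for v in values]
    counts = Counter(labels)
    # Order labels by descending frequency. Breaking ties alphabetically is an
    # assumption made here so that the result is deterministic.
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    label_to_index = {label: float(i) for i, label in enumerate(ordered)}
    return [label_to_index[label] for label in labels]


categories = ["a", "b", "c", "a", "a", "c"]
print(index_labels(categories))  # [0.0, 2.0, 1.0, 0.0, 0.0, 1.0]
```

The printed list matches the `categoryIndex` column in the table above: "a" occurs three times, "c" twice, and "b" once.
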
<div class="codetabs">

<div data-lang="scala" markdown="1">

[`StringIndexer`](api/scala/index.html#org.apache.spark.ml.feature.StringIndexer) takes an input
column name and an output column name.

{% highlight scala %}
import org.apache.spark.ml.feature.StringIndexer

val df = sqlContext.createDataFrame(
  Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
).toDF("id", "category")
val indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
val indexed = indexer.fit(df).transform(df)
indexed.show()
{% endhighlight %}
518+
</div>
519+
520+
<div data-lang="java" markdown="1">
521+
[`StringIndexer`](api/java/org/apache/spark/ml/feature/StringIndexer.html) takes an input column
522+
name and an output column name.
523+
524+
{% highlight java %}
import java.util.Arrays;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import static org.apache.spark.sql.types.DataTypes.*;

JavaRDD<Row> jrdd = jsc.parallelize(Arrays.asList(
  RowFactory.create(0, "a"),
  RowFactory.create(1, "b"),
  RowFactory.create(2, "c"),
  RowFactory.create(3, "a"),
  RowFactory.create(4, "a"),
  RowFactory.create(5, "c")
));
// The ids above are Java ints, so the schema declares `id` as IntegerType.
StructType schema = new StructType(new StructField[] {
  createStructField("id", IntegerType, false),
  createStructField("category", StringType, false)
});
DataFrame df = sqlContext.createDataFrame(jrdd, schema);
StringIndexer indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex");
DataFrame indexed = indexer.fit(df).transform(df);
indexed.show();
{% endhighlight %}
</div>

<div data-lang="python" markdown="1">

[`StringIndexer`](api/python/pyspark.ml.html#pyspark.ml.feature.StringIndexer) takes an input
column name and an output column name.

{% highlight python %}
from pyspark.ml.feature import StringIndexer

df = sqlContext.createDataFrame(
    [(0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")],
    ["id", "category"])
indexer = StringIndexer(inputCol="category", outputCol="categoryIndex")
indexed = indexer.fit(df).transform(df)
indexed.show()
{% endhighlight %}
</div>
</div>

## OneHotEncoder

[One-hot encoding](http://en.wikipedia.org/wiki/One-hot) maps a column of label indices to a column of binary vectors, with at most a single one-value. This encoding allows algorithms which expect continuous features, such as Logistic Regression, to use categorical features.
